/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote.streaming;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Generated;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.internal.http2.StreamResetException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.agui.BaseEvent;
import org.opensearch.ml.common.agui.RunFinishedEvent;
import org.opensearch.ml.common.agui.TextMessageContentEvent;
import org.opensearch.ml.common.agui.TextMessageEndEvent;
import org.opensearch.ml.common.agui.TextMessageStartEvent;
import org.opensearch.ml.common.agui.ToolCallArgsEvent;
import org.opensearch.ml.common.agui.ToolCallEndEvent;
import org.opensearch.ml.common.agui.ToolCallStartEvent;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.streaming.BaseStreamingHandler;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;

public class HttpStreamingHandler
extends BaseStreamingHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(HttpStreamingHandler.class);
    private final Connector connector;
    private OkHttpClient okHttpClient;
    private String llmInterface;
    private Map<String, String> parameters;

    public HttpStreamingHandler(String llmInterface, Connector connector, ConnectorClientConfig connectorClientConfig) {
        this(llmInterface, connector, connectorClientConfig, null);
    }

    public HttpStreamingHandler(String llmInterface, Connector connector, ConnectorClientConfig connectorClientConfig, Map<String, String> parameters) {
        this.connector = connector;
        this.llmInterface = llmInterface;
        this.parameters = parameters;
        Duration connectionTimeout = Duration.ofSeconds(connectorClientConfig.getConnectionTimeout().intValue());
        Duration readTimeout = Duration.ofSeconds(connectorClientConfig.getReadTimeout().intValue());
        try {
            AccessController.doPrivileged(() -> {
                this.okHttpClient = new OkHttpClient.Builder().connectTimeout(connectionTimeout).readTimeout(readTimeout).retryOnConnectionFailure(true).build();
                return null;
            });
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to build OkHttpClient", e);
        }
    }

    @Override
    public void startStream(String action, Map<String, String> parameters, String payload, StreamPredictActionListener<MLTaskResponse, ?> actionListener) {
        try {
            log.info("Creating SSE connection for streaming request");
            HTTPEventSourceListener listener = new HTTPEventSourceListener(actionListener, this.llmInterface, parameters);
            Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, this.connector, parameters, payload);
            AccessController.doPrivileged(() -> {
                EventSources.createFactory((OkHttpClient)this.okHttpClient).newEventSource(request, listener);
                return null;
            });
        }
        catch (Exception e) {
            log.error("Failed to start HTTP streaming", (Throwable)e);
            this.handleError(e, actionListener);
        }
    }

    @Override
    public void handleError(Throwable error, StreamPredictActionListener<MLTaskResponse, ?> listener) {
        log.error("HTTP streaming error", error);
        listener.onFailure((Exception)new MLException("Fail to execute streaming", error));
    }

    public final class HTTPEventSourceListener
    extends EventSourceListener {
        private StreamPredictActionListener<MLTaskResponse, ?> streamActionListener;
        private final String llmInterface;
        private final boolean isAGUIAgent;
        private final Map<String, String> parameters;
        private AtomicBoolean isStreamClosed;
        private boolean functionCallInProgress = false;
        private boolean agentExecutionInProgress = false;
        private String accumulatedToolCallId = null;
        private String accumulatedToolName = null;
        private String accumulatedArguments = "";

        public HTTPEventSourceListener(StreamPredictActionListener<MLTaskResponse, ?> streamActionListener, String llmInterface, Map<String, String> parameters) {
            this.streamActionListener = streamActionListener;
            this.llmInterface = llmInterface;
            this.parameters = parameters;
            this.isStreamClosed = new AtomicBoolean(false);
            this.isAGUIAgent = AgentUtils.isAGUIAgent(parameters);
            if (this.isAGUIAgent) {
                log.debug("HttpStreamingHandler: Detected AG-UI agent");
            }
        }

        public void onOpen(EventSource eventSource, Response response) {
            log.debug("Connected to SSE Endpoint.");
        }

        public void onEvent(EventSource eventSource, String id, String type, String data) {
            log.debug("The data is: {}", (Object)data);
            switch (this.llmInterface) {
                case "openai/v1/chat/completions": {
                    this.onOpenAIEvent(data);
                    break;
                }
                default: {
                    throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", this.llmInterface));
                }
            }
        }

        public void onClosed(EventSource eventSource) {
            log.debug("SSE CLOSED.");
        }

        public void onFailure(EventSource eventSource, Throwable t, Response response) {
            if (t != null) {
                log.error("Error: " + t.getMessage(), t);
                if (!(t instanceof StreamResetException) || !t.getMessage().contains("NO_ERROR")) {
                    this.streamActionListener.onFailure((Exception)new MLException("SSE failure with network error", t));
                }
            } else if (response != null) {
                try {
                    String errorBody = response.body() != null ? response.body().string() : "";
                    this.streamActionListener.onFailure((Exception)new MLException("Error from remote service: " + errorBody));
                }
                catch (IOException e) {
                    this.streamActionListener.onFailure((Exception)new MLException("SSE failure - unable to read error details"));
                }
            } else {
                this.streamActionListener.onFailure((Exception)new MLException("SSE failure"));
            }
        }

        private void onOpenAIEvent(String data) {
            if ("[DONE]".equals(data)) {
                this.handleDoneEvent();
                return;
            }
            try {
                Map dataMap = (Map)StringUtils.gson.fromJson(data, Map.class);
                this.processStreamChunk(dataMap);
            }
            catch (Exception e) {
                log.debug("Skipping malformed chunk: {}", (Object)data);
            }
        }

        private void handleDoneEvent() {
            if (!this.agentExecutionInProgress) {
                boolean textMessageStarted;
                String messageId = this.isAGUIAgent && this.parameters != null ? this.parameters.get("agui_message_id") : null;
                boolean bl = textMessageStarted = this.isAGUIAgent && this.parameters != null && "true".equalsIgnoreCase(this.parameters.get("agui_text_message_started"));
                if (this.isAGUIAgent && textMessageStarted) {
                    this.parameters.put("agui_text_message_started", "false");
                    TextMessageEndEvent textMessageEndEvent = new TextMessageEndEvent(messageId);
                    HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)textMessageEndEvent, false, this.streamActionListener);
                    log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {} at stream end", (Object)messageId);
                    String threadId = this.parameters.get("agui_thread_id");
                    String runId = this.parameters.get("agui_run_id");
                    RunFinishedEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
                    HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)runFinishedEvent, true, this.streamActionListener);
                    log.debug("AG-UI: Sent RUN_FINISHED event at [DONE] - threadId={}, runId={}", (Object)threadId, (Object)runId);
                }
                HttpStreamingHandler.this.sendCompletionResponse(this.isStreamClosed, this.streamActionListener);
            }
        }

        private void processStreamChunk(Map<String, Object> dataMap) {
            List toolCalls;
            Object messageId = this.isAGUIAgent && this.parameters != null ? this.parameters.get("agui_message_id") : null;
            boolean textMessageStarted = this.isAGUIAgent && this.parameters != null && "true".equalsIgnoreCase(this.parameters.get("agui_text_message_started"));
            String finishReason = (String)this.extractPath(dataMap, "$.choices[0].finish_reason");
            if ("stop".equals(finishReason)) {
                this.agentExecutionInProgress = false;
                if (this.isAGUIAgent && textMessageStarted) {
                    this.parameters.put("agui_text_message_started", "false");
                    TextMessageEndEvent textMessageEndEvent = new TextMessageEndEvent((String)messageId);
                    HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)textMessageEndEvent, false, this.streamActionListener);
                    log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {}", messageId);
                    String threadId = this.parameters.get("agui_thread_id");
                    String runId = this.parameters.get("agui_run_id");
                    RunFinishedEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
                    HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)runFinishedEvent, true, this.streamActionListener);
                    log.debug("AG-UI: Sent RUN_FINISHED event - threadId={}, runId={}", (Object)threadId, (Object)runId);
                }
                HttpStreamingHandler.this.sendCompletionResponse(this.isStreamClosed, this.streamActionListener);
                return;
            }
            String content = (String)this.extractPath(dataMap, "$.choices[0].delta.content");
            if (content != null && !content.isEmpty()) {
                if (this.isAGUIAgent) {
                    if (!textMessageStarted) {
                        messageId = "msg_" + System.nanoTime();
                        this.parameters.put("agui_message_id", (String)messageId);
                        this.parameters.put("agui_text_message_started", "true");
                        TextMessageStartEvent textMessageStartEvent = new TextMessageStartEvent((String)messageId, "assistant");
                        HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)textMessageStartEvent, false, this.streamActionListener);
                        log.debug("AG-UI: Sent TEXT_MESSAGE_START for messageId: {}", messageId);
                    }
                    TextMessageContentEvent textMessageContentEvent = new TextMessageContentEvent((String)messageId, content);
                    HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)textMessageContentEvent, false, this.streamActionListener);
                    log.debug("AG-UI: Sent TEXT_MESSAGE_CONTENT for messageId: {}", messageId);
                } else {
                    HttpStreamingHandler.this.sendContentResponse(content, false, this.streamActionListener);
                }
            }
            if ((toolCalls = (List)this.extractPath(dataMap, "$.choices[0].delta.tool_calls")) != null) {
                if (this.isAGUIAgent) {
                    if (textMessageStarted) {
                        this.parameters.put("agui_text_message_started", "false");
                        TextMessageEndEvent textMessageEndEvent = new TextMessageEndEvent((String)messageId);
                        HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)textMessageEndEvent, false, this.streamActionListener);
                        log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {} before tool call", messageId);
                    }
                    this.processAGUIToolCalls(toolCalls);
                } else {
                    this.accumulateFunctionCall(toolCalls);
                    HttpStreamingHandler.this.sendContentResponse(StringUtils.toJson((Object)toolCalls), false, this.streamActionListener);
                }
            }
            if ("tool_calls".equals(finishReason) && this.functionCallInProgress) {
                this.completeToolCall();
            }
        }

        private <T> T extractPath(Map<String, Object> dataMap, String path) {
            try {
                return (T)JsonPath.read(dataMap, (String)path, (Predicate[])new Predicate[0]);
            }
            catch (Exception e) {
                return null;
            }
        }

        private void completeToolCall() {
            this.agentExecutionInProgress = true;
            if (this.isAGUIAgent) {
                ToolCallEndEvent toolCallEndEvent = new ToolCallEndEvent(this.accumulatedToolCallId);
                HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)toolCallEndEvent, false, this.streamActionListener);
                log.debug("AG-UI: Sent TOOL_CALL_END for toolCallId: {}", (Object)this.accumulatedToolCallId);
                String completeFunctionCall = this.buildCompleteFunctionCallResponse();
                Map response = (Map)StringUtils.gson.fromJson(completeFunctionCall, Map.class);
                ModelTensorOutput output = this.createModelTensorOutput(response);
                this.streamActionListener.onResponse(new MLTaskResponse((MLOutput)output));
                log.debug("AG-UI: Sent tool execution response to agent");
            } else {
                String completeFunctionCall = this.buildCompleteFunctionCallResponse();
                HttpStreamingHandler.this.sendContentResponse(completeFunctionCall, false, this.streamActionListener);
                Map response = (Map)StringUtils.gson.fromJson(completeFunctionCall, Map.class);
                ModelTensorOutput output = this.createModelTensorOutput(response);
                this.streamActionListener.onResponse(new MLTaskResponse((MLOutput)output));
            }
            this.accumulatedToolCallId = null;
            this.accumulatedToolName = null;
            this.accumulatedArguments = "";
            this.functionCallInProgress = false;
        }

        private String buildCompleteFunctionCallResponse() {
            String arguments = this.accumulatedArguments == null || this.accumulatedArguments.isEmpty() ? "{}" : this.accumulatedArguments;
            Map<String, String> function = Map.of("name", this.accumulatedToolName, "arguments", arguments);
            Map<String, Map<String, String>> toolCall = Map.of("id", this.accumulatedToolCallId, "type", "function", "function", function);
            Map<String, List<Map<String, Map<String, String>>>> message = Map.of("tool_calls", List.of(toolCall));
            Map<String, String> choice = Map.of("message", message, "finish_reason", "tool_calls");
            Map<String, List<Map<String, String>>> response = Map.of("choices", List.of(choice));
            return StringUtils.toJson(response);
        }

        private ModelTensorOutput createModelTensorOutput(Map<String, Object> responseData) {
            ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(responseData).build();
            ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build();
            return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
        }

        private void processAGUIToolCalls(List<?> toolCalls) {
            this.functionCallInProgress = true;
            String messageId = this.isAGUIAgent && this.parameters != null ? this.parameters.get("agui_message_id") : null;
            for (Object toolCall : toolCalls) {
                Map tcMap = (Map)toolCall;
                if (tcMap.containsKey("id")) {
                    String toolCallId = (String)tcMap.get("id");
                    if (this.accumulatedToolCallId == null) {
                        this.accumulatedToolCallId = toolCallId;
                    }
                }
                if (!tcMap.containsKey("function")) continue;
                Map func = (Map)tcMap.get("function");
                if (func.containsKey("name")) {
                    String toolName = (String)func.get("name");
                    if (this.accumulatedToolName == null) {
                        this.accumulatedToolName = toolName;
                        ToolCallStartEvent startEvent = new ToolCallStartEvent(this.accumulatedToolCallId, toolName, messageId);
                        HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)startEvent, false, this.streamActionListener);
                    }
                }
                if (!func.containsKey("arguments")) continue;
                String argsDelta = (String)func.get("arguments");
                this.accumulatedArguments = this.accumulatedArguments + argsDelta;
                ToolCallArgsEvent argsEvent = new ToolCallArgsEvent(this.accumulatedToolCallId, argsDelta);
                HttpStreamingHandler.this.sendAGUIEvent((BaseEvent)argsEvent, false, this.streamActionListener);
            }
        }

        private void accumulateFunctionCall(List<?> toolCalls) {
            this.functionCallInProgress = true;
            for (Object toolCall : toolCalls) {
                Map tcMap = (Map)toolCall;
                if (tcMap.containsKey("id")) {
                    this.accumulatedToolCallId = (String)tcMap.get("id");
                }
                if (!tcMap.containsKey("function")) continue;
                Map func = (Map)tcMap.get("function");
                if (func.containsKey("name")) {
                    this.accumulatedToolName = (String)func.get("name");
                }
                if (!func.containsKey("arguments")) continue;
                this.accumulatedArguments = this.accumulatedArguments + (String)func.get("arguments");
            }
        }
    }
}

