/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchpipelines.questionanswering.generative.prompt;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import lombok.Generated;
import org.apache.commons.text.StringEscapeUtils;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

/**
 * Helpers for assembling the prompt sent to an LLM by the generative question-answering search
 * pipeline. The prompt is produced either as a single flat string ({@link #buildSingleStringPrompt})
 * or as a provider-specific JSON message array ({@link #getChatCompletionPrompt}).
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public class PromptUtil {
    /**
     * Instructions used when the caller supplies neither a system prompt nor user instructions.
     *
     * <p>NOTE(review): the list items are concatenated without separators (e.g. "search results- A
     * rephrase"). The text is left unchanged because it is a public constant that callers and tests
     * may compare against.
     */
    public static final String DEFAULT_SYSTEM_PROMPT = "Generate a concise and informative answer in less than 100 words for the given question, taking into context: - An enumerated list of search results- A rephrase of the question that was used to generate the search results- The conversation historyCite search results using [${number}] notation.Do not repeat yourself, and NEVER repeat anything in the chat history.If there are any necessary steps or procedures in your answer, enumerate them.";

    /**
     * Separator used by {@link #buildSingleStringPrompt}. It is the two-character sequence backslash
     * + 'n', not a line feed. Presumably this is because the flat prompt is substituted into a JSON
     * request template, where the sequence decodes to a line break; confirm against connector
     * templates before changing it.
     */
    private static final String NEWLINE = "\\n";
    private static final String CONTENT_FIELD_TEXT = "text";
    private static final String CONTENT_FIELD_TYPE = "type";

    /**
     * Builds a prompt asking the model to rephrase {@code originalQuestion}.
     *
     * <p>NOTE(review): not implemented; this always returns {@code null}.
     *
     * @param originalQuestion the user's question
     * @param chatHistory prior interactions
     * @return always {@code null}
     */
    public static String getQuestionRephrasingPrompt(String originalQuestion, List<Interaction> chatHistory) {
        return null;
    }

    /**
     * Builds a chat-completion message array using {@link #DEFAULT_SYSTEM_PROMPT}, with no user
     * instructions and no extra LLM message blocks.
     *
     * @param provider target model provider (only OPENAI and BEDROCK_CONVERSE are accepted)
     * @param question the question to answer
     * @param chatHistory prior interactions; {@code null} is treated as empty
     * @param contexts search results to include; {@code null} is treated as empty
     * @return the message array serialized as a JSON string
     */
    public static String getChatCompletionPrompt(Llm.ModelProvider provider, String question, List<Interaction> chatHistory, List<String> contexts) {
        return getChatCompletionPrompt(provider, DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts, null);
    }

    /**
     * Builds a chat-completion message array for {@code provider}.
     *
     * @param provider target model provider (only OPENAI and BEDROCK_CONVERSE are accepted)
     * @param systemPrompt optional system message
     * @param userInstructions optional text placed at the start of the first user message
     * @param question the question; ignored when {@code llmMessages} is non-empty
     * @param chatHistory prior interactions; {@code null} is treated as empty
     * @param contexts search results to include; {@code null} is treated as empty
     * @param llmMessages optional caller-supplied content blocks that replace the question
     * @return the message array serialized as a JSON string
     * @throws IllegalArgumentException if {@code provider} is not supported
     */
    public static String getChatCompletionPrompt(Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List<Interaction> chatHistory, List<String> contexts, List<MessageBlock> llmMessages) {
        return buildMessageParameter(provider, systemPrompt, userInstructions, question, chatHistory, contexts, llmMessages);
    }

    /**
     * Builds a flat prompt in which every section is terminated by {@link #NEWLINE}, in this order:
     * <ol>
     *   <li>the system prompt;</li>
     *   <li>the user instructions;</li>
     *   <li>the numbered search results;</li>
     *   <li>the chat history, oldest interaction first, each rendered as a "user: ..." line followed
     *       by an "assistant: ..." line;</li>
     *   <li>the question.</li>
     * </ol>
     * If both the system prompt and the user instructions are empty, {@link #DEFAULT_SYSTEM_PROMPT}
     * is used as the system prompt.
     *
     * @param systemPrompt optional system prompt
     * @param userInstructions optional instructions
     * @param question the question to answer
     * @param chatHistory prior interactions; {@code null} is treated as empty. Assumed to be ordered
     *     newest-first (buildMessageParameter also walks it from the last index down). TODO confirm
     *     with callers.
     * @param contexts search results; {@code null} is treated as empty
     * @return the assembled prompt
     */
    public static String buildSingleStringPrompt(String systemPrompt, String userInstructions, String question, List<Interaction> chatHistory, List<String> contexts) {
        if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
            systemPrompt = DEFAULT_SYSTEM_PROMPT;
        }
        StringBuilder bldr = new StringBuilder();
        if (!Strings.isNullOrEmpty(systemPrompt)) {
            bldr.append(systemPrompt).append(NEWLINE);
        }
        if (!Strings.isNullOrEmpty(userInstructions)) {
            bldr.append(userInstructions).append(NEWLINE);
        }
        if (contexts != null) {
            for (int i = 0; i < contexts.size(); ++i) {
                bldr.append("SEARCH RESULT ").append(i + 1).append(": ").append(contexts.get(i)).append(NEWLINE);
            }
        }
        if (chatHistory != null && !chatHistory.isEmpty()) {
            // Reverse the interactions, not the flattened message list. Reversing the flattened list
            // would also flip each turn, putting the assistant reply before the user input.
            List<Interaction> oldestFirst = new ArrayList<>(chatHistory);
            Collections.reverse(oldestFirst);
            for (Message m : Messages.fromInteractions(oldestFirst).getMessages()) {
                bldr.append(m.toString()).append(NEWLINE);
            }
        }
        bldr.append("QUESTION: ").append(question).append(NEWLINE);
        return bldr.toString();
    }

    /**
     * Same as the seven-argument overload with no LLM message blocks.
     */
    @VisibleForTesting
    static String buildMessageParameter(Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List<Interaction> chatHistory, List<String> contexts) {
        return buildMessageParameter(provider, systemPrompt, userInstructions, question, chatHistory, contexts, null);
    }

    /**
     * Assembles the provider-specific message array.
     *
     * <p>Layout:
     * <ol>
     *   <li>An optional system message.</li>
     *   <li>A user message containing the instructions and the numbered search results. It also
     *       holds the oldest chat input if there is history, or otherwise the question or LLM blocks.</li>
     *   <li>Alternating assistant and user messages for the remaining history.</li>
     *   <li>A final user message with the question (followed by "ANSWER:") or with the caller's LLM
     *       blocks.</li>
     * </ol>
     * If both the system prompt and the user instructions are empty, {@link #DEFAULT_SYSTEM_PROMPT}
     * is used as the user instructions.
     *
     * @return the message array serialized as a JSON string
     * @throws IllegalArgumentException if {@code provider} is not supported
     */
    static String buildMessageParameter(Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List<Interaction> chatHistory, List<String> contexts, List<MessageBlock> llmMessages) {
        if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
            userInstructions = DEFAULT_SYSTEM_PROMPT;
        }
        MessageArrayBuilder messageArrayBuilder = new MessageArrayBuilder(provider);
        if (!Strings.isNullOrEmpty(systemPrompt)) {
            messageArrayBuilder.startMessage(ChatRole.SYSTEM, systemPrompt);
            messageArrayBuilder.endMessage();
        }
        // This user message is left open so the oldest chat input, or the question, joins it.
        messageArrayBuilder.startMessage(ChatRole.USER);
        if (!Strings.isNullOrEmpty(userInstructions)) {
            messageArrayBuilder.addTextContent(userInstructions);
        }
        if (contexts != null) {
            for (int i = 0; i < contexts.size(); ++i) {
                messageArrayBuilder.addTextContent("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i));
            }
        }
        boolean lastRoleIsAssistant = addChatHistory(messageArrayBuilder, chatHistory);
        if (llmMessages != null && !llmMessages.isEmpty()) {
            if (lastRoleIsAssistant) {
                messageArrayBuilder.startMessage(ChatRole.USER);
            }
            addMessageBlocks(messageArrayBuilder, llmMessages);
        } else {
            String questionText = "QUESTION: " + question + "\n";
            if (lastRoleIsAssistant) {
                messageArrayBuilder.startMessage(ChatRole.USER, questionText);
            } else {
                messageArrayBuilder.addTextContent(questionText);
            }
            messageArrayBuilder.addTextContent("ANSWER:");
        }
        messageArrayBuilder.endMessage();
        return messageArrayBuilder.toJsonArray().toString();
    }

    /**
     * Writes the chat history into {@code builder}, walking from the last index down to 0.
     *
     * <p>The input of the entry at the last index is appended to the already-open user message, which
     * is then closed. Every other input and response becomes its own message.
     *
     * @return {@code true} if any history was written. In that case no message is left open and the
     *     last message written has the assistant role.
     */
    private static boolean addChatHistory(MessageArrayBuilder builder, List<Interaction> chatHistory) {
        if (chatHistory == null || chatHistory.isEmpty()) {
            return false;
        }
        int oldest = chatHistory.size() - 1;
        Interaction firstInteraction = chatHistory.get(oldest);
        builder.addTextContent(firstInteraction.getInput());
        builder.endMessage();
        builder.startMessage(ChatRole.ASSISTANT, firstInteraction.getResponse());
        builder.endMessage();
        for (int i = oldest - 1; i >= 0; --i) {
            Interaction interaction = chatHistory.get(i);
            builder.startMessage(ChatRole.USER, interaction.getInput());
            builder.endMessage();
            builder.startMessage(ChatRole.ASSISTANT, interaction.getResponse());
            builder.endMessage();
        }
        return true;
    }

    /**
     * Appends every text, image and document block of {@code llmMessages} to the currently open
     * message. Blocks of any other type are skipped.
     */
    private static void addMessageBlocks(MessageArrayBuilder builder, List<MessageBlock> llmMessages) {
        for (MessageBlock message : llmMessages) {
            List<MessageBlock.AbstractBlock> blockList = message.getBlockList();
            for (MessageBlock.Block block : blockList) {
                switch (block.getType()) {
                    case "text": {
                        builder.addTextContent(((MessageBlock.TextBlock) block).getText());
                        break;
                    }
                    case "image": {
                        MessageBlock.ImageBlock ib = (MessageBlock.ImageBlock) block;
                        // Inline data takes precedence over a URL.
                        if (ib.getData() != null) {
                            builder.addImageData(ib.getFormat(), ib.getData());
                        } else if (ib.getUrl() != null) {
                            builder.addImageUrl(ib.getFormat(), ib.getUrl());
                        }
                        break;
                    }
                    case "document": {
                        MessageBlock.DocumentBlock db = (MessageBlock.DocumentBlock) block;
                        builder.addDocumentContent(db.getFormat(), db.getName(), db.getData());
                        break;
                    }
                    default: {
                        // Unknown block types are ignored.
                        break;
                    }
                }
            }
        }
    }

    /**
     * Returns a JSON array string with an optional system message and an optional user message.
     *
     * @param systemPrompt optional system message content
     * @param userInstructions optional user message content
     * @return the serialized array (possibly {@code []})
     */
    public static String getPromptTemplate(String systemPrompt, String userInstructions) {
        return getPromptTemplateAsJsonArray(systemPrompt, userInstructions).toString();
    }

    /**
     * Builds the JSON array for {@link #getPromptTemplate}, using plain-string message content.
     */
    static JsonArray getPromptTemplateAsJsonArray(String systemPrompt, String userInstructions) {
        JsonArray messageArray = new JsonArray();
        if (!Strings.isNullOrEmpty(systemPrompt)) {
            messageArray.add(new Message(ChatRole.SYSTEM, systemPrompt).toJson());
        }
        if (!Strings.isNullOrEmpty(userInstructions)) {
            messageArray.add(new Message(ChatRole.USER, userInstructions).toJson());
        }
        return messageArray;
    }

    @Generated
    private PromptUtil() {
    }

    /**
     * An ordered list of {@link Message}s.
     */
    static class Messages {
        private final List<Message> messages = new ArrayList<>();

        public Messages(List<Message> messages) {
            // Add directly rather than via the overridable addMessages(...) from the constructor.
            this.messages.addAll(messages);
        }

        /** Appends {@code messages} in order. */
        public void addMessages(List<Message> messages) {
            this.messages.addAll(messages);
        }

        /**
         * Converts interactions to messages in the same order. Each interaction yields a user message
         * (its input) followed by an assistant message (its response).
         */
        public static Messages fromInteractions(List<Interaction> interactions) {
            List<Message> messages = new ArrayList<>();
            for (Interaction interaction : interactions) {
                messages.add(new Message(ChatRole.USER, interaction.getInput()));
                messages.add(new Message(ChatRole.ASSISTANT, interaction.getResponse()));
            }
            return new Messages(messages);
        }

        /** Returns the backing (mutable) list. */
        @Generated
        public List<Message> getMessages() {
            return this.messages;
        }
    }

    /**
     * Incrementally builds a JSON message array whose content encoding depends on the provider.
     *
     * <p>Usage: call {@code startMessage}, then any number of {@code add*} calls, then
     * {@code endMessage}; repeat as needed and finish with {@link #toJsonArray()}.
     */
    static class MessageArrayBuilder {
        private final Llm.ModelProvider provider;
        private final List<Message> messages = new ArrayList<>();
        private Message message = null;
        private Content content = null;

        /**
         * @throws IllegalArgumentException unless {@code provider} is OPENAI or BEDROCK_CONVERSE
         */
        public MessageArrayBuilder(Llm.ModelProvider provider) {
            if (!EnumSet.of(Llm.ModelProvider.OPENAI, Llm.ModelProvider.BEDROCK_CONVERSE).contains(provider)) {
                throw new IllegalArgumentException("Unsupported provider: " + provider);
            }
            this.provider = provider;
        }

        /** Opens a new message with {@code role} and an empty, provider-specific content list. */
        public void startMessage(ChatRole role) {
            this.message = new Message();
            this.message.setChatRole(role);
            if (this.provider == Llm.ModelProvider.OPENAI) {
                this.content = new OpenAIContent();
            } else if (this.provider == Llm.ModelProvider.BEDROCK_CONVERSE) {
                this.content = new BedrockContent();
            }
        }

        /** Opens a new message with {@code role} and adds {@code text} as its first text item. */
        public void startMessage(ChatRole role, String text) {
            startMessage(role);
            addTextContent(text);
        }

        /**
         * Closes the open message and appends it to the array.
         *
         * @throws IllegalStateException if no message is open
         */
        public void endMessage() {
            if (this.message == null) {
                throw new IllegalStateException("You must call startMessage before calling endMessage !!");
            }
            this.message.setContent(this.content);
            this.messages.add(this.message);
            this.message = null;
            this.content = null;
        }

        /**
         * Adds a text item to the open message.
         *
         * @throws IllegalStateException if no message is open
         */
        public void addTextContent(String content) {
            if (this.message == null || this.content == null) {
                throw new IllegalStateException("You must call startMessage before calling addTextContent !!");
            }
            this.content.addText(content);
        }

        /** Adds inline image data if the open content supports images; otherwise does nothing. */
        public void addImageData(String format, String data) {
            if (this.content instanceof ImageContent) {
                ((ImageContent) this.content).addImageData(format, data);
            }
        }

        /** Adds an image URL if the open content supports images; otherwise does nothing. */
        public void addImageUrl(String format, String url) {
            if (this.content instanceof ImageContent) {
                ((ImageContent) this.content).addImageUrl(format, url);
            }
        }

        /** Adds a document if the open content supports documents; otherwise does nothing. */
        public void addDocumentContent(String format, String name, String data) {
            if (this.content instanceof DocumentContent) {
                ((DocumentContent) this.content).addDocument(format, name, data);
            }
        }

        /**
         * Returns the completed messages as a JSON array.
         *
         * @throws IllegalStateException if a message is still open
         */
        public JsonArray toJsonArray() {
            Preconditions.checkState(this.message == null && this.content == null, "You must call endMessage before calling toJsonArray !!");
            JsonArray ja = new JsonArray();
            for (Message m : this.messages) {
                ja.add(m.toJson());
            }
            return ja;
        }
    }

    /** Chat roles, with the lowercase name written to the "role" field. */
    enum ChatRole {
        USER("user"),
        ASSISTANT("assistant"),
        SYSTEM("system");

        private final String name;

        ChatRole(String name) {
            this.name = name;
        }

        @Generated
        public String getName() {
            return this.name;
        }
    }

    /**
     * A single chat message whose JSON form ({@code {"role": ..., "content": ...}}) is updated as
     * each setter is called.
     */
    static class Message {
        private static final String MESSAGE_FIELD_ROLE = "role";
        private static final String MESSAGE_FIELD_CONTENT = "content";
        private ChatRole chatRole;
        private String content;
        private final JsonObject json = new JsonObject();

        public Message() {
        }

        public Message(ChatRole chatRole, String content) {
            this();
            setChatRole(chatRole);
            setContent(content);
        }

        public Message(ChatRole chatRole, Content content) {
            this();
            setChatRole(chatRole);
            setContent(content);
        }

        /** Sets the role and replaces the "role" JSON field. */
        public void setChatRole(ChatRole chatRole) {
            this.chatRole = chatRole;
            this.json.remove(MESSAGE_FIELD_ROLE);
            this.json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName()));
        }

        /**
         * Stores {@code content} JSON-escaped and writes it as a string "content" field.
         *
         * <p>NOTE(review): Gson escapes again during serialization, so the output contains
         * doubly-escaped text. This is presumably intentional for downstream template substitution;
         * confirm before changing.
         */
        public void setContent(String content) {
            this.content = StringEscapeUtils.escapeJson(content);
            this.json.remove(MESSAGE_FIELD_CONTENT);
            this.json.add(MESSAGE_FIELD_CONTENT, new JsonPrimitive(this.content));
        }

        /**
         * Writes structured content as the "content" field. The string returned by
         * {@link #getContent()} is not updated.
         */
        public void setContent(Content content) {
            this.json.remove(MESSAGE_FIELD_CONTENT);
            this.json.add(MESSAGE_FIELD_CONTENT, content.toJson());
        }

        public JsonObject toJson() {
            return this.json;
        }

        /** Renders the message as "role: escapedContent". */
        @Override
        public String toString() {
            return String.format(Locale.ROOT, "%s: %s", this.chatRole.getName(), this.content);
        }

        @Generated
        public ChatRole getChatRole() {
            return this.chatRole;
        }

        @Generated
        public String getContent() {
            return this.content;
        }
    }

    /**
     * Content array in the Bedrock Converse shape: {@code {"text": ...}},
     * {@code {"image": {"format", "source": {"bytes"}}}} and
     * {@code {"document": {"format", "name", "source": {"bytes"}}}}.
     */
    static class BedrockContent implements MultimodalContent {
        private final JsonArray json = new JsonArray();

        public BedrockContent() {
        }

        /** Creates content holding {@code value} as text when {@code type} is "text"; otherwise empty. */
        public BedrockContent(String type, String value) {
            if (CONTENT_FIELD_TEXT.equals(type)) {
                addText(value);
            }
        }

        @Override
        public void addText(String text) {
            JsonObject item = new JsonObject();
            item.add(CONTENT_FIELD_TEXT, new JsonPrimitive(text));
            this.json.add(item);
        }

        @Override
        public JsonElement toJson() {
            return this.json;
        }

        @Override
        public void addImageData(String format, String data) {
            JsonObject imageData = new JsonObject();
            imageData.add("bytes", new JsonPrimitive(data));
            JsonObject image = new JsonObject();
            image.add("format", new JsonPrimitive(format));
            image.add("source", imageData);
            JsonObject item = new JsonObject();
            item.add("image", image);
            this.json.add(item);
        }

        /** URL-referenced images are ignored by this implementation (intentional no-op). */
        @Override
        public void addImageUrl(String format, String url) {
        }

        @Override
        public void addDocument(String format, String name, String data) {
            JsonObject documentData = new JsonObject();
            documentData.add("bytes", new JsonPrimitive(data));
            JsonObject document = new JsonObject();
            document.add("format", new JsonPrimitive(format));
            document.add("name", new JsonPrimitive(name));
            document.add("source", documentData);
            JsonObject item = new JsonObject();
            item.add("document", document);
            this.json.add(item);
        }
    }

    /**
     * Content array in the OpenAI chat shape: {@code {"type": "text", "text": ...}} and
     * {@code {"type": "image_url", "image_url": {"url": ...}}}. Inline images are encoded as
     * {@code data:image/<format>;base64,<data>} URLs.
     */
    static class OpenAIContent implements ImageContent {
        private final JsonArray json = new JsonArray();

        @Override
        public void addText(String text) {
            JsonObject item = new JsonObject();
            item.add(CONTENT_FIELD_TYPE, new JsonPrimitive(CONTENT_FIELD_TEXT));
            item.add(CONTENT_FIELD_TEXT, new JsonPrimitive(text));
            this.json.add(item);
        }

        @Override
        public void addImageData(String format, String data) {
            String imageData = String.format(Locale.ROOT, "data:image/%s;base64,%s", format, data);
            addImageUrlItem(imageData);
        }

        /** Adds {@code url} as-is; {@code format} is not used. */
        @Override
        public void addImageUrl(String format, String url) {
            addImageUrlItem(url);
        }

        /** Appends an "image_url" item pointing at {@code url}. */
        private void addImageUrlItem(String url) {
            JsonObject item = new JsonObject();
            item.add(CONTENT_FIELD_TYPE, new JsonPrimitive("image_url"));
            JsonObject urlContent = new JsonObject();
            urlContent.add("url", new JsonPrimitive(url));
            item.add("image_url", urlContent);
            this.json.add(item);
        }

        @Override
        public JsonElement toJson() {
            return this.json;
        }
    }

    /** Content that accepts text, images and documents. */
    interface MultimodalContent extends ImageContent, DocumentContent {
    }

    /** Content that accepts documents. */
    interface DocumentContent extends Content {
        void addDocument(String format, String name, String data);
    }

    /** Content that accepts images, either inline or by URL. */
    interface ImageContent extends Content {
        void addImageData(String format, String data);

        void addImageUrl(String format, String url);
    }

    /** A provider-specific message content list. */
    interface Content {
        void addText(String text);

        JsonElement toJson();
    }
}

