/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_similarity;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;

/**
 * Translator that feeds a pair of texts (a sentence and a context) to a text-similarity model
 * and packages the model's output tensors as {@link ModelTensors}.
 *
 * <p>The tokenizer used in {@link #processInput} is inherited from
 * {@link SentenceTransformerTranslator} (presumably a DJL HuggingFace tokenizer, given the
 * {@link Encoding} type it returns — confirm in the superclass).
 */
public class TextSimilarityTranslator
extends SentenceTransformerTranslator {
    /** Name assigned to every output tensor produced by {@link #processOutput}. */
    public final String SIMILARITY_NAME = "similarity";

    /**
     * Tokenizes the sentence/context pair and builds the model input tensors.
     *
     * @param ctx translator context supplying the {@link NDManager} that creates the arrays
     * @param input request whose item at index 0 is the sentence and item at index 1 is the
     *     context, both read as strings
     * @return an {@link NDList} holding, in this order, arrays named {@code input_ids},
     *     {@code attention_mask} and {@code token_type_ids}
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Input input) {
        String sentence = input.getAsString(0);
        String context = input.getAsString(1);
        NDManager manager = ctx.getNDManager();

        // Both texts are encoded in a single call as a pair rather than separately.
        Encoding encodings = this.tokenizer.encode(sentence, context);

        NDArray indicesArray = manager.create(encodings.getIds());
        indicesArray.setName("input_ids");
        NDArray attentionMaskArray = manager.create(encodings.getAttentionMask());
        attentionMaskArray.setName("attention_mask");
        NDArray tokenTypeArray = manager.create(encodings.getTypeIds());
        tokenTypeArray.setName("token_type_ids");

        // The order of this list is the order the model receives its inputs in.
        NDList ndList = new NDList();
        ndList.add(indicesArray);
        ndList.add(attentionMaskArray);
        ndList.add(tokenTypeArray);
        return ndList;
    }

    /**
     * Converts every tensor returned by the model into a {@link ModelTensor} and serializes them.
     *
     * @param ctx translator context (not used by this method)
     * @param list tensors produced by the model
     * @return an {@link Output} with status {@code 200 "OK"} whose content is the serialized
     *     {@link ModelTensors}; every contained tensor is named {@link #SIMILARITY_NAME}
     */
    @Override
    public Output processOutput(TranslatorContext ctx, NDList list) {
        Output output = new Output(200, "OK");
        ArrayList<ModelTensor> outputs = new ArrayList<>(list.size());
        for (NDArray ndArray : list) {
            Number[] data = ndArray.toArray();
            long[] shape = ndArray.getShape().getShape();
            DataType dataType = ndArray.getDataType();
            // NOTE(review): relies on MLResultDataType having a constant with the same name as
            // the DJL DataType; if it is an enum, valueOf throws IllegalArgumentException for
            // any type it does not define — confirm the supported set.
            MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name());
            ByteBuffer buffer = ndArray.toByteBuffer();
            ModelTensor tensor = ModelTensor.builder()
                    .name(SIMILARITY_NAME)
                    .data(data)
                    .shape(shape)
                    .dataType(mlResultDataType)
                    .byteBuffer(buffer)
                    .build();
            outputs.add(tensor);
        }
        ModelTensors modelTensorOutput = new ModelTensors(outputs);
        output.add(modelTensorOutput.toBytes());
        return output;
    }
}

