package org.opensearch.neuralsearch.ml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.util.RetryUtil;

/* loaded from: input_file:org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.class */
/**
 * Wrapper around {@link MachineLearningNodeClient} for text-embedding, multimodal-embedding and
 * text-similarity predictions.
 *
 * <p>Every request goes through {@code mlClient.predict}. When a prediction fails, the request is
 * rebuilt and re-sent for as long as {@link RetryUtil#shouldRetry(Exception, int)} allows, using a
 * retry counter that starts at 0. Otherwise the failure is forwarded to the caller's listener.
 * Responses are cast to {@link ModelTensorOutput} and flattened into plain Java collections.
 */
public class MLCommonsClientAccessor {

    @Generated
    private static final Logger log = LogManager.getLogger(MLCommonsClientAccessor.class);

    /** Default target-response filter passed to {@link ModelResultFilter} for embedding requests. */
    private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");

    private final MachineLearningNodeClient mlClient;

    /**
     * Embeds a single sentence with a text-embedding model.
     *
     * @param modelId   id of the model to run
     * @param inputText sentence to embed
     * @param listener  receives the one produced vector, or an {@link IllegalStateException} when the
     *                  model returns a number of vectors other than one
     */
    public void inferenceSentence(@NonNull String modelId, @NonNull String inputText, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(vectors -> {
            if (vectors.size() != 1) {
                listener.onFailure(new IllegalStateException("Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + vectors.size() + "]"));
                return;
            }
            listener.onResponse(vectors.get(0));
        }, listener::onFailure));
    }

    /**
     * Embeds a batch of sentences using the default {@code sentence_embedding} response filter.
     *
     * @param modelId   id of the model to run
     * @param inputText sentences to embed
     * @param listener  receives one vector per model tensor returned, in response order
     */
    public void inferenceSentences(@NonNull String modelId, @NonNull List<String> inputText, @NonNull ActionListener<List<List<Float>>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
    }

    /**
     * Embeds a batch of sentences using caller-supplied response filters.
     *
     * @param targetResponseFilters target-response names handed to {@link ModelResultFilter}
     * @param modelId               id of the model to run
     * @param inputText             sentences to embed
     * @param listener              receives one vector per model tensor returned, in response order
     */
    public void inferenceSentences(@NonNull List<String> targetResponseFilters, @NonNull String modelId, @NonNull List<String> inputText, @NonNull ActionListener<List<List<Float>>> listener) {
        Objects.requireNonNull(targetResponseFilters, "targetResponseFilters is marked non-null but is null");
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
    }

    /**
     * Runs a text-embedding request without a target-response filter and returns each tensor's data as a map.
     *
     * @param modelId   id of the model to run
     * @param inputText input documents
     * @param listener  receives {@code ModelTensor#getDataAsMap()} for every tensor in the response, or an
     *                  {@link IllegalStateException} when the response contains no tensors
     */
    public void inferenceSentencesWithMapResult(@NonNull String modelId, @NonNull List<String> inputText, @NonNull ActionListener<List<Map<String, ?>>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
    }

    /**
     * Embeds a text and optional image input, as built by {@code createMLMultimodalInput}.
     *
     * @param modelId      id of the model to run
     * @param inputObjects map read via {@link TextImageEmbeddingProcessor#INPUT_TEXT} and, if present,
     *                     {@link TextImageEmbeddingProcessor#INPUT_IMAGE}
     * @param listener     receives the first vector of the response, or an empty list when there is none
     */
    public void inferenceSentences(@NonNull String modelId, @NonNull Map<String, String> inputObjects, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputObjects, "inputObjects is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
    }

    /**
     * Scores each input text against a query with a text-similarity model.
     *
     * @param modelId   id of the model to run
     * @param queryText query text paired with every input text
     * @param inputText texts to score
     * @param listener  receives the first value of every returned tensor, in response order
     */
    public void inferenceSimilarity(@NonNull String modelId, @NonNull String queryText, @NonNull List<String> inputText, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(queryText, "queryText is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
    }

    /** Map-result prediction. Re-sends the request while {@link RetryUtil#shouldRetry} allows it. */
    private void retryableInferenceSentencesWithMapResult(String modelId, List<String> inputText, int retryTime, ActionListener<List<Map<String, ?>>> listener) {
        mlClient.predict(modelId, createMLTextInput(null, inputText), ActionListener.wrap(
            mlOutput -> listener.onResponse(buildMapResultFromResponse(mlOutput)),
            e -> {
                if (RetryUtil.shouldRetry(e, retryTime)) {
                    retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener);
                } else {
                    listener.onFailure(e);
                }
            }
        ));
    }

    /** Vector-result prediction. Re-sends the request while {@link RetryUtil#shouldRetry} allows it. */
    private void retryableInferenceSentencesWithVectorResult(List<String> targetResponseFilters, String modelId, List<String> inputText, int retryTime, ActionListener<List<List<Float>>> listener) {
        mlClient.predict(modelId, createMLTextInput(targetResponseFilters, inputText), ActionListener.wrap(
            mlOutput -> listener.onResponse(buildVectorFromResponse(mlOutput)),
            e -> {
                if (RetryUtil.shouldRetry(e, retryTime)) {
                    retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener);
                } else {
                    listener.onFailure(e);
                }
            }
        ));
    }

    /** Similarity prediction. Keeps element 0 of each tensor and re-sends while {@link RetryUtil#shouldRetry} allows it. */
    private void retryableInferenceSimilarityWithVectorResult(String modelId, String queryText, List<String> inputText, int retryTime, ActionListener<List<Float>> listener) {
        mlClient.predict(modelId, createMLTextPairsInput(queryText, inputText), ActionListener.wrap(mlOutput -> {
            final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
                .map(vector -> vector.get(0))
                .collect(Collectors.toList());
            listener.onResponse(scores);
        }, e -> {
            if (RetryUtil.shouldRetry(e, retryTime)) {
                retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
            } else {
                listener.onFailure(e);
            }
        }));
    }

    /**
     * Builds a TEXT_EMBEDDING request over {@code inputText}.
     * {@code targetResponseFilters} may be {@code null} (the map-result path passes null).
     */
    private MLInput createMLTextInput(List<String> targetResponseFilters, List<String> inputText) {
        final ModelResultFilter resultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
        return new MLInput(FunctionName.TEXT_EMBEDDING, (MLAlgoParams) null, new TextDocsInputDataSet(inputText, resultFilter));
    }

    /** Builds a TEXT_SIMILARITY request that pairs {@code queryText} with every entry of {@code inputText}. */
    private MLInput createMLTextPairsInput(String queryText, List<String> inputText) {
        return new MLInput(FunctionName.TEXT_SIMILARITY, (MLAlgoParams) null, new TextSimilarityInputDataSet(queryText, inputText));
    }

    /**
     * Flattens a {@link ModelTensorOutput} into one float list per model tensor, in response order.
     *
     * <p>Tensor data is typed {@code Number[]}. Values are converted with {@link Number#floatValue()}
     * instead of being cast to {@link Float}, so non-Float numeric output (for example {@code Double})
     * no longer raises {@link ClassCastException}.
     */
    private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
        final List<List<Float>> vectors = new ArrayList<>();
        for (final ModelTensors tensors : ((ModelTensorOutput) mlOutput).getMlModelOutputs()) {
            for (final ModelTensor tensor : tensors.getMlModelTensors()) {
                vectors.add(Arrays.stream(tensor.getData()).map(MLCommonsClientAccessor::toFloat).collect(Collectors.toList()));
            }
        }
        return vectors;
    }

    /** Converts one tensor element to {@link Float}. A {@code null} element is passed through unchanged. */
    private static Float toFloat(Number value) {
        return value == null ? null : Float.valueOf(value.floatValue());
    }

    /**
     * Collects {@code getDataAsMap()} of every model tensor in the response.
     *
     * @throws IllegalStateException when there are no model outputs, or the first output has no tensors
     */
    private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
        final List<ModelTensors> outputs = ((ModelTensorOutput) mlOutput).getMlModelOutputs();
        if (CollectionUtils.isEmpty(outputs) || CollectionUtils.isEmpty(outputs.get(0).getMlModelTensors())) {
            throw new IllegalStateException("Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]");
        }
        final List<Map<String, ?>> results = new ArrayList<>();
        for (final ModelTensors tensors : outputs) {
            for (final ModelTensor tensor : tensors.getMlModelTensors()) {
                results.add(tensor.getDataAsMap());
            }
        }
        return results;
    }

    /** Returns the first vector of the response, or a new empty mutable list when there is none. */
    private List<Float> buildSingleVectorFromResponse(MLOutput mlOutput) {
        final List<List<Float>> vectors = buildVectorFromResponse(mlOutput);
        return vectors.isEmpty() ? new ArrayList<>() : vectors.get(0);
    }

    /** Multimodal single-vector prediction. Re-sends the request while {@link RetryUtil#shouldRetry} allows it. */
    private void retryableInferenceSentencesWithSingleVectorResult(List<String> targetResponseFilters, String modelId, Map<String, String> inputObjects, int retryTime, ActionListener<List<Float>> listener) {
        mlClient.predict(modelId, createMLMultimodalInput(targetResponseFilters, inputObjects), ActionListener.wrap(mlOutput -> {
            final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
            log.debug("Inference Response for input sentence is : {} ", vector);
            listener.onResponse(vector);
        }, e -> {
            if (RetryUtil.shouldRetry(e, retryTime)) {
                retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTime + 1, listener);
            } else {
                listener.onFailure(e);
            }
        }));
    }

    /**
     * Builds a TEXT_EMBEDDING request whose documents are the text entry and, if present, the image entry.
     *
     * <p>The text value is always added first, even when it is absent ({@code null}). As a result the
     * image, when supplied, always occupies the second position.
     * NOTE(review): assumes the model reads inputs positionally — confirm against the model contract.
     */
    private MLInput createMLMultimodalInput(List<String> targetResponseFilters, Map<String, String> inputObjects) {
        final List<String> inputs = new ArrayList<>();
        inputs.add(inputObjects.get(TextImageEmbeddingProcessor.INPUT_TEXT));
        if (inputObjects.containsKey(TextImageEmbeddingProcessor.INPUT_IMAGE)) {
            inputs.add(inputObjects.get(TextImageEmbeddingProcessor.INPUT_IMAGE));
        }
        final ModelResultFilter resultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
        return new MLInput(FunctionName.TEXT_EMBEDDING, (MLAlgoParams) null, new TextDocsInputDataSet(inputs, resultFilter));
    }

    /**
     * Creates an accessor backed by the given ML Commons node client.
     *
     * @param machineLearningNodeClient client used for all predictions
     */
    @Generated
    public MLCommonsClientAccessor(MachineLearningNodeClient machineLearningNodeClient) {
        this.mlClient = machineLearningNodeClient;
    }
}
