package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.DocumentContext;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

/* loaded from: input_file:org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.class */
/**
 * Search request processor that runs ML model predictions on values read from an incoming query
 * and writes the prediction outputs back into the query before it is executed.
 *
 * <p>Every {@code input_map} entry produces one prediction request. Its values are JSON paths
 * that are read from the query. The {@code output_map} entry with the same index describes where
 * each model output goes:
 * <ul>
 *   <li>into a JSON path of the original query, when no {@code query_template} is configured;</li>
 *   <li>into the {@code ${...}} placeholders of {@code query_template}, when one is configured.</li>
 * </ul>
 * All predictions of one request are collected first, and the rewritten request is then passed
 * downstream exactly once.
 */
public class MLInferenceSearchRequestProcessor extends AbstractProcessor implements SearchRequestProcessor, ModelExecutor {
    private final NamedXContentRegistry xContentRegistry;
    private static final Logger logger = LogManager.getLogger(MLInferenceSearchRequestProcessor.class);
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final String functionName;
    private final String queryTemplate;
    private final boolean fullResponsePath;
    private final boolean ignoreFailure;
    private final String modelInput;
    // Instance-scoped on purpose: a static field would be overwritten by every processor constructed.
    private final Client client;
    public static final String TYPE = "ml_inference";
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final String QUERY_TEMPLATE = "query_template";
    public static final String FUNCTION_NAME = "function_name";
    public static final String FULL_RESPONSE_PATH = "full_response_path";
    public static final String MODEL_INPUT = "model_input";
    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;

    /** Factory that builds {@link MLInferenceSearchRequestProcessor} instances from pipeline configuration. */
    public static class Factory implements Processor.Factory<SearchRequestProcessor> {
        private final Client client;
        private final NamedXContentRegistry xContentRegistry;

        /**
         * @param client client used by created processors to run prediction tasks
         * @param namedXContentRegistry registry used to parse rewritten queries
         */
        public Factory(Client client, NamedXContentRegistry namedXContentRegistry) {
            this.client = client;
            this.xContentRegistry = namedXContentRegistry;
        }

        /**
         * Reads the processor settings from {@code config} and builds a processor.
         *
         * @param processorFactories all registered search request processor factories (unused)
         * @param processorTag tag of the processor, used in configuration error messages
         * @param description processor description
         * @param ignoreFailure the {@code ignore_failure} value supplied by the pipeline framework;
         *     it is the default when the config does not carry its own {@code ignore_failure}
         * @param config processor configuration; consumed keys are removed from it
         * @param pipelineContext pipeline context (unused)
         * @return the configured processor
         * @throws IllegalArgumentException if a non-remote model has no {@code model_input}, or if
         *     {@code input_map} has more entries than {@code max_prediction_tasks}
         */
        @Override
        public MLInferenceSearchRequestProcessor create(
            Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
            String processorTag,
            String description,
            boolean ignoreFailure,
            Map<String, Object> config,
            Processor.PipelineContext pipelineContext
        ) {
            String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "model_id");
            String queryTemplate = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, QUERY_TEMPLATE);
            Map<String, Object> modelConfigInput = ConfigurationUtils
                .readOptionalMap(TYPE, processorTag, config, InferenceProcessorAttributes.MODEL_CONFIG);
            List<Map<String, String>> inputMaps = ConfigurationUtils
                .readList(TYPE, processorTag, config, InferenceProcessorAttributes.INPUT_MAP);
            List<Map<String, String>> outputMaps = ConfigurationUtils
                .readList(TYPE, processorTag, config, InferenceProcessorAttributes.OUTPUT_MAP);
            int maxPredictionTask = ConfigurationUtils
                .readIntProperty(TYPE, processorTag, config, InferenceProcessorAttributes.MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
            String functionName = ConfigurationUtils
                .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
            String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);

            boolean isRemoteModel = functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
            if (isRemoteModel) {
                // Remote models fall back to forwarding the collected parameters as-is.
                if (modelInput == null) {
                    modelInput = DEFAULT_MODEl_INPUT;
                }
            } else if (modelInput == null) {
                throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
            }

            boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, !isRemoteModel);
            // The framework may already have consumed "ignore_failure" from the config. Default to its
            // value so that the user's setting is not silently replaced by false.
            boolean processorIgnoreFailure = ConfigurationUtils
                .readBooleanProperty(TYPE, processorTag, config, "ignore_failure", ignoreFailure);

            Map<String, String> modelConfigMaps = null;
            if (modelConfigInput != null) {
                modelConfigMaps = StringUtils.getParameterMap(modelConfigInput);
            }

            if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException(
                    "The number of prediction task setting in this process is "
                        + inputMaps.size()
                        + ". It exceeds the max_prediction_tasks of "
                        + maxPredictionTask
                        + ". Please reduce the size of input_map or increase max_prediction_tasks."
                );
            }
            return new MLInferenceSearchRequestProcessor(
                modelId,
                queryTemplate,
                inputMaps,
                outputMaps,
                modelConfigMaps,
                maxPredictionTask,
                processorTag,
                description,
                ignoreMissing,
                functionName,
                fullResponsePath,
                processorIgnoreFailure,
                modelInput,
                this.client,
                this.xContentRegistry
            );
        }

        /**
         * Untyped variant of {@link #create}; delegates to it unchanged.
         * Kept so that the existing public surface of this class does not change.
         */
        @SuppressWarnings("unchecked")
        public Processor m79create(Map map, String str, String str2, boolean z, Map map2, Processor.PipelineContext pipelineContext) throws Exception {
            return create((Map<String, Processor.Factory<SearchRequestProcessor>>) map, str, str2, z, (Map<String, Object>) map2, pipelineContext);
        }
    }

    /**
     * @param modelId id of the model to run
     * @param queryTemplate optional template that replaces the query; {@code null} to edit the query in place
     * @param inputMaps per-prediction mapping of model input name to query JSON path
     * @param outputMaps per-prediction mapping of query path (or template variable) to model output path
     * @param modelConfigMaps optional model config parameters, may be {@code null}
     * @param maxPredictionTask maximum number of prediction tasks
     * @param tag processor tag
     * @param description processor description
     * @param ignoreMissing whether missing query fields or model outputs are tolerated
     * @param functionName ML function name of the model
     * @param fullResponsePath whether output paths address the full model response
     * @param ignoreFailure whether failures leave the request unchanged instead of failing it
     * @param modelInput model input template
     * @param client client used to run prediction tasks
     * @param namedXContentRegistry registry used to parse the rewritten query
     */
    protected MLInferenceSearchRequestProcessor(
        String modelId,
        String queryTemplate,
        List<Map<String, String>> inputMaps,
        List<Map<String, String>> outputMaps,
        Map<String, String> modelConfigMaps,
        int maxPredictionTask,
        String tag,
        String description,
        boolean ignoreMissing,
        String functionName,
        boolean fullResponsePath,
        boolean ignoreFailure,
        String modelInput,
        Client client,
        NamedXContentRegistry namedXContentRegistry
    ) {
        super(tag, description, ignoreFailure);
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask);
        this.ignoreMissing = ignoreMissing;
        this.functionName = functionName;
        this.fullResponsePath = fullResponsePath;
        this.queryTemplate = queryTemplate;
        this.ignoreFailure = ignoreFailure;
        this.modelInput = modelInput;
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
    }

    /**
     * Not supported: this processor only works asynchronously.
     *
     * @throws RuntimeException always
     */
    @Override
    public SearchRequest processRequest(SearchRequest searchRequest) throws Exception {
        throw new RuntimeException("ML inference search request processor make asynchronous calls and does not call processRequest");
    }

    /**
     * Runs the configured predictions and passes the rewritten request to {@code requestListener}.
     * A request without a source body is reported as an {@link IllegalArgumentException}. When
     * {@code ignoreFailure} is set, that error is ignored and the request is passed on unchanged.
     */
    @Override
    public void processRequestAsync(
        SearchRequest request,
        PipelineProcessingContext requestContext,
        ActionListener<SearchRequest> requestListener
    ) {
        try {
            if (request.source() == null) {
                throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request.");
            }
            rewriteQueryString(request, request.source().toString(), requestListener);
        } catch (Exception e) {
            if (ignoreFailure) {
                requestListener.onResponse(request);
            } else {
                requestListener.onFailure(e);
            }
        }
    }

    /**
     * Validates the configured paths against the query and starts one prediction per input map.
     * A request with no input maps is passed through unchanged. Validation and dispatch errors are
     * tolerated (request passed on unchanged) only when {@code ignoreMissing} is set.
     */
    private void rewriteQueryString(SearchRequest request, String queryString, ActionListener<SearchRequest> requestListener)
        throws IOException {
        List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
        int inputMapSize = processInputMap == null ? 0 : processInputMap.size();
        if (inputMapSize == 0) {
            requestListener.onResponse(request);
            return;
        }
        try {
            if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString)) {
                requestListener.onResponse(request);
                // Without this return the predictions would still run and notify the listener again.
                return;
            }
            GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
                createRewriteRequestListener(request, queryString, requestListener, processOutputMap),
                inputMapSize
            );
            for (int inputMapIndex = 0; inputMapIndex < inputMapSize; inputMapIndex++) {
                processPredictions(queryString, processInputMap, inputMapIndex, batchPredictionListener);
            }
        } catch (Exception e) {
            if (ignoreMissing) {
                requestListener.onResponse(request);
            } else {
                requestListener.onFailure(e);
            }
        }
    }

    /**
     * Builds the listener that receives all prediction outputs, keyed by input-map index, and
     * rewrites the request.
     *
     * <p>All outputs are applied to a single rewritten query, and {@code requestListener} is notified
     * exactly once.
     */
    private ActionListener<Map<Integer, MLOutput>> createRewriteRequestListener(
        final SearchRequest request,
        final String queryString,
        final ActionListener<SearchRequest> requestListener,
        final List<Map<String, String>> processOutputMap
    ) {
        return new ActionListener<Map<Integer, MLOutput>>() {
            @Override
            public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
                SearchSourceBuilder rewrittenSource;
                try {
                    rewrittenSource = getSearchSourceBuilder(xContentRegistry, buildRewrittenQuery(multipleMLOutputs));
                } catch (Exception e) {
                    onFailure(e);
                    return;
                }
                // Notify outside the try so that a downstream exception cannot trigger a second callback.
                request.source(rewrittenSource);
                requestListener.onResponse(request);
            }

            @Override
            public void onFailure(Exception exc) {
                if (!ignoreFailure) {
                    requestListener.onFailure(exc);
                } else {
                    logger.error("Failed in writing prediction outcomes to new query", exc);
                    requestListener.onResponse(request);
                }
            }

            /** Produces the new query JSON from all prediction outputs. */
            private String buildRewrittenQuery(Map<Integer, MLOutput> multipleMLOutputs) {
                if (queryTemplate == null) {
                    Object incomeQueryObject = JsonPath.parse(queryString).read("$");
                    for (Map.Entry<Integer, MLOutput> entry : multipleMLOutputs.entrySet()) {
                        updateIncomeQueryObject(incomeQueryObject, processOutputMap.get(entry.getKey()), entry.getValue());
                    }
                    return StringUtils.toJson(incomeQueryObject);
                }
                Map<String, Object> templateValues = new HashMap<>();
                for (Map.Entry<Integer, MLOutput> entry : multipleMLOutputs.entrySet()) {
                    collectTemplateValues(templateValues, processOutputMap.get(entry.getKey()), entry.getValue());
                }
                return new StringSubstitutor(templateValues).replace(queryTemplate);
            }

            /** Writes each mapped model output into the parsed query object at its JSON path. */
            private void updateIncomeQueryObject(Object incomeQueryObject, Map<String, String> outputMapping, MLOutput mlOutput) {
                for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
                    Object modelOutputValue = getModelOutputValue(mlOutput, outputMapEntry.getValue(), ignoreMissing, fullResponsePath);
                    JsonPath.parse(incomeQueryObject).set("$." + outputMapEntry.getKey(), modelOutputValue);
                }
            }

            /** Adds each mapped model output to the template substitution values. */
            private void collectTemplateValues(Map<String, Object> templateValues, Map<String, String> outputMapping, MLOutput mlOutput) {
                for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
                    templateValues
                        .put(outputMapEntry.getKey(), getModelOutputValue(mlOutput, outputMapEntry.getValue(), ignoreMissing, fullResponsePath));
                }
            }
        };
    }

    /**
     * Wraps {@code rewriteRequestListener} in a listener that waits for {@code inputMapSize}
     * prediction results (at least one) and merges them into a single map.
     */
    private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListener(
        final ActionListener<Map<Integer, MLOutput>> rewriteRequestListener,
        int inputMapSize
    ) {
        ActionListener<Collection<Map<Integer, MLOutput>>> mergingListener = new ActionListener<Collection<Map<Integer, MLOutput>>>() {
            @Override
            public void onResponse(Collection<Map<Integer, MLOutput>> mlOutputMapCollection) {
                Map<Integer, MLOutput> mlOutputMaps = new HashMap<>();
                for (Map<Integer, MLOutput> mlOutputMap : mlOutputMapCollection) {
                    mlOutputMaps.putAll(mlOutputMap);
                }
                rewriteRequestListener.onResponse(mlOutputMaps);
            }

            @Override
            public void onFailure(Exception e) {
                logger.error("Prediction Failed:", e);
                rewriteRequestListener.onFailure(e);
            }
        };
        return new GroupedActionListener<>(mergingListener, Math.max(inputMapSize, 1));
    }

    /**
     * Checks that every input path resolves to a non-null value in the query. Unless a query
     * template is used, it also checks that every output target path resolves.
     *
     * @return {@code true} when all paths resolve
     * @throws IllegalArgumentException naming the first path that does not resolve
     */
    private boolean validateQueryFieldInQueryString(
        List<Map<String, String>> processInputMap,
        List<Map<String, String>> processOutputMap,
        String queryString
    ) {
        // SUPPRESS_EXCEPTIONS makes missing paths read as null instead of throwing.
        DocumentContext queryContext = JsonPath
            .using(Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS))
            .parse(queryString);
        for (Map<String, String> inputMap : processInputMap) {
            for (String modelInputPath : inputMap.values()) {
                Object fieldValue = queryContext.read(modelInputPath);
                if (fieldValue == null) {
                    throw new IllegalArgumentException("cannot find field: " + modelInputPath + " in query string: " + queryContext.jsonString());
                }
            }
        }
        // With a template, output keys are template variables rather than query paths.
        if (queryTemplate != null) {
            return true;
        }
        for (Map<String, String> outputMap : processOutputMap) {
            for (String queryFieldPath : outputMap.keySet()) {
                Object fieldValue = queryContext.read(queryFieldPath);
                if (fieldValue == null) {
                    throw new IllegalArgumentException("cannot find field: " + queryFieldPath + " in query string: " + queryContext.jsonString());
                }
            }
        }
        return true;
    }

    /**
     * Builds and dispatches the prediction request for input map {@code inputMapIndex}. The result
     * is reported to {@code batchPredictionListener} keyed by that index.
     *
     * <p>Model parameters are the configured model-config entries plus the JSON-serialized query
     * values selected by the input map. The input mappings passed on are the subset of those
     * parameters that are not model-config keys.
     */
    private void processPredictions(
        String queryString,
        List<Map<String, String>> processInputMap,
        final int inputMapIndex,
        final GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener
    ) throws IOException {
        Map<String, String> modelParameters = new HashMap<>();
        Map<String, String> modelConfigs = new HashMap<>();
        Map<String, String> modelConfigMaps = inferenceProcessorAttributes.getModelConfigMaps();
        if (modelConfigMaps != null) {
            modelParameters.putAll(modelConfigMaps);
            modelConfigs.putAll(modelConfigMaps);
        }
        if (processInputMap != null) {
            Map<String, String> inputMapping = processInputMap.get(inputMapIndex);
            DocumentContext queryContext = JsonPath.parse(queryString);
            for (Map.Entry<String, String> inputMapEntry : inputMapping.entrySet()) {
                Object queryValue = queryContext.read(inputMapEntry.getValue());
                modelParameters.put(inputMapEntry.getKey(), StringUtils.toJson(queryValue));
            }
        }

        Map<String, String> inputMappings = new HashMap<>();
        for (Map.Entry<String, String> parameter : modelParameters.entrySet()) {
            if (!modelConfigs.containsKey(parameter.getKey())) {
                inputMappings.put(parameter.getKey(), parameter.getValue());
            }
        }

        client
            .execute(
                MLPredictionTaskAction.INSTANCE,
                getMLModelInferenceRequest(
                    xContentRegistry,
                    modelParameters,
                    modelConfigs,
                    inputMappings,
                    inferenceProcessorAttributes.getModelId(),
                    functionName,
                    modelInput
                ),
                new ActionListener<MLTaskResponse>() {
                    @Override
                    public void onResponse(MLTaskResponse mlTaskResponse) {
                        Map<Integer, MLOutput> indexedOutput = new HashMap<>();
                        indexedOutput.put(inputMapIndex, mlTaskResponse.getOutput());
                        batchPredictionListener.onResponse(indexedOutput);
                    }

                    @Override
                    public void onFailure(Exception e) {
                        batchPredictionListener.onFailure(e);
                    }
                }
            );
    }

    /**
     * Parses {@code query} (a JSON object) into a {@link SearchSourceBuilder}.
     *
     * @throws IOException if the JSON cannot be parsed
     */
    private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry namedXContentRegistry, String query) throws IOException {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        // The parser holds resources and must be closed once parsing is done.
        try (XContentParser queryParser = XContentType.JSON.xContent().createParser(namedXContentRegistry, LoggingDeprecationHandler.INSTANCE, query)) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, queryParser.nextToken(), queryParser);
            searchSourceBuilder.parseXContent(queryParser);
        }
        return searchSourceBuilder;
    }

    /** @return the processor type name, {@value #TYPE} */
    @Override
    public String getType() {
        return TYPE;
    }
}
