package org.opensearch.ml.processor;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.ValueSource;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptService;

/* loaded from: input_file:org/opensearch/ml/processor/MLInferenceIngestProcessor.class */
/**
 * Ingest processor of type {@code ml_inference} that sends values taken from an ingest document to an
 * ML model and writes the prediction output back into the document.
 *
 * <p>Each entry of {@code input_map} (paired by index with {@code output_map}) triggers one prediction
 * request. When no {@code input_map} is configured, a single request is issued whose inputs are all
 * top-level source/metadata fields of the document.
 */
public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor {
    private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class);
    public static final String DOT_SYMBOL = ".";
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final String functionName;
    private final boolean fullResponsePath;
    private final boolean ignoreFailure;
    private final boolean override;
    private final String modelInput;
    private final ScriptService scriptService;
    // Per-instance client. A static field here would be silently overwritten by every processor that
    // gets constructed, so all instances would share whichever client was created last.
    private final Client client;
    public static final String TYPE = "ml_inference";
    public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final String OVERRIDE = "override";
    public static final String FUNCTION_NAME = "function_name";
    public static final String FULL_RESPONSE_PATH = "full_response_path";
    public static final String MODEL_INPUT = "model_input";
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
    private final NamedXContentRegistry xContentRegistry;

    /**
     * Builds {@link MLInferenceIngestProcessor} instances from pipeline configuration.
     */
    public static class Factory implements Processor.Factory {
        private final ScriptService scriptService;
        private final Client client;
        private final NamedXContentRegistry xContentRegistry;

        public Factory(ScriptService scriptService, Client client, NamedXContentRegistry namedXContentRegistry) {
            this.scriptService = scriptService;
            this.client = client;
            this.xContentRegistry = namedXContentRegistry;
        }

        /**
         * Parses and validates the processor configuration.
         *
         * @param map  registered processor factories (unused here)
         * @param str  processor tag
         * @param str2 processor description
         * @param map2 processor configuration; recognised properties are consumed from it
         * @return the configured processor
         * @throws IllegalArgumentException if a local model has no {@code model_input}, if
         *     {@code input_map} has more entries than {@code max_prediction_tasks}, or if
         *     {@code input_map} and {@code output_map} are both present with different lengths
         */
        @Override
        public MLInferenceIngestProcessor create(Map<String, Processor.Factory> map, String str, String str2, Map<String, Object> map2) throws Exception {
            String modelId = ConfigurationUtils.readStringProperty(TYPE, str, map2, "model_id");
            Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, str, map2, InferenceProcessorAttributes.MODEL_CONFIG);
            List<Map<String, String>> inputMaps = ConfigurationUtils.readOptionalList(TYPE, str, map2, InferenceProcessorAttributes.INPUT_MAP);
            List<Map<String, String>> outputMaps = ConfigurationUtils.readOptionalList(TYPE, str, map2, InferenceProcessorAttributes.OUTPUT_MAP);
            int maxPredictionTask = ConfigurationUtils
                .readIntProperty(TYPE, str, map2, InferenceProcessorAttributes.MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS)
                .intValue();
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, str, map2, IGNORE_MISSING, false);
            boolean override = ConfigurationUtils.readBooleanProperty(TYPE, str, map2, OVERRIDE, false);
            String functionName = ConfigurationUtils.readStringProperty(TYPE, str, map2, FUNCTION_NAME, FunctionName.REMOTE.name());
            boolean isRemoteModel = functionName.equalsIgnoreCase(FunctionName.REMOTE.name());

            String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, str, map2, MODEL_INPUT);
            if (isRemoteModel) {
                // Remote models fall back to a template that forwards all parameters.
                if (modelInput == null) {
                    modelInput = DEFAULT_MODEl_INPUT;
                }
            } else if (modelInput == null) {
                throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
            }

            // Local models default to reading the full response; remote models default to not doing so.
            boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, str, map2, FULL_RESPONSE_PATH, !isRemoteModel);
            boolean ignoreFailure = ConfigurationUtils.readBooleanProperty(TYPE, str, map2, "ignore_failure", false);

            Map<String, String> modelConfigMaps = modelConfigInput == null ? null : StringUtils.getParameterMap(modelConfigInput);

            if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException("The number of prediction task setting in this process is " + inputMaps.size() + ". It exceeds the max_prediction_tasks of " + maxPredictionTask + ". Please reduce the size of input_map or increase max_prediction_tasks.");
            }
            if (inputMaps != null && outputMaps != null && outputMaps.size() != inputMaps.size()) {
                throw new IllegalArgumentException("The length of output_map and the length of input_map do no match.");
            }
            return new MLInferenceIngestProcessor(
                modelId,
                inputMaps,
                outputMaps,
                modelConfigMaps,
                maxPredictionTask,
                str,
                str2,
                ignoreMissing,
                functionName,
                fullResponsePath,
                ignoreFailure,
                override,
                modelInput,
                this.scriptService,
                this.client,
                this.xContentRegistry
            );
        }

        /* renamed from: create, reason: collision with other method in class */
        // NOTE(review): this appears to be a decompiler-emitted bridge method; the real override is create(...).
        public /* bridge */ /* synthetic */ Processor m77create(Map map, String str, String str2, Map map2) throws Exception {
            return create((Map<String, Processor.Factory>) map, str, str2, (Map<String, Object>) map2);
        }
    }

    /**
     * @param str                   model id
     * @param list                  input maps (model input name -> document field path), may be null
     * @param list2                 output maps (document field name -> model output path), may be null
     * @param map                   model config parameters, may be null
     * @param i                     maximum number of prediction tasks
     * @param str2                  processor tag
     * @param str3                  processor description
     * @param z                     ignore missing input fields / output values
     * @param str4                  function name of the model
     * @param z2                    whether output paths are resolved against the full model response
     * @param z3                    report success even when a prediction fails
     * @param z4                    overwrite output fields that already exist in the document
     * @param str5                  model input template
     * @param scriptService         script service used to compile field-name templates
     * @param client2               client used to send prediction requests
     * @param namedXContentRegistry registry passed to request construction
     */
    protected MLInferenceIngestProcessor(String str, List<Map<String, String>> list, List<Map<String, String>> list2, Map<String, String> map, int i, String str2, String str3, boolean z, String str4, boolean z2, boolean z3, boolean z4, String str5, ScriptService scriptService, Client client2, NamedXContentRegistry namedXContentRegistry) {
        super(str2, str3);
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(str, list, list2, map, i);
        this.ignoreMissing = z;
        this.functionName = str4;
        this.fullResponsePath = z2;
        this.ignoreFailure = z3;
        this.override = z4;
        this.modelInput = str5;
        this.scriptService = scriptService;
        this.client = client2;
        this.xContentRegistry = namedXContentRegistry;
    }

    /**
     * Runs one prediction per input map (or one prediction when no input map is configured) and
     * invokes {@code biConsumer} once all of them have completed. On failure the document is passed
     * through unchanged when {@code ignore_failure} is set; otherwise the exception is reported.
     */
    @Override
    public void execute(final IngestDocument ingestDocument, final BiConsumer<IngestDocument, Exception> biConsumer) {
        List<Map<String, String>> inputMaps = this.inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> outputMaps = this.inferenceProcessorAttributes.getOutputMaps();
        int inputMapSize = inputMaps == null ? 0 : inputMaps.size();
        // Without an input map there is still exactly one prediction, using the whole document as input.
        int predictionCount = Math.max(inputMapSize, 1);

        GroupedActionListener<Void> batchPredictionListener = new GroupedActionListener<>(new ActionListener<Collection<Void>>() {
            @Override
            public void onResponse(Collection<Void> results) {
                biConsumer.accept(ingestDocument, null);
            }

            @Override
            public void onFailure(Exception exc) {
                if (ignoreFailure) {
                    biConsumer.accept(ingestDocument, null);
                } else {
                    biConsumer.accept(null, exc);
                }
            }
        }, predictionCount);

        for (int inputMapIndex = 0; inputMapIndex < predictionCount; inputMapIndex++) {
            try {
                processPredictions(ingestDocument, batchPredictionListener, inputMaps, outputMaps, inputMapIndex, inputMapSize);
            } catch (Exception e) {
                batchPredictionListener.onFailure(e);
            }
        }
    }

    @Override
    public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
        throw new UnsupportedOperationException("this method should not get executed.");
    }

    /**
     * Issues the prediction for one input/output map pair. Exactly one of {@code onResponse} or
     * {@code onFailure} is eventually invoked on {@code batchPredictionListener}, unless this method
     * itself throws (the caller converts that into a failure).
     */
    private void processPredictions(
        final IngestDocument ingestDocument,
        final GroupedActionListener<Void> batchPredictionListener,
        final List<Map<String, String>> processInputMap,
        final List<Map<String, String>> processOutputMap,
        final int inputMapIndex,
        final int inputMapSize
    ) throws IOException {
        Map<String, String> modelParameters = new HashMap<>();
        Map<String, String> modelConfigs = new HashMap<>();
        if (this.inferenceProcessorAttributes.getModelConfigMaps() != null) {
            modelParameters.putAll(this.inferenceProcessorAttributes.getModelConfigMaps());
            modelConfigs.putAll(this.inferenceProcessorAttributes.getModelConfigMaps());
        }

        // Output field name -> concrete document dot paths that will receive the prediction.
        final Map<String, List<String>> outputFieldPaths = new HashMap<>();
        // An empty output map is treated like a missing one (results go to the default field),
        // matching how the response handler below interprets it.
        if (processOutputMap != null && !processOutputMap.isEmpty()) {
            Map<String, Object> documentView = buildDocumentView(ingestDocument);
            for (String newDocumentFieldName : processOutputMap.get(inputMapIndex).keySet()) {
                List<String> dotPaths = writeNewDotPathForNestedObject(documentView, newDocumentFieldName);
                if (!this.override && allFieldsExist(ingestDocument, dotPaths)) {
                    logger.debug("{} already exists in the ingest document. Removing it from output mapping", newDocumentFieldName);
                } else {
                    outputFieldPaths.put(newDocumentFieldName, dotPaths);
                }
            }
            if (outputFieldPaths.isEmpty()) {
                // Every target field already exists and override is off: skip the prediction.
                batchPredictionListener.onResponse(null);
                return;
            }
        }

        if (inputMapSize == 0) {
            for (String documentFieldName : ingestDocument.getSourceAndMetadata().keySet()) {
                getMappedModelInputFromDocuments(ingestDocument, modelParameters, documentFieldName, documentFieldName);
            }
        } else {
            for (Map.Entry<String, String> entry : processInputMap.get(inputMapIndex).entrySet()) {
                getMappedModelInputFromDocuments(ingestDocument, modelParameters, entry.getValue(), entry.getKey());
            }
        }

        // Parameters that came from the document rather than from model_config.
        Map<String, String> modelInputParameters = new HashMap<>();
        for (Map.Entry<String, String> entry : modelParameters.entrySet()) {
            if (!modelConfigs.containsKey(entry.getKey())) {
                modelInputParameters.put(entry.getKey(), entry.getValue());
            }
        }

        client.execute(
            MLPredictionTaskAction.INSTANCE,
            getMLModelInferenceRequest(this.xContentRegistry, modelParameters, modelConfigs, modelInputParameters, this.inferenceProcessorAttributes.getModelId(), this.functionName, this.modelInput),
            new ActionListener<MLTaskResponse>() {
                @Override
                public void onResponse(MLTaskResponse mlTaskResponse) {
                    try {
                        MLOutput output = mlTaskResponse.getOutput();
                        if (processOutputMap == null || processOutputMap.isEmpty()) {
                            MLInferenceIngestProcessor.this.appendFieldValue(output, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
                        } else {
                            for (Map.Entry<String, String> entry : processOutputMap.get(inputMapIndex).entrySet()) {
                                String newDocumentFieldName = entry.getKey();
                                if (outputFieldPaths.containsKey(newDocumentFieldName)) {
                                    MLInferenceIngestProcessor.this.appendFieldValue(output, entry.getValue(), newDocumentFieldName, ingestDocument);
                                }
                            }
                        }
                    } catch (Exception e) {
                        // Without this, a failure while writing results would escape the callback and the
                        // grouped listener would never complete.
                        batchPredictionListener.onFailure(e);
                        return;
                    }
                    batchPredictionListener.onResponse(null);
                }

                @Override
                public void onFailure(Exception exc) {
                    batchPredictionListener.onFailure(exc);
                }
            }
        );
    }

    /** Returns true when every path in {@code dotPaths} exists in the document (vacuously true if empty). */
    private static boolean allFieldsExist(IngestDocument ingestDocument, List<String> dotPaths) {
        for (String dotPath : dotPaths) {
            if (!ingestDocument.hasField(dotPath)) {
                return false;
            }
        }
        return true;
    }

    /** Builds a copy of the document's source and metadata with the ingest metadata under {@code _ingest}. */
    private static Map<String, Object> buildDocumentView(IngestDocument ingestDocument) {
        Map<String, Object> documentView = new HashMap<>(ingestDocument.getSourceAndMetadata());
        documentView.put("_ingest", ingestDocument.getIngestMetadata());
        return documentView;
    }

    /**
     * Resolves {@code documentFieldName} in the document, either as a plain field path or as a JSON
     * path, and stores the stringified value under {@code modelInputFieldName}.
     *
     * @throws IllegalArgumentException if the field cannot be found and {@code ignore_missing} is off,
     *     or if the name is neither an existing field nor a valid JSON path
     */
    private void getMappedModelInputFromDocuments(IngestDocument ingestDocument, Map<String, String> modelParameters, String documentFieldName, String modelInputFieldName) {
        String originalFieldPath = getFieldPath(ingestDocument, documentFieldName);
        if (originalFieldPath != null) {
            Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
            updateModelParameters(modelInputFieldName, toString(documentFieldValue), modelParameters);
            return;
        }
        if (!StringUtils.isValidJSONPath(documentFieldName)) {
            throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
        }
        Object documentFieldValue = JsonPath.using(suppressExceptionConfiguration).parse(ingestDocument.getSourceAndMetadata()).read(documentFieldName);
        if (documentFieldValue instanceof List) {
            List<?> documentFieldValues = (List<?>) documentFieldValue;
            if (!documentFieldValues.isEmpty()) {
                updateModelParameters(modelInputFieldName, toString(documentFieldValues), modelParameters);
                return;
            }
        } else if (documentFieldValue != null) {
            updateModelParameters(modelInputFieldName, toString(documentFieldValue), modelParameters);
            return;
        }
        // Null or empty-list result: treat as missing.
        if (!this.ignoreMissing) {
            throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
        }
    }

    /**
     * Stores {@code documentFieldValueAsString} under {@code modelInputFieldName}. If the key is already
     * present (for example from {@code model_config}), the existing and new values are combined into a
     * list and stored in stringified form. The old code cast the existing {@code String} to
     * {@code List}, which always threw {@code ClassCastException}.
     */
    private void updateModelParameters(String modelInputFieldName, String documentFieldValueAsString, Map<String, String> modelParameters) {
        if (!modelParameters.containsKey(modelInputFieldName)) {
            modelParameters.put(modelInputFieldName, documentFieldValueAsString);
            return;
        }
        List<String> mergedValues = new ArrayList<>();
        mergedValues.add(modelParameters.get(modelInputFieldName));
        mergedValues.add(documentFieldValueAsString);
        modelParameters.put(modelInputFieldName, toString(mergedValues));
    }

    /** Returns {@code documentFieldName} if it is non-empty and exists in the document, otherwise null. */
    private String getFieldPath(IngestDocument ingestDocument, String documentFieldName) {
        if (Strings.isNullOrEmpty(documentFieldName) || !ingestDocument.hasField(documentFieldName, true)) {
            return null;
        }
        return documentFieldName;
    }

    /**
     * Extracts {@code modelOutputFieldName} from the model output and writes it to
     * {@code newDocumentFieldName}. When the document path expands to several array element paths,
     * the model output must be a list of the same size and is written element by element.
     *
     * @throws RuntimeException         if the model output is null or the array sizes differ
     * @throws IllegalArgumentException if several target paths exist but the output is not a list
     */
    private void appendFieldValue(MLOutput mlOutput, String modelOutputFieldName, String newDocumentFieldName, IngestDocument ingestDocument) {
        if (mlOutput == null) {
            throw new RuntimeException("model inference output is null");
        }
        Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, this.ignoreMissing, this.fullResponsePath);
        List<String> dotPathsInArray = writeNewDotPathForNestedObject(buildDocumentView(ingestDocument), newDocumentFieldName);
        if (dotPathsInArray.size() == 1) {
            setDocumentField(ingestDocument, dotPathsInArray.get(0), modelOutputValue);
            return;
        }
        if (!(modelOutputValue instanceof List)) {
            throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
        }
        List<?> modelOutputValueArray = (List<?>) modelOutputValue;
        if (dotPathsInArray.size() != modelOutputValueArray.size()) {
            throw new RuntimeException("the prediction field: " + modelOutputFieldName + " is an array in size of " + modelOutputValueArray.size() + " but the document field array from field " + newDocumentFieldName + " is in size of " + dotPathsInArray.size());
        }
        for (int i = 0; i < dotPathsInArray.size(); i++) {
            setDocumentField(ingestDocument, dotPathsInArray.get(i), modelOutputValueArray.get(i));
        }
    }

    /** Writes {@code value} at {@code dotPath}, compiling the path as a field-name template. */
    private void setDocumentField(IngestDocument ingestDocument, String dotPath, Object value) {
        ingestDocument.setFieldValue(
            ConfigurationUtils.compileTemplate(TYPE, this.tag, dotPath, dotPath, this.scriptService),
            ValueSource.wrap(value, this.scriptService),
            this.ignoreMissing
        );
    }

    @Override
    public String getType() {
        return TYPE;
    }
}
