package org.opensearch.neuralsearch.processor;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/* loaded from: input_file:org/opensearch/neuralsearch/processor/InferenceProcessor.class */
public abstract class InferenceProcessor extends AbstractBatchingProcessor {
    /** Configuration key whose literal value is {@code "model_id"}. */
    public static final String MODEL_ID_FIELD = "model_id";
    /** Configuration key whose literal value is {@code "field_map"}; also passed to the field-map validation helper. */
    public static final String FIELD_MAP_FIELD = "field_map";
    /** Processor type name, returned as-is by {@link #getType()}. */
    private final String type;
    /** Key used for every single-entry map produced when embedding the elements of a list-typed value. */
    private final String listTypeNestedMapKey;
    /** Model identifier; the constructor rejects blank values. Not read inside this class (protected, so presumably for subclasses). */
    protected final String modelId;
    /** Source-field to target-field mapping; the constructor rejects null/empty maps and blank keys or values. */
    private final Map<String, Object> fieldMap;
    /** ML client accessor; not used inside this class (protected, so presumably for subclasses). */
    protected final MLCommonsClientAccessor mlCommonsClientAccessor;
    /** Forwarded to {@code ProcessorDocumentUtils.validateMapTypeValue} when validating a document. */
    private final Environment environment;
    /** Forwarded to {@code ProcessorDocumentUtils.validateMapTypeValue} when validating a document. */
    private final ClusterService clusterService;

    // Class-wide logger (the @Generated marker suggests it was produced by Lombok).
    @Generated
    private static final Logger log = LogManager.getLogger(InferenceProcessor.class);
    /**
     * Merge strategy used with {@code Map.merge}: when both the existing and the incoming value are
     * collections, or both are maps, the incoming content is added into the existing container and the
     * existing container is kept; in every other case the incoming value replaces the existing one.
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (existing, incoming) -> {
        final boolean bothCollections = existing instanceof Collection && incoming instanceof Collection;
        if (bothCollections) {
            Collection target = (Collection) existing;
            target.addAll((Collection) incoming);
            return existing;
        }
        final boolean bothMaps = existing instanceof Map && incoming instanceof Map;
        if (bothMaps) {
            Map target = (Map) existing;
            target.putAll((Map) incoming);
            return existing;
        }
        // Mixed or scalar types: the new value wins.
        return incoming;
    };

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/opensearch/neuralsearch/processor/InferenceProcessor$DataForInference.class */
    /**
     * Immutable holder tying one ingest document wrapper to the data extracted from it for inference:
     * the map of target keys to source values and the flat list of texts to embed. Either the map or
     * the list may be {@code null} when extraction failed for the document.
     */
    public static class DataForInference {
        private final IngestDocumentWrapper ingestDocumentWrapper;
        private final Map<String, Object> processMap;
        private final List<String> inferenceList;

        /**
         * @param ingestDocumentWrapper wrapper of the document the data belongs to
         * @param map                   process map built for the document (may be null)
         * @param list                  texts to run inference on (may be null)
         */
        @Generated
        public DataForInference(IngestDocumentWrapper ingestDocumentWrapper, Map<String, Object> map, List<String> list) {
            this.ingestDocumentWrapper = ingestDocumentWrapper;
            this.processMap = map;
            this.inferenceList = list;
        }

        /** @return the wrapper of the document this data was extracted from */
        @Generated
        public IngestDocumentWrapper getIngestDocumentWrapper() {
            return ingestDocumentWrapper;
        }

        /** @return the process map built for the document, possibly null */
        @Generated
        public Map<String, Object> getProcessMap() {
            return processMap;
        }

        /** @return the texts to run inference on, possibly null */
        @Generated
        public List<String> getInferenceList() {
            return inferenceList;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/opensearch/neuralsearch/processor/InferenceProcessor$IndexWrapper.class */
    /**
     * Mutable integer cursor. The enclosing class uses it to track the next unread position in the
     * inference result list while the result structure is built across several (recursive) calls.
     */
    public static class IndexWrapper {
        /** Next position to read; advanced by the enclosing class after each read. */
        private int index;

        /**
         * @param i initial cursor position
         */
        protected IndexWrapper(int i) {
            index = i;
        }
    }

    /**
     * Creates the processor after validating its configuration.
     *
     * @param str                     forwarded to the superclass constructor (presumably the processor tag - confirm)
     * @param str2                    forwarded to the superclass constructor (presumably the description - confirm)
     * @param i                       forwarded to the superclass constructor (presumably the batch size - confirm)
     * @param str3                    processor type name returned by {@link #getType()}
     * @param str4                    key used for the per-element maps built for list-typed values
     * @param str5                    model id; must not be blank
     * @param map                     field map; must be non-empty with non-blank keys and values
     * @param mLCommonsClientAccessor ML client accessor stored for later use
     * @param environment             environment passed on during document validation
     * @param clusterService          cluster service passed on during document validation
     * @throws IllegalArgumentException if the model id is blank or the field map is invalid
     */
    public InferenceProcessor(String str, String str2, int i, String str3, String str4, String str5, Map<String, Object> map, MLCommonsClientAccessor mLCommonsClientAccessor, Environment environment, ClusterService clusterService) {
        super(str, str2, i);

        // Reject bad configuration before storing anything.
        if (StringUtils.isBlank(str5)) {
            throw new IllegalArgumentException("model_id is null or empty, cannot process it");
        }
        validateEmbeddingConfiguration(map);

        this.type = str3;
        this.listTypeNestedMapKey = str4;
        this.modelId = str5;
        this.fieldMap = map;
        this.mlCommonsClientAccessor = mLCommonsClientAccessor;
        this.environment = environment;
        this.clusterService = clusterService;
    }

    /**
     * Checks that the field map is usable: it must be non-null and non-empty, every key must be
     * non-blank, and every value must be non-null with a non-blank string form.
     *
     * @param map field map to validate
     * @throws IllegalArgumentException if any of the conditions above is violated
     */
    private void validateEmbeddingConfiguration(Map<String, Object> map) {
        boolean invalid = map == null || map.isEmpty();
        if (!invalid) {
            for (Map.Entry<String, Object> mapping : map.entrySet()) {
                final Object target = mapping.getValue();
                if (StringUtils.isBlank(mapping.getKey()) || target == null || StringUtils.isBlank(target.toString())) {
                    invalid = true;
                    break;
                }
            }
        }
        if (invalid) {
            throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value");
        }
    }

    /**
     * Runs inference for a single document. Invoked by {@link #execute(IngestDocument, BiConsumer)}
     * only when there is at least one text to infer on.
     *
     * @param ingestDocument document being processed
     * @param map            map of target keys to source values built from the document
     * @param list           non-empty flat list of texts extracted from {@code map}
     * @param biConsumer     completion handler; receives the document or an exception
     */
    public abstract void doExecute(IngestDocument ingestDocument, Map<String, Object> map, List<String> list, BiConsumer<IngestDocument, Exception> biConsumer);

    /**
     * Synchronous entry point; performs no work and hands the document back untouched.
     *
     * @param ingestDocument document to process
     * @return the same {@code ingestDocument} instance, unmodified
     */
    public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
        return ingestDocument;
    }

    /**
     * Asynchronous entry point for a single document. Validates the mapped fields, gathers the texts
     * to embed, and delegates to {@link #doExecute} when there is something to infer on. When no text
     * is found the handler is completed immediately with the unchanged document. Any exception thrown
     * along the way is reported through the handler with a {@code null} document.
     *
     * @param ingestDocument document to process
     * @param biConsumer     completion handler; receives the document or an exception
     */
    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> biConsumer) {
        try {
            validateEmbeddingFieldsValue(ingestDocument);
            final Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
            final List<String> inferenceTexts = createInferenceList(processMap);
            if (!inferenceTexts.isEmpty()) {
                doExecute(ingestDocument, processMap, inferenceTexts, biConsumer);
                return;
            }
            // Nothing to embed: pass the document through as-is.
            biConsumer.accept(ingestDocument, null);
        } catch (Exception e) {
            biConsumer.accept(null, e);
        }
    }

    /**
     * Runs inference for a batch of texts. Called by {@link #subBatchExecute} with the texts sorted by
     * ascending length; the caller maps results back to inputs by position, so the result list is
     * expected to line up index-for-index with {@code list}.
     *
     * @param list      texts to run inference on
     * @param consumer  success handler receiving the inference results
     * @param consumer2 failure handler receiving the exception
     */
    abstract void doBatchExecute(List<String> list, Consumer<List<?>> consumer, Consumer<Exception> consumer2);

    /**
     * Processes a batch of documents with a single batched inference call.
     * <p>
     * Texts are collected from every document that validated successfully, sorted by length, and sent
     * to {@link #doBatchExecute}. On success the results are put back into their original order and
     * sliced per document, in document order, before being written to each document. On failure the
     * exception is recorded on every wrapper that does not already carry one. In all cases visible
     * here the handler is eventually invoked with the incoming wrapper list (or an empty list when the
     * input is empty).
     *
     * @param list     wrappers of the documents in this batch
     * @param consumer handler invoked once the batch has been processed
     */
    public void subBatchExecute(List<IngestDocumentWrapper> list, Consumer<List<IngestDocumentWrapper>> consumer) {
        if (CollectionUtils.isEmpty(list)) {
            consumer.accept(Collections.emptyList());
            return;
        }

        final List<DataForInference> perDocumentData = getDataForInference(list);
        final List<String> inferenceTexts = constructInferenceTexts(perDocumentData);
        if (inferenceTexts.isEmpty()) {
            // No document contributed any text; nothing to infer.
            consumer.accept(list);
            return;
        }

        final Tuple<List<String>, Map<Integer, Integer>> sorted = sortByLengthAndReturnOriginalOrder(inferenceTexts);
        final List<String> sortedTexts = sorted.v1();
        final Map<Integer, Integer> originalOrder = sorted.v2();

        doBatchExecute(sortedTexts, results -> {
            final List<?> orderedResults = restoreToOriginalOrder(results, originalOrder);
            int offset = 0;
            for (DataForInference data : perDocumentData) {
                final IngestDocumentWrapper wrapper = data.getIngestDocumentWrapper();
                final List<String> texts = data.getInferenceList();
                // Failed or text-less documents contributed no texts, so they consume no results.
                if (wrapper.getException() != null || CollectionUtils.isEmpty(texts)) {
                    continue;
                }
                final int end = offset + texts.size();
                final List<?> documentResults = orderedResults.subList(offset, end);
                offset = end;
                setVectorFieldsToDocument(wrapper.getIngestDocument(), data.getProcessMap(), documentResults);
            }
            consumer.accept(list);
        }, failure -> {
            for (IngestDocumentWrapper wrapper : list) {
                // Keep an earlier, document-specific failure if there is one.
                if (wrapper.getException() == null) {
                    wrapper.update(wrapper.getIngestDocument(), failure);
                }
            }
            consumer.accept(list);
        });
    }

    /**
     * Sorts the given texts by ascending length and records where each one came from.
     * <p>
     * The sort is stable ({@code List.sort}), so texts of equal length keep their relative order.
     *
     * @param list texts in their original order; elements must be non-null
     * @return a tuple of (texts sorted by length, map from position in the sorted list to the
     *         position the text had in {@code list})
     */
    private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> list) {
        // Pair every text with its original position; typed so the comparator sees Tuple, not Object.
        final List<Tuple<Integer, String>> indexedTexts = new ArrayList<>(list.size());
        for (int i = 0; i < list.size(); i++) {
            indexedTexts.add(Tuple.tuple(i, list.get(i)));
        }
        indexedTexts.sort(Comparator.comparingInt(indexedText -> indexedText.v2().length()));

        final List<String> sortedTexts = new ArrayList<>(indexedTexts.size());
        final Map<Integer, Integer> originalOrder = new HashMap<>();
        for (int sortedIndex = 0; sortedIndex < indexedTexts.size(); sortedIndex++) {
            final Tuple<Integer, String> indexedText = indexedTexts.get(sortedIndex);
            sortedTexts.add(indexedText.v2());
            originalOrder.put(sortedIndex, indexedText.v1());
        }
        return Tuple.tuple(sortedTexts, originalOrder);
    }

    /**
     * Undoes the length-based sort: places each result back at the position its input text had before
     * sorting.
     *
     * @param list results in sorted order
     * @param map  map from position in the sorted order to original position
     * @return a fixed-size list of the same size as {@code list} with results in original order;
     *         positions that are not mapped keep the value they had in {@code list}
     */
    private List<?> restoreToOriginalOrder(List<?> list, Map<Integer, Integer> map) {
        // Copy into a List<Object> so elements can be written; a List<?> cannot be set() into.
        final List<Object> restored = Arrays.asList(list.toArray());
        for (int sortedIndex = 0; sortedIndex < list.size(); sortedIndex++) {
            final Integer originalIndex = map.get(sortedIndex);
            if (originalIndex != null) {
                restored.set(originalIndex, list.get(sortedIndex));
            }
        }
        return restored;
    }

    /**
     * Concatenates, in document order, the inference texts of every document that has no recorded
     * exception and has a non-empty text list.
     *
     * @param list per-document inference data
     * @return flat list of all texts to send to the model
     */
    private List<String> constructInferenceTexts(List<DataForInference> list) {
        return list.stream()
            .filter(data -> data.getIngestDocumentWrapper().getException() == null)
            .map(DataForInference::getInferenceList)
            .filter(texts -> !CollectionUtils.isEmpty(texts))
            .flatMap(List::stream)
            .collect(Collectors.toList());
    }

    /**
     * Extracts the inference inputs of every document in the batch.
     * <p>
     * A document whose validation or extraction throws an {@code Exception} gets that exception
     * recorded on its wrapper. Each document yields exactly one {@link DataForInference} entry, in
     * input order; the entry's map and/or list stay {@code null} if extraction failed before they
     * were produced.
     *
     * @param list wrappers of the documents in the batch
     * @return one entry per wrapper, in the same order
     */
    private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> list) {
        final List<DataForInference> prepared = new ArrayList<>();
        for (IngestDocumentWrapper wrapper : list) {
            Map<String, Object> processMap = null;
            List<String> inferenceTexts = null;
            try {
                validateEmbeddingFieldsValue(wrapper.getIngestDocument());
                processMap = buildMapWithTargetKeys(wrapper.getIngestDocument());
                inferenceTexts = createInferenceList(processMap);
            } catch (Exception e) {
                wrapper.update(wrapper.getIngestDocument(), e);
            } finally {
                // Always emit an entry so positions stay aligned with the input wrappers.
                prepared.add(new DataForInference(wrapper, processMap, inferenceTexts));
            }
        }
        return prepared;
    }

    /**
     * Flattens the values of the process map into the list of texts to embed, in map iteration order.
     * {@code null} values are skipped, list values are appended element by element, map values are
     * flattened recursively, and any other value contributes its {@code toString()}.
     *
     * @param map map of target keys to source values
     * @return flat list of texts
     */
    @SuppressWarnings("unchecked")
    private List<String> createInferenceList(Map<String, Object> map) {
        final List<String> texts = new ArrayList<>();
        for (Object sourceValue : map.values()) {
            if (sourceValue == null) {
                continue;
            }
            if (sourceValue instanceof List) {
                texts.addAll((List<String>) sourceValue);
            } else if (sourceValue instanceof Map) {
                createInferenceListForMapTypeInput(sourceValue, texts);
            } else {
                texts.add(sourceValue.toString());
            }
        }
        return texts;
    }

    /**
     * Recursively flattens a nested value into {@code list}.
     * <ul>
     *   <li>a map: every value is flattened in iteration order (keys are ignored);</li>
     *   <li>a list: every non-null element is appended;</li>
     *   <li>{@code null}: nothing is added;</li>
     *   <li>anything else: its {@code toString()} is appended.</li>
     * </ul>
     *
     * @param obj  value to flatten
     * @param list accumulator receiving the texts
     */
    @SuppressWarnings("unchecked")
    private void createInferenceListForMapTypeInput(Object obj, List<String> list) {
        if (obj == null) {
            return;
        }
        if (obj instanceof Map) {
            ((Map<String, Object>) obj).forEach((key, nested) -> createInferenceListForMapTypeInput(nested, list));
        } else if (obj instanceof List) {
            // Append straight into the accumulator, skipping null elements.
            ((List<String>) obj).stream().filter(Objects::nonNull).forEach(list::add);
        } else {
            list.add(obj.toString());
        }
    }

    /**
     * Builds the process map for a document from the configured field map.
     * <p>
     * Each field-map entry is first normalised with {@link #processNestedKey}. For a plain mapping the
     * result holds {@code target -> document value of the source field}. For a nested mapping (the
     * normalised value is a map) the result holds {@code source key -> subtree} where the subtree is
     * produced by {@link #buildNestedMap}.
     *
     * @param ingestDocument document to read source values from
     * @return insertion-ordered map describing what to embed and where the result goes
     */
    @VisibleForTesting
    Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
        final Map<String, Object> source = ingestDocument.getSourceAndMetadata();
        final Map<String, Object> targets = new LinkedHashMap<>();
        for (Map.Entry<String, Object> mapping : this.fieldMap.entrySet()) {
            final Pair<String, Object> normalized = processNestedKey(mapping);
            final String sourceKey = normalized.getKey();
            final Object target = normalized.getValue();
            if (!(target instanceof Map)) {
                targets.put(String.valueOf(target), source.get(sourceKey));
                continue;
            }
            final Map<String, Object> subtree = new LinkedHashMap<>();
            buildNestedMap(sourceKey, target, source, subtree);
            targets.put(sourceKey, subtree.get(sourceKey));
        }
        return targets;
    }

    /**
     * Recursively builds the subtree of the process map for one (possibly nested) field mapping.
     * <p>
     * When {@code obj} is not a map it is a leaf target name: the entry
     * {@code String.valueOf(obj) -> map.get(str)} is written to {@code map2}. When {@code obj} is a
     * map, the children are resolved against the document value at {@code str}:
     * <ul>
     *   <li>document value is a map: each child mapping is normalised with
     *       {@link #processNestedKey} and resolved inside that map;</li>
     *   <li>document value is a list (of maps): for each child mapping, the child key's value is
     *       collected from every list element and the resulting list is resolved as the source;</li>
     *   <li>anything else: no children are produced.</li>
     * </ul>
     * The collected children are then merged into {@code map2} under {@code str} using the class's
     * remapping function.
     *
     * @param str  key of the current level in both the field map and the document
     * @param obj  field-map value for {@code str}: a nested map or a leaf target name; ignored if null
     * @param map  document (sub)map to read source values from; ignored if null
     * @param map2 output map receiving the built subtree
     */
    @VisibleForTesting
    @SuppressWarnings("unchecked")
    void buildNestedMap(String str, Object obj, Map<String, Object> map, Map<String, Object> map2) {
        if (Objects.isNull(obj) || Objects.isNull(map)) {
            return;
        }
        if (!(obj instanceof Map)) {
            map2.put(String.valueOf(obj), map.get(str));
            return;
        }

        final Map<String, Object> nestedMappings = (Map<String, Object>) obj;
        final Object sourceValue = map.get(str);
        final Map<String, Object> children = new LinkedHashMap<>();
        if (sourceValue instanceof Map) {
            final Map<String, Object> nestedSource = (Map<String, Object>) sourceValue;
            for (Map.Entry<String, Object> nestedMapping : nestedMappings.entrySet()) {
                final Pair<String, Object> normalized = processNestedKey(nestedMapping);
                buildNestedMap(normalized.getKey(), normalized.getValue(), nestedSource, children);
            }
        } else if (sourceValue instanceof List) {
            final List<Map<String, Object>> elements = (List<Map<String, Object>>) sourceValue;
            for (Map.Entry<String, Object> nestedMapping : nestedMappings.entrySet()) {
                final String childKey = nestedMapping.getKey();
                // Gather this child's value from every element, keeping element order (nulls included).
                final List<Object> childValues = elements.stream().map(element -> element.get(childKey)).collect(Collectors.toList());
                final Map<String, Object> syntheticSource = new LinkedHashMap<>();
                syntheticSource.put(childKey, childValues);
                buildNestedMap(childKey, nestedMapping.getValue(), syntheticSource, children);
            }
        }
        map2.merge(str, children, REMAPPING_FUNCTION);
    }

    /**
     * Splits a dotted key at its first {@code '.'}.
     * <p>
     * {@code "a.b.c" -> v} becomes the pair {@code ("a", {"b.c" -> v})}; a key without a dot is
     * returned unchanged together with its value.
     *
     * @param entry field-map entry to normalise
     * @return pair of (first path segment, original value or a single-entry map for the remainder)
     */
    @VisibleForTesting
    protected Pair<String, Object> processNestedKey(Map.Entry<String, Object> entry) {
        final String fullKey = entry.getKey();
        final Object value = entry.getValue();
        final int dot = fullKey.indexOf('.');
        if (dot == -1) {
            return new ImmutablePair<>(fullKey, value);
        }
        final Map<String, Object> remainder = new LinkedHashMap<>();
        remainder.put(fullKey.substring(dot + 1), value);
        return new ImmutablePair<>(fullKey.substring(0, dot), remainder);
    }

    /**
     * Validates the document's mapped field values by delegating to
     * {@code ProcessorDocumentUtils.validateMapTypeValue}, using the document's {@code _index} value
     * as the index name.
     *
     * @param ingestDocument document to validate; its source must contain {@code _index}
     */
    private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
        final Map<String, Object> source = ingestDocument.getSourceAndMetadata();
        final String indexName = source.get("_index").toString();
        ProcessorDocumentUtils.validateMapTypeValue(
            FIELD_MAP_FIELD,
            source,
            this.fieldMap,
            indexName,
            this.clusterService,
            this.environment,
            false
        );
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Writes inference results into the document: builds the result structure with
     * {@link #buildNLPResult} and sets every top-level entry on the document as a field value.
     *
     * @param ingestDocument document to update
     * @param map            process map describing where each result goes
     * @param list           inference results for this document, in text order
     * @throws NullPointerException if {@code list} is null
     */
    public void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> map, List<?> list) {
        Objects.requireNonNull(list, "embedding failed, inference returns null result!");
        log.debug("Model inference result fetched, starting build vector output!");
        final Map<String, Object> vectorFields = buildNLPResult(map, list, ingestDocument.getSourceAndMetadata());
        for (Map.Entry<String, Object> field : vectorFields.entrySet()) {
            ingestDocument.setFieldValue(field.getKey(), field.getValue());
        }
    }

    /**
     * Distributes inference results over the process map, consuming results strictly in order.
     * <p>
     * Each process-map entry is normalised with {@link #processNestedKey}; then a string value takes
     * the next result, a list value takes one result per element (wrapped per element), and a map
     * value has its results written directly into {@code map2} rather than into the returned map.
     * Values of any other type (including null) consume no result.
     *
     * @param map  process map of target keys to source values
     * @param list inference results, in the order the texts were collected
     * @param map2 document source map, updated in place for nested (map-typed) values
     * @return insertion-ordered map of top-level target keys to their results
     */
    @VisibleForTesting
    @SuppressWarnings("unchecked")
    Map<String, Object> buildNLPResult(Map<String, Object> map, List<?> list, Map<String, Object> map2) {
        final IndexWrapper cursor = new IndexWrapper(0);
        final Map<String, Object> result = new LinkedHashMap<>();
        for (Map.Entry<String, Object> processEntry : map.entrySet()) {
            final Pair<String, Object> normalized = processNestedKey(processEntry);
            final String key = normalized.getKey();
            final Object sourceValue = normalized.getValue();
            if (sourceValue instanceof String) {
                result.put(key, list.get(cursor.index++));
            } else if (sourceValue instanceof List) {
                result.put(key, buildNLPResultForListType((List<String>) sourceValue, list, cursor));
            } else if (sourceValue instanceof Map) {
                putNLPResultToSourceMapForMapType(key, sourceValue, list, cursor, map2);
            }
        }
        return result;
    }

    /**
     * Recursively writes inference results for a nested (map-typed) process-map value directly into
     * the document source map, consuming results in order through {@code indexWrapper}.
     * <ul>
     *   <li>string value: the next result is merged into {@code map} under {@code str};</li>
     *   <li>list value: one wrapped result per element is merged into {@code map} under {@code str};</li>
     *   <li>map value, and {@code map.get(str)} is a list (of maps): for each child entry, the child
     *       value list is walked in parallel with the document list, and every element whose
     *       corresponding child value is non-null receives the next result under the child key;</li>
     *   <li>map value otherwise: each child entry is normalised with {@link #processNestedKey} and
     *       handled recursively inside the map at {@code map.get(str)}, which is created if absent.</li>
     * </ul>
     * Nothing happens when {@code str}, {@code obj} or {@code map} is null, or when {@code obj} is of
     * another type.
     *
     * @param str          key of the current level
     * @param obj          process-map value at this level
     * @param list         inference results
     * @param indexWrapper cursor over {@code list}; advanced for every result consumed
     * @param map          document (sub)map to write into
     */
    @SuppressWarnings("unchecked")
    private void putNLPResultToSourceMapForMapType(String str, Object obj, List<?> list, IndexWrapper indexWrapper, Map<String, Object> map) {
        if (str == null || map == null || obj == null) {
            return;
        }
        if (obj instanceof String) {
            map.merge(str, list.get(indexWrapper.index++), REMAPPING_FUNCTION);
            return;
        }
        if (obj instanceof List) {
            map.merge(str, buildNLPResultForListType((List<String>) obj, list, indexWrapper), REMAPPING_FUNCTION);
            return;
        }
        if (!(obj instanceof Map)) {
            return;
        }

        for (Map.Entry<String, Object> nestedEntry : ((Map<String, Object>) obj).entrySet()) {
            final Object existing = map.get(str);
            if (existing instanceof List) {
                // List of nested objects: pair each document element with the value extracted from it.
                final Iterator<Object> nestedValues = ((List<Object>) nestedEntry.getValue()).iterator();
                for (Map<String, Object> element : (List<Map<String, Object>>) existing) {
                    // Only elements that actually supplied a text consume a result.
                    if (nestedValues.hasNext() && nestedValues.next() != null) {
                        element.put(nestedEntry.getKey(), list.get(indexWrapper.index++));
                    }
                }
            } else {
                final Pair<String, Object> normalized = processNestedKey(nestedEntry);
                final Map<String, Object> child;
                if (existing == null) {
                    child = new HashMap<>();
                    map.put(str, child);
                } else {
                    child = (Map<String, Object>) existing;
                }
                putNLPResultToSourceMapForMapType(normalized.getKey(), normalized.getValue(), list, indexWrapper, child);
            }
        }
    }

    /**
     * Builds the result for a list-typed source value: for each element of {@code list}, the next
     * inference result is wrapped in a single-entry immutable map keyed by the configured
     * list-type nested map key.
     *
     * @param list         source texts; only its size is used
     * @param list2        inference results
     * @param indexWrapper cursor over {@code list2}; advanced once per element of {@code list}
     * @return one single-entry map per element of {@code list}, in order
     */
    private List<Map<String, Object>> buildNLPResultForListType(List<String> list, List<?> list2, IndexWrapper indexWrapper) {
        final int count = list.size();
        final List<Map<String, Object>> keyToResult = new ArrayList<>(count);
        for (int position = 0; position < count; position++) {
            final Object embedding = list2.get(indexWrapper.index++);
            keyToResult.add(ImmutableMap.of(this.listTypeNestedMapKey, embedding));
        }
        return keyToResult;
    }

    /**
     * @return the processor type name supplied at construction
     */
    public String getType() {
        return this.type;
    }
}
