package org.opensearch.neuralsearch.processor;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.Collections;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.collect.Tuple;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.rescore.RescorerBuilder;

/**
 * Search request processor ({@value #TYPE}) that appends a two-phase rescorer for every
 * {@link NeuralSparseQueryBuilder} reachable from the top-level query through {@code bool.should} clauses.
 *
 * <p>For each such query, a copy is obtained via
 * {@link NeuralSparseQueryBuilder#getCopyNeuralSparseQueryBuilderForTwoPhase(float)} using the prune ratio.
 * Presumably the copy carries the tokens pruned from the first phase; TODO confirm against
 * {@code NeuralSparseQueryBuilder}. All copies are OR-ed into one {@code bool.should} query, which is added
 * to the request as a {@link QueryRescorerBuilder}. The rescore window is the request size (or {@code 10}
 * when the size is unset, i.e. {@code -1}) multiplied by the expansion rate. It must lie in
 * {@code [0, max_window_size]}.
 *
 * <p>Pipeline configuration keys:
 * <ul>
 *   <li>{@code enabled}: boolean, default {@code true}</li>
 *   <li>{@code two_phase_parameter.prune_ratio}: number in [0, 1], default {@code 0.4}</li>
 *   <li>{@code two_phase_parameter.expansion_rate}: number &gt;= 1.0, default {@code 5.0}</li>
 *   <li>{@code two_phase_parameter.max_window_size}: integer &gt;= 50, default {@code 10000}</li>
 * </ul>
 */
public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements SearchRequestProcessor {
    public static final String TYPE = "neural_sparse_two_phase_processor";
    private boolean enabled;
    private float ratio;
    private float windowExpansion;
    private int maxWindowSize;
    private static final String PARAMETER_KEY = "two_phase_parameter";
    private static final String RATIO_KEY = "prune_ratio";
    private static final String ENABLE_KEY = "enabled";
    private static final String EXPANSION_KEY = "expansion_rate";
    private static final String MAX_WINDOW_SIZE_KEY = "max_window_size";
    private static final boolean DEFAULT_ENABLED = true;
    private static final float DEFAULT_RATIO = 0.4f;
    private static final float DEFAULT_WINDOW_EXPANSION = 5.0f;
    private static final int DEFAULT_MAX_WINDOW_SIZE = 10000;
    private static final int DEFAULT_BASE_QUERY_SIZE = 10;
    private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50;
    private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f;
    private static final float RATIO_LOWER_BOUND = 0.0f;
    private static final float RATIO_UPPER_BOUND = 1.0f;

    /** Builds {@link NeuralSparseTwoPhaseProcessor} instances from a search pipeline definition. */
    public static class Factory implements Processor.Factory<SearchRequestProcessor> {
        /**
         * Creates a processor from its pipeline configuration.
         *
         * @param processorFactories registered processor factories (not used by this factory)
         * @param tag processor tag
         * @param description processor description
         * @param ignoreFailure ignore-failure flag forwarded to the processor
         * @param config processor configuration holding {@code enabled} and the optional
         *     {@code two_phase_parameter} map
         * @param pipelineContext pipeline context (not used by this factory)
         * @return the configured processor
         * @throws IllegalArgumentException if a two-phase parameter is not a number or is out of range
         */
        public NeuralSparseTwoPhaseProcessor create(
            Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
            String tag,
            String description,
            boolean ignoreFailure,
            Map<String, Object> config,
            Processor.PipelineContext pipelineContext
        ) throws IllegalArgumentException {
            boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, DEFAULT_ENABLED);
            Map<String, Object> twoPhaseParams = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY);
            float ratio = DEFAULT_RATIO;
            float windowExpansion = DEFAULT_WINDOW_EXPANSION;
            int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE;
            if (Objects.nonNull(twoPhaseParams)) {
                ratio = readNumber(twoPhaseParams, RATIO_KEY, ratio).floatValue();
                windowExpansion = readNumber(twoPhaseParams, EXPANSION_KEY, windowExpansion).floatValue();
                maxWindowSize = readNumber(twoPhaseParams, MAX_WINDOW_SIZE_KEY, maxWindowSize).intValue();
            }
            return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, windowExpansion, maxWindowSize);
        }

        /**
         * Reads an optional numeric entry of {@code two_phase_parameter}.
         *
         * @param params the {@code two_phase_parameter} map
         * @param key entry to read
         * @param defaultValue value returned when {@code key} is absent
         * @return the configured number, or {@code defaultValue} when {@code key} is absent
         * @throws IllegalArgumentException if {@code key} is present but not mapped to a {@link Number}
         */
        private static Number readNumber(Map<String, Object> params, String key, Number defaultValue) {
            if (!params.containsKey(key)) {
                return defaultValue;
            }
            Object value = params.get(key);
            // Report a user-facing configuration error instead of leaking a ClassCastException/NPE.
            if (!(value instanceof Number)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "The %s.%s must be a number. Received: %s", PARAMETER_KEY, key, value)
                );
            }
            return (Number) value;
        }

        /* renamed from: create, reason: collision with other method in class */
        // Decompiler-renamed copy of the compiler-generated bridge; it only delegates to the typed create(...).
        public /* bridge */ /* synthetic */ Processor m7create(Map map, String str, String str2, boolean z, Map map2, Processor.PipelineContext pipelineContext) throws Exception {
            return create((Map<String, Processor.Factory<SearchRequestProcessor>>) map, str, str2, z, (Map<String, Object>) map2, pipelineContext);
        }
    }

    /**
     * Creates a processor after validating the two-phase parameters.
     *
     * @param tag processor tag
     * @param description processor description
     * @param ignoreFailure ignore-failure flag forwarded to {@link AbstractProcessor}
     * @param enabled when {@code false}, {@link #processRequest(SearchRequest)} returns requests unchanged
     * @param ratio prune ratio; must be within [0, 1]
     * @param windowExpansion multiplier applied to the request size to get the rescore window; must be &gt;= 1.0
     * @param maxWindowSize upper bound for the computed rescore window; must be &gt;= 50
     * @throws IllegalArgumentException if any parameter is out of range (NaN included)
     */
    protected NeuralSparseTwoPhaseProcessor(
        String tag,
        String description,
        boolean ignoreFailure,
        boolean enabled,
        float ratio,
        float windowExpansion,
        int maxWindowSize
    ) {
        super(tag, description, ignoreFailure);
        this.enabled = enabled;
        // Negated range checks: every comparison with NaN is false, so NaN is rejected as well.
        if (!(ratio >= RATIO_LOWER_BOUND && ratio <= RATIO_UPPER_BOUND)) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f", ratio)
            );
        }
        this.ratio = ratio;
        if (!(windowExpansion >= WINDOW_EXPANSION_LOWER_BOUND)) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "The two_phase_parameter.expansion_rate must >= 1.0. Received: %f", windowExpansion)
            );
        }
        this.windowExpansion = windowExpansion;
        if (maxWindowSize < MAX_WINDOWS_SIZE_LOWER_BOUND) {
            // %d (not %n, which is a line separator) so the received value is actually formatted.
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "The two_phase_parameter.max_window_size must >= 50. Received: %d", maxWindowSize)
            );
        }
        this.maxWindowSize = maxWindowSize;
    }

    /**
     * Appends a two-phase rescorer for the neural sparse queries found in the request.
     *
     * <p>The request is returned unchanged in any of these cases:
     * <ul>
     *   <li>the processor is disabled;</li>
     *   <li>the prune ratio is {@code 0};</li>
     *   <li>the request has no source;</li>
     *   <li>no {@link NeuralSparseQueryBuilder} is reachable through {@code bool.should} clauses.</li>
     * </ul>
     *
     * @param searchRequest incoming request; its source is modified in place by adding a rescorer
     * @return the same request instance
     * @throws IllegalArgumentException if the computed rescore window is outside {@code [0, maxWindowSize]}
     */
    public SearchRequest processRequest(SearchRequest searchRequest) {
        SearchSourceBuilder source = searchRequest.source();
        if (!this.enabled || this.ratio == 0f || Objects.isNull(source)) {
            return searchRequest;
        }
        Multimap<NeuralSparseQueryBuilder, Float> twoPhaseQueries = collectNeuralSparseQueryBuilder(source.query(), 1.0f);
        if (twoPhaseQueries.isEmpty()) {
            return searchRequest;
        }
        QueryBuilder rescoreQuery = getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(twoPhaseQueries);
        // Boost by the product of the query weights of rescorers already on the request. Presumably this keeps
        // the added scores on the same scale as the original query after earlier rescores (TODO confirm).
        rescoreQuery.boost(getOriginQueryWeightAfterRescore(source));
        source.addRescorer(buildRescoreQueryBuilderForTwoPhase(rescoreQuery, searchRequest));
        return searchRequest;
    }

    /** @return the processor type, {@value #TYPE} */
    public String getType() {
        return TYPE;
    }

    /**
     * Splits query tokens into those whose weight is at least {@code max(0, maxWeight) * thresholdRatio} and the rest.
     *
     * @param queryTokens token to weight map; must not be null, and its values must not be null
     * @param thresholdRatio fraction of the maximum weight used as the threshold
     * @return a tuple of (tokens at or above the threshold, tokens below it); both maps are non-null
     * @throws IllegalArgumentException if {@code queryTokens} is null
     */
    public static Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokensByRatioedMaxScoreAsThreshold(
        Map<String, Float> queryTokens,
        float thresholdRatio
    ) {
        if (Objects.isNull(queryTokens)) {
            throw new IllegalArgumentException("Query tokens cannot be null or empty.");
        }
        // Starts at 0, so the threshold never goes below 0 even when every weight is negative.
        float maxWeight = 0f;
        for (float weight : queryTokens.values()) {
            maxWeight = Math.max(weight, maxWeight);
        }
        final float threshold = maxWeight * thresholdRatio;
        Map<Boolean, Map<String, Float>> partitioned = queryTokens.entrySet()
            .stream()
            .collect(
                Collectors.partitioningBy(
                    entry -> entry.getValue() >= threshold,
                    Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)
                )
            );
        Map<String, Float> highScoreTokens = partitioned.getOrDefault(Boolean.TRUE, Collections.emptyMap());
        Map<String, Float> lowScoreTokens = partitioned.getOrDefault(Boolean.FALSE, Collections.emptyMap());
        return Tuple.tuple(highScoreTokens, lowScoreTokens);
    }

    /**
     * Combines the collected two-phase queries into one {@code bool.should} query.
     * Each query is boosted by the sum of all boosts recorded for it.
     */
    private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(Multimap<NeuralSparseQueryBuilder, Float> queryBuilderMap) {
        BoolQueryBuilder nestedQuery = new BoolQueryBuilder();
        queryBuilderMap.asMap().forEach((twoPhaseQuery, boosts) -> {
            float totalBoost = 0f;
            for (float boost : boosts) {
                totalBoost += boost;
            }
            // Note: boost(...) mutates an object that is still a key of the multimap. This is harmless because
            // the map is discarded once this method returns.
            nestedQuery.should(twoPhaseQuery.boost(totalBoost));
        });
        return nestedQuery;
    }

    /**
     * Returns the product of the query weights of the {@link QueryRescorerBuilder}s already on the request.
     * Returns {@code 1.0f} when there are none. Rescorers of other types have no query weight and are skipped.
     */
    private float getOriginQueryWeightAfterRescore(SearchSourceBuilder searchSourceBuilder) {
        if (Objects.isNull(searchSourceBuilder.rescores())) {
            return 1.0f;
        }
        return searchSourceBuilder.rescores()
            .stream()
            .filter(rescorer -> rescorer instanceof QueryRescorerBuilder)
            .map(rescorer -> ((QueryRescorerBuilder) rescorer).getQueryWeight())
            .reduce(1.0f, (a, b) -> a * b);
    }

    /**
     * Recursively collects two-phase copies of {@link NeuralSparseQueryBuilder}s.
     *
     * <p>The walk descends only through {@code bool.should} clauses; other bool clause types and other query
     * types are ignored. Each copy is mapped to its effective boost, which is the product of all enclosing bool
     * boosts, {@code baseBoost}, and the query's own boost.
     */
    private Multimap<NeuralSparseQueryBuilder, Float> collectNeuralSparseQueryBuilder(QueryBuilder queryBuilder, float baseBoost) {
        Multimap<NeuralSparseQueryBuilder, Float> result = ArrayListMultimap.create();
        if (queryBuilder instanceof BoolQueryBuilder) {
            BoolQueryBuilder boolQuery = (BoolQueryBuilder) queryBuilder;
            float updatedBoost = baseBoost * boolQuery.boost();
            for (QueryBuilder clause : boolQuery.should()) {
                result.putAll(collectNeuralSparseQueryBuilder(clause, updatedBoost));
            }
        } else if (queryBuilder instanceof NeuralSparseQueryBuilder) {
            NeuralSparseQueryBuilder neuralSparseQuery = (NeuralSparseQueryBuilder) queryBuilder;
            result.put(neuralSparseQuery.getCopyNeuralSparseQueryBuilderForTwoPhase(this.ratio), baseBoost * neuralSparseQuery.boost());
        }
        return result;
    }

    /**
     * Wraps {@code queryBuilder} in a {@link QueryRescorerBuilder}.
     * The window size is {@code (size == -1 ? 10 : size) * windowExpansion}, truncated to int.
     *
     * @throws IllegalArgumentException if the window size is outside {@code [0, maxWindowSize]}
     */
    private RescorerBuilder<QueryRescorerBuilder> buildRescoreQueryBuilderForTwoPhase(QueryBuilder queryBuilder, SearchRequest searchRequest) {
        QueryRescorerBuilder queryRescorerBuilder = new QueryRescorerBuilder(queryBuilder);
        int requestSize = searchRequest.source().size();
        int baseSize = requestSize == -1 ? DEFAULT_BASE_QUERY_SIZE : requestSize;
        int windowSize = (int) (baseSize * this.windowExpansion);
        if (windowSize > this.maxWindowSize || windowSize < 0) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "The two-phase window size of neural_sparse_two_phase_processor should be [0,%d], but get the value of %d",
                    Integer.valueOf(this.maxWindowSize),
                    Integer.valueOf(windowSize)
                )
            );
        }
        queryRescorerBuilder.windowSize(windowSize);
        return queryRescorerBuilder;
    }

    // Note: the setters below assign directly and skip the range validation done in the constructor.

    @Generated
    public void setEnabled(boolean z) {
        this.enabled = z;
    }

    @Generated
    public void setRatio(float f) {
        this.ratio = f;
    }

    @Generated
    public void setWindowExpansion(float f) {
        this.windowExpansion = f;
    }

    @Generated
    public void setMaxWindowSize(int i) {
        this.maxWindowSize = i;
    }

    @Generated
    public boolean isEnabled() {
        return this.enabled;
    }

    @Generated
    public float getRatio() {
        return this.ratio;
    }

    @Generated
    public float getWindowExpansion() {
        return this.windowExpansion;
    }

    @Generated
    public int getMaxWindowSize() {
        return this.maxWindowSize;
    }
}
