package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.client.Client;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.SetOnce;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

/**
 * Query builder for the {@code neural_sparse} query, which targets a single {@code rank_features} field.
 *
 * <p>The query is given either as explicit {@code query_tokens} (token to weight), or as {@code query_text}
 * plus a {@code model_id}. In the text case, {@link #doRewrite} registers an asynchronous ML Commons inference
 * call that expands the text into tokens. {@link #doToQuery} turns the tokens into a boolean disjunction of
 * {@link FeatureField#newLinearQuery} clauses, one per token.
 *
 * <p>The builder also carries two-phase state (a prune ratio and a shared map of split-off tokens). It is set up
 * by {@link #getCopyNeuralSparseQueryBuilderForTwoPhase(float)}.
 */
public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder> implements ModelInferenceQueryBuilder {
    public static final String NAME = "neural_sparse";

    /** Inference client; must be set via {@link #initialize} before any text query is rewritten. */
    private static MLCommonsClientAccessor ML_CLIENT;

    private String fieldName;
    private String queryText;
    private String modelId;
    /** Deprecated: parsed, serialized and compared, but not used when building the Lucene query. */
    private Float maxTokenScore;
    /** Provides the query tokens. It is null until tokens are parsed/deserialized or a rewrite registers inference. */
    private Supplier<Map<String, Float>> queryTokensSupplier;
    /** Tokens split off for two-phase evaluation; may be replaced by the inference callback. */
    private Map<String, Float> twoPhaseSharedQueryToken;
    /** Ratio passed to the two-phase token split; the two-phase copy carries the negated value. */
    private float twoPhasePruneRatio = 0.0f;

    @VisibleForTesting
    static final ParseField QUERY_TEXT_FIELD = new ParseField(QueryContextSourceFetcher.QUERY_TEXT_FIELD);

    @VisibleForTesting
    static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens");

    @VisibleForTesting
    static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

    @VisibleForTesting
    @Deprecated
    static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();

    /** First version in which {@code model_id} may be omitted and is serialized as an optional string. */
    private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

    /**
     * Sets the ML Commons client used for text-to-token inference by all instances of this builder.
     *
     * @param mlClient the client accessor
     */
    public static void initialize(MLCommonsClientAccessor mlClient) {
        ML_CLIENT = mlClient;
    }

    /**
     * Deserializes a builder written by {@link #doWriteTo(StreamOutput)}.
     *
     * <p>Two-phase state is not part of the wire format; it starts at its defaults.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.queryText = in.readString();
        // The wire format of the model id depends on the cluster's minimum version (mirrors doWriteTo).
        if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            this.modelId = in.readOptionalString();
        } else {
            this.modelId = in.readString();
        }
        this.maxTokenScore = in.readOptionalFloat();
        if (in.readBoolean()) {
            final Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
            this.queryTokensSupplier = () -> queryTokens;
        }
        // doWriteTo encodes absent values as "", so map them back to null.
        if ("".equals(this.queryText)) {
            this.queryText = null;
        }
        if ("".equals(this.modelId)) {
            this.modelId = null;
        }
    }

    /**
     * Prepares this builder and a companion copy for two-phase evaluation.
     *
     * <p>This builder's prune ratio is set to {@code ratio}, and the returned copy gets {@code -ratio}.
     * <ul>
     *   <li>If query tokens are already available, they are split with
     *       {@link NeuralSparseTwoPhaseProcessor#splitQueryTokensByRatioedMaxScoreAsThreshold}. This builder keeps
     *       the first part and the copy receives the second part.</li>
     *   <li>Otherwise an empty shared map is installed on this builder and the copy reads that field lazily. When
     *       this builder is rewritten, the inference callback can then replace the map with the second part of
     *       the inferred tokens.</li>
     * </ul>
     *
     * @param ratio the ratio used to split tokens by their maximum score
     * @return a new builder holding the split-off tokens
     */
    public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float ratio) {
        twoPhasePruneRatio(ratio);
        NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder().fieldName(this.fieldName)
            .queryName(this.queryName)
            .queryText(this.queryText)
            .modelId(this.modelId)
            .maxTokenScore(this.maxTokenScore)
            .twoPhasePruneRatio(-1.0f * ratio);
        if (Objects.nonNull(this.queryTokensSupplier)) {
            Tuple<Map<String, Float>, Map<String, Float>> splitTokens = NeuralSparseTwoPhaseProcessor
                .splitQueryTokensByRatioedMaxScoreAsThreshold(this.queryTokensSupplier.get(), ratio);
            queryTokensSupplier(splitTokens::v1);
            copy.queryTokensSupplier(splitTokens::v2);
        } else {
            this.twoPhaseSharedQueryToken = new HashMap<>();
            // Read the field at get() time: the inference callback may reassign it after this copy is created.
            copy.queryTokensSupplier(() -> this.twoPhaseSharedQueryToken);
        }
        return copy;
    }

    /**
     * Serializes the field name, query text, model id, max token score and (if available) the query tokens.
     * Absent text / model id are written as empty strings where the format has no optional encoding.
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeString(StringUtils.defaultString(this.queryText, ""));
        if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            out.writeOptionalString(this.modelId);
        } else {
            out.writeString(StringUtils.defaultString(this.modelId, ""));
        }
        out.writeOptionalFloat(this.maxTokenScore);
        // Evaluate the supplier once so the null check and the written map refer to the same value.
        final Map<String, Float> queryTokens = currentQueryTokens();
        if (Objects.isNull(queryTokens)) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeMap(queryTokens, StreamOutput::writeString, StreamOutput::writeFloat);
        }
    }

    /** Renders {@code {"neural_sparse": {"<field>": {...non-null params..., boost, name}}}}. */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(this.fieldName);
        if (Objects.nonNull(this.queryText)) {
            builder.field(QUERY_TEXT_FIELD.getPreferredName(), this.queryText);
        }
        if (Objects.nonNull(this.modelId)) {
            builder.field(MODEL_ID_FIELD.getPreferredName(), this.modelId);
        }
        if (Objects.nonNull(this.maxTokenScore)) {
            builder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), this.maxTokenScore);
        }
        final Map<String, Float> queryTokens = currentQueryTokens();
        if (Objects.nonNull(queryTokens)) {
            builder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokens);
        }
        printBoostAndQueryName(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Parses a {@code neural_sparse} query body of the form {@code {"<field>": {params}}}.
     *
     * <p>Either {@code query_tokens} or {@code query_text} is required. When only text is given on clusters
     * older than {@link #MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID}, {@code model_id} is also required.
     *
     * @param xContentParser parser positioned on the START_OBJECT of the query body
     * @return the parsed builder
     * @throws ParsingException if the structure is invalid or an unknown parameter is present
     * @throws IllegalArgumentException if {@code query_text} or {@code model_id} is an empty string
     */
    public static NeuralSparseQueryBuilder fromXContent(XContentParser xContentParser) throws IOException {
        NeuralSparseQueryBuilder builder = new NeuralSparseQueryBuilder();
        if (xContentParser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(xContentParser.getTokenLocation(), "First token of neural_sparsequery must be START_OBJECT");
        }
        xContentParser.nextToken();
        builder.fieldName(xContentParser.currentName());
        xContentParser.nextToken();
        parseQueryParams(xContentParser, builder);
        if (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(
                xContentParser.getTokenLocation(),
                String.format(
                    Locale.ROOT,
                    "[%s] query doesn't support multiple fields, found [%s] and [%s]",
                    NAME,
                    builder.fieldName(),
                    xContentParser.currentName()
                )
            );
        }
        requireValue(builder.fieldName(), "Field name must be provided for neural_sparse query");
        if (Objects.isNull(builder.queryTokensSupplier())) {
            requireValue(
                builder.queryText(),
                String.format(
                    Locale.ROOT,
                    "either %s field or %s field must be provided for [%s] query",
                    QUERY_TEXT_FIELD.getPreferredName(),
                    QUERY_TOKENS_FIELD.getPreferredName(),
                    NAME
                )
            );
            // Newer clusters may omit model_id here; it must still be non-blank by rewrite time
            // (see validateForRewrite) — presumably filled in from a default elsewhere.
            if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
                requireValue(
                    builder.modelId(),
                    String.format(
                        Locale.ROOT,
                        "using %s, %s field must be provided for [%s] query",
                        QUERY_TEXT_FIELD.getPreferredName(),
                        MODEL_ID_FIELD.getPreferredName(),
                        NAME
                    )
                );
            }
        }
        if ("".equals(builder.queryText())) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", QUERY_TEXT_FIELD.getPreferredName()));
        }
        if ("".equals(builder.modelId())) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName()));
        }
        return builder;
    }

    /**
     * Consumes the per-field parameter object up to and including its END_OBJECT.
     * Scalar values set the query name, boost, text, model id or max token score. An object under
     * {@code query_tokens} is read as a token-to-weight map.
     */
    private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBuilder builder) throws IOException {
        String currentFieldName = "";
        XContentParser.Token token;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token.isValue()) {
                if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.queryName(parser.text());
                } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.boost(parser.floatValue());
                } else if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.queryText(parser.text());
                } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.modelId(parser.text());
                } else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.maxTokenScore(parser.floatValue());
                } else {
                    throw new ParsingException(
                        parser.getTokenLocation(),
                        String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName)
                    );
                }
            } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                final Map<String, Float> queryTokens = parser.map(HashMap::new, XContentParser::floatValue);
                builder.queryTokensSupplier(() -> queryTokens);
            } else {
                throw new ParsingException(
                    parser.getTokenLocation(),
                    String.format(Locale.ROOT, "[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName)
                );
            }
        }
    }

    /**
     * Returns {@code this} once a token supplier exists. Otherwise it registers an async inference action and
     * returns a new builder whose token supplier reads the inference result when it becomes available.
     *
     * @throws IllegalArgumentException if query text or model id is blank
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (Objects.nonNull(this.queryTokensSupplier)) {
            return this;
        }
        validateForRewrite(this.queryText, this.modelId);
        SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
        queryRewriteContext.registerAsyncAction(getModelInferenceAsync(queryTokensSetOnce));
        return new NeuralSparseQueryBuilder().fieldName(this.fieldName)
            .queryText(this.queryText)
            .modelId(this.modelId)
            .maxTokenScore(this.maxTokenScore)
            .queryTokensSupplier(queryTokensSetOnce::get)
            .twoPhaseSharedQueryToken(this.twoPhaseSharedQueryToken)
            .twoPhasePruneRatio(this.twoPhasePruneRatio);
    }

    /**
     * Builds the async action that infers tokens for {@link #queryText} and stores them in {@code setOnce}.
     *
     * <p>If {@link #twoPhaseSharedQueryToken} is set, the inferred tokens are split by {@link #twoPhasePruneRatio}.
     * The first part goes into {@code setOnce}, and the second part replaces {@link #twoPhaseSharedQueryToken} on
     * this builder. Inference failures are forwarded to the action listener.
     */
    private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map<String, Float>> setOnce) {
        return (client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
            modelId(),
            List.of(this.queryText),
            ActionListener.wrap(mapResultList -> {
                Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
                if (Objects.nonNull(this.twoPhaseSharedQueryToken)) {
                    Tuple<Map<String, Float>, Map<String, Float>> splitTokens = NeuralSparseTwoPhaseProcessor
                        .splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, this.twoPhasePruneRatio);
                    setOnce.set(splitTokens.v1());
                    this.twoPhaseSharedQueryToken = splitTokens.v2();
                } else {
                    setOnce.set(queryTokens);
                }
                actionListener.onResponse(null);
            }, actionListener::onFailure)
        );
    }

    /**
     * Builds a boolean query with one SHOULD {@link FeatureField#newLinearQuery} clause per token, weighted by
     * the token's value.
     *
     * @throws IllegalArgumentException if the field is not {@code rank_features}, or no query tokens are available
     */
    @Override
    protected Query doToQuery(QueryShardContext queryShardContext) throws IOException {
        validateFieldType(queryShardContext.fieldMapper(this.fieldName));
        // A missing supplier (e.g. a text query that was never rewritten) reports the same error as missing tokens
        // instead of a NullPointerException.
        final Map<String, Float> queryTokens = currentQueryTokens();
        if (Objects.isNull(queryTokens)) {
            throw new IllegalArgumentException("Query tokens cannot be null.");
        }
        final BooleanQuery.Builder builder = new BooleanQuery.Builder();
        for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
            builder.add(FeatureField.newLinearQuery(this.fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD);
        }
        return builder.build();
    }

    /** Returns the supplied query tokens, or {@code null} when there is no supplier or it yields nothing. */
    private Map<String, Float> currentQueryTokens() {
        return Objects.isNull(this.queryTokensSupplier) ? null : this.queryTokensSupplier.get();
    }

    /** Rejects a rewrite whose query text or model id is blank. */
    private static void validateForRewrite(String queryText, String modelId) {
        if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "%s and %s cannot be null", QUERY_TEXT_FIELD.getPreferredName(), MODEL_ID_FIELD.getPreferredName())
            );
        }
    }

    /** Rejects fields that are unmapped or not of type {@code rank_features}. */
    private static void validateFieldType(MappedFieldType fieldType) {
        if (Objects.isNull(fieldType) || !fieldType.typeName().equals("rank_features")) {
            throw new IllegalArgumentException("[neural_sparse] query only works on [rank_features] fields");
        }
    }

    /**
     * Compares all query parameters, the two-phase state and the currently supplied tokens.
     * Two builders are unequal if only one of them has a token supplier.
     */
    @Override
    public boolean doEquals(NeuralSparseQueryBuilder other) {
        if (this == other) {
            return true;
        }
        if (Objects.isNull(other) || getClass() != other.getClass()) {
            return false;
        }
        if (Objects.isNull(this.queryTokensSupplier) != Objects.isNull(other.queryTokensSupplier)) {
            return false;
        }
        EqualsBuilder equalsBuilder = new EqualsBuilder().append(this.fieldName, other.fieldName)
            .append(this.queryText, other.queryText)
            .append(this.modelId, other.modelId)
            .append(this.maxTokenScore, other.maxTokenScore)
            .append(this.twoPhasePruneRatio, other.twoPhasePruneRatio)
            .append(this.twoPhaseSharedQueryToken, other.twoPhaseSharedQueryToken);
        if (Objects.nonNull(this.queryTokensSupplier)) {
            equalsBuilder.append(this.queryTokensSupplier.get(), other.queryTokensSupplier.get());
        }
        return equalsBuilder.isEquals();
    }

    /**
     * Hash consistent with {@link #doEquals}.
     * NOTE(review): twoPhaseSharedQueryToken and the supplied tokens can change after inference completes,
     * which changes this hash; avoid using builders as hash keys across a rewrite.
     */
    @Override
    protected int doHashCode() {
        HashCodeBuilder hashCodeBuilder = new HashCodeBuilder().append(this.fieldName)
            .append(this.queryText)
            .append(this.modelId)
            .append(this.maxTokenScore)
            .append(this.twoPhasePruneRatio)
            .append(this.twoPhaseSharedQueryToken);
        if (Objects.nonNull(this.queryTokensSupplier)) {
            hashCodeBuilder.append(this.queryTokensSupplier.get());
        }
        return hashCodeBuilder.toHashCode();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** True when every node in the cluster is on or after {@link #MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID}. */
    private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
        return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
    }

    // ---- Fluent getters and chainable setters ----

    @Override
    @Generated
    public String fieldName() {
        return this.fieldName;
    }

    @Generated
    public String queryText() {
        return this.queryText;
    }

    @Override
    @Generated
    public String modelId() {
        return this.modelId;
    }

    @Generated
    public Float maxTokenScore() {
        return this.maxTokenScore;
    }

    @Generated
    public Supplier<Map<String, Float>> queryTokensSupplier() {
        return this.queryTokensSupplier;
    }

    @Generated
    public Map<String, Float> twoPhaseSharedQueryToken() {
        return this.twoPhaseSharedQueryToken;
    }

    @Generated
    public float twoPhasePruneRatio() {
        return this.twoPhasePruneRatio;
    }

    @Generated
    public NeuralSparseQueryBuilder fieldName(String fieldName) {
        this.fieldName = fieldName;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder queryText(String queryText) {
        this.queryText = queryText;
        return this;
    }

    @Override
    @Generated
    public NeuralSparseQueryBuilder modelId(String modelId) {
        this.modelId = modelId;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder maxTokenScore(Float maxTokenScore) {
        this.maxTokenScore = maxTokenScore;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder queryTokensSupplier(Supplier<Map<String, Float>> queryTokensSupplier) {
        this.queryTokensSupplier = queryTokensSupplier;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder twoPhaseSharedQueryToken(Map<String, Float> twoPhaseSharedQueryToken) {
        this.twoPhaseSharedQueryToken = twoPhaseSharedQueryToken;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder twoPhasePruneRatio(float twoPhasePruneRatio) {
        this.twoPhasePruneRatio = twoPhasePruneRatio;
        return this;
    }

    /** Creates an empty builder; configure it through the fluent setters. */
    @Generated
    public NeuralSparseQueryBuilder() {}

    /** Creates a builder with every field set explicitly. */
    @Generated
    public NeuralSparseQueryBuilder(
        String fieldName,
        String queryText,
        String modelId,
        Float maxTokenScore,
        Supplier<Map<String, Float>> queryTokensSupplier,
        Map<String, Float> twoPhaseSharedQueryToken,
        float twoPhasePruneRatio
    ) {
        this.fieldName = fieldName;
        this.queryText = queryText;
        this.modelId = modelId;
        this.maxTokenScore = maxTokenScore;
        this.queryTokensSupplier = queryTokensSupplier;
        this.twoPhaseSharedQueryToken = twoPhaseSharedQueryToken;
        this.twoPhasePruneRatio = twoPhasePruneRatio;
    }
}
