package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;

/**
 * Builder for the {@code neural} query.
 *
 * <p>A neural query carries a text and/or image input together with k-NN search options (k,
 * max_distance, min_score, filter, method_parameters, rescore). It is never turned into a Lucene
 * query directly ({@link #doToQuery} throws). Instead, {@link #doRewrite} first schedules an
 * asynchronous ML inference call and returns a copy of this builder whose vector supplier will
 * yield the inferred vector; once that vector is available, a later rewrite replaces the builder
 * with an equivalent {@link KNNQueryBuilder}.
 */
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> implements ModelInferenceQueryBuilder {
    public static final String NAME = "neural";

    /** Value of {@code k} applied when the request sets none of k, max_distance or min_score. */
    private static final int DEFAULT_K = 10;

    /** ML client used for inference during rewrite; assigned by {@link #initialize}. */
    private static MLCommonsClientAccessor ML_CLIENT;

    private String fieldName;
    private String queryText;
    private String queryImage;
    private String modelId;
    private Integer k;
    private Float maxDistance;
    private Float minScore;

    /**
     * Supplies the inferred query vector. Set on the copy returned by {@link #doRewrite} (or via
     * the package-private setter); the supplier returns {@code null} until inference completes.
     */
    @VisibleForTesting
    private Supplier<float[]> vectorSupplier;
    private QueryBuilder filter;
    private Map<String, ?> methodParameters;
    private RescoreContext rescoreContext;

    @Generated
    private static final Logger log = LogManager.getLogger(NeuralQueryBuilder.class);

    @VisibleForTesting
    static final ParseField QUERY_TEXT_FIELD = new ParseField(QueryContextSourceFetcher.QUERY_TEXT_FIELD);

    @VisibleForTesting
    static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");
    public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

    @VisibleForTesting
    static final ParseField K_FIELD = new ParseField("k");

    /**
     * Sets the ML client shared by all neural query builders.
     *
     * @param mLCommonsClientAccessor client used by {@link #doRewrite} to run inference
     */
    public static void initialize(MLCommonsClientAccessor mLCommonsClientAccessor) {
        ML_CLIENT = mLCommonsClientAccessor;
    }

    /**
     * Deserializes a builder from the transport stream. The read order must mirror
     * {@link #doWriteTo} exactly, including the cluster-version gates.
     *
     * <p>NOTE(review): {@code queryImage} is not part of the wire format, so it is always
     * {@code null} after deserialization.
     *
     * @param streamInput stream positioned at this builder's payload
     * @throws IOException if reading from the stream fails
     */
    public NeuralQueryBuilder(StreamInput streamInput) throws IOException {
        super(streamInput);
        this.fieldName = streamInput.readString();
        this.queryText = streamInput.readString();
        // model_id became optional once default-model-id support was introduced.
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            this.modelId = streamInput.readOptionalString();
        } else {
            this.modelId = streamInput.readString();
        }
        // k became optional with radial search (max_distance / min_score as alternatives).
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            this.k = streamInput.readOptionalInt();
        } else {
            this.k = streamInput.readVInt();
        }
        this.filter = streamInput.readOptionalNamedWriteable(QueryBuilder.class);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            this.maxDistance = streamInput.readOptionalFloat();
            this.minScore = streamInput.readOptionalFloat();
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.METHOD_PARAMS_FIELD.getPreferredName())) {
            this.methodParameters = MethodParametersParser.streamInput(streamInput, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
        }
        this.rescoreContext = RescoreParser.streamInput(streamInput);
    }

    /**
     * Serializes this builder to the transport stream; the write order must match
     * {@link #NeuralQueryBuilder(StreamInput)}.
     *
     * <p>NOTE(review): {@code queryImage} is not written, and {@code queryText} goes through
     * {@code writeString}, which presumably rejects {@code null}, so an image-only query likely
     * cannot be sent. On clusters older than the radial-search version a {@code null} k (radial
     * query) fails at {@code k.intValue()}. Either fix needs a new version-gated wire field.
     *
     * @param streamOutput destination stream
     * @throws IOException if writing to the stream fails
     */
    @Override
    protected void doWriteTo(StreamOutput streamOutput) throws IOException {
        streamOutput.writeString(this.fieldName);
        streamOutput.writeString(this.queryText);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            streamOutput.writeOptionalString(this.modelId);
        } else {
            streamOutput.writeString(this.modelId);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            streamOutput.writeOptionalInt(this.k);
        } else {
            streamOutput.writeVInt(this.k.intValue());
        }
        streamOutput.writeOptionalNamedWriteable(this.filter);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            streamOutput.writeOptionalFloat(this.maxDistance);
            streamOutput.writeOptionalFloat(this.minScore);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.METHOD_PARAMS_FIELD.getPreferredName())) {
            MethodParametersParser.streamOutput(streamOutput, this.methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
        }
        RescoreParser.streamOutput(streamOutput, this.rescoreContext);
    }

    /**
     * Renders the query as {@code {"neural": {"<field>": {...params}}}}, emitting only the
     * parameters that are set so that the output can be parsed back by {@link #fromXContent}.
     *
     * @param xContentBuilder destination builder
     * @param params rendering parameters (unused here)
     * @throws IOException if writing fails
     */
    @Override
    protected void doXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.startObject(this.fieldName);
        // Emit inputs only when present: a null query_text would be written as an explicit null
        // that parseQueryParams cannot read back, and query_image must not be dropped.
        if (Objects.nonNull(this.queryText)) {
            xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), this.queryText);
        }
        if (Objects.nonNull(this.queryImage)) {
            xContentBuilder.field(QUERY_IMAGE_FIELD.getPreferredName(), this.queryImage);
        }
        if (Objects.nonNull(this.modelId)) {
            xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), this.modelId);
        }
        if (Objects.nonNull(this.k)) {
            xContentBuilder.field(K_FIELD.getPreferredName(), this.k);
        }
        if (Objects.nonNull(this.filter)) {
            xContentBuilder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), this.filter);
        }
        if (Objects.nonNull(this.maxDistance)) {
            xContentBuilder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), this.maxDistance);
        }
        if (Objects.nonNull(this.minScore)) {
            xContentBuilder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), this.minScore);
        }
        if (Objects.nonNull(this.methodParameters)) {
            MethodParametersParser.doXContent(xContentBuilder, this.methodParameters);
        }
        if (Objects.nonNull(this.rescoreContext)) {
            RescoreParser.doXContent(xContentBuilder, this.rescoreContext);
        }
        printBoostAndQueryName(xContentBuilder);
        xContentBuilder.endObject();
        xContentBuilder.endObject();
    }

    /**
     * Parses the body of a {@code neural} query, which must contain exactly one field object.
     *
     * <p>If none of k, max_distance or min_score is given, k defaults to {@value #DEFAULT_K}.
     *
     * @param xContentParser parser positioned on the START_OBJECT of the neural query body
     * @return the parsed builder
     * @throws ParsingException if the structure is malformed, a parameter is unknown, or more
     *     than one field is given
     * @throws IllegalArgumentException if neither query text nor image is given, or more than one
     *     of k / max_distance / min_score is set
     * @throws IOException if reading from the parser fails
     */
    public static NeuralQueryBuilder fromXContent(XContentParser xContentParser) throws IOException {
        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
        if (xContentParser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(xContentParser.getTokenLocation(), "Token must be START_OBJECT");
        }
        // The single key of the outer object is the target vector field name.
        xContentParser.nextToken();
        neuralQueryBuilder.fieldName(xContentParser.currentName());
        xContentParser.nextToken();
        parseQueryParams(xContentParser, neuralQueryBuilder);
        if (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(
                xContentParser.getTokenLocation(),
                "[neural] query doesn't support multiple fields, found ["
                    + neuralQueryBuilder.fieldName()
                    + "] and ["
                    + xContentParser.currentName()
                    + "]"
            );
        }
        if (StringUtils.isBlank(neuralQueryBuilder.queryText()) && StringUtils.isBlank(neuralQueryBuilder.queryImage())) {
            throw new IllegalArgumentException("Either query text or image text must be provided for neural query");
        }
        requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
        // Before default-model-id support, the model id is mandatory in the request.
        if (!MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
        }
        if (!validateKNNQueryType(neuralQueryBuilder)) {
            neuralQueryBuilder.k(DEFAULT_K);
        }
        return neuralQueryBuilder;
    }

    /**
     * Reads the parameters of the per-field object into {@code neuralQueryBuilder}, stopping at
     * that object's END_OBJECT.
     *
     * @throws ParsingException on an unknown parameter name or an unexpected token
     */
    private static void parseQueryParams(XContentParser xContentParser, NeuralQueryBuilder neuralQueryBuilder) throws IOException {
        String currentFieldName = "";
        XContentParser.Token token;
        while ((token = xContentParser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = xContentParser.currentName();
            } else if (token.isValue()) {
                if (QUERY_TEXT_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryText(xContentParser.text());
                } else if (QUERY_IMAGE_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryImage(xContentParser.text());
                } else if (MODEL_ID_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.modelId(xContentParser.text());
                } else if (K_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.k((Integer) NumberFieldMapper.NumberType.INTEGER.parse(xContentParser.objectBytes(), false));
                } else if (NAME_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryName(xContentParser.text());
                } else if (BOOST_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.boost(xContentParser.floatValue());
                } else if (KNNQueryBuilder.MAX_DISTANCE_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.maxDistance(xContentParser.floatValue());
                } else if (KNNQueryBuilder.MIN_SCORE_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.minScore(xContentParser.floatValue());
                } else {
                    throw new ParsingException(
                        xContentParser.getTokenLocation(),
                        "[neural] query does not support [" + currentFieldName + "]"
                    );
                }
            } else if (token == XContentParser.Token.START_OBJECT) {
                if (KNNQueryBuilder.FILTER_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.filter(parseInnerQueryBuilder(xContentParser));
                } else if (KNNQueryBuilder.METHOD_PARAMS_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(xContentParser));
                } else if (KNNQueryBuilder.RESCORE_FIELD.match(currentFieldName, xContentParser.getDeprecationHandler())) {
                    neuralQueryBuilder.rescoreContext(RescoreParser.fromXContent(xContentParser));
                } else {
                    // An unknown object is not consumed here; continuing would misread its inner
                    // tokens as neural parameters (or stop early on its END_OBJECT), so reject it.
                    throw new ParsingException(
                        xContentParser.getTokenLocation(),
                        "[neural] query does not support [" + currentFieldName + "]"
                    );
                }
            } else {
                throw new ParsingException(
                    xContentParser.getTokenLocation(),
                    "[neural] unknown token [" + token + "] after [" + currentFieldName + "]"
                );
            }
        }
    }

    /**
     * Rewrites the neural query in two phases.
     *
     * <ol>
     *   <li>Without a vector supplier: registers an async inference call on the non-blank text
     *       and/or image input and returns a copy of this builder whose supplier will yield the
     *       resulting vector.
     *   <li>With a vector supplier: returns {@code this} while the vector is still {@code null};
     *       otherwise returns the equivalent {@link KNNQueryBuilder}.
     * </ol>
     *
     * @param queryRewriteContext context used to register the async inference action
     * @return this builder, a copy awaiting inference, or the resulting k-NN query
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (vectorSupplier() != null) {
            float[] vector = vectorSupplier().get();
            if (vector == null) {
                return this;
            }
            return KNNQueryBuilder.builder()
                .fieldName(fieldName())
                .vector(vector)
                .filter(filter())
                .maxDistance(this.maxDistance)
                .minScore(this.minScore)
                .k(this.k)
                .methodParameters(this.methodParameters)
                .rescoreContext(this.rescoreContext)
                .build();
        }

        SetOnce<float[]> vectorSetOnce = new SetOnce<>();
        Map<String, String> inferenceInput = new HashMap<>();
        if (StringUtils.isNotBlank(queryText())) {
            inferenceInput.put(TextImageEmbeddingProcessor.INPUT_TEXT, queryText());
        }
        if (StringUtils.isNotBlank(queryImage())) {
            inferenceInput.put(TextImageEmbeddingProcessor.INPUT_IMAGE, queryImage());
        }
        queryRewriteContext.registerAsyncAction(
            (client, actionListener) -> ML_CLIENT.inferenceSentences(
                modelId(),
                inferenceInput,
                ActionListener.wrap(vectorAsList -> {
                    vectorSetOnce.set(VectorUtil.vectorAsListToArray(vectorAsList));
                    actionListener.onResponse(null);
                }, actionListener::onFailure)
            )
        );
        return new NeuralQueryBuilder(
            fieldName(),
            queryText(),
            queryImage(),
            modelId(),
            k(),
            maxDistance(),
            minScore(),
            vectorSetOnce::get,
            filter(),
            methodParameters(),
            rescoreContext()
        );
    }

    /**
     * Always throws: a neural query must be rewritten into a k-NN query before execution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected Query doToQuery(QueryShardContext queryShardContext) {
        throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly");
    }

    /**
     * Compares every user-settable query parameter; the transient vector supplier is ignored.
     *
     * @param other builder to compare with
     * @return {@code true} if all parameters are equal
     */
    @Override
    public boolean doEquals(NeuralQueryBuilder other) {
        if (this == other) {
            return true;
        }
        if (other == null || getClass() != other.getClass()) {
            return false;
        }
        return new EqualsBuilder().append(this.fieldName, other.fieldName)
            .append(this.queryText, other.queryText)
            .append(this.queryImage, other.queryImage)
            .append(this.modelId, other.modelId)
            .append(this.k, other.k)
            .append(this.maxDistance, other.maxDistance)
            .append(this.minScore, other.minScore)
            .append(this.filter, other.filter)
            .append(this.methodParameters, other.methodParameters)
            .append(this.rescoreContext, other.rescoreContext)
            .isEquals();
    }

    /** Hashes the same fields that {@link #doEquals} compares. */
    @Override
    protected int doHashCode() {
        return new HashCodeBuilder().append(this.fieldName)
            .append(this.queryText)
            .append(this.queryImage)
            .append(this.modelId)
            .append(this.k)
            .append(this.maxDistance)
            .append(this.minScore)
            .append(this.filter)
            .append(this.methodParameters)
            .append(this.rescoreContext)
            .toHashCode();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Checks that at most one of k, max_distance and min_score is set.
     *
     * @return {@code true} if exactly one is set, {@code false} if none is
     * @throws IllegalArgumentException if more than one is set
     */
    private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
        int specified = 0;
        if (neuralQueryBuilder.k() != null) {
            specified++;
        }
        if (neuralQueryBuilder.maxDistance() != null) {
            specified++;
        }
        if (neuralQueryBuilder.minScore() != null) {
            specified++;
        }
        if (specified > 1) {
            throw new IllegalArgumentException("Only one of k, max_distance, or min_score can be provided");
        }
        return specified == 1;
    }

    @Override // org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder
    @Generated
    public String fieldName() {
        return this.fieldName;
    }

    @Generated
    public String queryText() {
        return this.queryText;
    }

    @Generated
    public String queryImage() {
        return this.queryImage;
    }

    @Override // org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder
    @Generated
    public String modelId() {
        return this.modelId;
    }

    @Generated
    public Integer k() {
        return this.k;
    }

    @Generated
    public Float maxDistance() {
        return this.maxDistance;
    }

    @Generated
    public Float minScore() {
        return this.minScore;
    }

    @Generated
    public QueryBuilder filter() {
        return this.filter;
    }

    @Generated
    public Map<String, ?> methodParameters() {
        return this.methodParameters;
    }

    @Generated
    public RescoreContext rescoreContext() {
        return this.rescoreContext;
    }

    @Generated
    public NeuralQueryBuilder fieldName(String fieldName) {
        this.fieldName = fieldName;
        return this;
    }

    @Generated
    public NeuralQueryBuilder queryText(String queryText) {
        this.queryText = queryText;
        return this;
    }

    @Generated
    public NeuralQueryBuilder queryImage(String queryImage) {
        this.queryImage = queryImage;
        return this;
    }

    @Override // org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder
    @Generated
    public NeuralQueryBuilder modelId(String modelId) {
        this.modelId = modelId;
        return this;
    }

    @Generated
    public NeuralQueryBuilder k(Integer k) {
        this.k = k;
        return this;
    }

    @Generated
    public NeuralQueryBuilder maxDistance(Float maxDistance) {
        this.maxDistance = maxDistance;
        return this;
    }

    @Generated
    public NeuralQueryBuilder minScore(Float minScore) {
        this.minScore = minScore;
        return this;
    }

    @Generated
    public NeuralQueryBuilder filter(QueryBuilder filter) {
        this.filter = filter;
        return this;
    }

    @Generated
    public NeuralQueryBuilder methodParameters(Map<String, ?> methodParameters) {
        this.methodParameters = methodParameters;
        return this;
    }

    @Generated
    public NeuralQueryBuilder rescoreContext(RescoreContext rescoreContext) {
        this.rescoreContext = rescoreContext;
        return this;
    }

    /** Creates an empty builder; fields are populated through the fluent setters. */
    @Generated
    public NeuralQueryBuilder() {}

    /** Creates a fully-populated builder (used by {@link #doRewrite} to attach the vector supplier). */
    @Generated
    public NeuralQueryBuilder(
        String fieldName,
        String queryText,
        String queryImage,
        String modelId,
        Integer k,
        Float maxDistance,
        Float minScore,
        Supplier<float[]> vectorSupplier,
        QueryBuilder filter,
        Map<String, ?> methodParameters,
        RescoreContext rescoreContext
    ) {
        this.fieldName = fieldName;
        this.queryText = queryText;
        this.queryImage = queryImage;
        this.modelId = modelId;
        this.k = k;
        this.maxDistance = maxDistance;
        this.minScore = minScore;
        this.vectorSupplier = vectorSupplier;
        this.filter = filter;
        this.methodParameters = methodParameters;
        this.rescoreContext = rescoreContext;
    }

    @Generated
    Supplier<float[]> vectorSupplier() {
        return this.vectorSupplier;
    }

    @Generated
    NeuralQueryBuilder vectorSupplier(Supplier<float[]> vectorSupplier) {
        this.vectorSupplier = vectorSupplier;
        return this;
    }
}
