package org.opensearch.neuralsearch.search.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

/* loaded from: input_file:org/opensearch/neuralsearch/search/query/HybridCollectorManager.class */
/**
 * {@link CollectorManager} for hybrid queries.
 *
 * <p>It creates the hybrid collectors and, on {@link #reduce(Collection)}, converts each collector's per-sub-query
 * top docs into the flattened hybrid result format. That format is a start/stop marker, then for every sub-query a
 * delimiter marker followed by its docs, then a closing start/stop marker. The converted results are then set on, or
 * merged into, the {@link QuerySearchResult}.
 *
 * <p>Two variants exist:
 * <ul>
 *   <li>{@link HybridCollectorConcurrentSearchManager}: creates a new collector on every {@link #newCollector()} call.</li>
 *   <li>{@link HybridCollectorNonConcurrentManager}: creates one collector up front and always returns it.</li>
 * </ul>
 */
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

    private static final Logger log = LogManager.getLogger(HybridCollectorManager.class);
    /** Boost passed when creating the post-filter {@link Weight}; the weight is built with COMPLETE_NO_SCORES. */
    private static final float BOOST_FACTOR = 1.0f;
    /** Sentinel returned by {@link #findFirstDocId(List)} when no sub-query produced any doc. */
    private static final int NO_DOC_ID = -1;

    private final int numHits;
    private final HitsThresholdChecker hitsThresholdChecker;
    private final int trackTotalHitsUpTo;
    private final SortAndFormats sortAndFormats;
    @Nullable
    private final Weight filterWeight;
    private final TopDocsMerger topDocsMerger;
    @Nullable
    private final FieldDoc after;
    private final SearchContext searchContext;

    /**
     * Collector manager used when the search runs with concurrent segment search. A fresh collector is created for
     * every {@link #newCollector()} call and all of them are reduced together.
     */
    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {

        public HybridCollectorConcurrentSearchManager(
            final int numHits,
            final HitsThresholdChecker hitsThresholdChecker,
            final int trackTotalHitsUpTo,
            final Weight filterWeight,
            final SearchContext searchContext
        ) {
            super(
                numHits,
                hitsThresholdChecker,
                trackTotalHitsUpTo,
                searchContext.sort(),
                filterWeight,
                new TopDocsMerger(searchContext.sort()),
                searchContext.searchAfter(),
                searchContext
            );
        }
    }

    /**
     * Collector manager used for non-concurrent search. It creates a single collector in the constructor, returns that
     * same instance from every {@link #newCollector()} call, and reduces only that collector.
     */
    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
        private final Collector scoreCollector;

        public HybridCollectorNonConcurrentManager(
            final int numHits,
            final HitsThresholdChecker hitsThresholdChecker,
            final int trackTotalHitsUpTo,
            final Weight filterWeight,
            final SearchContext searchContext
        ) {
            super(
                numHits,
                hitsThresholdChecker,
                trackTotalHitsUpTo,
                searchContext.sort(),
                filterWeight,
                new TopDocsMerger(searchContext.sort()),
                searchContext.searchAfter(),
                searchContext
            );
            scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
        }

        @Override
        public Collector newCollector() {
            return scoreCollector;
        }

        /**
         * Reduces the single collector owned by this manager. The passed collection is expected to be empty; that is
         * enforced only when assertions are enabled.
         */
        @Override
        public ReduceableSearchResult reduce(final Collection<Collector> collectors) {
            assert collectors.isEmpty() : "reduce on HybridCollectorNonConcurrentManager called with non-empty collectors";
            return super.reduce(List.of(scoreCollector));
        }
    }

    /**
     * Creates a manager for the given search context.
     *
     * <p>The number of hits is capped by {@code from + size} and the number of docs in the index reader. Sort
     * criteria are validated when a sort is present. A post-filter weight is created when the context has a parsed
     * post filter with a query.
     *
     * @param searchContext current search context
     * @return the concurrent variant if {@code searchContext.shouldUseConcurrentSearch()} is true, otherwise the
     *         non-concurrent variant
     * @throws IOException if rewriting or creating the post-filter weight fails
     * @throws IllegalArgumentException if the sort criteria are not supported by hybrid search
     */
    public static CollectorManager<Collector, ReduceableSearchResult> createHybridCollectorManager(final SearchContext searchContext)
        throws IOException {
        final int numDocs = Math.max(0, searchContext.searcher().getIndexReader().numDocs());
        final int numHits = Math.min(searchContext.from() + searchContext.size(), numDocs);
        final int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
        if (searchContext.sort() != null) {
            validateSortCriteria(searchContext, searchContext.trackScores());
        }
        final Weight filterWeight = createPostFilterWeight(searchContext);
        final HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(Math.max(numHits, trackTotalHitsUpTo));
        if (searchContext.shouldUseConcurrentSearch()) {
            return new HybridCollectorConcurrentSearchManager(numHits, hitsThresholdChecker, trackTotalHitsUpTo, filterWeight, searchContext);
        }
        return new HybridCollectorNonConcurrentManager(numHits, hitsThresholdChecker, trackTotalHitsUpTo, filterWeight, searchContext);
    }

    /**
     * Builds a non-scoring weight for the post filter.
     *
     * @return the weight, or {@code null} when the context has no post filter or the filter has no query
     */
    private static Weight createPostFilterWeight(final SearchContext searchContext) throws IOException {
        if (Objects.isNull(searchContext.parsedPostFilter()) || Objects.isNull(searchContext.parsedPostFilter().query())) {
            return null;
        }
        final Query query = searchContext.parsedPostFilter().query();
        final ContextIndexSearcher searcher = searchContext.searcher();
        return searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, BOOST_FACTOR);
    }

    /**
     * Creates a hybrid collector, wrapped in a {@link FilteredCollector} when a post-filter weight is present.
     */
    @Override
    public Collector newCollector() {
        final Collector hybridQueryCollector = getHybridQueryCollector();
        return Objects.nonNull(filterWeight) ? new FilteredCollector(hybridQueryCollector, filterWeight) : hybridQueryCollector;
    }

    /**
     * Picks the collector implementation.
     *
     * <p>Without a sort it uses a score-based collector. With a sort and no {@code search_after} it uses a simple
     * field collector. With both a sort and {@code search_after} it uses a paging field collector, after validating
     * the {@code search_after} values.
     */
    private Collector getHybridQueryCollector() {
        if (sortAndFormats == null) {
            return new HybridTopScoreDocCollector(numHits, hitsThresholdChecker);
        }
        if (after == null) {
            return new SimpleFieldCollector(numHits, hitsThresholdChecker, sortAndFormats.sort);
        }
        validateSearchAfterFieldAndSortFormats();
        return new PagingFieldCollector(numHits, hitsThresholdChecker, sortAndFormats.sort, after);
    }

    /**
     * Extracts hybrid collectors from {@code collectors} and returns a result that applies all their top docs to a
     * {@link QuerySearchResult}.
     *
     * @throws IllegalStateException if no hybrid collector can be found
     * @throws HybridSearchRescoreQueryException if rescoring fails
     */
    @Override
    public ReduceableSearchResult reduce(final Collection<Collector> collectors) {
        final List<HybridSearchCollector> hybridSearchCollectors = getHybridSearchCollectors(collectors);
        if (hybridSearchCollectors.isEmpty()) {
            throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper collectors");
        }
        return reduceSearchResults(getSearchResults(hybridSearchCollectors));
    }

    /** Converts each collector's top docs into a per-collector {@link ReduceableSearchResult}. */
    private List<ReduceableSearchResult> getSearchResults(final List<HybridSearchCollector> hybridSearchCollectors) {
        final List<ReduceableSearchResult> results = new ArrayList<>(hybridSearchCollectors.size());
        final DocValueFormat[] sortValueFormats = getSortValueFormats(sortAndFormats);
        final boolean isSortEnabled = sortValueFormats != null;
        for (HybridSearchCollector hybridSearchCollector : hybridSearchCollectors) {
            // Top docs (including any rescoring) are computed here, during reduce(), not when the result is applied.
            final TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndMaxScoreForCollector(hybridSearchCollector, isSortEnabled);
            results.add(querySearchResult -> reduceCollectorResults(querySearchResult, topDocsAndMaxScore, sortValueFormats));
        }
        return results;
    }

    /**
     * Builds the flattened top docs for one collector, choosing the sorted or score-based format.
     *
     * <p>When sorting is enabled the collector is expected to be one of the field-sort collectors, whose per-sub-query
     * results are {@link TopFieldDocs}. Hence the unchecked casts.
     */
    @SuppressWarnings("unchecked")
    private TopDocsAndMaxScore getTopDocsAndMaxScoreForCollector(final HybridSearchCollector hybridSearchCollector, final boolean isSortEnabled) {
        final List<? extends TopDocs> topDocs = hybridSearchCollector.topDocs();
        if (isSortEnabled) {
            return getSortedTopDocsAndMaxScore((List<TopFieldDocs>) topDocs, hybridSearchCollector);
        }
        return getScoredTopDocsAndMaxScore((List<TopDocs>) topDocs, hybridSearchCollector);
    }

    /** Field-sorted format; no rescoring is applied on this path and the collector's max score is used as-is. */
    private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(final List<TopFieldDocs> topFieldDocs, final HybridSearchCollector hybridSearchCollector) {
        final TotalHits totalHits = getTotalHits(trackTotalHitsUpTo, topFieldDocs, hybridSearchCollector.getTotalHits());
        return new TopDocsAndMaxScore(
            getNewTopFieldDocs(totalHits, topFieldDocs, sortAndFormats.sort.getSort()),
            hybridSearchCollector.getMaxScore()
        );
    }

    /** Score-based format; applies the context's rescorers first when any are configured. */
    private TopDocsAndMaxScore getScoredTopDocsAndMaxScore(final List<TopDocs> topDocs, final HybridSearchCollector hybridSearchCollector) {
        final List<TopDocs> finalTopDocs = shouldRescore() ? rescore(topDocs) : topDocs;
        final TotalHits totalHits = getTotalHits(trackTotalHitsUpTo, finalTopDocs, hybridSearchCollector.getTotalHits());
        return new TopDocsAndMaxScore(
            getNewTopDocs(totalHits, finalTopDocs),
            calculateMaxScore(finalTopDocs, hybridSearchCollector.getMaxScore())
        );
    }

    /** @return true when the search context has at least one rescore context */
    private boolean shouldRescore() {
        final List<RescoreContext> rescoreContexts = searchContext.rescore();
        return Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty();
    }

    /** Applies every rescore context, in order, to each sub-query's top docs. */
    private List<TopDocs> rescore(final List<TopDocs> topDocs) {
        List<TopDocs> rescored = topDocs;
        for (RescoreContext rescoreContext : searchContext.rescore()) {
            rescored = rescoredTopDocs(rescoreContext, rescored);
        }
        return rescored;
    }

    /**
     * Rescores each sub-query's top docs with one rescore context.
     *
     * @throws HybridSearchRescoreQueryException wrapping the {@link IOException} thrown by the rescorer
     */
    private List<TopDocs> rescoredTopDocs(final RescoreContext rescoreContext, final List<TopDocs> topDocs) {
        final List<TopDocs> rescored = new ArrayList<>(topDocs.size());
        for (TopDocs subQueryTopDocs : topDocs) {
            try {
                rescored.add(rescoreContext.rescorer().rescore(subQueryTopDocs, searchContext.searcher(), rescoreContext));
            } catch (IOException exception) {
                log.error("rescore failed for hybrid query in collector_manager.reduce call", exception);
                throw new HybridSearchRescoreQueryException(exception);
            }
        }
        return rescored;
    }

    /**
     * Returns {@code initialMaxScore} unchanged when no rescoring is configured. Otherwise returns the maximum of
     * {@code initialMaxScore} and the first doc score of every non-empty sub-query result.
     */
    private float calculateMaxScore(final List<TopDocs> topDocs, final float initialMaxScore) {
        if (!shouldRescore()) {
            return initialMaxScore;
        }
        float maxScore = initialMaxScore;
        for (TopDocs subQueryTopDocs : topDocs) {
            if (Objects.nonNull(subQueryTopDocs) && Objects.nonNull(subQueryTopDocs.scoreDocs) && subQueryTopDocs.scoreDocs.length > 0) {
                // NOTE(review): assumes rescorers return docs ordered by score descending, so index 0 is the max - confirm
                maxScore = Math.max(maxScore, subQueryTopDocs.scoreDocs[0].score);
            }
        }
        return maxScore;
    }

    /**
     * Finds the hybrid collectors among {@code collectors}.
     *
     * <p>It looks one level inside {@link MultiCollectorWrapper}s and unwraps {@link FilteredCollector}s at either
     * level.
     */
    private List<HybridSearchCollector> getHybridSearchCollectors(final Collection<Collector> collectors) {
        final List<HybridSearchCollector> hybridSearchCollectors = new ArrayList<>();
        for (Collector collector : collectors) {
            if (collector instanceof MultiCollectorWrapper) {
                for (Collector subCollector : ((MultiCollectorWrapper) collector).getCollectors()) {
                    addIfHybridSearchCollector(subCollector, hybridSearchCollectors);
                }
            } else {
                addIfHybridSearchCollector(collector, hybridSearchCollectors);
            }
        }
        return hybridSearchCollectors;
    }

    /** Adds {@code collector} (or the collector wrapped by a {@link FilteredCollector}) if it is a hybrid collector. */
    private static void addIfHybridSearchCollector(final Collector collector, final List<HybridSearchCollector> target) {
        final Collector candidate = collector instanceof FilteredCollector ? ((FilteredCollector) collector).getCollector() : collector;
        if (candidate instanceof HybridTopScoreDocCollector || candidate instanceof HybridTopFieldDocSortCollector) {
            target.add((HybridSearchCollector) candidate);
        }
    }

    /**
     * Rejects sort and {@code track_scores} combinations that hybrid search does not support.
     *
     * @throws IllegalArgumentException if _score is mixed with other criteria, or {@code trackScores} is true
     */
    private static void validateSortCriteria(final SearchContext searchContext, final boolean trackScores) {
        boolean hasFieldSort = false;
        boolean hasScoreSort = false;
        for (SortField sortField : searchContext.sort().sort.getSort()) {
            if (sortField.getType() == SortField.Type.SCORE) {
                hasScoreSort = true;
            } else {
                hasFieldSort = true;
            }
            if (hasScoreSort && hasFieldSort) {
                break;
            }
        }
        if (hasScoreSort && hasFieldSort) {
            throw new IllegalArgumentException(
                "_score sort criteria cannot be applied with any other criteria. Please select one sort criteria out of them."
            );
        }
        if (trackScores && hasFieldSort) {
            throw new IllegalArgumentException(
                "Hybrid search results when sorted by any field, docId or _id, track_scores must be set to false."
            );
        }
        if (trackScores && hasScoreSort) {
            throw new IllegalArgumentException("Hybrid search results are by default sorted by _score, track_scores must be set to false.");
        }
    }

    /**
     * Validates that the {@code search_after} doc carries one value per sort field.
     *
     * @throws IllegalArgumentException if the fields are missing or their count differs from the sort
     */
    private void validateSearchAfterFieldAndSortFormats() {
        if (after.fields == null) {
            throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search");
        }
        final int sortFieldCount = sortAndFormats.sort.getSort().length;
        if (after.fields.length != sortFieldCount) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "after.fields has %s values but sort has %s", after.fields.length, sortFieldCount)
            );
        }
    }

    /**
     * Returns the doc id of the first doc of the first non-empty sub-query result, or {@link #NO_DOC_ID} if there is
     * none. The result is used as the doc id of the marker elements.
     */
    private static int findFirstDocId(final List<? extends TopDocs> topDocs) {
        for (TopDocs subQueryTopDocs : topDocs) {
            if (Objects.nonNull(subQueryTopDocs) && Objects.nonNull(subQueryTopDocs.scoreDocs) && subQueryTopDocs.scoreDocs.length > 0) {
                return subQueryTopDocs.scoreDocs[0].doc;
            }
        }
        return NO_DOC_ID;
    }

    /**
     * Flattens score-based sub-query results into one {@link TopDocs} with start/stop and delimiter markers.
     *
     * <p>Every element, including the sub-query docs, is copied into a new {@link ScoreDoc}. When there is no input
     * or no sub-query has docs, an empty result is returned.
     */
    private static TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
        if (Objects.isNull(topDocs)) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }
        final int markerDocId = findFirstDocId(topDocs);
        if (markerDocId == NO_DOC_ID) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }
        final List<ScoreDoc> flattened = new ArrayList<>();
        flattened.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(markerDocId));
        for (TopDocs subQueryTopDocs : topDocs) {
            // a delimiter is emitted for every sub-query, even one without results, so positions stay aligned
            flattened.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(markerDocId));
            if (Objects.nonNull(subQueryTopDocs) && Objects.nonNull(subQueryTopDocs.scoreDocs)) {
                flattened.addAll(Arrays.asList(subQueryTopDocs.scoreDocs));
            }
        }
        flattened.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(markerDocId));

        final ScoreDoc[] scoreDocs = new ScoreDoc[flattened.size()];
        for (int index = 0; index < scoreDocs.length; index++) {
            final ScoreDoc source = flattened.get(index);
            scoreDocs[index] = new ScoreDoc(source.doc, source.score, source.shardIndex);
        }
        return new TopDocs(totalHits, scoreDocs);
    }

    /** Adjusts the total hit count and relation: zero hits for missing/empty results, and a lower bound when tracking is disabled. */
    private TotalHits getTotalHits(final int trackTotalHitsUpTo, final List<?> topDocs, final long maxTotalHits) {
        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        final long hits = (topDocs == null || topDocs.isEmpty()) ? 0L : maxTotalHits;
        return new TotalHits(hits, relation);
    }

    /**
     * Flattens field-sorted sub-query results into one {@link TopFieldDocs} with start/stop and delimiter markers.
     *
     * <p>Marker elements get sort values derived from {@code sortFields}. Sub-query docs must be {@link FieldDoc}s.
     */
    private static TopDocs getNewTopFieldDocs(final TotalHits totalHits, final List<TopFieldDocs> topFieldDocs, final SortField[] sortFields) {
        if (Objects.isNull(topFieldDocs)) {
            return new TopFieldDocs(totalHits, new FieldDoc[0], sortFields);
        }
        final int markerDocId = findFirstDocId(topFieldDocs);
        if (markerDocId == NO_DOC_ID) {
            return new TopFieldDocs(totalHits, new FieldDoc[0], sortFields);
        }
        final Object[] markerSortValues = HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults(sortFields);
        final List<FieldDoc> flattened = new ArrayList<>();
        flattened.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(markerDocId, markerSortValues));
        for (TopFieldDocs subQueryTopDocs : topFieldDocs) {
            flattened.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(markerDocId, markerSortValues));
            if (Objects.nonNull(subQueryTopDocs) && Objects.nonNull(subQueryTopDocs.scoreDocs)) {
                for (ScoreDoc scoreDoc : subQueryTopDocs.scoreDocs) {
                    flattened.add((FieldDoc) scoreDoc);
                }
            }
        }
        flattened.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(markerDocId, markerSortValues));
        return new TopFieldDocs(totalHits, flattened.toArray(new FieldDoc[0]), sortFields);
    }

    /** @return the sort value formats, or {@code null} when no sort is configured */
    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
        return sortAndFormats == null ? null : sortAndFormats.formats;
    }

    /**
     * Applies one collector's top docs to {@code querySearchResult}.
     *
     * <p>If the result reports consumed top docs, the new top docs are set directly. Otherwise a collector with zero
     * total hits is skipped, and any other collector is merged with the existing top docs.
     */
    private void reduceCollectorResults(
        final QuerySearchResult querySearchResult,
        final TopDocsAndMaxScore topDocsAndMaxScore,
        final DocValueFormat[] sortValueFormats
    ) {
        // NOTE(review): presumably hasConsumedTopDocs() is true when no top docs are set yet (first collector) - confirm
        if (querySearchResult.hasConsumedTopDocs()) {
            querySearchResult.topDocs(topDocsAndMaxScore, sortValueFormats);
            return;
        }
        if (topDocsAndMaxScore.topDocs.totalHits.value == 0) {
            return;
        }
        querySearchResult.topDocs(topDocsMerger.merge(querySearchResult.topDocs(), topDocsAndMaxScore), sortValueFormats);
    }

    /** Combines per-collector results into one that applies each of them, in order, to the same query result. */
    private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchResult> results) {
        return querySearchResult -> {
            for (ReduceableSearchResult result : results) {
                result.reduce(querySearchResult);
            }
        };
    }

    /**
     * @param numHits            maximum number of hits each collector keeps
     * @param hitsThresholdChecker threshold checker shared by created collectors
     * @param trackTotalHitsUpTo value of {@code track_total_hits} from the search context
     * @param sortAndFormats     sort definition, or {@code null} for score ordering
     * @param filterWeight       post-filter weight, or {@code null} when there is no post filter
     * @param topDocsMerger      merger used to combine results of several collectors
     * @param after              {@code search_after} doc, or {@code null}
     * @param searchContext      current search context (used for rescoring)
     */
    public HybridCollectorManager(
        final int numHits,
        final HitsThresholdChecker hitsThresholdChecker,
        final int trackTotalHitsUpTo,
        final SortAndFormats sortAndFormats,
        final Weight filterWeight,
        final TopDocsMerger topDocsMerger,
        final FieldDoc after,
        final SearchContext searchContext
    ) {
        this.numHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
        this.trackTotalHitsUpTo = trackTotalHitsUpTo;
        this.sortAndFormats = sortAndFormats;
        this.filterWeight = filterWeight;
        this.topDocsMerger = topDocsMerger;
        this.after = after;
        this.searchContext = searchContext;
    }
}
