package org.opensearch.neuralsearch.search.collector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.lucene.MultiLeafFieldComparator;

/**
 * Base class for hybrid-query collectors that order hits by a {@link Sort} rather than by score.
 *
 * <p>One {@link FieldValueHitQueue} with capacity {@code numHits} is kept per sub-query of the hybrid query.
 * Each queue has its own per-segment {@link LeafFieldComparator}. {@link #topDocs()} drains these queues into
 * one {@link TopFieldDocs} per sub-query.
 *
 * <p>Concrete subclasses supply the leaf collector, built on {@link HybridTopDocSortLeafCollector}. This class
 * reads {@link #docBase} and {@link #collectedHits} but never writes {@code docBase} or increments
 * {@code collectedHits}/{@link #maxScore}. NOTE(review): presumably subclasses maintain them — confirm.
 */
public abstract class HybridTopFieldDocSortCollector implements HybridSearchCollector {

    @Generated
    private static final Logger log = LogManager.getLogger(HybridTopFieldDocSortCollector.class);

    /** Capacity of every per-sub-query queue, i.e. the maximum number of hits retained per sub-query. */
    private final int numHits;
    /** Counts collected hits, reports whether the hits threshold is reached and supplies the score mode. */
    private final HitsThresholdChecker hitsThresholdChecker;
    /** Search-time sort used to build every sub-query queue. */
    private final Sort sort;
    /** When non-null, its {@code fields} are installed as the top value of each queue's comparators. */
    @Nullable
    private final FieldDoc after;

    /**
     * Current least competitive entry of each sub-query queue, indexed by sub-query number.
     *
     * <p>Kept per queue because entries and slot numbers belong to exactly one queue. A single shared bottom
     * would let a competitive hit of one sub-query overwrite another sub-query's entry and comparator slot.
     */
    private FieldValueHitQueue.Entry[] bottomEntries;
    /** Number of calls to {@link HybridTopDocSortLeafCollector#incrementTotalHitCount()}. */
    private int totalHits;
    /** Added to segment-relative doc ids when queue entries are created or updated. */
    protected int docBase;
    /** Leaf comparator of the current segment, one per sub-query. */
    protected LeafFieldComparator[] comparators;
    /**
     * Multiplier applied to {@code compareBottom}. It is 1 when a {@link MultiLeafFieldComparator} is used,
     * because that comparator receives the per-field multipliers itself.
     */
    protected int reverseMul;
    /** Priority queue of retained hits, one per sub-query; {@code null} until the first segment is seen. */
    protected FieldValueHitQueue<FieldValueHitQueue.Entry>[] compoundScores;
    /** Whether each sub-query queue has reached {@code numHits} entries. */
    protected boolean[] queueFull;
    /** Per-sub-query hit counts, reported as each sub-query's total hits by {@link #topDocs()}. */
    protected int[] collectedHits;
    private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
    protected float maxScore = 0.0f;
    /** Whether the search sort allows early termination against the index sort; {@code null} until computed. */
    private Boolean searchSortPartOfIndexSort = null;

    /**
     * Per-segment collection logic shared by sort-based hybrid leaf collectors. It covers:
     * <ul>
     *   <li>locating the {@link HybridQueryScorer};</li>
     *   <li>lazily creating the per-sub-query queues and leaf comparators;</li>
     *   <li>inserting or replacing entries in one sub-query's queue.</li>
     * </ul>
     */
    protected abstract class HybridTopDocSortLeafCollector implements LeafCollector {
        /** Hybrid scorer located by {@link #setScorer(Scorable)}; {@code null} if none was found. */
        protected HybridQueryScorer compoundQueryScorer;
        /** Set when a hit is accepted via the index-sort path of {@link #thresholdCheck(int, int)}. */
        private boolean collectedAllCompetitiveHits = false;
        /** True until leaf comparators have been created for this leaf collector. */
        private boolean initializeLeafComparatorsPerSegmentOnce = true;

        public HybridTopDocSortLeafCollector() {}

        /**
         * Stores the {@link HybridQueryScorer} used to collect documents and scores.
         *
         * <p>If {@code scorable} is not itself a hybrid scorer, its child hierarchy is searched depth-first.
         * When none is found, an error is logged and {@link #compoundQueryScorer} is left {@code null}.
         *
         * @param scorable scorer handed in by the search
         * @throws IOException if walking the scorer's children fails
         */
        public void setScorer(Scorable scorable) throws IOException {
            if (scorable instanceof HybridQueryScorer) {
                log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores");
                compoundQueryScorer = (HybridQueryScorer) scorable;
                return;
            }
            compoundQueryScorer = getHybridQueryScorer(scorable);
            if (Objects.isNull(compoundQueryScorer)) {
                log.error("cannot find scorer of type HybridQueryScorer in a hierarchy of scorer {}", scorable);
            }
        }

        /** Returns the first {@link HybridQueryScorer} found in a depth-first walk from {@code scorable}, or null. */
        private HybridQueryScorer getHybridQueryScorer(Scorable scorable) throws IOException {
            if (scorable == null) {
                return null;
            }
            if (scorable instanceof HybridQueryScorer) {
                return (HybridQueryScorer) scorable;
            }
            for (Scorable.ChildScorable childScorable : scorable.getChildren()) {
                HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child);
                if (Objects.nonNull(hybridQueryScorer)) {
                    log.debug("found hybrid query scorer, it's child of scorer {}", childScorable.child.getClass().getSimpleName());
                    return hybridQueryScorer;
                }
            }
            return null;
        }

        /**
         * Counts one more hit, both locally and in the {@link HitsThresholdChecker}.
         *
         * <p>Once the threshold is reached under a non-exhaustive score mode, every leaf comparator is notified
         * via {@code setHitsThresholdReached()}. The total-hits relation then becomes
         * {@link TotalHits.Relation#GREATER_THAN_OR_EQUAL_TO}; this transition happens only once.
         */
        public void incrementTotalHitCount() throws IOException {
            totalHits++;
            hitsThresholdChecker.incrementHitCount();
            boolean thresholdJustReached = !scoreMode().isExhaustive()
                && getTotalHitsRelation() == TotalHits.Relation.EQUAL_TO
                && hitsThresholdChecker.isThresholdReached();
            if (thresholdJustReached) {
                for (LeafFieldComparator leafFieldComparator : comparators) {
                    leafFieldComparator.setHitsThresholdReached();
                }
                setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
            }
        }

        /**
         * Adds a hit to a sub-query queue that is not full yet.
         *
         * <p>If {@code numHits <= 0}, nothing is stored and the queue is only marked full.
         *
         * @param doc segment-relative document id
         * @param hitsCollected hits collected so far for the sub-query, including this one;
         *        {@code hitsCollected - 1} is used as the queue/comparator slot (assumes a 1-based count — TODO
         *        confirm against subclasses)
         * @param subQueryNumber index of the sub-query
         * @param score score of the hit for this sub-query
         */
        public void collectHit(int doc, int hitsCollected, int subQueryNumber, float score) throws IOException {
            if (numHits <= 0) {
                queueFull[subQueryNumber] = true;
                return;
            }
            int slot = hitsCollected - 1;
            comparators[subQueryNumber].copy(slot, doc);
            add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score);
            if (queueFull[subQueryNumber]) {
                comparators[subQueryNumber].setBottom(bottomEntries[subQueryNumber].slot);
            }
        }

        /**
         * Replaces the bottom entry of a full sub-query queue with {@code doc}, then re-heapifies the queue.
         *
         * @param doc segment-relative document id
         * @param subQueryNumber index of the sub-query whose queue is updated
         */
        public void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException {
            if (numHits <= 0) {
                return;
            }
            comparators[subQueryNumber].copy(bottomEntries[subQueryNumber].slot, doc);
            updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber);
            comparators[subQueryNumber].setBottom(bottomEntries[subQueryNumber].slot);
        }

        /**
         * Evaluates a hit against the bottom of a full sub-query queue.
         *
         * @param doc segment-relative document id
         * @param subQueryNumber index of the sub-query
         * @return {@code false} if not all competitive hits were collected yet and
         *         {@code reverseMul * compareBottom(doc) > 0}. Otherwise {@code true}, presumably meaning the
         *         caller should skip the hit — confirm against subclasses.
         * @throws CollectionTerminatedException when the search sort is index-aligned (see
         *         {@code canEarlyTerminate}) and the hits threshold has been reached
         */
        public boolean thresholdCheck(int doc, int subQueryNumber) throws IOException {
            boolean competitive = !collectedAllCompetitiveHits
                && reverseMul * comparators[subQueryNumber].compareBottom(doc) > 0;
            if (competitive) {
                return false;
            }
            if (!Boolean.TRUE.equals(searchSortPartOfIndexSort)) {
                return true;
            }
            if (!hitsThresholdChecker.isThresholdReached()) {
                collectedAllCompetitiveHits = true;
                return true;
            }
            setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
            log.info("Terminating collection as hits threshold is reached");
            throw new CollectionTerminatedException();
        }

        /**
         * Lazily creates the per-sub-query queues and bookkeeping arrays on first use.
         *
         * <p>It also builds the leaf comparators once per leaf collector, for the segment in {@code context}.
         *
         * @param context segment being collected
         * @param numSubQueries number of sub-queries of the hybrid query
         */
        @SuppressWarnings("unchecked")
        public void initializePriorityQueuesWithComparators(LeafReaderContext context, int numSubQueries) throws IOException {
            if (compoundScores == null) {
                compoundScores = new FieldValueHitQueue[numSubQueries];
                comparators = new LeafFieldComparator[numSubQueries];
                queueFull = new boolean[numSubQueries];
                collectedHits = new int[numSubQueries];
                bottomEntries = new FieldValueHitQueue.Entry[numSubQueries];
                for (int subQuery = 0; subQuery < numSubQueries; subQuery++) {
                    initializeQueue(subQuery);
                }
            }
            if (initializeLeafComparatorsPerSegmentOnce) {
                for (int subQuery = 0; subQuery < numSubQueries; subQuery++) {
                    initializeComparators(context, subQuery);
                }
                initializeLeafComparatorsPerSegmentOnce = false;
            }
        }

        /** Creates the queue for one sub-query and configures its top-level field comparators. */
        private void initializeQueue(int subQueryNumber) {
            FieldValueHitQueue<FieldValueHitQueue.Entry> queue = FieldValueHitQueue.create(sort.getSort(), numHits);
            compoundScores[subQueryNumber] = queue;
            FieldComparator<?>[] fieldComparators = queue.getComparators();
            if (fieldComparators.length == 1) {
                fieldComparators[0].setSingleSort();
            }
            if (after != null) {
                setAfterFieldValueInFieldComparator(subQueryNumber);
            }
        }

        /**
         * Builds the leaf comparator of one sub-query for the segment in {@code context}.
         *
         * <p>On the first call, this also decides whether the search sort is compatible with the index sort.
         */
        private void initializeComparators(LeafReaderContext context, int subQueryNumber) throws IOException {
            if (searchSortPartOfIndexSort == null) {
                Sort indexSort = context.reader().getMetaData().getSort();
                searchSortPartOfIndexSort = canEarlyTerminate(sort, indexSort);
            }
            FieldValueHitQueue<FieldValueHitQueue.Entry> queue = compoundScores[subQueryNumber];
            if (searchSortPartOfIndexSort) {
                // Every sub-query has its own queue, so each queue's first comparator must be disabled.
                // This is done before its leaf comparators are created; repeating it per segment is harmless.
                queue.getComparators()[0].disableSkipping();
            }
            LeafFieldComparator[] leafComparators = queue.getComparators(context);
            int[] reverseMuls = queue.getReverseMul();
            if (leafComparators.length == 1) {
                reverseMul = reverseMuls[0];
                comparators[subQueryNumber] = leafComparators[0];
            } else {
                // MultiLeafFieldComparator applies the per-field multipliers itself.
                reverseMul = 1;
                comparators[subQueryNumber] = new MultiLeafFieldComparator(leafComparators, reverseMuls);
            }
            comparators[subQueryNumber].setScorer(compoundQueryScorer);
        }

        /** Installs {@code after.fields} as the top value of each comparator of one sub-query's queue. */
        @SuppressWarnings({ "rawtypes", "unchecked" })
        private void setAfterFieldValueInFieldComparator(int subQueryNumber) {
            // Raw type: comparator value types differ per field and after.fields holds plain Objects.
            FieldComparator[] fieldComparators = compoundScores[subQueryNumber].getComparators();
            for (int i = 0; i < fieldComparators.length; i++) {
                fieldComparators[i].setTopValue(after.fields[i]);
            }
        }
    }

    /**
     * @param numHits maximum number of hits retained per sub-query
     * @param hitsThresholdChecker hit counter and score-mode provider
     * @param sort search-time sort applied to every sub-query
     * @param after optional sort values installed as the comparators' top value; may be {@code null}
     */
    public HybridTopFieldDocSortCollector(int numHits, HitsThresholdChecker hitsThresholdChecker, Sort sort, @Nullable FieldDoc after) {
        this.numHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
        this.sort = sort;
        this.after = after;
    }

    /**
     * Returns one {@link TopFieldDocs} per sub-query, best hit first.
     *
     * <p>This pops entries from the internal queues, so a second call yields fewer (or no) hits. An empty list
     * is returned if no segment was ever collected.
     */
    @Override
    public List<TopFieldDocs> topDocs() {
        List<TopFieldDocs> topDocs = new ArrayList<>();
        if (compoundScores == null) {
            return topDocs;
        }
        SortField[] sortFields = sort.getSort();
        for (int i = 0; i < compoundScores.length; i++) {
            int howMany = Math.min(collectedHits[i], compoundScores[i].size());
            topDocs.add(topDocsPerQuery(0, howMany, compoundScores[i], collectedHits[i], sortFields));
        }
        return topDocs;
    }

    public ScoreMode scoreMode() {
        return hitsThresholdChecker.scoreMode();
    }

    /**
     * Builds the result of one sub-query from its queue.
     *
     * <p>When there is nothing to return, a fresh empty result is built. It still carries the sub-query's hit
     * count, the collector's total-hits relation and the sort fields. No shared instance is returned, because
     * {@link TopFieldDocs} has mutable public fields.
     */
    private TopFieldDocs topDocsPerQuery(int start, int howMany, PriorityQueue<FieldValueHitQueue.Entry> queue, int hitsCollected, SortField[] sortFields) {
        if (howMany < 0) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Number of hits requested must be greater than 0 but value was %d", howMany)
            );
        }
        if (start < 0) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Expected value of starting position is between 0 and %d, got %d", howMany, start)
            );
        }
        TotalHits hits = new TotalHits(hitsCollected, totalHitsRelation);
        if (start >= howMany || howMany == 0) {
            return new TopFieldDocs(hits, new ScoreDoc[0], sortFields);
        }
        int size = howMany - start;
        ScoreDoc[] scoreDocs = new ScoreDoc[size];
        populateResults(scoreDocs, size, queue);
        return new TopFieldDocs(hits, scoreDocs, sortFields);
    }

    /**
     * Pops up to {@code howMany} entries into {@code results} as {@link FieldDoc}s.
     *
     * <p>Entries are written from the last index backwards, so the last-popped entry lands at index 0.
     */
    private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<FieldValueHitQueue.Entry> queue) {
        FieldValueHitQueue<?> fieldValueHitQueue = (FieldValueHitQueue<?>) queue;
        FieldComparator<?>[] fieldComparators = fieldValueHitQueue.getComparators();
        int numFields = fieldComparators.length;
        for (int index = Math.min(howMany, results.length) - 1; index >= 0 && fieldValueHitQueue.size() > 0; index--) {
            FieldValueHitQueue.Entry entry = fieldValueHitQueue.pop();
            Object[] fieldValues = new Object[numFields];
            for (int f = 0; f < numFields; f++) {
                fieldValues[f] = fieldComparators[f].value(entry.slot);
            }
            results[index] = new FieldDoc(entry.doc, entry.score, fieldValues);
        }
    }

    /** Inserts a new entry into one sub-query queue and records that queue's new bottom and fullness. */
    private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> queue, int subQueryNumber, float score) {
        FieldValueHitQueue.Entry entry = new FieldValueHitQueue.Entry(slot, docBase + doc);
        entry.score = score;
        bottomEntries[subQueryNumber] = queue.add(entry);
        assert slot < numHits;
        queueFull[subQueryNumber] = slot == numHits - 1;
    }

    /** Points one sub-query's bottom entry at {@code doc} and re-heapifies that queue. */
    private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> queue, int subQueryNumber) {
        bottomEntries[subQueryNumber].doc = docBase + doc;
        bottomEntries[subQueryNumber] = queue.updateTop();
    }

    /** True if the search sort starts with doc order or is a prefix of the index sort. */
    private static boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
        return canEarlyTerminateOnDocId(searchSort) || canEarlyTerminateOnPrefix(searchSort, indexSort);
    }

    private static boolean canEarlyTerminateOnDocId(Sort searchSort) {
        return SortField.FIELD_DOC.equals(searchSort.getSort()[0]);
    }

    private static boolean canEarlyTerminateOnPrefix(Sort searchSort, Sort indexSort) {
        if (indexSort == null) {
            return false;
        }
        SortField[] searchFields = searchSort.getSort();
        SortField[] indexFields = indexSort.getSort();
        if (searchFields.length > indexFields.length) {
            return false;
        }
        for (int i = 0; i < searchFields.length; i++) {
            if (!searchFields[i].equals(indexFields[i])) {
                return false;
            }
        }
        return true;
    }

    @Override
    @Generated
    public int getTotalHits() {
        return this.totalHits;
    }

    @Generated
    public TotalHits.Relation getTotalHitsRelation() {
        return this.totalHitsRelation;
    }

    @Generated
    public void setTotalHitsRelation(TotalHits.Relation relation) {
        this.totalHitsRelation = relation;
    }

    @Override
    @Generated
    public float getMaxScore() {
        return this.maxScore;
    }
}
