package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.util.HybridSearchSortUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

/* loaded from: input_file:org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.class */
public class NormalizationProcessorWorkflow {

    @Generated
    private static final Logger log = LogManager.getLogger(NormalizationProcessorWorkflow.class);
    private final ScoreNormalizer scoreNormalizer;
    private final ScoreCombiner scoreCombiner;

    /**
     * Normalizes and combines the sub-query scores of a hybrid query, then writes the combined
     * results back into the query results and, when present, into the fetch result.
     *
     * @param querySearchResults query phase results; their top docs are replaced in place
     * @param fetchSearchResultOptional fetch phase result, if one is available; its hits are replaced in place
     * @param normalizationTechnique technique handed to the score normalizer
     * @param combinationTechnique technique handed to the score combiner
     * @throws IllegalStateException if query results and top docs, or query and fetch results, are inconsistent
     */
    public void execute(
        final List<QuerySearchResult> querySearchResults,
        final Optional<FetchSearchResult> fetchSearchResultOptional,
        final ScoreNormalizationTechnique normalizationTechnique,
        final ScoreCombinationTechnique combinationTechnique
    ) {
        // capture the doc ids before the query results are overwritten with combined scores;
        // they are needed later to map fetched hits back to documents
        List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
        log.debug("Pre-process query results");
        List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
        log.debug("Do score normalization");
        this.scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);
        CombineScoresDto combineScoresDto = CombineScoresDto.builder()
            .queryTopDocs(queryTopDocs)
            .scoreCombinationTechnique(combinationTechnique)
            .querySearchResults(querySearchResults)
            .sort(HybridSearchSortUtil.evaluateSortCriteria(querySearchResults, queryTopDocs))
            .build();
        log.debug("Do score combination");
        this.scoreCombiner.combineScores(combineScoresDto);
        log.debug("Post-process query results after score normalization and combination");
        updateOriginalQueryResults(combineScoresDto);
        updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
    }

    /**
     * Wraps the top docs of every query result into a {@link CompoundTopDocs}.
     *
     * @throws IllegalStateException if any query result has no top docs
     */
    private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
        List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
            .filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
            .map(searchResult -> searchResult.topDocs().topDocs)
            .map(CompoundTopDocs::new)
            .collect(Collectors.toList());
        // results without top docs were filtered out above, which surfaces here as a size mismatch
        validateResultSizes(querySearchResults.size(), queryTopDocs.size());
        return queryTopDocs;
    }

    /**
     * Replaces the top docs of each query result with the normalized and combined top docs at the
     * same position.
     */
    private void updateOriginalQueryResults(final CombineScoresDto combineScoresDto) {
        List<QuerySearchResult> querySearchResults = combineScoresDto.getQuerySearchResults();
        List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDto, querySearchResults);
        Sort sort = combineScoresDto.getSort();
        boolean isSortEnabled = sort != null;
        for (int index = 0; index < querySearchResults.size(); index++) {
            QuerySearchResult querySearchResult = querySearchResults.get(index);
            CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
            TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
                buildTopDocs(updatedTopDocs, sort),
                maxScoreForShard(updatedTopDocs, isSortEnabled)
            );
            querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
        }
    }

    /**
     * Returns the combined top docs held by the DTO after checking that they line up one-to-one with
     * the query results.
     *
     * @throws IllegalStateException if the two lists differ in size
     */
    private List<CompoundTopDocs> getCompoundTopDocs(
        final CombineScoresDto combineScoresDto,
        final List<QuerySearchResult> querySearchResults
    ) {
        List<CompoundTopDocs> queryTopDocs = combineScoresDto.getQueryTopDocs();
        validateResultSizes(querySearchResults.size(), queryTopDocs.size());
        return queryTopDocs;
    }

    /**
     * Ensures there is exactly one top-docs entry per query result.
     *
     * @throws IllegalStateException if the sizes differ
     */
    private static void validateResultSizes(final int querySearchResultsSize, final int queryTopDocsSize) {
        if (querySearchResultsSize != queryTopDocsSize) {
            throw new IllegalStateException(
                String.format(
                    Locale.ROOT,
                    "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
                    querySearchResultsSize,
                    queryTopDocsSize
                )
            );
        }
    }

    /**
     * Computes the max score to report for a shard's combined top docs.
     *
     * @param compoundTopDocs combined top docs of one query result
     * @param isSortEnabled whether results are ordered by a sort rather than by score
     * @return {@code ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND} when there are no hits; otherwise the
     *     first score (unsorted) or the maximum over all scores (sorted)
     */
    private float maxScoreForShard(final CompoundTopDocs compoundTopDocs, final boolean isSortEnabled) {
        if (compoundTopDocs.getTotalHits().value == 0 || compoundTopDocs.getScoreDocs().isEmpty()) {
            return ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND.floatValue();
        }
        if (!isSortEnabled) {
            // without a sort the docs are taken as score-ordered, so the head holds the max
            return compoundTopDocs.getScoreDocs().get(0).score;
        }
        // with a sort the order is by sort fields, so the max score has to be searched for
        float maxScore = ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND.floatValue();
        for (ScoreDoc scoreDoc : compoundTopDocs.getScoreDocs()) {
            maxScore = Math.max(maxScore, scoreDoc.score);
        }
        return maxScore;
    }

    /**
     * Converts combined top docs into Lucene {@link TopDocs}, or {@link TopFieldDocs} when a sort is
     * present.
     */
    private TopDocs buildTopDocs(final CompoundTopDocs compoundTopDocs, final Sort sort) {
        if (sort == null) {
            return new TopDocs(compoundTopDocs.getTotalHits(), compoundTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));
        }
        // with a sort every score doc must be a FieldDoc; otherwise toArray throws ArrayStoreException
        return new TopFieldDocs(
            compoundTopDocs.getTotalHits(),
            compoundTopDocs.getScoreDocs().toArray(new FieldDoc[0]),
            sort.getSort()
        );
    }

    /**
     * Rewrites the fetched hits so they follow the combined order and carry the combined scores.
     *
     * @param querySearchResults query results; the first one's (already updated) top docs define the output
     * @param fetchSearchResultOptional fetch result to rewrite; nothing happens when absent
     * @param docIds doc ids of the query-phase results, in the order the fetched hits arrive
     * @throws IllegalStateException if the fetched hits cannot be matched to the query results
     */
    private void updateOriginalFetchResults(
        final List<QuerySearchResult> querySearchResults,
        final Optional<FetchSearchResult> fetchSearchResultOptional,
        final List<Integer> docIds
    ) {
        if (fetchSearchResultOptional.isEmpty()) {
            return;
        }
        FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
        SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, isRequestCacheEnabled(querySearchResults));
        if (Objects.isNull(querySearchResults) || querySearchResults.isEmpty()) {
            // nothing to reorder the (already validated) fetched hits against
            return;
        }

        // doc ids may repeat in the query-phase list; keying by doc id keeps one hit per id (last one wins)
        Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
        for (int i = 0; i < searchHitArray.length; i++) {
            docIdToSearchHit.put(docIds.get(i), searchHitArray[i]);
        }

        // execute() calls this after updateOriginalQueryResults, so these top docs hold combined scores
        QuerySearchResult querySearchResult = querySearchResults.get(0);
        SearchHit[] updatedSearchHitArray = Arrays.stream(querySearchResult.topDocs().topDocs.scoreDocs)
            .map(scoreDoc -> toUpdatedSearchHit(docIdToSearchHit, scoreDoc))
            .toArray(SearchHit[]::new);
        fetchSearchResult.hits(
            new SearchHits(updatedSearchHitArray, querySearchResult.getTotalHits(), querySearchResult.getMaxScore())
        );
    }

    /**
     * Looks up the fetched hit for a score doc and copies the combined score onto it.
     *
     * @throws IllegalStateException if no hit was fetched for the score doc's id
     */
    private static SearchHit toUpdatedSearchHit(final Map<Integer, SearchHit> docIdToSearchHit, final ScoreDoc scoreDoc) {
        SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
        if (Objects.isNull(searchHit)) {
            // can happen when fewer hits were fetched than query results exist (allowed for cached requests)
            throw new IllegalStateException(
                String.format(
                    Locale.ROOT,
                    "score normalization processor cannot produce final query result, fetch phase returned no document for doc id [%d]",
                    scoreDoc.doc
                )
            );
        }
        searchHit.score(scoreDoc.score);
        return searchHit;
    }

    /**
     * Reports whether the first query result's shard request explicitly enabled the request cache.
     */
    private static boolean isRequestCacheEnabled(final List<QuerySearchResult> querySearchResults) {
        if (Objects.isNull(querySearchResults) || querySearchResults.isEmpty()) {
            return false;
        }
        // requestCache() is a nullable Boolean; null counts as disabled
        return Boolean.TRUE.equals(querySearchResults.get(0).getShardSearchRequest().requestCache());
    }

    /**
     * Returns the fetched hits after checking that their count is compatible with the query phase.
     *
     * @param docIds doc ids produced by the query phase
     * @param fetchSearchResult fetch phase result
     * @param requestCache whether the request cache is enabled for the request
     * @throws IllegalStateException if no hits were fetched or the counts are incompatible
     */
    private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) {
        SearchHits searchHits = fetchSearchResult.hits();
        SearchHit[] searchHitArray = Objects.isNull(searchHits) ? null : searchHits.getHits();
        if (Objects.isNull(searchHitArray)) {
            throw new IllegalStateException(
                "score normalization processor cannot produce final query result, fetch query phase returns empty results"
            );
        }
        // without the request cache, fetch and query results must match one-to-one; with it,
        // fewer fetched hits are tolerated but never more than the query phase produced
        boolean sizesConsistent = requestCache ? searchHitArray.length <= docIds.size() : searchHitArray.length == docIds.size();
        if (!sizesConsistent) {
            throw new IllegalStateException(
                String.format(
                    Locale.ROOT,
                    "score normalization processor cannot produce final query result, the number of documents after fetch phase [%d] is different from number of documents from query phase [%d]",
                    searchHitArray.length,
                    docIds.size()
                )
            );
        }
        return searchHitArray;
    }

    /**
     * Collects the doc ids of the first query result's top docs, before any score processing.
     *
     * @return the doc ids in their original order, or an empty list when there is nothing to collect
     */
    private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
        if (querySearchResults.isEmpty()) {
            return List.of();
        }
        TopDocsAndMaxScore topDocsAndMaxScore = querySearchResults.get(0).topDocs();
        if (Objects.isNull(topDocsAndMaxScore)) {
            // getQueryTopDocs rejects this malformed input with a descriptive error right afterwards
            return List.of();
        }
        return Arrays.stream(topDocsAndMaxScore.topDocs.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList());
    }

    @Generated
    public NormalizationProcessorWorkflow(ScoreNormalizer scoreNormalizer, ScoreCombiner scoreCombiner) {
        this.scoreNormalizer = scoreNormalizer;
        this.scoreCombiner = scoreCombiner;
    }
}
