package org.opensearch.knn.index.query;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.io.PathUtils;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

/* loaded from: input_file:org/opensearch/knn/index/query/KNNWeight.class */
public class KNNWeight extends Weight {
    private static Logger logger = LogManager.getLogger(KNNWeight.class);
    private static ModelDao modelDao;
    private final KNNQuery knnQuery;
    private final float boost;
    private NativeMemoryCacheManager nativeMemoryCacheManager;

    public KNNWeight(KNNQuery kNNQuery, float f) {
        super(kNNQuery);
        this.knnQuery = kNNQuery;
        this.boost = f;
        this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
    }

    public static void initialize(ModelDao modelDao2) {
        modelDao = modelDao2;
    }

    public Explanation explain(LeafReaderContext leafReaderContext, int i) {
        return Explanation.match(Float.valueOf(1.0f), "No Explanation", new Explanation[0]);
    }

    public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
        KNNEngine engine;
        SpaceType space;
        SegmentReader unwrap = FilterLeafReader.unwrap(leafReaderContext.reader());
        String path = FilterDirectory.unwrap(unwrap.directory()).getDirectory().toString();
        FieldInfo fieldInfo = unwrap.getFieldInfos().fieldInfo(this.knnQuery.getField());
        if (fieldInfo == null) {
            logger.debug("[KNN] Field info not found for {}:{}", this.knnQuery.getField(), unwrap.getSegmentName());
            return null;
        }
        String attribute = fieldInfo.getAttribute(KNNConstants.MODEL_ID);
        if (attribute != null) {
            ModelMetadata metadata = modelDao.getMetadata(attribute);
            if (metadata == null) {
                throw new RuntimeException("Model \"" + attribute + "\" does not exist.");
            }
            engine = metadata.getKnnEngine();
            space = metadata.getSpaceType();
        } else {
            engine = KNNEngine.getEngine((String) fieldInfo.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()));
            space = SpaceType.getSpace((String) fieldInfo.attributes().getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()));
        }
        String str = this.knnQuery.getField() + (unwrap.getSegmentInfo().info.getUseCompoundFile() ? engine.getExtension() + "c" : engine.getExtension());
        List list = (List) unwrap.getSegmentInfo().files().stream().filter(str2 -> {
            return str2.endsWith(str);
        }).collect(Collectors.toList());
        if (list.isEmpty()) {
            logger.debug("[KNN] No engine index found for field {} for segment {}", this.knnQuery.getField(), unwrap.getSegmentName());
            return null;
        }
        Path path2 = PathUtils.get(path, new String[]{(String) list.get(0)});
        KNNCounter.GRAPH_QUERY_REQUESTS.increment();
        try {
            NativeMemoryAllocation nativeMemoryAllocation = this.nativeMemoryCacheManager.get(new NativeMemoryEntryContext.IndexEntryContext(path2.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), IndexUtil.getParametersAtLoading(space, engine, this.knnQuery.getIndexName()), this.knnQuery.getIndexName()), true);
            nativeMemoryAllocation.readLock();
            try {
                try {
                    if (nativeMemoryAllocation.isClosed()) {
                        throw new RuntimeException("Index has already been closed");
                    }
                    KNNQueryResult[] queryIndex = JNIService.queryIndex(nativeMemoryAllocation.getMemoryAddress(), this.knnQuery.getQueryVector(), this.knnQuery.getK(), engine.getName());
                    nativeMemoryAllocation.readUnlock();
                    if (queryIndex.length == 0) {
                        logger.debug("[KNN] Query yielded 0 results");
                        return null;
                    }
                    KNNEngine kNNEngine = engine;
                    SpaceType spaceType = space;
                    Map map = (Map) Arrays.stream(queryIndex).collect(Collectors.toMap((v0) -> {
                        return v0.getId();
                    }, kNNQueryResult -> {
                        return Float.valueOf(kNNEngine.score(kNNQueryResult.getScore(), spaceType));
                    }));
                    DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(((Integer) Collections.max(map.keySet())).intValue() + 1);
                    DocIdSetBuilder.BulkAdder grow = docIdSetBuilder.grow(queryIndex.length);
                    Arrays.stream(queryIndex).forEach(kNNQueryResult2 -> {
                        grow.add(kNNQueryResult2.getId());
                    });
                    return new KNNScorer(this, docIdSetBuilder.build().iterator(), map, this.boost);
                } catch (Throwable th) {
                    nativeMemoryAllocation.readUnlock();
                    throw th;
                }
            } catch (Exception e) {
                KNNCounter.GRAPH_QUERY_ERRORS.increment();
                throw new RuntimeException(e);
            }
        } catch (ExecutionException e2) {
            KNNCounter.GRAPH_QUERY_ERRORS.increment();
            throw new RuntimeException(e2);
        }
    }

    public boolean isCacheable(LeafReaderContext leafReaderContext) {
        return true;
    }

    public static float normalizeScore(float f) {
        return f >= 0.0f ? 1.0f / (1.0f + f) : (-f) + 1.0f;
    }
}
