package org.opensearch.knn;

import com.google.common.collect.ImmutableMap;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.function.BiFunction;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.opensearch.common.xcontent.DeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.plugin.script.KNNScoringUtil;

/**
 * Shared helpers for k-NN tests: property keys and names used by the REST / BWC suites, random vector
 * generation, loading vector data files, brute-force ground-truth computation and recall calculation.
 */
public class TestUtils {
    /**
     * Scoring function per {@link SpaceType}, used by {@link #computeDistFromSpaceType} to compare a query
     * vector with an index vector (invoked as {@code apply(queryVector, indexVector)}).
     *
     * <p>NOTE(review): {@link #insertWithOverflow} evicts the queue head whenever a new value is smaller, so it
     * presumably keeps the {@code k} smallest values (assuming {@code DistComparator} puts the largest value at
     * the head). If {@code cosinesimil} / {@code innerProduct} return similarities (higher = closer), ground truth
     * for those space types would select the least similar vectors. Confirm against KNNScoringUtil.
     */
    public static final Map<SpaceType, BiFunction<float[], float[], Float>> KNN_SCORING_SPACE_TYPE = ImmutableMap.of(
        SpaceType.L1, KNNScoringUtil::l1Norm,
        SpaceType.L2, KNNScoringUtil::l2Squared,
        SpaceType.LINF, KNNScoringUtil::lInfNorm,
        SpaceType.COSINESIMIL, KNNScoringUtil::cosinesimil,
        SpaceType.INNER_PRODUCT, KNNScoringUtil::innerProduct
    );
    public static final String KNN_BWC_PREFIX = "knn-bwc-";
    public static final String OPENDISTRO_SECURITY = ".opendistro_security";
    public static final String BWCSUITE_CLUSTER = "tests.rest.bwcsuite_cluster";
    public static final String BWC_VERSION = "tests.plugin_bwc_version";
    public static final String CLIENT_TIMEOUT_VALUE = "90s";
    public static final String FIELD = "field";
    public static final int KNN_ALGO_PARAM_M_MIN_VALUE = 2;
    public static final int KNN_ALGO_PARAM_EF_CONSTRUCTION_MIN_VALUE = 2;
    public static final String MIXED_CLUSTER = "mixed_cluster";
    public static final String NODES_BWC_CLUSTER = "3";
    public static final String NUMBER_OF_SHARDS = "number_of_shards";
    public static final String NUMBER_OF_REPLICAS = "number_of_replicas";
    public static final String INDEX_KNN = "index.knn";
    public static final String OLD_CLUSTER = "old_cluster";
    public static final String PROPERTIES = "properties";
    public static final String VECTOR_TYPE = "type";
    public static final String KNN_VECTOR = "knn_vector";
    public static final String QUERY_VALUE = "query_value";
    public static final String RESTART_UPGRADE_OLD_CLUSTER = "tests.is_old_cluster";
    public static final String ROLLING_UPGRADE_FIRST_ROUND = "tests.rest.first_round";
    public static final String SKIP_DELETE_MODEL_INDEX = "tests.skip_delete_model_index";
    public static final String UPGRADED_CLUSTER = "upgraded_cluster";

    /** Index vectors and query vectors loaded from local data files. */
    public static class TestData {
        /** Document ids and their vectors, read from the index data file. */
        public KNNCodecUtil.Pair indexData;
        /** Query vectors read from the query file, one row per non-blank line. */
        public float[][] queries;

        /**
         * Loads the index data and the queries from disk (decoded as UTF-8).
         *
         * @param indexDataPath file with one JSON document per line, each holding a numeric {@code "id"} and a
         *     numeric array {@code "vector"}
         * @param queriesPath file with one comma-separated vector per line
         * @throws IOException if either file cannot be read, or if its vectors do not all share one dimension
         */
        public TestData(String indexDataPath, String queriesPath) throws IOException {
            this.indexData = readIndexData(indexDataPath);
            this.queries = readQueries(queriesPath);
        }

        /** Parses the JSON-lines index data file into ids and vectors. Blank lines are skipped. */
        private KNNCodecUtil.Pair readIndexData(String path) throws IOException {
            List<Integer> ids = new ArrayList<>();
            List<float[]> vectors = new ArrayList<>();
            // try-with-resources ensures the reader is closed even if parsing a line fails
            try (BufferedReader reader = Files.newBufferedReader(Paths.get(path), StandardCharsets.UTF_8)) {
                for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                    if (line.trim().isEmpty()) {
                        continue; // e.g. a trailing empty line at the end of the file
                    }
                    Map<String, Object> document = XContentFactory.xContent(XContentType.JSON)
                        .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line)
                        .map();
                    // Read through Number: a JSON numeric literal may be parsed as Integer, Long or Double
                    ids.add(((Number) document.get("id")).intValue());
                    addRow(vectors, toFloatArray((List<?>) document.get("vector")), path);
                }
            }
            int[] idArray = new int[ids.size()];
            for (int i = 0; i < idArray.length; i++) {
                idArray[i] = ids.get(i);
            }
            return new KNNCodecUtil.Pair(idArray, vectors.toArray(new float[0][]));
        }

        /** Parses a file of comma-separated float vectors, one per line. Blank lines are skipped. */
        private float[][] readQueries(String path) throws IOException {
            List<float[]> rows = new ArrayList<>();
            try (BufferedReader reader = Files.newBufferedReader(Paths.get(path), StandardCharsets.UTF_8)) {
                for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                    if (line.trim().isEmpty()) {
                        continue;
                    }
                    String[] tokens = line.split(",");
                    float[] row = new float[tokens.length];
                    for (int i = 0; i < tokens.length; i++) {
                        row[i] = Float.parseFloat(tokens[i]);
                    }
                    addRow(rows, row, path);
                }
            }
            return rows.toArray(new float[0][]);
        }

        /** Converts a list of JSON numbers into a primitive float array. */
        private static float[] toFloatArray(List<?> values) {
            float[] result = new float[values.size()];
            for (int i = 0; i < result.length; i++) {
                result[i] = ((Number) values.get(i)).floatValue();
            }
            return result;
        }

        /** Appends {@code row}, rejecting it if its dimension differs from the first row already read. */
        private static void addRow(List<float[]> rows, float[] row, String path) throws IOException {
            if (!rows.isEmpty() && rows.get(0).length != row.length) {
                throw new IOException(
                    "Inconsistent vector dimension in " + path + ": expected " + rows.get(0).length + " but got " + row.length
                );
            }
            rows.add(row);
        }
    }

    /**
     * Generates reproducible vectors whose components come from a {@link Random} seeded with {@code seed}.
     *
     * @param numVectors number of vectors to generate
     * @param dimension length of each vector
     * @param seed seed for the random generator; equal seeds yield equal output
     * @return a {@code numVectors x dimension} array of floats in {@code [0, 1)}
     */
    public static float[][] randomlyGenerateStandardVectors(int numVectors, int dimension, int seed) {
        float[][] vectors = new float[numVectors][dimension];
        Random random = new Random(seed);
        for (float[] vector : vectors) {
            for (int d = 0; d < dimension; d++) {
                vector[d] = random.nextFloat();
            }
        }
        return vectors;
    }

    /**
     * Generates vectors whose components come from the Lucene test framework's randomness
     * ({@link LuceneTestCase#random()}).
     *
     * @param numVectors number of vectors to generate
     * @param dimension length of each vector
     * @return a {@code numVectors x dimension} array of floats in {@code [0, 1)}
     */
    public static float[][] generateRandomVectors(int numVectors, int dimension) {
        float[][] vectors = new float[numVectors][dimension];
        for (float[] vector : vectors) {
            for (int d = 0; d < dimension; d++) {
                vector[d] = LuceneTestCase.random().nextFloat();
            }
        }
        return vectors;
    }

    /**
     * Brute-force ground truth: for every query, the ids (array indices, as strings) of the {@code k} index
     * vectors selected by {@link #insertWithOverflow}.
     *
     * @param indexVectors candidate vectors; a vector's position is used as its document id
     * @param queryVectors query vectors
     * @param spaceType space type whose scoring function compares the vectors
     * @param k number of neighbours per query; must be at least 1 ({@link PriorityQueue} rejects a zero capacity)
     * @return one set of document ids per query, in query order
     */
    public static List<Set<String>> computeGroundTruthValues(
        float[][] indexVectors,
        float[][] queryVectors,
        SpaceType spaceType,
        int k
    ) {
        List<Set<String>> groundTruth = new ArrayList<>(queryVectors.length);
        for (float[] query : queryVectors) {
            PriorityQueue<DistVector> nearest = new PriorityQueue<>(k, new DistComparator());
            for (int docId = 0; docId < indexVectors.length; docId++) {
                float dist = computeDistFromSpaceType(spaceType, indexVectors[docId], query);
                nearest = insertWithOverflow(nearest, k, dist, docId);
            }
            Set<String> docIds = new HashSet<>();
            while (!nearest.isEmpty()) {
                docIds.add(nearest.poll().getDocID());
            }
            groundTruth.add(docIds);
        }
        return groundTruth;
    }

    /** Returns query vectors: seeded with {@code seed + 1} when {@code standard}, otherwise Lucene-random. */
    public static float[][] getQueryVectors(int numVectors, int dimension, int seed, boolean standard) {
        return standard ? randomlyGenerateStandardVectors(numVectors, dimension, seed + 1) : generateRandomVectors(numVectors, dimension);
    }

    /** Returns index vectors: seeded with {@code 1} when {@code standard}, otherwise Lucene-random. */
    public static float[][] getIndexVectors(int numVectors, int dimension, boolean standard) {
        return standard ? randomlyGenerateStandardVectors(numVectors, dimension, 1) : generateRandomVectors(numVectors, dimension);
    }

    /**
     * Mean recall over all queries. For each query, recall is the number of returned ids found in the
     * ground-truth set divided by {@code k}. The arithmetic is done in {@code float}.
     *
     * @param searchResults returned document ids per query
     * @param groundTruth expected document ids per query, aligned by index with {@code searchResults}
     * @param k the number of neighbours requested per query
     * @return the average per-query recall
     * @throws IllegalArgumentException if {@code searchResults} is empty or {@code k <= 0}
     */
    public static double calculateRecallValue(List<List<String>> searchResults, List<Set<String>> groundTruth, int k) {
        if (searchResults.isEmpty()) {
            throw new IllegalArgumentException("searchResults must not be empty");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        float recallSum = 0.0f;
        for (int q = 0; q < searchResults.size(); q++) {
            Set<String> expected = groundTruth.get(q);
            int hits = 0;
            for (String docId : searchResults.get(q)) {
                if (expected.contains(docId)) {
                    hits++;
                }
            }
            recallSum += (float) hits / k;
        }
        return recallSum / searchResults.size();
    }

    /**
     * Computes the {@code k} vectors selected by {@link #insertWithOverflow} among ids
     * {@code 0 .. getVectorCount()-1}, compared against the vector the producer returns for id
     * {@code getVectorCount()}.
     *
     * <p>NOTE(review): this presumes the producer exposes the query vector one past the last candidate.
     * Confirm against the IDVectorProducer implementations.
     *
     * @return the queue of retained candidates (at most {@code k})
     */
    public static PriorityQueue<DistVector> computeGroundTruthValues(int k, SpaceType spaceType, IDVectorProducer iDVectorProducer) {
        PriorityQueue<DistVector> nearest = new PriorityQueue<>(k, new DistComparator());
        int vectorCount = iDVectorProducer.getVectorCount();
        float[] query = iDVectorProducer.getVector(vectorCount);
        for (int docId = 0; docId < vectorCount; docId++) {
            float dist = computeDistFromSpaceType(spaceType, iDVectorProducer.getVector(docId), query);
            nearest = insertWithOverflow(nearest, k, dist, docId);
        }
        return nearest;
    }

    /**
     * Scores {@code queryVector} against {@code indexVector} with the function registered for {@code spaceType}
     * in {@link #KNN_SCORING_SPACE_TYPE}.
     *
     * @throws NullPointerException if {@code spaceType} is null
     * @throws IllegalArgumentException if no function is registered for {@code spaceType}
     */
    public static float computeDistFromSpaceType(SpaceType spaceType, float[] indexVector, float[] queryVector) {
        if (spaceType == null) {
            throw new NullPointerException("SpaceType is null. Provide a valid SpaceType.");
        }
        BiFunction<float[], float[], Float> scorer = KNN_SCORING_SPACE_TYPE.get(spaceType);
        if (scorer == null) {
            throw new IllegalArgumentException(String.format("Invalid SpaceType function: \"%s\"", spaceType));
        }
        return scorer.apply(queryVector, indexVector);
    }

    /**
     * Offers {@code (dist, docId)} to a bounded queue of capacity {@code k}. Below capacity the entry is
     * always added. At capacity, the head is replaced only when its distance is greater than {@code dist}.
     * A non-positive {@code k} leaves the queue unchanged.
     *
     * @return the same queue instance, for call chaining
     */
    public static PriorityQueue<DistVector> insertWithOverflow(PriorityQueue<DistVector> priorityQueue, int k, float dist, int docId) {
        if (k <= 0) {
            return priorityQueue; // nothing can be retained; also avoids peek() on an empty queue
        }
        if (priorityQueue.size() < k) {
            priorityQueue.add(new DistVector(dist, String.valueOf(docId)));
        } else if (priorityQueue.peek().getDist() > dist) {
            priorityQueue.poll();
            priorityQueue.add(new DistVector(dist, String.valueOf(docId)));
        }
        return priorityQueue;
    }
}
