package org.opensearch.neuralsearch;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.WarningsHandler;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TestUtils;
import org.opensearch.neuralsearch.util.TokenWeightUtil;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.test.ClusterServiceUtils;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
/* loaded from: input_file:org/opensearch/neuralsearch/BaseNeuralSearchIT.class */
public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
    /** Locale used for every {@code String.format} call in these tests. */
    protected static final Locale LOCALE = Locale.ROOT;
    /** Classpath resource holding the ingest-pipeline JSON template for each processor type. */
    protected static final Map<ProcessorType, String> PIPELINE_CONFIGS_BY_TYPE = Map.of(
        ProcessorType.TEXT_EMBEDDING,
        "processor/PipelineConfiguration.json",
        ProcessorType.SPARSE_ENCODING,
        "processor/SparseEncodingPipelineConfiguration.json",
        ProcessorType.TEXT_IMAGE_EMBEDDING,
        "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json",
        ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING,
        "processor/PipelineConfigurationWithNestedFieldsMapping.json"
    );
    /** HTTP statuses accepted as success when indexing a single document. */
    private static final Set<RestStatus> SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK);
    /** Cluster setting key toggling concurrent segment search. */
    protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled";
    /** Class loader used to resolve JSON fixtures from the test classpath. */
    protected final ClassLoader classLoader = getClass().getClassLoader();
    /** Thread pool created before each test in {@code setupSettings()}. */
    protected ThreadPool threadPool;
    /** Cluster service created before each test in {@code setupSettings()}. */
    protected ClusterService clusterService;

    /**
     * Immutable description of one {@code knn_vector} field used when building test index mappings.
     */
    public static class KNNFieldConfig {
        /** Name of the vector field in the mapping. */
        private final String name;
        /** Vector dimension written into the mapping. */
        private final Integer dimension;
        /** Space type written as the field's {@code space_type}. */
        private final SpaceType spaceType;

        /**
         * @param name      field name
         * @param dimension vector dimension
         * @param spaceType similarity space type
         */
        @Generated
        public KNNFieldConfig(String name, Integer dimension, SpaceType spaceType) {
            this.name = name;
            this.dimension = dimension;
            this.spaceType = spaceType;
        }

        /** @return the field name */
        @Generated
        public String getName() {
            return name;
        }

        /** @return the vector dimension (may be {@code null} if constructed that way) */
        @Generated
        public Integer getDimension() {
            return dimension;
        }

        /** @return the space type */
        @Generated
        public SpaceType getSpaceType() {
            return spaceType;
        }
    }

    /**
     * Kinds of ingest processors these tests can create; each constant is mapped to a pipeline JSON template
     * in {@code PIPELINE_CONFIGS_BY_TYPE}. Constant order is left untouched.
     */
    public enum ProcessorType {
        TEXT_EMBEDDING,
        TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING,
        TEXT_IMAGE_EMBEDDING,
        SPARSE_ENCODING
    }

    /**
     * Per-test setup: creates the thread pool and cluster service, optionally pushes the ML Commons cluster
     * settings, then initializes {@link NeuralSearchClusterUtil} with the freshly created cluster service.
     */
    @Before
    public void setupSettings() {
        final ThreadPool pool = setUpThreadPool();
        this.threadPool = pool;
        this.clusterService = createClusterService(pool);
        if (isUpdateClusterSettings()) {
            updateClusterSettings();
        }
        NeuralSearchClusterUtil.instance().initialize(this.clusterService);
    }

    /**
     * Creates a test thread pool named after the concrete test class, with no custom executors.
     *
     * @return the new thread pool
     */
    protected ThreadPool setUpThreadPool() {
        final String poolName = getClass().getName();
        final ExecutorBuilder[] noCustomExecutors = new ExecutorBuilder[0];
        return new TestThreadPool(poolName, threadPoolSettings(), noCustomExecutors);
    }

    /**
     * Settings handed to the per-test thread pool; subclasses may override.
     *
     * @return {@link Settings#EMPTY} by default
     */
    public Settings threadPoolSettings() {
        final Settings defaults = Settings.EMPTY;
        return defaults;
    }

    /**
     * Builds a test {@link ClusterService} backed by the given thread pool.
     *
     * @param pool thread pool for the cluster service
     * @return the cluster service created by {@link ClusterServiceUtils}
     */
    public static ClusterService createClusterService(ThreadPool pool) {
        final ClusterService service = ClusterServiceUtils.createClusterService(pool);
        return service;
    }

    /**
     * Pushes the ML Commons cluster settings these tests rely on, one persistent update per key:
     * run models on any node, set the native-memory and JVM-heap thresholds, and allow registering
     * models via URL.
     */
    protected void updateClusterSettings() {
        final String prefix = "plugins.ml_commons.";
        updateClusterSettings(prefix + "only_run_on_ml_node", false);
        updateClusterSettings(prefix + "native_memory_threshold", 100);
        updateClusterSettings(prefix + "jvm_heap_memory_threshold", 95);
        updateClusterSettings(prefix + "allow_registering_model_via_url", true);
    }

    /**
     * Sets a single persistent cluster setting and asserts the request returned 200 OK.
     *
     * @param settingKey cluster setting name
     * @param value      value to store
     */
    protected void updateClusterSettings(String settingKey, Object value) {
        final RestClient restClient = client();
        final String body = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("persistent")
            .field(settingKey, value)
            .endObject()
            .endObject()
            .toString();
        final Response response = makeRequest(
            restClient,
            "PUT",
            "_cluster/settings",
            null,
            toHttpEntity(body),
            ImmutableList.of(new BasicHeader("User-Agent", ""))
        );
        assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Formats the model-group id (from {@code getModelGroupId()}, defined outside this view) into the
     * request template and uploads the model.
     *
     * @param requestBodyTemplate upload request JSON with one {@code %s} placeholder for the model-group id
     * @return id of the uploaded model
     */
    protected String registerModelGroupAndUploadModel(String requestBodyTemplate) throws Exception {
        final String modelGroupId = getModelGroupId();
        final String requestBody = String.format(LOCALE, requestBodyTemplate, modelGroupId);
        return uploadModel(requestBody);
    }

    /**
     * Uploads a model through the ML Commons {@code _upload} API and polls its task until it completes
     * (or reports an error), for up to roughly 300 seconds.
     *
     * @param requestBody JSON body of the upload request
     * @return id of the uploaded model
     * @throws Exception if a request fails or the polling sleep is interrupted
     */
    protected String uploadModel(String requestBody) throws Exception {
        Response uploadResponse = makeRequest(
            client(),
            "POST",
            "/_plugins/_ml/models/_upload",
            null,
            toHttpEntity(requestBody),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        Map<String, Object> uploadResult = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(uploadResponse.getEntity()),
            false
        );
        Object taskId = uploadResult.get("task_id");
        // Assert before dereferencing so a missing task id fails with the response body instead of an NPE.
        assertNotNull("Model upload response has no task_id: " + uploadResult, taskId);

        Map<String, Object> taskQueryResponse = getTaskQueryResponse(taskId.toString());
        // Sleep BEFORE each re-query: no immediate re-poll after a miss and no wasted sleep after completion.
        for (int attempt = 0; attempt < 300 && !checkComplete(taskQueryResponse); attempt++) {
            Thread.sleep(1000L);
            taskQueryResponse = getTaskQueryResponse(taskId.toString());
        }

        Object modelId = taskQueryResponse.get("model_id");
        assertNotNull("Upload task did not produce a model_id: " + taskQueryResponse, modelId);
        return modelId.toString();
    }

    /**
     * Deploys a model through the ML Commons {@code _deploy} API and polls the deploy task until
     * {@link #checkComplete} reports it finished, for up to roughly 300 seconds.
     *
     * @param modelId id of the model to deploy
     * @throws Exception if a request fails or the polling sleep is interrupted
     */
    protected void loadModel(String modelId) throws Exception {
        Response deployResponse = makeRequest(
            client(),
            "POST",
            String.format(LOCALE, "/_plugins/_ml/models/%s/_deploy", modelId),
            null,
            toHttpEntity(""),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        Map<String, Object> deployResult = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(deployResponse.getEntity()),
            false
        );
        Object taskId = deployResult.get("task_id");
        // Assert before dereferencing so a missing task id fails with the response body instead of an NPE.
        assertNotNull("Model deploy response has no task_id: " + deployResult, taskId);

        Map<String, Object> taskQueryResponse = getTaskQueryResponse(taskId.toString());
        // Sleep BEFORE each re-query: no immediate re-poll after a miss and no wasted sleep after completion.
        for (int attempt = 0; attempt < 300 && !checkComplete(taskQueryResponse); attempt++) {
            Thread.sleep(1000L);
            taskQueryResponse = getTaskQueryResponse(taskId.toString());
        }
        assertTrue("Deploy task did not finish in time: " + taskQueryResponse, checkComplete(taskQueryResponse));
    }

    /**
     * Uploads and deploys the text-embedding model described by {@code processor/UploadModelRequestBody.json}.
     *
     * @return id of the deployed model
     */
    protected String prepareModel() {
        final Path requestBodyPath = Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI());
        final String modelId = registerModelGroupAndUploadModel(Files.readString(requestBodyPath));
        loadModel(modelId);
        return modelId;
    }

    /**
     * Uploads and deploys the sparse-encoding model described by
     * {@code processor/UploadSparseEncodingModelRequestBody.json}.
     *
     * @return id of the deployed model
     */
    protected String prepareSparseEncodingModel() {
        final Path requestBodyPath = Path.of(
            classLoader.getResource("processor/UploadSparseEncodingModelRequestBody.json").toURI()
        );
        final String modelId = registerModelGroupAndUploadModel(Files.readString(requestBodyPath));
        loadModel(modelId);
        return modelId;
    }

    /**
     * Calls the ML Commons text-embedding predict API for a single text and returns the embedding.
     *
     * @param modelId   id of a deployed text-embedding model
     * @param queryText text to embed (inserted verbatim into the JSON request)
     * @return the {@code sentence_embedding} vector of the single inference result
     */
    @SuppressWarnings("unchecked")
    protected float[] runInference(String modelId, String queryText) {
        Response response = makeRequest(
            client(),
            "POST",
            String.format(LOCALE, "/_plugins/_ml/_predict/text_embedding/%s", modelId),
            null,
            toHttpEntity(
                String.format(LOCALE, "{\"text_docs\": [\"%s\"],\"target_response\": [\"sentence_embedding\"]}", queryText)
            ),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        Map<String, Object> result = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );

        Object inferenceResults = result.get("inference_results");
        assertTrue("inference_results is not a list: " + result, inferenceResults instanceof List);
        List<Object> resultList = (List<Object>) inferenceResults;
        assertEquals(1, resultList.size());

        List<Object> output = (List<Object>) ((Map<String, Object>) resultList.get(0)).get("output");
        assertEquals(1, output.size());

        // JSON numbers may be parsed as Double (or Integer for whole values); Number covers both.
        List<Number> data = (List<Number>) ((Map<String, Object>) output.get(0)).get("data");
        return VectorUtil.vectorAsListToArray(data.stream().map(Number::floatValue).collect(Collectors.toList()));
    }

    /**
     * Creates an index and asserts it was acknowledged under the expected name.
     *
     * @param indexName          index to create
     * @param indexConfiguration index body; formatted with {@code pipelineName} when that is not blank
     * @param pipelineName       value substituted into {@code indexConfiguration}, or blank to use it as-is
     */
    protected void createIndexWithConfiguration(String indexName, String indexConfiguration, String pipelineName) throws Exception {
        final String body = StringUtils.isNotBlank(pipelineName)
            ? String.format(LOCALE, indexConfiguration, pipelineName)
            : indexConfiguration;
        final Response response = makeRequest(
            client(),
            "PUT",
            indexName,
            null,
            toHttpEntity(body),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        final Map<String, Object> result = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        assertEquals("true", result.get("acknowledged").toString());
        assertEquals(indexName, result.get("index").toString());
    }

    /**
     * Creates an ingest pipeline of the given processor type with the default second template argument.
     *
     * @param modelId       model id substituted into the pipeline template
     * @param pipelineName  name of the ingest pipeline
     * @param processorType selects the pipeline template
     */
    protected void createPipelineProcessor(String modelId, String pipelineName, ProcessorType processorType) throws Exception {
        final Integer defaultBatchSize = null;
        createPipelineProcessor(modelId, pipelineName, processorType, defaultBatchSize);
    }

    /**
     * Loads the pipeline template registered for {@code processorType} and creates the pipeline from it.
     *
     * @param modelId       model id substituted into the template
     * @param pipelineName  name of the ingest pipeline
     * @param processorType selects the template from {@code PIPELINE_CONFIGS_BY_TYPE}
     * @param batchSize     second template argument; {@code null} means 1 (presumably the batch size — confirm)
     */
    protected void createPipelineProcessor(String modelId, String pipelineName, ProcessorType processorType, Integer batchSize)
        throws Exception {
        final String templateResource = PIPELINE_CONFIGS_BY_TYPE.get(processorType);
        final String pipelineConfiguration = Files.readString(Path.of(classLoader.getResource(templateResource).toURI()));
        createPipelineProcessor(pipelineConfiguration, pipelineName, modelId, batchSize);
    }

    /**
     * Creates an ingest pipeline from a JSON template and asserts it was acknowledged.
     *
     * @param pipelineConfiguration template with two placeholders: model id, then an integer
     * @param pipelineName          name of the ingest pipeline
     * @param modelId               first template argument
     * @param batchSize             second template argument; {@code null} means 1 (presumably the batch size — confirm)
     */
    protected void createPipelineProcessor(String pipelineConfiguration, String pipelineName, String modelId, Integer batchSize)
        throws Exception {
        final int effectiveBatchSize = batchSize == null ? 1 : batchSize;
        final String body = String.format(LOCALE, pipelineConfiguration, modelId, effectiveBatchSize);
        final Response response = makeRequest(
            client(),
            "PUT",
            "/_ingest/pipeline/" + pipelineName,
            null,
            toHttpEntity(body),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        final Map<String, Object> result = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        assertEquals("true", result.get("acknowledged").toString());
    }

    /**
     * Creates the two-phase search pipeline with the default template arguments (0.4, 5.0, 10000).
     *
     * @param pipelineName name of the search pipeline
     */
    protected void createNeuralSparseTwoPhaseSearchProcessor(String pipelineName) throws Exception {
        final float defaultFirst = 0.4f;
        final float defaultSecond = 5.0f;
        final int defaultThird = 10000;
        createNeuralSparseTwoPhaseSearchProcessor(pipelineName, defaultFirst, defaultSecond, defaultThird);
    }

    /**
     * Creates a search pipeline from {@code processor/NeuralSparseTwoPhaseProcessorConfiguration.json}
     * and asserts it was acknowledged. The three numbers are substituted into the template in order;
     * their names below are presumed from the processor's settings — confirm against the template.
     *
     * @param pipelineName  name of the search pipeline
     * @param pruneRatio    first template value
     * @param expansionRate second template value
     * @param maxWindowSize third template value
     */
    protected void createNeuralSparseTwoPhaseSearchProcessor(String pipelineName, float pruneRatio, float expansionRate, int maxWindowSize)
        throws Exception {
        final URL templateUrl = Objects.requireNonNull(
            classLoader.getResource("processor/NeuralSparseTwoPhaseProcessorConfiguration.json")
        );
        final String body = String.format(
            Locale.ROOT,
            Files.readString(Path.of(templateUrl.toURI())),
            pruneRatio,
            expansionRate,
            maxWindowSize
        );
        final Response response = makeRequest(
            client(),
            "PUT",
            "/_search/pipeline/" + pipelineName,
            null,
            toHttpEntity(body),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        final Map<String, Object> result = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        assertEquals("true", result.get("acknowledged").toString());
    }

    /**
     * Creates a search pipeline from {@code processor/SearchRequestPipelineConfiguration.json} and asserts
     * it was acknowledged.
     *
     * @param modelId      value substituted into the template
     * @param pipelineName name of the search pipeline
     */
    protected void createSearchRequestProcessor(String modelId, String pipelineName) throws Exception {
        final Path templatePath = Path.of(classLoader.getResource("processor/SearchRequestPipelineConfiguration.json").toURI());
        final String body = String.format(LOCALE, Files.readString(templatePath), modelId);
        final Response response = makeRequest(
            client(),
            "PUT",
            "/_search/pipeline/" + pipelineName,
            null,
            toHttpEntity(body),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        final Map<String, Object> result = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        assertEquals("true", result.get("acknowledged").toString());
    }

    /**
     * Creates a search pipeline from an arbitrary classpath template and asserts it was acknowledged.
     *
     * @param modelId      value substituted into the template
     * @param pipelineName name of the search pipeline
     * @param configPath   classpath resource holding the template
     */
    protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception {
        final Path templatePath = Path.of(classLoader.getResource(configPath).toURI());
        final String body = String.format(LOCALE, Files.readString(templatePath), modelId);
        final Response response = makeRequest(
            client(),
            "PUT",
            "/_search/pipeline/" + pipelineName,
            null,
            toHttpEntity(body),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        final Map<String, Object> result = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        assertEquals("true", result.get("acknowledged").toString());
    }

    /**
     * Adds an alias to an index via {@code POST /_aliases}, optionally with a filter, and asserts 200 OK.
     *
     * @param index         target index
     * @param alias         alias name
     * @param filterBuilder optional alias filter; omitted when {@code null}
     */
    protected void createIndexAlias(String index, String alias, QueryBuilder filterBuilder) throws Exception {
        final XContentBuilder builder = XContentFactory.jsonBuilder()
            .startObject()
            .startArray("actions")
            .startObject()
            .startObject("add")
            .field("index", index)
            .field("alias", alias);
        if (filterBuilder != null) {
            builder.field("filter");
            filterBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        builder.endObject().endObject().endArray().endObject();

        final Request request = new Request("POST", "/_aliases");
        request.setJsonEntity(builder.toString());
        final Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Removes an alias from an index; the response is not checked.
     *
     * @param index index the alias points to
     * @param alias alias to remove
     */
    protected void deleteIndexAlias(String index, String alias) {
        final String endpoint = String.format(Locale.ROOT, "%s/_alias/%s", index, alias);
        makeRequest(client(), "DELETE", endpoint, null, null, ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT)));
    }

    /**
     * Returns the document count of an index via {@code GET /{index}/_count}.
     *
     * @param indexName index to count
     * @return the {@code count} field of the response
     */
    protected int getDocCount(String indexName) {
        final Request countRequest = new Request("GET", "/" + indexName + "/_count");
        final Response countResponse = client().performRequest(countRequest);
        assertEquals(
            countRequest.getEndpoint() + ": failed",
            RestStatus.OK,
            RestStatus.fromCode(countResponse.getStatusLine().getStatusCode())
        );
        final Map<String, Object> body = createParser(XContentType.JSON.xContent(), EntityUtils.toString(countResponse.getEntity())).map();
        final Integer count = (Integer) body.get("count");
        return count.intValue();
    }

    /**
     * Fetches a document via {@code GET /{index}/_doc/{id}} and asserts 200 OK.
     *
     * @param index index holding the document
     * @param docId document id
     * @return the parsed response body
     */
    protected Map<String, Object> getDocById(String index, String docId) {
        final Request getRequest = new Request("GET", "/" + index + "/_doc/" + docId);
        final Response getResponse = client().performRequest(getRequest);
        assertEquals(
            getRequest.getEndpoint() + ": failed",
            RestStatus.OK,
            RestStatus.fromCode(getResponse.getStatusLine().getStatusCode())
        );
        final String body = EntityUtils.toString(getResponse.getEntity());
        return createParser(XContentType.JSON.xContent(), body).map();
    }

    /**
     * Runs a search with only a query and a result size.
     *
     * @param index        index to search
     * @param queryBuilder query to run
     * @param resultSize   value sent as the {@code size} parameter
     * @return the parsed search response
     */
    protected Map<String, Object> search(String index, QueryBuilder queryBuilder, int resultSize) {
        final QueryBuilder noRescorer = null;
        return search(index, queryBuilder, noRescorer, resultSize);
    }

    /**
     * Runs a search with an optional rescore query and no extra request parameters.
     *
     * @param index        index to search
     * @param queryBuilder query to run
     * @param rescorer     optional rescore query
     * @param resultSize   value sent as the {@code size} parameter
     * @return the parsed search response
     */
    protected Map<String, Object> search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) {
        final Map<String, String> noParams = Map.of();
        return search(index, queryBuilder, rescorer, resultSize, noParams);
    }

    /**
     * Runs a search with extra request parameters and no aggregations.
     *
     * @param index         index to search
     * @param queryBuilder  query to run
     * @param rescorer      optional rescore query
     * @param resultSize    value sent as the {@code size} parameter
     * @param requestParams extra URL parameters
     * @return the parsed search response
     */
    protected Map<String, Object> search(
        String index,
        QueryBuilder queryBuilder,
        QueryBuilder rescorer,
        int resultSize,
        Map<String, String> requestParams
    ) {
        final List<Object> noAggs = null;
        return search(index, queryBuilder, rescorer, resultSize, requestParams, noAggs);
    }

    /**
     * Runs a search with optional aggregations; no post filter, sort, track_scores or search_after,
     * starting from offset 0.
     *
     * @param index         index to search
     * @param queryBuilder  query to run
     * @param rescorer      optional rescore query
     * @param resultSize    value sent as the {@code size} parameter
     * @param requestParams extra URL parameters
     * @param aggs          optional aggregation builders
     * @return the parsed search response
     */
    protected Map<String, Object> search(
        String index,
        QueryBuilder queryBuilder,
        QueryBuilder rescorer,
        int resultSize,
        Map<String, String> requestParams,
        List<Object> aggs
    ) {
        final boolean trackScores = false;
        final int from = 0;
        return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null, null, trackScores, null, from);
    }

    /**
     * Runs {@code GET /{index}/_search} with the given request sections, asserts 200 OK and returns the
     * parsed response.
     *
     * @param index             index to search
     * @param queryBuilder      optional query
     * @param rescorer          optional rescore query (sent with {@code query_weight} 0)
     * @param resultSize        value sent as the {@code size} parameter
     * @param requestParams     optional extra URL parameters
     * @param aggs              optional aggregation builders written inside {@code aggs}
     * @param postFilterBuilder optional post filter
     * @param sortBuilders      optional sort builders
     * @param trackScores       whether to send {@code track_scores: true}
     * @param searchAfter       optional {@code search_after} values
     * @param from              value sent as {@code from}
     * @return the parsed search response
     */
    protected Map<String, Object> search(
        String index,
        QueryBuilder queryBuilder,
        QueryBuilder rescorer,
        int resultSize,
        Map<String, String> requestParams,
        List<Object> aggs,
        QueryBuilder postFilterBuilder,
        List<SortBuilder<?>> sortBuilders,
        boolean trackScores,
        List<Object> searchAfter,
        int from
    ) {
        // Serialize once; the same string is logged and sent.
        String requestBody = buildSearchRequestBody(
            queryBuilder,
            rescorer,
            aggs,
            postFilterBuilder,
            sortBuilders,
            trackScores,
            searchAfter,
            from
        );

        Request request = new Request("GET", "/" + index + "/_search?timeout=1000s");
        request.addParameter("size", Integer.toString(resultSize));
        if (requestParams != null && !requestParams.isEmpty()) {
            requestParams.forEach(request::addParameter);
        }
        // Every search is logged here, not only sorted ones, so the label is generic.
        logger.info("Search request {}", requestBody);
        request.setJsonEntity(requestBody);

        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

        String responseBody = EntityUtils.toString(response.getEntity());
        logger.info("Search response {}", responseBody);
        return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
    }

    /** Builds the JSON body of a search request; {@code null} or empty optional sections are omitted. */
    private static String buildSearchRequestBody(
        QueryBuilder queryBuilder,
        QueryBuilder rescorer,
        List<Object> aggs,
        QueryBuilder postFilterBuilder,
        List<SortBuilder<?>> sortBuilders,
        boolean trackScores,
        List<Object> searchAfter,
        int from
    ) throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
        builder.field("from", from);
        if (queryBuilder != null) {
            builder.field("query");
            queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        if (rescorer != null) {
            builder.startObject("rescore").startObject("query").field("query_weight", 0.0f).field("rescore_query");
            rescorer.toXContent(builder, ToXContent.EMPTY_PARAMS);
            builder.endObject().endObject();
        }
        if (aggs != null) {
            builder.startObject("aggs");
            for (Object agg : aggs) {
                builder.value(agg);
            }
            builder.endObject();
        }
        if (postFilterBuilder != null) {
            builder.field("post_filter");
            postFilterBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        if (sortBuilders != null && !sortBuilders.isEmpty()) {
            builder.startArray("sort");
            for (SortBuilder<?> sortBuilder : sortBuilders) {
                sortBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
            }
            builder.endArray();
        }
        if (trackScores) {
            builder.field("track_scores", true);
        }
        if (searchAfter != null && !searchAfter.isEmpty()) {
            builder.startArray("search_after");
            for (Object value : searchAfter) {
                builder.value(value);
            }
            builder.endArray();
        }
        builder.endObject();
        return builder.toString();
    }

    /**
     * Indexes a document containing only vector fields.
     *
     * @param index            target index
     * @param docId            document id
     * @param vectorFieldNames vector field names, parallel to {@code vectors}
     * @param vectors          vector values
     */
    protected void addKnnDoc(String index, String docId, List<String> vectorFieldNames, List<Object[]> vectors) {
        final List<String> none = Collections.emptyList();
        addKnnDoc(index, docId, vectorFieldNames, vectors, none, none);
    }

    /**
     * Indexes a document with vector fields and text fields.
     *
     * @param index            target index
     * @param docId            document id
     * @param vectorFieldNames vector field names, parallel to {@code vectors}
     * @param vectors          vector values
     * @param textFieldNames   text field names, parallel to {@code texts}
     * @param texts            text values
     */
    protected void addKnnDoc(
        String index,
        String docId,
        List<String> vectorFieldNames,
        List<Object[]> vectors,
        List<String> textFieldNames,
        List<String> texts
    ) {
        final List<String> noNestedNames = Collections.emptyList();
        final List<Map<String, String>> noNestedValues = Collections.emptyList();
        addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, noNestedNames, noNestedValues);
    }

    /**
     * Indexes a document with vector, text and nested-object fields; all other field groups are empty.
     *
     * @param index            target index
     * @param docId            document id
     * @param vectorFieldNames vector field names, parallel to {@code vectors}
     * @param vectors          vector values
     * @param textFieldNames   text field names, parallel to {@code texts}
     * @param texts            text values
     * @param nestedFieldNames object field names, parallel to {@code nestedFields}
     * @param nestedFields     key/value pairs written inside each object field
     */
    protected void addKnnDoc(
        String index,
        String docId,
        List<String> vectorFieldNames,
        List<Object[]> vectors,
        List<String> textFieldNames,
        List<String> texts,
        List<String> nestedFieldNames,
        List<Map<String, String>> nestedFields
    ) {
        final List<String> noNames = Collections.emptyList();
        final List<Integer> noInts = Collections.emptyList();
        final List<String> noStrings = Collections.emptyList();
        addKnnDoc(
            index,
            docId,
            vectorFieldNames,
            vectors,
            textFieldNames,
            texts,
            nestedFieldNames,
            nestedFields,
            noNames,
            noInts,
            noNames,
            noStrings,
            noNames,
            noStrings
        );
    }

    /**
     * Indexes one document (with {@code refresh=true}) built from parallel name/value lists and asserts
     * the response status is CREATED or OK. Each value list is read at the same indices as its name list.
     *
     * @param index              target index
     * @param docId              document id
     * @param vectorFieldNames   vector field names, parallel to {@code vectors}
     * @param vectors            vector values
     * @param textFieldNames     text field names, parallel to {@code texts}
     * @param texts              text values
     * @param nestedFieldNames   object field names, parallel to {@code nestedFields}
     * @param nestedFields       key/value pairs written inside each object field
     * @param integerFieldNames  integer field names, parallel to {@code integerFieldValues}
     * @param integerFieldValues integer values
     * @param keywordFieldNames  string field names (presumably keyword fields — confirm), parallel to {@code keywordFieldValues}
     * @param keywordFieldValues string values
     * @param dateFieldNames     string field names (presumably date fields — confirm), parallel to {@code dateFieldValues}
     * @param dateFieldValues    string values
     */
    protected void addKnnDoc(
        String index,
        String docId,
        List<String> vectorFieldNames,
        List<Object[]> vectors,
        List<String> textFieldNames,
        List<String> texts,
        List<String> nestedFieldNames,
        List<Map<String, String>> nestedFields,
        List<String> integerFieldNames,
        List<Integer> integerFieldValues,
        List<String> keywordFieldNames,
        List<String> keywordFieldValues,
        List<String> dateFieldNames,
        List<String> dateFieldValues
    ) {
        final Request indexRequest = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
        final XContentBuilder doc = XContentFactory.jsonBuilder().startObject();

        for (int v = 0; v < vectorFieldNames.size(); v++) {
            doc.field(vectorFieldNames.get(v), vectors.get(v));
        }
        appendStringFields(doc, textFieldNames, texts);
        for (int n = 0; n < nestedFieldNames.size(); n++) {
            doc.field(nestedFieldNames.get(n)).startObject();
            for (Map.Entry<String, String> nestedEntry : nestedFields.get(n).entrySet()) {
                doc.field(nestedEntry.getKey(), nestedEntry.getValue());
            }
            doc.endObject();
        }
        for (int k = 0; k < integerFieldNames.size(); k++) {
            doc.field(integerFieldNames.get(k), integerFieldValues.get(k));
        }
        appendStringFields(doc, keywordFieldNames, keywordFieldValues);
        appendStringFields(doc, dateFieldNames, dateFieldValues);
        doc.endObject();

        indexRequest.setJsonEntity(doc.toString());
        final Response indexResponse = client().performRequest(indexRequest);
        final RestStatus status = RestStatus.fromCode(indexResponse.getStatusLine().getStatusCode());
        assertTrue(indexRequest.getEndpoint() + ": failed", SUCCESS_STATUSES.contains(status));
    }

    /** Writes {@code names.get(i): values.get(i)} for every index of {@code names}. */
    private static void appendStringFields(XContentBuilder doc, List<String> names, List<String> values) throws IOException {
        for (int i = 0; i < names.size(); i++) {
            doc.field(names.get(i), values.get(i));
        }
    }

    /**
     * Indexes a document containing only sparse (token-weight map) fields.
     *
     * @param index       target index
     * @param docId       document id
     * @param fieldNames  sparse field names, parallel to {@code weightMaps}
     * @param weightMaps  token-to-weight maps
     */
    protected void addSparseEncodingDoc(String index, String docId, List<String> fieldNames, List<Map<String, Float>> weightMaps) {
        final List<String> noTextFields = Collections.emptyList();
        final List<String> noTexts = Collections.emptyList();
        addSparseEncodingDoc(index, docId, fieldNames, weightMaps, noTextFields, noTexts);
    }

    /**
     * Indexes one document (with {@code refresh=true}) holding sparse token-weight fields and text fields,
     * and asserts success. Like {@code addKnnDoc}, both CREATED (new id) and OK (existing id overwritten)
     * count as success via {@code SUCCESS_STATUSES}.
     *
     * @param index          target index
     * @param docId          document id
     * @param fieldNames     sparse field names, parallel to {@code weightMaps}
     * @param weightMaps     token-to-weight maps
     * @param textFieldNames text field names, parallel to {@code texts}
     * @param texts          text values
     */
    protected void addSparseEncodingDoc(
        String index,
        String docId,
        List<String> fieldNames,
        List<Map<String, Float>> weightMaps,
        List<String> textFieldNames,
        List<String> texts
    ) {
        Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
        XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
        for (int i = 0; i < fieldNames.size(); i++) {
            builder.field(fieldNames.get(i), weightMaps.get(i));
        }
        for (int i = 0; i < textFieldNames.size(); i++) {
            builder.field(textFieldNames.get(i), texts.get(i));
        }
        builder.endObject();
        request.setJsonEntity(builder.toString());

        Response response = client().performRequest(request);
        RestStatus status = RestStatus.fromCode(response.getStatusLine().getStatusCode());
        // Previously only CREATED was accepted, so re-indexing an existing id (200 OK) failed spuriously.
        assertTrue(request.getEndpoint() + ": failed with status " + status, SUCCESS_STATUSES.contains(status));
    }

    /**
     * Bulk-indexes text documents through an ingest pipeline with {@code refresh=true} and asserts 200 OK.
     *
     * @param index     target index
     * @param textField field that receives each document's {@code "text"} value
     * @param pipeline  ingest pipeline passed as the {@code pipeline} URL parameter
     * @param docs      documents, each a map with {@code "id"} and {@code "text"} entries
     * @throws IOException if building the payload or performing the request fails
     */
    protected void bulkAddDocuments(String index, String textField, String pipeline, List<Map<String, String>> docs) throws IOException {
        StringBuilder payload = new StringBuilder();
        for (Map<String, String> doc : docs) {
            // Build each NDJSON line with XContentBuilder so ids/texts are JSON-escaped and the action line
            // is a single valid JSON object (the old template left a trailing comma after it).
            XContentBuilder actionLine = XContentFactory.jsonBuilder()
                .startObject()
                .startObject("index")
                .field("_index", index)
                .field("_id", doc.get("id"))
                .endObject()
                .endObject();
            XContentBuilder sourceLine = XContentFactory.jsonBuilder().startObject().field(textField, doc.get("text")).endObject();
            payload.append(actionLine.toString()).append('\n').append(sourceLine.toString()).append('\n');
        }
        Request request = new Request("POST", String.format(Locale.ROOT, "/_bulk?refresh=true&pipeline=%s", pipeline));
        request.setJsonEntity(payload.toString());
        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Sends a raw bulk payload with {@code refresh=true} (and an optional pipeline) and asserts that no
     * item reported an error and that the request returned 200 OK.
     *
     * @param ingestBulkPayload NDJSON bulk body
     * @param pipeline          ingest pipeline name, or {@code null} for none
     */
    @SuppressWarnings("unchecked")
    protected void bulkIngest(String ingestBulkPayload, String pipeline) {
        Map<String, String> params = new HashMap<>();
        params.put("refresh", "true");
        if (Objects.nonNull(pipeline)) {
            params.put("pipeline", pipeline);
        }
        Response response = makeRequest(
            client(),
            "POST",
            "_bulk",
            params,
            toHttpEntity(ingestBulkPayload),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        Map<String, Object> bulkResult = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        List<Map<String, Object>> items = (List<Map<String, Object>>) bulkResult.get("items");
        // Each item is keyed by its action type (index/create/update/delete); inspect whichever is present
        // instead of assuming "index", which threw an NPE for other actions.
        long failedItems = items.stream()
            .flatMap(item -> item.values().stream())
            .filter(actionResult -> ((Map<String, Object>) actionResult).get("error") != null)
            .count();
        assertEquals("_bulk reported failed items", 0L, failedItems);
        assertEquals("_bulk failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Returns the first element of {@code hits.hits} in a search response, asserting there is at least one.
     *
     * @param searchResponse parsed search response
     * @return the first hit
     */
    @SuppressWarnings("unchecked")
    protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchResponse) {
        final Map<String, Object> hitsSection = (Map<String, Object>) searchResponse.get("hits");
        final List<Map<String, Object>> hits = (List<Map<String, Object>>) hitsSection.get("hits");
        assertTrue(!hits.isEmpty());
        return hits.get(0);
    }

    /**
     * Counts the hits returned in {@code hits.hits} (the returned page, not the total).
     *
     * @param searchResponse parsed search response
     * @return number of returned hits
     */
    @SuppressWarnings("unchecked")
    protected int getHitCount(Map<String, Object> searchResponse) {
        final Map<String, Object> hitsSection = (Map<String, Object>) searchResponse.get("hits");
        final List<Object> hits = (List<Object>) hitsSection.get("hits");
        return hits.size();
    }

    /**
     * Collects the {@code _score} of every returned hit, in response order.
     *
     * @param searchResponse parsed search response
     * @return mutable list of scores (each cast to {@link Double})
     */
    @SuppressWarnings("unchecked")
    protected List<Double> getNormalizationScoreList(Map<String, Object> searchResponse) {
        final Map<String, Object> hitsSection = (Map<String, Object>) searchResponse.get("hits");
        final List<Map<String, Object>> hits = (List<Map<String, Object>>) hitsSection.get("hits");
        return hits.stream()
            .map(hit -> (Double) hit.get("_score"))
            .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Creates a k-NN index with 3 shards.
     *
     * @param indexName       index to create
     * @param knnFieldConfigs vector fields to map
     */
    protected void prepareKnnIndex(String indexName, List<KNNFieldConfig> knnFieldConfigs) {
        final int defaultShardCount = 3;
        prepareKnnIndex(indexName, knnFieldConfigs, defaultShardCount);
    }

    /**
     * Creates a k-NN index with the given vector fields and shard count.
     *
     * @param indexName       index to create
     * @param knnFieldConfigs vector fields to map
     * @param numberOfShards  shard count
     */
    protected void prepareKnnIndex(String indexName, List<KNNFieldConfig> knnFieldConfigs, int numberOfShards) {
        final String indexConfiguration = buildIndexConfiguration(knnFieldConfigs, numberOfShards);
        createIndexWithConfiguration(indexName, indexConfiguration, "");
    }

    /**
     * Creates an index mapping each given field as {@code rank_features}.
     *
     * @param indexName        index to create
     * @param sparseFieldNames fields to map as {@code rank_features}
     */
    protected void prepareSparseEncodingIndex(String indexName, List<String> sparseFieldNames) {
        final XContentBuilder mapping = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("mappings")
            .startObject("properties");
        for (String fieldName : sparseFieldNames) {
            mapping.startObject(fieldName).field("type", "rank_features").endObject();
        }
        mapping.endObject().endObject().endObject();
        createIndexWithConfiguration(indexName, mapping.toString(), "");
    }

    /**
     * Embeds {@code queryText} with the model and scores it against {@code indexVector} using the
     * space type's vector similarity function.
     *
     * @param modelId     model used to embed the query text
     * @param indexVector stored document vector
     * @param spaceType   space type whose similarity function is applied
     * @param queryText   text to embed
     * @return similarity of the query embedding and {@code indexVector}
     */
    protected float computeExpectedScore(String modelId, float[] indexVector, SpaceType spaceType, String queryText) {
        final float[] queryVector = runInference(modelId, queryText);
        return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, indexVector);
    }

    /**
     * Fetches an ML Commons task by id.
     *
     * @param taskId task to query
     * @return the parsed task document
     */
    protected Map<String, Object> getTaskQueryResponse(String taskId) throws Exception {
        final String endpoint = String.format(LOCALE, "_plugins/_ml/tasks/%s", taskId);
        final Response taskResponse = makeRequest(
            client(),
            "GET",
            endpoint,
            null,
            toHttpEntity(""),
            ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT))
        );
        final String body = EntityUtils.toString(taskResponse.getEntity());
        return XContentHelper.convertToMap(XContentType.JSON.xContent(), body, false);
    }

    /**
     * Tells whether an ML Commons task has stopped: it either reports an {@code error} or its
     * {@code state} is {@code COMPLETED}.
     *
     * @param taskResponse parsed task document
     * @return {@code true} once the task is terminal (success or error)
     */
    protected boolean checkComplete(Map<String, Object> taskResponse) {
        // Replaces a raw Predicate whose lambda ignored its own argument and captured the outer map.
        if (taskResponse.get("error") != null) {
            return true;
        }
        return "COMPLETED".equals(String.valueOf(taskResponse.get("state")));
    }

    /**
     * Builds an index body with only vector fields.
     *
     * @param knnFieldConfigs vector fields to map
     * @param numberOfShards  shard count
     * @return index JSON
     */
    protected String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int numberOfShards) {
        final List<String> noNestedFields = Collections.emptyList();
        return buildIndexConfiguration(knnFieldConfigs, noNestedFields, numberOfShards);
    }

    /**
     * Builds an index body with vector fields and an optional nested field.
     *
     * @param knnFieldConfigs vector fields to map
     * @param nestedFields    nested field name followed by its keyword sub-fields (may be empty)
     * @param numberOfShards  shard count
     * @return index JSON
     */
    protected String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, List<String> nestedFields, int numberOfShards) {
        final List<String> none = Collections.emptyList();
        return buildIndexConfiguration(knnFieldConfigs, nestedFields, none, none, none, numberOfShards);
    }

    /**
     * Builds an index body without float fields.
     *
     * @param knnFieldConfigs vector fields to map
     * @param nestedFields    nested field name followed by its keyword sub-fields (may be empty)
     * @param intFields       fields mapped as {@code integer}
     * @param keywordFields   fields mapped as {@code keyword}
     * @param dateFields      fields mapped as {@code date} (format MM/dd/yyyy)
     * @param numberOfShards  shard count
     * @return index JSON
     */
    protected String buildIndexConfiguration(
        List<KNNFieldConfig> knnFieldConfigs,
        List<String> nestedFields,
        List<String> intFields,
        List<String> keywordFields,
        List<String> dateFields,
        int numberOfShards
    ) {
        final List<String> noFloatFields = Collections.emptyList();
        return buildIndexConfiguration(knnFieldConfigs, nestedFields, intFields, noFloatFields, keywordFields, dateFields, numberOfShards);
    }

    /**
     * Builds the JSON body for creating a k-NN enabled index.
     *
     * @param knnFieldConfigs vector fields, mapped as lucene/hnsw {@code knn_vector}
     * @param nestedFields    first entry is a {@code nested} field; remaining entries become its keyword sub-fields
     * @param intFields       fields mapped as {@code integer}
     * @param floatFields     fields mapped as {@code float}
     * @param keywordFields   fields mapped as {@code keyword}
     * @param dateFields      fields mapped as {@code date} with format {@code MM/dd/yyyy}
     * @param numberOfShards  value of {@code number_of_shards}
     * @return index JSON
     */
    protected String buildIndexConfiguration(
        List<KNNFieldConfig> knnFieldConfigs,
        List<String> nestedFields,
        List<String> intFields,
        List<String> floatFields,
        List<String> keywordFields,
        List<String> dateFields,
        int numberOfShards
    ) {
        final XContentBuilder body = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("settings")
            .field("number_of_shards", numberOfShards)
            .field("index.knn", true)
            .endObject()
            .startObject("mappings")
            .startObject("properties");

        for (KNNFieldConfig knnField : knnFieldConfigs) {
            body.startObject(knnField.getName())
                .field("type", "knn_vector")
                .field("dimension", Integer.toString(knnField.getDimension().intValue()))
                .startObject("method")
                .field("engine", "lucene")
                .field("space_type", knnField.getSpaceType().getValue())
                .field("name", "hnsw")
                .endObject()
                .endObject();
        }

        if (!nestedFields.isEmpty()) {
            body.startObject(nestedFields.get(0)).field("type", "nested");
            if (nestedFields.size() > 1) {
                body.startObject("properties");
                for (String subField : nestedFields.subList(1, nestedFields.size())) {
                    body.startObject(subField).field("type", "keyword").endObject();
                }
                body.endObject();
            }
            body.endObject();
        }

        for (String intField : intFields) {
            body.startObject(intField).field("type", "integer").endObject();
        }
        for (String floatField : floatFields) {
            body.startObject(floatField).field("type", "float").endObject();
        }
        for (String keywordField : keywordFields) {
            body.startObject(keywordField).field("type", "keyword").endObject();
        }
        for (String dateField : dateFields) {
            body.startObject(dateField).field("type", "date").field("format", "MM/dd/yyyy").endObject();
        }

        // Close properties, mappings and the root object.
        body.endObject().endObject().endObject();
        return body.toString();
    }

    /**
     * Performs a request with permissive (non-strict) deprecation-warning handling.
     *
     * @param client   REST client to use
     * @param method   HTTP method
     * @param endpoint request path
     * @param params   URL parameters, or {@code null}
     * @param entity   request body, or {@code null}
     * @param headers  headers to add, or {@code null}
     * @return the response
     * @throws IOException if the request fails
     */
    protected static Response makeRequest(
        RestClient client,
        String method,
        String endpoint,
        Map<String, String> params,
        HttpEntity entity,
        List<Header> headers
    ) throws IOException {
        final boolean strictDeprecationMode = false;
        return makeRequest(client, method, endpoint, params, entity, headers, strictDeprecationMode);
    }

    /**
     * Builds and performs a low-level REST request.
     *
     * @param client                REST client to use
     * @param method                HTTP method
     * @param endpoint              request path
     * @param params                URL parameters, or {@code null}
     * @param entity                request body, or {@code null}
     * @param headers               headers to add, or {@code null}
     * @param strictDeprecationMode {@code true} selects {@link WarningsHandler#STRICT}, otherwise PERMISSIVE
     * @return the response
     * @throws IOException if the request fails
     */
    protected static Response makeRequest(
        RestClient client,
        String method,
        String endpoint,
        Map<String, String> params,
        HttpEntity entity,
        List<Header> headers,
        boolean strictDeprecationMode
    ) throws IOException {
        final Request request = new Request(method, endpoint);

        final RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
        if (headers != null) {
            for (Header header : headers) {
                options.addHeader(header.getName(), header.getValue());
            }
        }
        options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE);
        request.setOptions(options.build());

        if (params != null) {
            params.forEach(request::addParameter);
        }
        if (entity != null) {
            request.setEntity(entity);
        }
        return client.performRequest(request);
    }

    /**
     * Wraps a string as an {@code application/json} HTTP entity.
     *
     * @param json request body
     * @return entity carrying {@code json}
     */
    protected static HttpEntity toHttpEntity(String json) {
        final ContentType jsonType = ContentType.APPLICATION_JSON;
        return new StringEntity(json, jsonType);
    }

    /**
     * Undeploys a model, waits until it reaches UNDEPLOYED or DEPLOY_FAILED, then deletes it.
     * The responses of the undeploy and delete calls are not checked.
     *
     * @param modelId model to remove
     */
    protected void deleteModel(String modelId) {
        final List<Header> headers = ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT));
        makeRequest(client(), "POST", String.format(LOCALE, "/_plugins/_ml/models/%s/_undeploy", modelId), null, toHttpEntity(""), headers);
        pollForModelState(modelId, Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED));
        makeRequest(client(), "DELETE", String.format(LOCALE, "/_plugins/_ml/models/%s", modelId), null, toHttpEntity(""), headers);
    }

    /**
     * Polls a model's state up to 5 times, sleeping {@code TestUtils.MAX_TIME_OUT_INTERVAL} ms before each
     * check, and fails the test if none of the given exit states is reached.
     *
     * @param modelId    model to poll
     * @param exitStates states that end the wait successfully
     * @throws InterruptedException if the sleep between polls is interrupted
     */
    protected void pollForModelState(String modelId, Set<MLModelState> exitStates) throws InterruptedException {
        final int maxAttempts = 5;
        // Derive the sleep from the same constant the failure message reports, so the two cannot drift
        // (the sleep was previously a hard-coded 3000 ms).
        final long pollIntervalMillis = TestUtils.MAX_TIME_OUT_INTERVAL;
        MLModelState latestState = null;
        for (int attempt = 0; attempt < maxAttempts; attempt++) {
            Thread.sleep(pollIntervalMillis);
            latestState = getModelState(modelId);
            if (exitStates.contains(latestState)) {
                return;
            }
        }
        fail(
            String.format(
                LOCALE,
                "Model state does not reached exit states %s after %d attempts with interval of %d ms, latest model state: %s.",
                StringUtils.join(exitStates, ","),
                maxAttempts,
                pollIntervalMillis,
                latestState
            )
        );
    }

    /**
     * Fetches an ML model and returns its current {@code model_state}.
     *
     * @param str id of the model
     * @return the state parsed from the {@code model_state} field of the GET response
     */
    protected MLModelState getModelState(String str) {
        final String modelPath = String.format(LOCALE, "/_plugins/_ml/models/%s", str);
        final Response response = makeRequest(client(), "GET", modelPath, null, toHttpEntity(""), ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT)));
        final String responseBody = EntityUtils.toString(response.getEntity());
        final Map<String, Object> modelInfo = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
        final String stateName = (String) modelInfo.get("model_state");
        return MLModelState.valueOf(stateName);
    }

    /**
     * Returns {@code true} unconditionally in this class.
     * The method is public and non-final, so subclasses may override it.
     * NOTE(review): name suggests it gates cluster-settings updates during setup — confirm at call sites.
     *
     * @return always {@code true} here
     */
    public boolean isUpdateClusterSettings() {
        return true;
    }

    /**
     * Creates a normalization search pipeline using the default normalization and combination
     * techniques from {@code TestUtils} and no extra combination parameters.
     *
     * @param str name of the search pipeline
     */
    protected void createSearchPipelineWithResultsPostProcessor(String str) {
        final Map<String, String> noCombinationParameters = Map.of();
        createSearchPipeline(str, TestUtils.DEFAULT_NORMALIZATION_METHOD, TestUtils.DEFAULT_COMBINATION_METHOD, noCombinationParameters);
    }

    /**
     * Creates (PUTs) a search pipeline containing a single normalization-processor.
     *
     * @param str  name of the search pipeline
     * @param str2 normalization technique
     * @param str3 combination technique
     * @param map  optional combination parameters; only the {@code TestUtils.PARAM_NAME_WEIGHTS} entry is used,
     *             and its value is inserted verbatim, so it must already be valid JSON
     */
    protected void createSearchPipeline(String str, String str2, String str3, Map<String, String> map) {
        // The technique names are appended directly instead of substituting them with String.format.
        // The previous version used the whole body (including the weights value) as a format pattern,
        // so any '%' in the weights would be misread as a format specifier.
        StringBuilder body = new StringBuilder();
        body.append("{\"description\": \"Post processor pipeline\",")
            .append("\"phase_results_processors\": [{ ")
            .append("\"normalization-processor\": {")
            .append("\"normalization\": {")
            .append("\"technique\": \"")
            .append(str2)
            .append("\"")
            .append("},")
            .append("\"combination\": {")
            .append("\"technique\": \"")
            .append(str3)
            .append("\"");
        if (Objects.nonNull(map) && !map.isEmpty()) {
            body.append(", \"parameters\": {");
            if (map.containsKey(TestUtils.PARAM_NAME_WEIGHTS)) {
                body.append("\"weights\": ").append(map.get(TestUtils.PARAM_NAME_WEIGHTS));
            }
            body.append(" }");
        }
        body.append("}").append("}}]}");
        makeRequest(client(), "PUT", String.format(LOCALE, "/_search/pipeline/%s", str), null, toHttpEntity(body.toString()), ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT)));
    }

    /**
     * Creates a search pipeline whose normalization-processor has an empty configuration object.
     *
     * @param str name of the search pipeline
     */
    protected void createSearchPipelineWithDefaultResultsPostProcessor(String str) {
        final String pipelinePath = String.format(LOCALE, "/_search/pipeline/%s", str);
        // Constant body with no format specifiers, so no String.format pass is needed.
        final String pipelineBody = "{\"description\": \"Post processor pipeline\",\"phase_results_processors\": [{ \"normalization-processor\": {}}]}";
        makeRequest(client(), "PUT", pipelinePath, null, toHttpEntity(pipelineBody), ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT)));
    }

    /**
     * Deletes a search pipeline by name.
     *
     * @param str name of the search pipeline
     */
    protected void deleteSearchPipeline(String str) {
        final String pipelinePath = String.format(LOCALE, "/_search/pipeline/%s", str);
        final List<Header> headers = ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT));
        makeRequest(client(), "DELETE", pipelinePath, null, toHttpEntity(""), headers);
    }

    /**
     * Registers a new model group named {@code public_model_<8 random alphanumerics>} and returns its id.
     * The request body is read from the classpath resource
     * {@code processor/CreateModelGroupRequestBody.json} and used as a format pattern for the name.
     *
     * @return the id of the newly registered model group
     */
    private String getModelGroupId() {
        final URL requestBodyResource = this.classLoader.getResource("processor/CreateModelGroupRequestBody.json");
        final String requestBodyTemplate = Files.readString(Path.of(requestBodyResource.toURI()));
        final String uniqueGroupName = "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8);
        return registerModelGroup(String.format(LOCALE, requestBodyTemplate, uniqueGroupName));
    }

    /**
     * Registers an ML model group and returns the id assigned to it.
     *
     * @param str JSON request body for {@code POST /_plugins/_ml/model_groups/_register}
     * @return the {@code model_group_id} value from the response, via {@code toString()}
     * @throws IOException    propagated from the HTTP request
     * @throws ParseException if the response entity cannot be read as a string
     */
    protected String registerModelGroup(String str) throws IOException, ParseException {
        Response response = makeRequest(client(), "POST", "/_plugins/_ml/model_groups/_register", null, toHttpEntity(str), ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT)));
        Map<String, Object> responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity()), false);
        // Assert on the raw value BEFORE calling toString(). Previously toString() ran first, so a missing
        // id surfaced as a NullPointerException and the assertion could never fail.
        Object modelGroupId = responseMap.get("model_group_id");
        assertNotNull("model_group_id is missing from the register model group response", modelGroupId);
        return modelGroupId.toString();
    }

    /**
     * Calls {@code GET /_cluster/health}, asking it to wait for the given node count and green status.
     * Both the cluster-manager and request timeouts are 60s.
     *
     * @param str value for the {@code wait_for_nodes} parameter
     * @throws IOException propagated from the HTTP request
     */
    protected void waitForClusterHealthGreen(String str) throws IOException {
        final String[][] healthParameters = {
            { "wait_for_nodes", str },
            { "wait_for_status", "green" },
            { "cluster_manager_timeout", "60s" },
            { "timeout", "60s" } };
        final Request healthRequest = new Request("GET", "/_cluster/health");
        for (String[] parameter : healthParameters) {
            healthRequest.addParameter(parameter[0], parameter[1]);
        }
        client().performRequest(healthRequest);
    }

    /**
     * Indexes a document with one or two string fields and refreshes the index ({@code ?refresh=true}).
     *
     * @param str  index name
     * @param str2 document id
     * @param str3 name of the first field
     * @param str4 value of the first field
     * @param str5 name of an optional second field; used only when both {@code str5} and {@code str6} are non-null
     * @param str6 value of the optional second field
     * @throws IOException propagated from building the JSON body or from the HTTP request
     */
    protected void addDocument(String str, String str2, String str3, String str4, String str5, String str6) throws IOException {
        final Request indexRequest = new Request("PUT", "/" + str + "/_doc/" + str2 + "?refresh=true");

        final XContentBuilder document = XContentFactory.jsonBuilder().startObject();
        document.field(str3, str4);
        final boolean hasSecondField = str5 != null && str6 != null;
        if (hasSecondField) {
            document.field(str5, str6);
        }
        document.endObject();

        indexRequest.setJsonEntity(document.toString());
        client().performRequest(indexRequest);
    }

    /**
     * Fetches an ingest pipeline definition and asserts that the request returned 200 OK.
     *
     * @param str ingest pipeline id
     * @return the entry keyed by {@code str} in the parsed response body
     */
    protected Map<String, Object> getIngestionPipeline(String str) {
        final Request getRequest = new Request("GET", "/_ingest/pipeline/" + str);
        final Response getResponse = client().performRequest(getRequest);
        final RestStatus status = RestStatus.fromCode(getResponse.getStatusLine().getStatusCode());
        assertEquals(getRequest.getEndpoint() + ": failed", RestStatus.OK, status);
        final String responseBody = EntityUtils.toString(getResponse.getEntity());
        final Map<String, Object> pipelinesById = createParser(XContentType.JSON.xContent(), responseBody).map();
        return (Map) pipelinesById.get(str);
    }

    /**
     * Deletes an ingest pipeline and asserts that the request returned 200 OK.
     *
     * @param str ingest pipeline id
     * @return the parsed response body
     */
    protected Map<String, Object> deletePipeline(String str) {
        final Request deleteRequest = new Request("DELETE", "/_ingest/pipeline/" + str);
        final Response deleteResponse = client().performRequest(deleteRequest);
        final RestStatus status = RestStatus.fromCode(deleteResponse.getStatusLine().getStatusCode());
        assertEquals(deleteRequest.getEndpoint() + ": failed", RestStatus.OK, status);
        final String responseBody = EntityUtils.toString(deleteResponse.getEntity());
        return createParser(XContentType.JSON.xContent(), responseBody).map();
    }

    /**
     * Runs sparse-model inference on {@code str2} and scores the resulting token weights against {@code map}.
     *
     * @param str  id of the sparse model used for inference
     * @param map  token weights to score against
     * @param str2 text passed to the model's {@code _predict} endpoint
     * @return the score computed by {@code computeExpectedScore(Map, Map)}
     */
    protected float computeExpectedScore(String str, Map<String, Float> map, String str2) {
        final Map<String, Float> inferredTokenWeights = runSparseModelInference(str, str2);
        return computeExpectedScore(map, inferredTokenWeights);
    }

    /**
     * Computes the dot product of two token-weight maps over their shared tokens.
     * Weights from {@code map} are first truncated with {@code getFeatureFieldCompressedNumber}.
     *
     * @param map  token weights that are compressed before multiplication
     * @param map2 token weights used as-is; iteration is driven by this map
     * @return sum over tokens present in both maps of {@code map2[t] * compressed(map[t])}
     */
    protected float computeExpectedScore(Map<String, Float> map, Map<String, Float> map2) {
        // Primitive accumulator: the previous boxed Float re-allocated on every addition.
        float score = 0.0f;
        for (Map.Entry<String, Float> entry : map2.entrySet()) {
            if (map.containsKey(entry.getKey())) {
                score += entry.getValue().floatValue() * getFeatureFieldCompressedNumber(map.get(entry.getKey())).floatValue();
            }
        }
        return score;
    }

    /**
     * Runs a sparse model's {@code _predict} endpoint on one text and returns the token weights.
     * Asserts that the response contains exactly one inference result, with exactly one output
     * entry, which itself has exactly one key.
     *
     * @param str  model id
     * @param str2 text embedded into {@code {"text_docs": ["..."]}} without JSON escaping
     * @return the first token-weight map produced from the output's {@code dataAsMap}
     */
    protected Map<String, Float> runSparseModelInference(String str, String str2) {
        Response response = makeRequest(client(), "POST", String.format(LOCALE, "/_plugins/_ml/models/%s/_predict", str), null, toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"]}", str2)), ImmutableList.of(new BasicHeader("User-Agent", TestUtils.DEFAULT_USER_AGENT)));
        Map<String, Object> responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity()), false);
        Object obj = responseMap.get("inference_results");
        assertTrue(obj instanceof List);
        List<?> inferenceResults = (List<?>) obj;
        // Previously referenced an undefined variable `r0`; the intended check is on the results list itself.
        assertEquals(1L, inferenceResults.size());
        List list = (List) ((Map) inferenceResults.get(0)).get("output");
        assertEquals(1L, list.size());
        Map map = (Map) list.get(0);
        assertEquals(1L, map.size());
        return (Map) TokenWeightUtil.fetchListOfTokenWeightMap(List.of((Map) map.get("dataAsMap"))).get(0);
    }

    /**
     * Clears the low 15 bits of the IEEE-754 bit pattern of {@code f}, discarding that much mantissa precision.
     * NOTE(review): presumably mirrors the precision loss of Lucene's FeatureField encoding — confirm.
     *
     * @param f value to truncate; must be non-null (it is unboxed)
     * @return the truncated value
     */
    protected Float getFeatureFieldCompressedNumber(Float f) {
        final int rawBits = Float.floatToIntBits(f.floatValue());
        // Same as (rawBits >> 15) << 15: keep bits 15..31, zero bits 0..14.
        final int truncatedBits = rawBits & ~0x7FFF;
        return Float.valueOf(Float.intBitsToFloat(truncatedBits));
    }

    /**
     * Deletes whichever test resources are non-null, in this order: ingest pipeline, search pipeline,
     * model, index.
     *
     * @param str  index name, or {@code null}
     * @param str2 ingest pipeline id, or {@code null}
     * @param str3 model id, or {@code null}
     * @param str4 search pipeline name, or {@code null}
     * @throws IOException propagated from index deletion
     */
    protected void wipeOfTestResources(String str, String str2, String str3, String str4) throws IOException {
        final String ingestPipelineId = str2;
        final String searchPipelineName = str4;
        final String modelId = str3;
        final String indexName = str;

        if (ingestPipelineId != null) {
            deletePipeline(ingestPipelineId);
        }
        if (searchPipelineName != null) {
            deleteSearchPipeline(searchPipelineName);
        }
        if (modelId != null) {
            try {
                deleteModel(modelId);
            } catch (AssertionError firstAttemptFailed) {
                // A failed first attempt (AssertionError, e.g. from the state polling) gets exactly one retry;
                // a second failure propagates.
                deleteModel(modelId);
            }
        }
        if (indexName != null) {
            deleteIndex(indexName);
        }
    }

    /**
     * Asserts an index's document count and that a fetched document's {@code _source} holds a field of the expected type.
     *
     * @param str      index name passed to {@code getDocCount}
     * @param i        expected document count
     * @param supplier supplies the document to inspect (a map expected to contain {@code _source})
     * @param str2     field name that must be present in {@code _source}
     * @param cls      type the field value must be an instance of
     * @return the field value
     */
    protected Object validateDocCountAndInfo(String str, int i, Supplier<Map<String, Object>> supplier, String str2, Class<?> cls) {
        assertEquals(i, getDocCount(str));
        Map<String, Object> document = supplier.get();
        assertNotNull("supplied document is null", document);
        Object source = document.get("_source");
        assertTrue("_source is missing or not an object", source instanceof Map);
        Map<?, ?> sourceMap = (Map<?, ?>) source;
        assertTrue("field [" + str2 + "] is missing from _source", sourceMap.containsKey(str2));
        Object fieldValue = sourceMap.get(str2);
        // isInstance returns false for null, producing an assertion failure; the previous
        // cls.isAssignableFrom(fieldValue.getClass()) threw a NullPointerException instead.
        assertTrue("field [" + str2 + "] is not an instance of " + cls.getName(), cls.isInstance(fieldValue));
        return fieldValue;
    }
}
