package org.opensearch.ml.rest;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.stats.MLClusterLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStatLevel;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.MLStatsInput;
import org.opensearch.ml.utils.IndexUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;

/* loaded from: input_file:org/opensearch/ml/rest/RestMLStatsAction.class */
public class RestMLStatsAction extends BaseRestHandler {
    private static final String STATS_ML_ACTION = "stats_ml";
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";
    private static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
    /**
     * Query that matches only documents without a {@code chunk_number} field; used when counting the model index
     * (presumably so that model chunk documents are excluded and only model metadata docs are counted — confirm).
     */
    private static final String QUERY_ALL_MODEL_META_DOC = "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}";

    @Generated
    private static final Logger log = LogManager.getLogger(RestMLStatsAction.class);

    /** Upper-case names of all node-level stats; used to classify each requested stat as node- or cluster-level. */
    private static final Set<String> ML_NODE_STAT_NAMES = EnumSet
        .allOf(MLNodeLevelStat.class)
        .stream()
        .map(Enum::name)
        .collect(Collectors.toSet());

    private final MLStats mlStats;
    private final ClusterService clusterService;
    private final IndexUtils indexUtils;
    private final NamedXContentRegistry xContentRegistry;

    /**
     * Creates the ML stats REST handler.
     *
     * @param mLStats source of in-memory cluster-level stats
     * @param clusterService used to enumerate node ids when stats are requested from all nodes
     * @param indexUtils used to count documents in the model and connector indices
     * @param namedXContentRegistry registry passed along when running the model-count query
     */
    public RestMLStatsAction(MLStats mLStats, ClusterService clusterService, IndexUtils indexUtils, NamedXContentRegistry namedXContentRegistry) {
        this.mlStats = mLStats;
        this.clusterService = clusterService;
        this.indexUtils = indexUtils;
        this.xContentRegistry = namedXContentRegistry;
    }

    @Override
    public String getName() {
        return STATS_ML_ACTION;
    }

    @Override
    public List<RestHandler.Route> routes() {
        return ImmutableList
            .of(
                new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/"),
                new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/{stat}"),
                new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/"),
                new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/{stat}")
            );
    }

    /**
     * Builds the stats input from the request body (if any) or from the path params, snapshots the requested in-memory
     * cluster stats, and returns a consumer that fetches index-backed counts and node stats before responding.
     */
    @Override
    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) throws IOException {
        final MLStatsInput mlStatsInput = restRequest.hasContent()
            ? parseMlStatsInput(restRequest)
            : createMlStatsInputFromRequestParams(restRequest);
        final String[] targetNodeIds = mlStatsInput.retrieveStatsOnAllNodes()
            ? getAllNodes()
            : mlStatsInput.getNodeIds().toArray(new String[0]);
        final MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(targetNodeIds, mlStatsInput);

        final Map<MLClusterLevelStat, Object> clusterStatsMap = new HashMap<>();
        if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.CLUSTER)) {
            clusterStatsMap.putAll(getClusterStatsMap(mlStatsInput));
        }
        return channel -> collectModelCount(mlStatsInput, clusterStatsMap, nodeClient, mlStatsNodesRequest, channel);
    }

    /** Parses an {@link MLStatsInput} from the request body, closing the parser afterwards. */
    private MLStatsInput parseMlStatsInput(RestRequest restRequest) throws IOException {
        try (XContentParser contentParser = restRequest.contentParser()) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, contentParser.nextToken(), contentParser);
            return MLStatsInput.parse(contentParser);
        }
    }

    /**
     * Builds an {@link MLStatsInput} from the {@code nodeId} and {@code stat} path params. Each requested stat name
     * that matches a node-level stat is added as a node stat; every other name is resolved as a cluster-level stat.
     * With no {@code stat} param, all stat levels are targeted.
     */
    MLStatsInput createMlStatsInputFromRequestParams(RestRequest restRequest) {
        MLStatsInput mlStatsInput = new MLStatsInput();
        RestActionUtils
            .splitCommaSeparatedParam(restRequest, "nodeId")
            .ifPresent(nodeIds -> mlStatsInput.getNodeIds().addAll(Arrays.asList(nodeIds)));

        Optional<String[]> requestedStats = RestActionUtils.splitCommaSeparatedParam(restRequest, "stat");
        if (!requestedStats.isPresent()) {
            mlStatsInput.getTargetStatLevels().addAll(EnumSet.allOf(MLStatLevel.class));
            return mlStatsInput;
        }
        for (String statName : requestedStats.get()) {
            String normalized = statName.toUpperCase(Locale.ROOT);
            if (ML_NODE_STAT_NAMES.contains(normalized)) {
                mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(normalized));
            } else {
                mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(normalized));
            }
        }
        if (!mlStatsInput.getClusterLevelStats().isEmpty()) {
            mlStatsInput.getTargetStatLevels().add(MLStatLevel.CLUSTER);
        }
        if (!mlStatsInput.getNodeLevelStats().isEmpty()) {
            mlStatsInput.getTargetStatLevels().add(MLStatLevel.NODE);
        }
        return mlStatsInput;
    }

    /** Returns true when cluster stats are targeted and {@code stat} is either explicitly requested or implied by "all". */
    private boolean isIndexBackedStatRequested(MLStatsInput mlStatsInput, MLClusterLevelStat stat) {
        return mlStatsInput.getTargetStatLevels().contains(MLStatLevel.CLUSTER)
            && (mlStatsInput.retrieveAllClusterLevelStats() || mlStatsInput.getClusterLevelStats().contains(stat));
    }

    /** Step 1: count model metadata docs if requested, then continue with the connector count. */
    private void collectModelCount(
        MLStatsInput mlStatsInput,
        Map<MLClusterLevelStat, Object> clusterStatsMap,
        NodeClient nodeClient,
        MLStatsNodesRequest mlStatsNodesRequest,
        RestChannel channel
    ) throws IOException {
        if (!isIndexBackedStatRequested(mlStatsInput, MLClusterLevelStat.ML_MODEL_COUNT)) {
            collectConnectorCount(mlStatsInput, clusterStatsMap, nodeClient, mlStatsNodesRequest, channel);
            return;
        }
        indexUtils.getNumberOfDocumentsInIndex(ML_MODEL_INDEX, QUERY_ALL_MODEL_META_DOC, xContentRegistry, ActionListener.wrap(modelCount -> {
            clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, modelCount);
            collectConnectorCount(mlStatsInput, clusterStatsMap, nodeClient, mlStatsNodesRequest, channel);
        }, e -> reportFailure(channel, "Failed to get ML model count", e)));
    }

    /** Step 2: count connector docs if requested, then gather node stats and respond. */
    private void collectConnectorCount(
        MLStatsInput mlStatsInput,
        Map<MLClusterLevelStat, Object> clusterStatsMap,
        NodeClient nodeClient,
        MLStatsNodesRequest mlStatsNodesRequest,
        RestChannel channel
    ) throws IOException {
        if (!isIndexBackedStatRequested(mlStatsInput, MLClusterLevelStat.ML_CONNECTOR_COUNT)) {
            getNodeStats(mlStatsInput, clusterStatsMap, nodeClient, mlStatsNodesRequest, channel);
            return;
        }
        indexUtils.getNumberOfDocumentsInIndex(ML_CONNECTOR_INDEX, ActionListener.wrap(connectorCount -> {
            clusterStatsMap.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, connectorCount);
            getNodeStats(mlStatsInput, clusterStatsMap, nodeClient, mlStatsNodesRequest, channel);
        }, e -> reportFailure(channel, "Failed to get ML connector count", e)));
    }

    /**
     * Writes the collected cluster stats and, unless only cluster-level stats were requested, the node-level stats
     * obtained via {@link MLStatsNodesAction}, then sends a 200 response on {@code restChannel}.
     *
     * @param mLStatsInput the parsed stats request
     * @param map cluster-level stats to include (may be null or empty)
     * @param nodeClient client used to fan out the node stats request
     * @param mLStatsNodesRequest the node stats request to execute
     * @param restChannel channel the response (or failure) is sent on
     * @throws IOException if the response builder cannot be created or written
     */
    void getNodeStats(
        MLStatsInput mLStatsInput,
        Map<MLClusterLevelStat, Object> map,
        NodeClient nodeClient,
        MLStatsNodesRequest mLStatsNodesRequest,
        RestChannel restChannel
    ) throws IOException {
        XContentBuilder builder = restChannel.newBuilder();
        if (mLStatsInput.onlyRetrieveClusterLevelStats()) {
            builder.startObject();
            writeClusterStats(builder, map);
            builder.endObject();
            restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
            return;
        }
        nodeClient.execute(MLStatsNodesAction.INSTANCE, mLStatsNodesRequest, ActionListener.wrap(nodesResponse -> {
            builder.startObject();
            writeClusterStats(builder, map);
            // Only emit the node section when at least one node actually reported stats.
            boolean anyNodeHasStats = nodesResponse.getNodes().stream().anyMatch(nodeResponse -> !nodeResponse.isEmpty());
            if (anyNodeHasStats) {
                nodesResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
            }
            builder.endObject();
            restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
        }, e -> reportFailure(restChannel, "Failed to get ML node level stats", e)));
    }

    /** Writes each cluster stat as a lower-cased field name; a null or empty map writes nothing. */
    private static void writeClusterStats(XContentBuilder builder, Map<MLClusterLevelStat, Object> clusterStats) throws IOException {
        if (clusterStats == null) {
            return;
        }
        for (Map.Entry<MLClusterLevelStat, Object> entry : clusterStats.entrySet()) {
            builder.field(entry.getKey().name().toLowerCase(Locale.ROOT), entry.getValue());
        }
    }

    /** Returns the ids of every node in the current cluster state. */
    private String[] getAllNodes() {
        List<String> nodeIds = new ArrayList<>();
        for (DiscoveryNode node : clusterService.state().nodes()) {
            nodeIds.add(node.getId());
        }
        return nodeIds.toArray(new String[0]);
    }

    /** Logs {@code message} with the cause and sends a 500 failure response. */
    private void reportFailure(RestChannel channel, String message, Exception e) {
        log.error(message, e);
        onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, message, e);
    }

    /**
     * Sends an error response derived from {@code exc}; if building that response itself fails, falls back to a plain
     * response with {@code restStatus} and {@code str}.
     */
    private void onFailure(RestChannel restChannel, RestStatus restStatus, String str, Exception exc) {
        BytesRestResponse response;
        try {
            response = new BytesRestResponse(restChannel, exc);
        } catch (Exception e) {
            response = new BytesRestResponse(restStatus, str);
        }
        restChannel.sendResponse(response);
    }

    /** Returns the in-memory cluster-level stats that {@code mLStatsInput} asks for, keyed by stat. */
    private Map<MLClusterLevelStat, Object> getClusterStatsMap(MLStatsInput mLStatsInput) {
        Map<MLClusterLevelStat, Object> clusterStats = new HashMap<>();
        mlStats.getClusterStats().forEach((stat, value) -> {
            if (mLStatsInput.retrieveStat(stat)) {
                clusterStats.put((MLClusterLevelStat) stat, value.getValue());
            }
        });
        return clusterStats;
    }
}
