package org.opensearch.ml.task;

import com.google.common.collect.ImmutableSet;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.naming.LimitExceededException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.action.stats.MLStatsNodeResponse;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.stats.MLNodeLevelStat;

/* loaded from: input_file:org/opensearch/ml/task/MLTaskDispatcher.class */
/**
 * Chooses the cluster node that should run an ML task or a predict request.
 *
 * <p>The policy comes from the cluster setting {@code ML_COMMONS_TASK_DISPATCH_POLICY}, which is kept in sync
 * through a settings-update consumer:
 * <ul>
 *   <li>{@code round_robin}: cycles through the candidate nodes using a shared counter;</li>
 *   <li>{@code least_load}: fetches ML stats from the candidate nodes and picks the one with the fewest executing
 *       ML tasks (ties broken by lower JVM heap usage). Nodes whose heap-usage stat is not below
 *       {@value #DEFAULT_JVM_HEAP_USAGE_THRESHOLD}, or whose executing-task count has reached
 *       {@code ML_COMMONS_MAX_ML_TASK_PER_NODE}, are skipped.</li>
 * </ul>
 * Any other policy value causes an {@link IllegalArgumentException}.
 */
public class MLTaskDispatcher {

    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskDispatcher.class);

    /**
     * Nodes whose {@code ML_JVM_HEAP_USAGE} stat is not strictly below this value are ineligible under least_load.
     * NOTE(review): presumably a heap-usage percentage; confirm against the stat's producer.
     */
    private static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85;
    private static final String ROUND_ROBIN = "round_robin";
    private static final String LEAST_LOAD = "least_load";

    /** Orders node stats by executing ML task count, then by JVM heap usage, both ascending. */
    private static final Comparator<MLStatsNodeResponse> LEAST_LOADED_FIRST = Comparator
        .<MLStatsNodeResponse>comparingLong(MLTaskDispatcher::executingTaskCount)
        .thenComparingLong(MLTaskDispatcher::heapUsage);

    private final ClusterService clusterService;
    private final Client client;
    private final DiscoveryNodeHelper nodeHelper;
    /** Shared round-robin cursor. Only its value modulo the candidate count is used, so it may grow freely. */
    private final AtomicInteger nextNode = new AtomicInteger(0);

    // Reassigned by the settings-update consumers registered in the constructor, hence volatile.
    private volatile Integer maxMLBatchTaskPerNode;
    private volatile String dispatchPolicy;

    /**
     * Creates a dispatcher and subscribes to updates of the dispatch policy and per-node task limit settings.
     *
     * @param clusterService      used to register the dynamic settings consumers
     * @param client              used to request ML stats from nodes (least_load policy)
     * @param settings            initial node settings
     * @param discoveryNodeHelper resolves eligible nodes and node ids to {@link DiscoveryNode}s
     */
    public MLTaskDispatcher(ClusterService clusterService, Client client, Settings settings, DiscoveryNodeHelper discoveryNodeHelper) {
        this.clusterService = clusterService;
        this.client = client;
        this.nodeHelper = discoveryNodeHelper;
        this.maxMLBatchTaskPerNode = (Integer) MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE.get(settings);
        this.dispatchPolicy = (String) MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY.get(settings);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY, policy -> this.dispatchPolicy = policy);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE, limit -> this.maxMLBatchTaskPerNode = limit);
    }

    /**
     * Picks a node eligible to run a task of the given function and reports it to {@code actionListener}.
     *
     * @param functionName   the ML function whose eligible nodes are considered
     * @param actionListener receives the chosen node, or a failure
     * @throws IllegalArgumentException if the configured policy is unknown, or (round_robin only) if no eligible
     *                                  node exists
     */
    public void dispatch(FunctionName functionName, ActionListener<DiscoveryNode> actionListener) {
        // Read the volatile once so a concurrent settings update cannot make both checks miss.
        String policy = this.dispatchPolicy;
        if (ROUND_ROBIN.equals(policy)) {
            dispatchTaskWithRoundRobin(functionName, actionListener);
        } else if (LEAST_LOAD.equals(policy)) {
            dispatchTaskWithLeastLoad(functionName, actionListener);
        } else {
            throw new IllegalArgumentException("Unknown policy");
        }
    }

    /**
     * Picks one of the given node ids to serve a predict request and reports the resolved node.
     *
     * @param strArr         candidate node ids; must be non-empty
     * @param actionListener receives the chosen node, or a failure
     * @throws IllegalArgumentException if {@code strArr} is null/empty or the configured policy is unknown
     */
    public void dispatchPredictTask(String[] strArr, ActionListener<DiscoveryNode> actionListener) {
        if (strArr == null || strArr.length == 0) {
            throw new IllegalArgumentException("no eligible node to run predict request");
        }
        String policy = this.dispatchPolicy;
        if (ROUND_ROBIN.equals(policy)) {
            dispatchTaskWithRoundRobin(
                strArr,
                ActionListener.wrap(nodeId -> actionListener.onResponse(this.nodeHelper.getNode(nodeId)), actionListener::onFailure)
            );
        } else if (LEAST_LOAD.equals(policy)) {
            dispatchTaskWithLeastLoad(strArr, actionListener);
        } else {
            throw new IllegalArgumentException("Unknown policy");
        }
    }

    /**
     * Reports the next candidate in round-robin order.
     *
     * <p>The index is derived atomically from the shared counter with {@code floorMod}. Concurrent callers
     * therefore get distinct consecutive slots, and the result stays in range even after the counter overflows
     * or when callers pass arrays of different lengths.
     */
    private <T> void dispatchTaskWithRoundRobin(T[] candidates, ActionListener<T> actionListener) {
        if (candidates == null || candidates.length == 0) {
            actionListener.onFailure(new IllegalArgumentException("No eligible node available to run ml jobs"));
            return;
        }
        int index = Math.floorMod(this.nextNode.getAndIncrement(), candidates.length);
        actionListener.onResponse(candidates[index]);
    }

    /** Resolves the node ids and applies the least-load policy to them. */
    private void dispatchTaskWithLeastLoad(String[] strArr, ActionListener<DiscoveryNode> actionListener) {
        dispatchTaskWithLeastLoad(this.nodeHelper.getNodes(strArr), actionListener);
    }

    /**
     * Requests executing-task-count and JVM-heap-usage stats from the candidate nodes, then reports the
     * least-loaded eligible one (see {@link #selectLeastLoadedNode}).
     */
    private void dispatchTaskWithLeastLoad(DiscoveryNode[] discoveryNodeArr, ActionListener<DiscoveryNode> actionListener) {
        if (discoveryNodeArr == null || discoveryNodeArr.length == 0) {
            // NOTE(review): how MLStatsNodesRequest treats a null/empty node array is not visible here. Failing
            // fast avoids depending on it and avoids a misleading "memory usage exceeds limitation" error.
            actionListener.onFailure(new IllegalArgumentException("No eligible node available to run ml jobs"));
            return;
        }
        MLStatsNodesRequest statsRequest = new MLStatsNodesRequest(discoveryNodeArr);
        statsRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE));
        this.client.execute(
            MLStatsNodesAction.INSTANCE,
            statsRequest,
            ActionListener.wrap(response -> selectLeastLoadedNode(response.getNodes(), actionListener), exc -> {
                log.error("Failed to get node's task stats", exc);
                actionListener.onFailure(exc);
            })
        );
    }

    /**
     * Filters nodes by heap usage and task limit, then reports the one with the fewest executing tasks (ties broken
     * by lower heap usage; on a full tie, the earliest in {@code nodeStats} wins).
     * Fails with {@link LimitExceededException} when no node passes the filters.
     */
    private void selectLeastLoadedNode(List<MLStatsNodeResponse> nodeStats, ActionListener<DiscoveryNode> actionListener) {
        List<MLStatsNodeResponse> underHeapLimit = nodeStats
            .stream()
            .filter(stats -> heapUsage(stats) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD)
            .collect(Collectors.toList());
        if (underHeapLimit.isEmpty()) {
            String message = "All nodes' memory usage exceeds limitation "
                + DEFAULT_JVM_HEAP_USAGE_THRESHOLD
                + ". No eligible node available to run ml jobs ";
            log.warn(message);
            actionListener.onFailure(new LimitExceededException(message));
            return;
        }
        // Snapshot the volatile limit so every node is judged against the same value.
        long maxTasksPerNode = this.maxMLBatchTaskPerNode.longValue();
        Optional<MLStatsNodeResponse> leastLoaded = underHeapLimit
            .stream()
            .filter(stats -> executingTaskCount(stats) < maxTasksPerNode)
            .min(LEAST_LOADED_FIRST);
        if (leastLoaded.isPresent()) {
            actionListener.onResponse(leastLoaded.get().getNode());
        } else {
            log.warn("All nodes' executing ML task count reach limitation.");
            actionListener.onFailure(new LimitExceededException("All nodes' executing ML task count reach limitation."));
        }
    }

    /** Returns the node's {@code ML_EXECUTING_TASK_COUNT} stat (expected to be a {@link Long}). */
    private static long executingTaskCount(MLStatsNodeResponse stats) {
        return ((Long) stats.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)).longValue();
    }

    /** Returns the node's {@code ML_JVM_HEAP_USAGE} stat (expected to be a {@link Long}). */
    private static long heapUsage(MLStatsNodeResponse stats) {
        return ((Long) stats.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)).longValue();
    }

    /** Applies the least-load policy to the nodes eligible for {@code functionName}. */
    private void dispatchTaskWithLeastLoad(FunctionName functionName, ActionListener<DiscoveryNode> actionListener) {
        dispatchTaskWithLeastLoad(this.nodeHelper.getEligibleNodes(functionName), actionListener);
    }

    /**
     * Applies the round-robin policy to the nodes eligible for {@code functionName}.
     *
     * @throws IllegalArgumentException if there is no eligible node
     */
    private void dispatchTaskWithRoundRobin(FunctionName functionName, ActionListener<DiscoveryNode> actionListener) {
        DiscoveryNode[] eligibleNodes = this.nodeHelper.getEligibleNodes(functionName);
        if (eligibleNodes == null || eligibleNodes.length == 0) {
            throw new IllegalArgumentException("No eligible node found to execute this request. It's best practice to provision ML nodes to serve your models. You can disable this setting to serve the model on your data node for development purposes by disabling the \"plugins.ml_commons.only_run_on_ml_node\" configuration using the _cluster/setting api");
        }
        dispatchTaskWithRoundRobin(eligibleNodes, actionListener);
    }
}
