package org.opensearch.ml.action.deploy;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/deploy/TransportDeployModelAction.class */
/**
 * Transport action that handles a deploy-model request.
 *
 * <p>The flow implemented by {@link #doExecute} is: load the model metadata, verify that the
 * calling user may access the model's group, work out which eligible nodes the model should be
 * deployed to, persist an {@link MLTask} describing the deployment, and finally dispatch an
 * {@link MLDeployModelOnNodeAction} request to the selected nodes.
 */
public class TransportDeployModelAction extends HandledTransportAction<ActionRequest, MLDeployModelResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(TransportDeployModelAction.class);

    // Collaborators supplied through the injected constructor.
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    MLTaskDispatcher mlTaskDispatcher;
    MLModelManager mlModelManager;
    MLStats mlStats;

    // Mirrors ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; refreshed by the settings-update consumer
    // registered in the constructor, hence volatile.
    private volatile boolean allowCustomDeploymentPlan;
    private ModelAccessControlHelper modelAccessControlHelper;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Creates the action, stores its collaborators, reads the initial value of the
     * "allow custom deployment plan" setting and subscribes to later updates of that setting.
     */
    @Inject
    public TransportDeployModelAction(TransportService transportService, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mlTaskManager, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, DiscoveryNodeHelper nodeFilter, MLTaskDispatcher mlTaskDispatcher, MLModelManager mlModelManager, MLStats mlStats, Settings settings, ModelAccessControlHelper modelAccessControlHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/deploy_model", transportService, actionFilters, MLDeployModelRequest::new);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlModelManager = mlModelManager;
        this.mlStats = mlStats;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.allowCustomDeploymentPlan = MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, allowed -> this.allowCustomDeploymentPlan = allowed);
    }

    /**
     * Entry point of the transport action.
     *
     * <p>Loads the model (excluding its content fields), rejects remote models when remote
     * inference is disabled, validates model-group access for the calling user and then hands
     * over to {@link #deployModel}. Any exception thrown synchronously is logged and reported
     * through {@code actionListener}.
     *
     * @param task           the transport task (not used by this method)
     * @param actionRequest  the incoming request, converted with
     *                       {@link MLDeployModelRequest#fromActionRequest}
     * @param actionListener receives the deployment response or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLDeployModelResponse> actionListener) {
        MLDeployModelRequest deployModelRequest = MLDeployModelRequest.fromActionRequest(actionRequest);
        String modelId = deployModelRequest.getModelId();
        User user = RestActionUtils.getUserContext(this.client);
        String[] excludedFields = {"model_content", "content"};
        // try-with-resources guarantees the stashed context is closed even when the body throws.
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            // Restores the stashed context right before the caller is notified.
            ActionListener<MLDeployModelResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
            this.mlModelManager.getModel(modelId, null, excludedFields, ActionListener.wrap(mlModel -> {
                if (mlModel.getAlgorithm() == FunctionName.REMOTE && !this.mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
                    throw new IllegalStateException(MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG);
                }
                this.modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), this.client, ActionListener.wrap(hasAccess -> {
                    if (!hasAccess) {
                        wrappedListener.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation on this model"));
                        return;
                    }
                    deployModel(deployModelRequest, mlModel, modelId, wrappedListener, actionListener);
                }, e -> {
                    log.error("Failed to Validate Access for ModelId " + modelId, e);
                    wrappedListener.onFailure(e);
                }));
            }, e -> {
                log.error("Failed to retrieve the ML model with ID: " + modelId, e);
                wrappedListener.onFailure(e);
            }));
        } catch (Exception e) {
            log.error("Failed to deploy the ML model with ID " + modelId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Resolves the target nodes, creates the deploy task and starts the deployment.
     *
     * <p>When the request names no node ids the model is deployed to every eligible node;
     * otherwise only to the requested ids that are eligible. A request with explicit node ids
     * is rejected when custom deployment plans are disabled, and also when the model already
     * has worker nodes that are missing from the requested ids.
     *
     * @param deployModelRequest the deploy request carrying the optional target node ids
     * @param mlModel            the model metadata loaded by the caller
     * @param modelId            id of the model being deployed
     * @param wrappedListener    listener that restores the thread context before notifying
     * @param listener           the original, unwrapped listener (used for the remote path)
     * @throws IllegalArgumentException if explicit node ids are given while custom deployment
     *                                  plans are not allowed
     */
    private void deployModel(MLDeployModelRequest deployModelRequest, MLModel mlModel, String modelId, ActionListener<MLDeployModelResponse> wrappedListener, ActionListener<MLDeployModelResponse> listener) {
        String[] targetNodeIds = deployModelRequest.getModelNodeIds();
        boolean deployToAllNodes = targetNodeIds == null || targetNodeIds.length == 0;
        if (!this.allowCustomDeploymentPlan && !deployToAllNodes) {
            throw new IllegalArgumentException("Don't allow custom deployment plan");
        }

        FunctionName functionName = mlModel.getAlgorithm();
        DiscoveryNode[] allEligibleNodes = this.nodeFilter.getEligibleNodes(functionName);
        HashMap<String, DiscoveryNode> eligibleNodesById = new HashMap<>();
        for (DiscoveryNode node : allEligibleNodes) {
            eligibleNodesById.put(node.getId(), node);
        }

        List<DiscoveryNode> eligibleNodes = new ArrayList<>();
        List<String> nodeIds = new ArrayList<>();
        if (deployToAllNodes) {
            nodeIds.addAll(eligibleNodesById.keySet());
            eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
        } else {
            // Requested ids that are not eligible are silently skipped.
            for (String nodeId : targetNodeIds) {
                DiscoveryNode node = eligibleNodesById.get(nodeId);
                if (node != null) {
                    eligibleNodes.add(node);
                    nodeIds.add(nodeId);
                }
            }
            String[] workerNodes = this.mlModelManager.getWorkerNodes(modelId, functionName);
            if (workerNodes != null && workerNodes.length > 0) {
                // Nodes that already run the model but were left out of the requested plan.
                Set<String> missingWorkerNodes = new HashSet<>(Arrays.asList(workerNodes));
                missingWorkerNodes.removeAll(Arrays.asList(targetNodeIds));
                if (!missingWorkerNodes.isEmpty()) {
                    wrappedListener.onFailure(new IllegalArgumentException("Model already deployed to these nodes: " + Arrays.toString(missingWorkerNodes.toArray(new String[0])) + ", but they are not included in target node ids. Undeploy model from these nodes if don't need them any more."));
                    return;
                }
            }
        }
        if (nodeIds.isEmpty()) {
            wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
            return;
        }

        log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds));
        String localNodeId = this.clusterService.localNode().getId();
        MLTask mlTask = MLTask
            .builder()
            .async(true)
            .modelId(modelId)
            .taskType(MLTaskType.DEPLOY_MODEL)
            .functionName(functionName)
            .createTime(Instant.now())
            .lastUpdateTime(Instant.now())
            .state(MLTaskState.CREATED)
            .workerNodes(nodeIds)
            .build();
        this.mlTaskManager.createMLTask(mlTask, ActionListener.wrap(indexResponse -> {
            String taskId = indexResponse.getId();
            mlTask.setTaskId(taskId);
            if (functionName == FunctionName.REMOTE) {
                this.mlTaskManager.add(mlTask, nodeIds);
                // NOTE(review): the remote path is handed the unwrapped listener, unlike every
                // other path here - confirm this is intentional before changing it.
                deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
                return;
            }
            try {
                this.mlTaskManager.add(mlTask, nodeIds);
                // Respond as soon as the task exists; the deployment itself continues on the
                // deploy thread pool.
                wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
                this.threadPool
                    .executor(MachineLearningPlugin.DEPLOY_THREAD_POOL)
                    .execute(() -> updateModelDeployStatusAndTriggerOnNodesAction(modelId, taskId, mlModel, localNodeId, mlTask, eligibleNodes, deployToAllNodes));
            } catch (Exception e) {
                // NOTE(review): if the failure happens after onResponse above, the listener is
                // notified a second time here - behavior kept as-is.
                log.error("Failed to deploy model", e);
                this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("state", MLTaskState.FAILED, "error", MLExceptionUtils.getRootCauseMessage(e)), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
                wrappedListener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create deploy model task for " + modelId, e);
            wrappedListener.onFailure(e);
        }));
    }

    /**
     * Deploys a remote model: marks the model as deploying, dispatches the on-node deploy
     * request and answers {@code actionListener} once the nodes have responded.
     *
     * @param mLModel        the model to deploy; its id and content hash are read
     * @param mLTask         the deploy task, whose task id must already be set
     * @param str            id of the local node, passed on in the deploy input
     * @param list           nodes the model is deployed to
     * @param z              whether the model is being deployed to all eligible nodes
     * @param actionListener receives the final response or the failure
     */
    @VisibleForTesting
    void deployRemoteModel(MLModel mLModel, MLTask mLTask, String str, List<DiscoveryNode> list, boolean z, ActionListener<MLDeployModelResponse> actionListener) {
        String modelId = mLModel.getModelId();
        String taskId = mLTask.getTaskId();
        MLDeployModelInput deployInput = new MLDeployModelInput(modelId, taskId, mLModel.getModelContentHash(), list.size(), str, z, mLTask);
        MLDeployModelNodesRequest deployRequest = new MLDeployModelNodesRequest(list.toArray(new DiscoveryNode[0]), deployInput);
        ActionListener<MLDeployModelNodesResponse> nodesListener = deployModelNodesResponseListener(taskId, modelId, actionListener);
        markDeployingAndDispatch(modelId, list, z, deployRequest, nodesListener);
    }

    /**
     * Builds the listener used for the remote deployment path.
     *
     * <p>On success it moves the task to RUNNING (when the task manager still tracks it) and
     * answers the caller; on failure it records the failure on task and model and forwards the
     * exception.
     *
     * @param str            the task id
     * @param str2           the model id
     * @param actionListener the caller's listener
     */
    private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListener(String str, String str2, ActionListener<MLDeployModelResponse> actionListener) {
        return ActionListener.wrap(nodesResponse -> {
            markTaskRunning(str);
            actionListener.onResponse(new MLDeployModelResponse(str, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name()));
        }, e -> {
            recordDeployFailure(str2, str, e);
            actionListener.onFailure(e);
        });
    }

    /**
     * Marks the model as deploying and dispatches the on-node deploy request for a
     * non-remote model. The outcome is only recorded on the task and the model; no listener
     * is notified from here.
     *
     * @param str     the model id
     * @param str2    the task id
     * @param mLModel the model to deploy; its content hash is read
     * @param str3    id of the local node, passed on in the deploy input
     * @param mLTask  the deploy task
     * @param list    nodes the model is deployed to
     * @param z       whether the model is being deployed to all eligible nodes
     */
    @VisibleForTesting
    void updateModelDeployStatusAndTriggerOnNodesAction(String str, String str2, MLModel mLModel, String str3, MLTask mLTask, List<DiscoveryNode> list, boolean z) {
        MLDeployModelInput deployInput = new MLDeployModelInput(str, str2, mLModel.getModelContentHash(), list.size(), str3, z, mLTask);
        MLDeployModelNodesRequest deployRequest = new MLDeployModelNodesRequest(list.toArray(new DiscoveryNode[0]), deployInput);
        ActionListener<MLDeployModelNodesResponse> nodesListener = ActionListener
            .wrap(nodesResponse -> markTaskRunning(str2), e -> recordDeployFailure(str, str2, e));
        markDeployingAndDispatch(str, list, z, deployRequest, nodesListener);
    }

    /**
     * Writes the DEPLOYING state and the planned worker nodes to the model and, once that
     * update succeeds, executes the on-node deploy action. A failed update is routed to
     * {@code nodesListener.onFailure}.
     */
    private void markDeployingAndDispatch(String modelId, List<DiscoveryNode> nodes, boolean deployToAllNodes, MLDeployModelNodesRequest deployRequest, ActionListener<MLDeployModelNodesResponse> nodesListener) {
        List<String> plannedNodeIds = nodes.stream().map(DiscoveryNode::getId).collect(Collectors.toList());
        ImmutableMap<String, Object> updatedFields = ImmutableMap
            .of(
                "model_state", MLModelState.DEPLOYING,
                "planning_worker_node_count", nodes.size(),
                "planning_worker_nodes", plannedNodeIds,
                "deploy_to_all_nodes", deployToAllNodes
            );
        this.mlModelManager
            .updateModel(
                modelId,
                updatedFields,
                ActionListener
                    .wrap(
                        updateResponse -> this.client.execute(MLDeployModelOnNodeAction.INSTANCE, deployRequest, nodesListener),
                        nodesListener::onFailure
                    )
            );
    }

    /** Moves the task to RUNNING, but only if the task manager still contains it. */
    private void markTaskRunning(String taskId) {
        if (this.mlTaskManager.contains(taskId)) {
            this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("state", MLTaskState.RUNNING), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, false);
        }
    }

    /** Logs a deployment failure and records it on both the task (FAILED) and the model (DEPLOY_FAILED). */
    private void recordDeployFailure(String modelId, String taskId, Exception e) {
        log.error("Failed to deploy model " + modelId, e);
        this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("error", MLExceptionUtils.getRootCauseMessage(e), "state", MLTaskState.FAILED), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
        this.mlModelManager.updateModel(modelId, ImmutableMap.of("model_state", MLModelState.DEPLOY_FAILED));
    }
}
