package org.opensearch.ml.action.deploy;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodeRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodeResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesResponse;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.class */
/**
 * Node-level transport action that starts a model deployment on the local node.
 *
 * <p>For every node request this action immediately answers with a
 * {@code modelId -> "received"} status map, while the actual deployment is handed to
 * {@link MLModelManager#deployModel}. Once that deployment reports success or failure,
 * a {@link MLForwardRequestType#DEPLOY_MODEL_DONE} message is sent to the coordinating
 * node named in the request.
 */
public class TransportDeployModelOnNodeAction extends TransportNodesAction<MLDeployModelNodesRequest, MLDeployModelNodesResponse, MLDeployModelNodeRequest, MLDeployModelNodeResponse> {

    /** Name under which this action is registered with the transport layer. */
    private static final String DEPLOY_ON_NODES_ACTION_NAME = "cluster:admin/opensearch/ml/deploy_model_on_nodes";

    /** Name of the transport action that receives the deploy-done notification. */
    private static final String FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";

    @Generated
    private static final Logger log = LogManager.getLogger(TransportDeployModelOnNodeAction.class);
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLCircuitBreakerService mlCircuitBreakerService;
    MLStats mlStats;

    /**
     * Creates the action and registers it under the deploy-on-nodes action name, using the
     * {@code "management"} executor name for node operations. All collaborators are stored
     * as-is on the instance.
     */
    @Inject
    public TransportDeployModelOnNodeAction(TransportService transportService, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mLTaskManager, MLModelManager mLModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry namedXContentRegistry, MLCircuitBreakerService mLCircuitBreakerService, MLStats mLStats) {
        super(DEPLOY_ON_NODES_ACTION_NAME, threadPool, clusterService, transportService, actionFilters, MLDeployModelNodesRequest::new, MLDeployModelNodeRequest::new, "management", MLDeployModelNodeResponse.class);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mLTaskManager;
        this.mlModelManager = mLModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
        this.mlCircuitBreakerService = mLCircuitBreakerService;
        this.mlStats = mLStats;
    }

    /**
     * Combines the per-node responses and failures into one cluster-level response.
     *
     * @param mLDeployModelNodesRequest the original multi-node request (not used here)
     * @param list                      responses from nodes that answered
     * @param list2                     failures from nodes that did not
     * @return the aggregated response tagged with this cluster's name
     */
    @Override
    protected MLDeployModelNodesResponse newResponse(MLDeployModelNodesRequest mLDeployModelNodesRequest, List<MLDeployModelNodeResponse> list, List<FailedNodeException> list2) {
        return new MLDeployModelNodesResponse(this.clusterService.getClusterName(), list, list2);
    }

    /**
     * Wraps the multi-node request into the request sent to a single node.
     *
     * @param mLDeployModelNodesRequest the multi-node request
     * @return a node request carrying the given multi-node request
     */
    @Override
    public MLDeployModelNodeRequest newNodeRequest(MLDeployModelNodesRequest mLDeployModelNodesRequest) {
        return new MLDeployModelNodeRequest(mLDeployModelNodesRequest);
    }

    /**
     * Reads a single node's response from the wire.
     *
     * @param streamInput stream positioned at a serialized node response
     * @return the deserialized node response
     * @throws IOException if the stream cannot be read
     */
    @Override
    public MLDeployModelNodeResponse newNodeResponse(StreamInput streamInput) throws IOException {
        return new MLDeployModelNodeResponse(streamInput);
    }

    /**
     * Alias kept only for code that still references the decompiler-assigned name.
     *
     * @param streamInput stream positioned at a serialized node response
     * @return the deserialized node response
     * @throws IOException if the stream cannot be read
     * @deprecated use {@link #newNodeResponse(StreamInput)}
     */
    @Deprecated
    public MLDeployModelNodeResponse m7newNodeResponse(StreamInput streamInput) throws IOException {
        return newNodeResponse(streamInput);
    }

    /**
     * Handles the request on the local node by kicking off the deployment.
     *
     * @param mLDeployModelNodeRequest the node-level request
     * @return the acknowledgement built by {@link #createDeployModelNodeResponse}
     */
    @Override
    public MLDeployModelNodeResponse nodeOperation(MLDeployModelNodeRequest mLDeployModelNodeRequest) {
        return createDeployModelNodeResponse(mLDeployModelNodeRequest.getMLDeployModelNodesRequest());
    }

    /**
     * Starts deploying the requested model and returns an acknowledgement for this node.
     *
     * <p>The returned response is built right after the deployment has been handed to the
     * model manager; the deployment outcome is reported to the coordinating node through
     * a forward request rather than through the return value.
     *
     * @param mLDeployModelNodesRequest request carrying the deploy input
     * @return a response for the local node mapping the model id to {@code "received"}
     */
    private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNodesRequest mLDeployModelNodesRequest) {
        MLDeployModelInput deployInput = mLDeployModelNodesRequest.getMlDeployModelInput();
        String modelId = deployInput.getModelId();
        String taskId = deployInput.getTaskId();
        String coordinatingNodeId = deployInput.getCoordinatingNodeId();
        MLTask mlTask = deployInput.getMlTask();
        String modelContentHash = deployInput.getModelContentHash();
        boolean deployToAllNodes = deployInput.getIsDeployToAllNodes().booleanValue();

        HashMap<String, String> modelDeployStatus = new HashMap<>();
        modelDeployStatus.put(modelId, "received");
        String localNodeId = this.clusterService.localNode().getId();

        // Outcome of the forward request itself: it is only logged.
        ActionListener<MLForwardResponse> forwardListener = ActionListener.wrap(
            forwardResponse -> log.info("deploy model task done {}", taskId),
            e -> MLExceptionUtils.logException("Deploy model task failed: " + taskId, e, log)
        );

        // Outcome of the deployment: both paths notify the coordinating node; the failure
        // path additionally carries the root-cause message.
        ActionListener<String> deployListener = ActionListener.wrap(
            deployResult -> sendDeployDone(
                coordinatingNodeId,
                MLForwardInput.builder()
                    .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
                    .taskId(taskId)
                    .modelId(modelId)
                    .workerNodeId(this.clusterService.localNode().getId())
                    .build(),
                forwardListener
            ),
            e -> sendDeployDone(
                coordinatingNodeId,
                MLForwardInput.builder()
                    .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
                    .taskId(taskId)
                    .modelId(modelId)
                    .workerNodeId(this.clusterService.localNode().getId())
                    .error(MLExceptionUtils.getRootCauseMessage(e))
                    .build(),
                forwardListener
            )
        );

        deployModel(modelId, modelContentHash, mlTask.getFunctionName(), localNodeId, coordinatingNodeId, deployToAllNodes, mlTask, deployListener);
        return new MLDeployModelNodeResponse(this.clusterService.localNode(), modelDeployStatus);
    }

    /**
     * Sends the deploy-done notification to the coordinating node.
     *
     * <p>If the coordinating node is not present in the current cluster state, the
     * failure is reported to {@code forwardListener} instead of passing {@code null} to
     * the transport service.
     *
     * @param coordinatingNodeId id of the node that should receive the notification
     * @param forwardInput       payload describing the deployment outcome
     * @param forwardListener    listener notified with the forward response or failure
     */
    private void sendDeployDone(String coordinatingNodeId, MLForwardInput forwardInput, ActionListener<MLForwardResponse> forwardListener) {
        DiscoveryNode coordinatingNode = getNodeById(coordinatingNodeId);
        if (coordinatingNode == null) {
            forwardListener.onFailure(new IllegalStateException("Coordinating node " + coordinatingNodeId + " not found in cluster state"));
            return;
        }
        this.transportService.sendRequest(
            coordinatingNode,
            FORWARD_ACTION_NAME,
            new MLForwardRequest(forwardInput),
            new ActionListenerResponseHandler<MLForwardResponse>(forwardListener, MLForwardResponse::new)
        );
    }

    /**
     * Looks up a node in the current cluster state by its id.
     *
     * @param str the node id to search for
     * @return the matching node, or {@code null} if no node has that id
     */
    private DiscoveryNode getNodeById(String str) {
        for (DiscoveryNode node : this.clusterService.state().getNodes()) {
            if (node.getId().equals(str)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Delegates the deployment to the model manager and routes synchronous failures to
     * the listener.
     *
     * <p>Before the listener is notified by the model manager, the task is removed from
     * the task manager unless the local node is also the coordinating node. An exception
     * thrown synchronously by the model manager is logged and passed to the listener
     * directly, without that removal step.
     *
     * @param modelId            id of the model to deploy
     * @param modelContentHash   content hash passed through to the model manager
     * @param functionName       function name of the model's task
     * @param localNodeId        id of the node this action is running on
     * @param coordinatingNodeId id of the node coordinating the deployment
     * @param deployToAllNodes   flag passed through to the model manager
     * @param mlTask             task tracking this deployment
     * @param actionListener     listener notified with the deployment outcome
     */
    private void deployModel(String modelId, String modelContentHash, FunctionName functionName, String localNodeId, String coordinatingNodeId, boolean deployToAllNodes, MLTask mlTask, ActionListener<String> actionListener) {
        try {
            log.debug("start deploying model {}", modelId);
            this.mlModelManager.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(actionListener, () -> {
                if (!coordinatingNodeId.equals(localNodeId)) {
                    this.mlTaskManager.remove(mlTask.getTaskId());
                }
            }));
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to deploy model " + modelId, e, log);
            actionListener.onFailure(e);
        }
    }
}
