package org.opensearch.ml.action.undeploy;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/**
 * Transport action registered as {@code cluster:admin/opensearch/ml/undeploy_model}.
 *
 * <p>On each target node, {@link #nodeOperation} snapshots the worker nodes of the affected models
 * and then calls {@link MLModelManager#undeployModel}, reporting the per-model status together with
 * that snapshot. The node handling the request aggregates those reports, partially updates the
 * affected model documents in {@code .plugins-ml-model}, and sends an {@link MLSyncUpAction}
 * carrying the removed worker nodes to all nodes.
 */
public class TransportUndeployModelAction extends TransportNodesAction<MLUndeployModelNodesRequest, MLUndeployModelNodesResponse, MLUndeployModelNodeRequest, MLUndeployModelNodeResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(TransportUndeployModelAction.class);

    /** Status value a node reports for a model it undeployed. */
    private static final String UNDEPLOYED_STATUS = "undeployed";
    /** Index holding the model documents updated after an undeploy. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";

    private final MLModelManager mlModelManager;
    private final ClusterService clusterService;
    private final Client client;
    private final DiscoveryNodeHelper nodeFilter;
    private final MLStats mlStats;

    /**
     * Creates the action; node-level work is dispatched on the {@code "management"} executor name
     * passed to the superclass.
     */
    @Inject
    public TransportUndeployModelAction(TransportService transportService, ActionFilters actionFilters, MLModelManager mLModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, DiscoveryNodeHelper discoveryNodeHelper, MLStats mLStats) {
        super("cluster:admin/opensearch/ml/undeploy_model", threadPool, clusterService, transportService, actionFilters, MLUndeployModelNodesRequest::new, MLUndeployModelNodeRequest::new, "management", MLUndeployModelNodeResponse.class);
        this.mlModelManager = mLModelManager;
        this.clusterService = clusterService;
        this.client = client;
        this.nodeFilter = discoveryNodeHelper;
        this.mlStats = mLStats;
    }

    /**
     * Runs the node fan-out, then post-processes the aggregated response before completing
     * {@code listener}. Fan-out failures are forwarded to {@code listener.onFailure}.
     */
    @Override
    protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
        Objects.requireNonNull(listener, "listener");
        CheckedConsumer<MLUndeployModelNodesResponse, Exception> onNodesResponse = nodesResponse -> processUndeployModelResponseAndUpdate(nodesResponse, listener);
        super.doExecute(task, request, ActionListener.wrap(onNodesResponse, listener::onFailure));
    }

    /**
     * Aggregates per-node undeploy results, persists the resulting model state, triggers a sync-up
     * of removed worker nodes, and finally completes {@code listener} with {@code nodesResponse}.
     *
     * <p>For each model reported as {@code "undeployed"} by at least one node, a partial update of
     * its model document is issued:
     * <ul>
     *   <li>when no pre-removal worker list was reported, or at least as many nodes reported removal
     *       as that list held, the planning/current worker counts are zeroed and {@code model_state}
     *       becomes {@link MLModelState#UNDEPLOYED};</li>
     *   <li>otherwise the remaining workers become the new planning worker nodes and
     *       {@code deploy_to_all_nodes} is set to {@code false}.</li>
     * </ul>
     * Bulk-update and sync-up failures are only logged; {@code listener} always receives the
     * original {@code nodesResponse}.
     *
     * @param nodesResponse aggregated node responses
     * @param listener      completed once the model-document update (if any) has finished
     */
    void processUndeployModelResponseAndUpdate(MLUndeployModelNodesResponse nodesResponse, ActionListener<MLUndeployModelNodesResponse> listener) {
        List<MLUndeployModelNodeResponse> responses = nodesResponse.getNodes();
        if (responses == null || responses.isEmpty()) {
            listener.onResponse(nodesResponse);
            return;
        }

        Map<String, String[]> workerNodesBeforeRemoval = collectWorkerNodesBeforeRemoval(responses);
        Map<String, List<String>> removedNodesByModel = collectUndeployedNodes(responses);

        MLSyncUpInput syncUpInput = MLSyncUpInput.builder().removedWorkerNodes(convertRemovedNodesMapForSyncUp(removedNodesByModel)).build();
        MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(this.nodeFilter.getAllNodes(), syncUpInput);

        try (ThreadContext.StoredContext storedContext = this.client.threadPool().getThreadContext().stashContext()) {
            if (removedNodesByModel.isEmpty()) {
                syncUpUndeployedModels(syncUpRequest);
                listener.onResponse(nodesResponse);
                return;
            }

            BulkRequest bulkRequest = new BulkRequest();
            bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            Map<String, Boolean> deployToAllNodes = new HashMap<>();
            for (Map.Entry<String, List<String>> entry : removedNodesByModel.entrySet()) {
                String modelId = entry.getKey();
                List<String> removedNodes = entry.getValue();
                String[] nodesBefore = workerNodesBeforeRemoval.get(modelId);
                Map<String, Object> updateDocument = new HashMap<>();
                // A missing pre-removal list previously caused an NPE; with no known remaining
                // worker, the model is recorded as fully undeployed.
                if (nodesBefore == null || nodesBefore.length <= removedNodes.size()) {
                    updateDocument.put("planning_worker_nodes", ImmutableList.of());
                    updateDocument.put("planning_worker_node_count", 0);
                    updateDocument.put("current_worker_node_count", 0);
                    updateDocument.put("model_state", MLModelState.UNDEPLOYED);
                } else {
                    List<String> remainingNodes = Arrays.stream(nodesBefore)
                        .filter(nodeId -> !removedNodes.contains(nodeId))
                        .collect(Collectors.toList());
                    updateDocument.put("deploy_to_all_nodes", false);
                    updateDocument.put("planning_worker_nodes", remainingNodes);
                    updateDocument.put("planning_worker_node_count", remainingNodes.size());
                    updateDocument.put("current_worker_node_count", remainingNodes.size());
                    deployToAllNodes.put(modelId, false);
                }
                bulkRequest.add(new UpdateRequest().index(ML_MODEL_INDEX).id(modelId).doc(updateDocument));
            }
            // syncUpRequest holds syncUpInput by reference, so this is visible to the sync-up call.
            syncUpInput.setDeployToAllNodes(deployToAllNodes);

            this.client.bulk(bulkRequest, ActionListener.runAfter(ActionListener.wrap(bulkResponse -> {
                if (bulkResponse.hasFailures()) {
                    log.error("Failed to update model state as undeployed: {}", bulkResponse.buildFailureMessage());
                } else {
                    log.debug("updated model state as undeployed for : {}", Arrays.toString(removedNodesByModel.keySet().toArray(new String[0])));
                }
            }, e -> log.error("Failed to update model state as undeployed", e)), () -> {
                // Runs after success or failure of the bulk update.
                syncUpUndeployedModels(syncUpRequest);
                listener.onResponse(nodesResponse);
            }));
        }
    }

    /**
     * Merges the per-node pre-removal worker lists, keeping the longest non-null list seen for each
     * model ID.
     */
    private static Map<String, String[]> collectWorkerNodesBeforeRemoval(List<MLUndeployModelNodeResponse> responses) {
        Map<String, String[]> merged = new HashMap<>();
        for (MLUndeployModelNodeResponse response : responses) {
            Map<String, String[]> nodeView = response.getModelWorkerNodeBeforeRemoval();
            if (nodeView == null) {
                continue;
            }
            for (Map.Entry<String, String[]> entry : nodeView.entrySet()) {
                String[] workers = entry.getValue();
                if (workers == null) {
                    continue;
                }
                String[] known = merged.get(entry.getKey());
                if (known == null || known.length < workers.length) {
                    merged.put(entry.getKey(), workers);
                }
            }
        }
        return merged;
    }

    /** Groups, per model ID, the IDs of the nodes that reported the model as undeployed. */
    private static Map<String, List<String>> collectUndeployedNodes(List<MLUndeployModelNodeResponse> responses) {
        Map<String, List<String>> removedNodesByModel = new HashMap<>();
        for (MLUndeployModelNodeResponse response : responses) {
            Map<String, String> statusByModel = response.getModelUndeployStatus();
            if (statusByModel == null) {
                continue;
            }
            for (Map.Entry<String, String> entry : statusByModel.entrySet()) {
                if (UNDEPLOYED_STATUS.equals(entry.getValue())) {
                    removedNodesByModel.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(response.getNode().getId());
                }
            }
        }
        return removedNodesByModel;
    }

    /** Builds the aggregated nodes response for this cluster. */
    @Override
    protected MLUndeployModelNodesResponse newResponse(MLUndeployModelNodesRequest request, List<MLUndeployModelNodeResponse> responses, List<FailedNodeException> failures) {
        return new MLUndeployModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /** Converts the removed-node lists to the array form expected by {@link MLSyncUpInput}. */
    private static Map<String, String[]> convertRemovedNodesMapForSyncUp(Map<String, List<String>> removedNodesByModel) {
        Map<String, String[]> converted = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : removedNodesByModel.entrySet()) {
            converted.put(entry.getKey(), entry.getValue().toArray(new String[0]));
            log.debug("removed node for model: {}, {}", entry.getKey(), entry.getValue());
        }
        return converted;
    }

    /** Fires the sync-up action; its outcome is only logged. */
    private void syncUpUndeployedModels(MLSyncUpNodesRequest syncUpRequest) {
        this.client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(
            syncUpResponse -> log.debug("sync up removed nodes successfully"),
            e -> log.error("failed to sync up removed node", e)
        ));
    }

    /** Wraps the nodes-level request for delivery to each node. */
    @Override
    public MLUndeployModelNodeRequest newNodeRequest(MLUndeployModelNodesRequest request) {
        return new MLUndeployModelNodeRequest(request);
    }

    /** Deserializes a node response received over the transport layer. */
    @Override
    public MLUndeployModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new MLUndeployModelNodeResponse(in);
    }

    /** Executes the undeploy on the local node. */
    @Override
    public MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest request) {
        return createUndeployModelNodeResponse(request.getMlUndeployModelNodesRequest());
    }

    /**
     * Snapshots worker nodes, undeploys the requested models locally, and reports both.
     *
     * <p>When the request names no model IDs, the snapshot covers every ID returned by
     * {@link MLModelManager#getAllModelIds()}. The {@code ML_EXECUTING_TASK_COUNT} stat is
     * incremented for the duration of the work and always decremented, even if the manager throws.
     */
    private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest nodesRequest) {
        String[] modelIds = nodesRequest.getModelIds();
        Map<String, String[]> workerNodesBeforeRemoval = new HashMap<>();
        Map<String, String> undeployStatus;
        this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        try {
            String[] snapshotIds = (modelIds != null && modelIds.length > 0) ? modelIds : this.mlModelManager.getAllModelIds();
            if (snapshotIds != null) {
                for (String modelId : snapshotIds) {
                    workerNodesBeforeRemoval.put(modelId, this.mlModelManager.getWorkerNodes(modelId, this.mlModelManager.getModelFunctionName(modelId)));
                }
            }
            undeployStatus = this.mlModelManager.undeployModel(modelIds);
        } finally {
            this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
        }
        return new MLUndeployModelNodeResponse(this.clusterService.localNode(), undeployStatus, workerNodesBeforeRemoval);
    }
}
