package org.opensearch.ml.action.syncup;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.class */
/**
 * Node-level half of the ML model sync-up transport action
 * ({@code cluster:admin/opensearch/mlinternal/syncup}).
 *
 * <p>For each node the request reaches, {@link #nodeOperation(MLSyncUpNodeRequest)}:
 * <ul>
 *   <li>applies the worker-node additions and removals carried in the {@link MLSyncUpInput};</li>
 *   <li>optionally collects the locally deployed models and running deploy tasks for the response;</li>
 *   <li>clears the local routing table, or replaces it with the one in the input;</li>
 *   <li>marks locally cached ML tasks FAILED once they exceed the configured task timeout;</li>
 *   <li>deletes on-disk model files that no known task or running model refers to.</li>
 * </ul>
 */
public class TransportSyncUpOnNodeAction extends TransportNodesAction<MLSyncUpNodesRequest, MLSyncUpNodesResponse, MLSyncUpNodeRequest, MLSyncUpNodeResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(TransportSyncUpOnNodeAction.class);
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLEngine mlEngine;
    // Seconds; refreshed by the cluster-settings consumer registered in the constructor.
    private volatile Integer mlTaskTimeout;
    private final MLModelCacheHelper mlModelCacheHelper;

    /**
     * Creates the action and subscribes to dynamic updates of the ML task timeout setting.
     *
     * @param transportService      transport layer used to fan the request out to nodes
     * @param settings              node settings; supplies the initial task timeout
     * @param actionFilters         action filters applied by the transport framework
     * @param modelHelper           model helper (stored, not used directly here)
     * @param mLTaskManager         local ML task cache and task updater
     * @param mLModelManager        local model state / routing-table manager
     * @param clusterService        cluster service; supplies the cluster name, local node and settings
     * @param threadPool            thread pool; requests run on the "management" executor
     * @param client                node client (stored, not used directly here)
     * @param namedXContentRegistry x-content registry (stored, not used directly here)
     * @param mLEngine              ML engine; resolves model file paths on disk
     * @param mLModelCacheHelper    model cache helper (stored, not used directly here)
     */
    @Inject
    public TransportSyncUpOnNodeAction(TransportService transportService, Settings settings, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mLTaskManager, MLModelManager mLModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry namedXContentRegistry, MLEngine mLEngine, MLModelCacheHelper mLModelCacheHelper) {
        super("cluster:admin/opensearch/mlinternal/syncup", threadPool, clusterService, transportService, actionFilters, MLSyncUpNodesRequest::new, MLSyncUpNodeRequest::new, "management", MLSyncUpNodeResponse.class);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mLTaskManager;
        this.mlModelManager = mLModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
        this.mlEngine = mLEngine;
        this.mlModelCacheHelper = mLModelCacheHelper;
        this.mlTaskTimeout = (Integer) MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, num -> this.mlTaskTimeout = num);
    }

    /**
     * Aggregates per-node responses and failures into a single cluster-level response.
     */
    @Override
    protected MLSyncUpNodesResponse newResponse(MLSyncUpNodesRequest mLSyncUpNodesRequest, List<MLSyncUpNodeResponse> list, List<FailedNodeException> list2) {
        return new MLSyncUpNodesResponse(this.clusterService.getClusterName(), list, list2);
    }

    /**
     * Wraps the cluster-level request into the request sent to each individual node.
     */
    @Override
    public MLSyncUpNodeRequest newNodeRequest(MLSyncUpNodesRequest mLSyncUpNodesRequest) {
        return new MLSyncUpNodeRequest(mLSyncUpNodesRequest);
    }

    /**
     * Deserializes a single node's response from the transport stream.
     *
     * <p>This method was emitted by the decompiler as {@code m27newNodeResponse}. That name did not
     * implement the framework's abstract {@code newNodeResponse(StreamInput)}, so the real name is
     * restored here.
     *
     * @throws IOException if the stream cannot be read
     */
    @Override
    public MLSyncUpNodeResponse newNodeResponse(StreamInput streamInput) throws IOException {
        return new MLSyncUpNodeResponse(streamInput);
    }

    /**
     * Executes the sync-up on the local node.
     */
    @Override
    public MLSyncUpNodeResponse nodeOperation(MLSyncUpNodeRequest mLSyncUpNodeRequest) {
        return createSyncUpNodeResponse(mLSyncUpNodeRequest.getSyncUpNodesRequest());
    }

    /**
     * Applies the sync-up input to this node's state and builds the node's response.
     */
    private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest mLSyncUpNodesRequest) {
        MLSyncUpInput syncUpInput = mLSyncUpNodesRequest.getSyncUpInput();
        Map<String, String[]> addedWorkerNodes = syncUpInput.getAddedWorkerNodes();
        Map<String, String[]> removedWorkerNodes = syncUpInput.getRemovedWorkerNodes();
        Map<String, Set<String>> modelRoutingTable = syncUpInput.getModelRoutingTable();
        Map<String, Set<String>> runningDeployModelTasks = syncUpInput.getRunningDeployModelTasks();
        // Only key membership is consulted, so the value type is irrelevant here.
        Map<?, ?> deployToAllNodes = syncUpInput.getDeployToAllNodes();

        if (addedWorkerNodes != null) {
            for (Map.Entry<String, String[]> entry : addedWorkerNodes.entrySet()) {
                this.mlModelManager.addModelWorkerNode(entry.getKey(), entry.getValue());
            }
        }
        if (removedWorkerNodes != null) {
            for (Map.Entry<String, String[]> entry : removedWorkerNodes.entrySet()) {
                String modelId = entry.getKey();
                // A null map means no model is flagged as deploy-to-all-nodes.
                boolean inDeployToAllNodes = deployToAllNodes != null && deployToAllNodes.containsKey(modelId);
                this.mlModelManager.removeModelWorkerNode(modelId, inDeployToAllNodes, entry.getValue());
            }
        }

        // Left null unless the caller asked for the local deployment state.
        String[] deployedModelIds = null;
        String[] runningDeployTaskInfo0 = null;
        String[] runningDeployTaskInfo1 = null;
        if (syncUpInput.isGetDeployedModels()) {
            deployedModelIds = this.mlModelManager.getLocalDeployedModels();
            // NOTE(review): presumably index 0 holds running deploy task ids and index 1 their
            // model ids; confirm against MLTaskManager.getLocalRunningDeployModelTasks().
            List<String[]> localRunningDeployModelTasks = this.mlTaskManager.getLocalRunningDeployModelTasks();
            runningDeployTaskInfo0 = localRunningDeployModelTasks.get(0);
            runningDeployTaskInfo1 = localRunningDeployModelTasks.get(1);
        }

        if (syncUpInput.isClearRoutingTable()) {
            this.mlModelManager.clearRoutingTable();
        } else if (modelRoutingTable != null) {
            // Avoid allocating arrays purely for a debug message when debug logging is off.
            if (log.isDebugEnabled()) {
                for (Map.Entry<String, Set<String>> entry : modelRoutingTable.entrySet()) {
                    log.debug("latest routing table for model: {}:  {}", entry.getKey(), entry.getValue().toArray(new String[0]));
                }
            }
            this.mlModelManager.syncModelWorkerNodes(modelRoutingTable);
        }

        cleanUpLocalCache(runningDeployModelTasks);
        cleanUpLocalCacheFiles();
        // Argument order is preserved exactly: the index-1 array precedes the index-0 array.
        return new MLSyncUpNodeResponse(this.clusterService.localNode(), "ok", deployedModelIds, runningDeployTaskInfo1, runningDeployTaskInfo0);
    }

    /**
     * Marks locally cached ML tasks as FAILED when their last update is older than the configured
     * task timeout.
     *
     * <p>Exception: a DEPLOY_MODEL task in CREATED state whose id appears in
     * {@code runningDeployModelTasks} is left untouched.
     *
     * @param runningDeployModelTasks running deploy tasks reported in the sync-up input; may be null
     */
    @VisibleForTesting
    void cleanUpLocalCache(Map<String, Set<String>> runningDeployModelTasks) {
        String[] allTaskIds = this.mlTaskManager.getAllTaskIds();
        if (allTaskIds == null) {
            return;
        }
        // Read the volatile setting once so the comparison and the error message agree.
        int timeoutSeconds = this.mlTaskTimeout;
        for (String taskId : allTaskIds) {
            // The id list is a snapshot; the task may have left the cache since, so a missing
            // entry is skipped instead of causing an NPE.
            MLTask mlTask = Optional.ofNullable(this.mlTaskManager.getMLTaskCache(taskId)).map(cache -> cache.getMlTask()).orElse(null);
            if (mlTask == null) {
                continue;
            }
            if (!Instant.now().isAfter(mlTask.getLastUpdateTime().plusSeconds(timeoutSeconds))) {
                continue;
            }
            log.info("ML task timeout. task id: {}, task type: {}", taskId, mlTask.getTaskType());
            boolean reportedRunningDeployTask = mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL
                && mlTask.getState() == MLTaskState.CREATED
                && runningDeployModelTasks != null
                && runningDeployModelTasks.containsKey(taskId);
            if (!reportedRunningDeployTask) {
                // NOTE(review): 10000L is passed straight through to updateMLTask; its unit
                // (presumably milliseconds) is defined there.
                this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("state", MLTaskState.FAILED, "error", "timeout after " + timeoutSeconds + " seconds"), 10000L, true);
            }
        }
    }

    /**
     * Scans the register, deploy and model-cache folders and deletes the files of every model
     * that no task or running model on this node refers to.
     */
    private void cleanUpLocalCacheFiles() {
        Set<String> fileNames = FileUtils.getFileNames(new Path[]{this.mlEngine.getRegisterModelRootPath(), this.mlEngine.getDeployModelRootPath(), this.mlEngine.getModelCacheRootPath()});
        if (fileNames.isEmpty()) {
            return;
        }
        log.debug("Found {} models in cache folder: {}", fileNames.size(), Arrays.toString(fileNames.toArray(new String[0])));
        for (String modelId : fileNames) {
            boolean referenced = this.mlTaskManager.contains(modelId)
                || this.mlTaskManager.containsModel(modelId)
                || this.mlModelManager.isModelRunningOnNode(modelId);
            if (!referenced) {
                log.info("ML model not in cache. Remove all of its cache files. model id: {}", modelId);
                deleteFileCache(modelId);
            }
        }
    }

    /**
     * Deletes the model's cache, deploy and register files; each deletion is quiet (best effort).
     */
    private void deleteFileCache(String str) {
        FileUtils.deleteFileQuietly(this.mlEngine.getModelCachePath(str));
        FileUtils.deleteFileQuietly(this.mlEngine.getDeployModelPath(str));
        FileUtils.deleteFileQuietly(this.mlEngine.getRegisterModelPath(str));
    }
}
