package org.opensearch.ml.action.tasks;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest;
import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.class */
/**
 * Transport action that cancels the remote batch-prediction job linked to an ML task.
 *
 * <p>Flow:
 * <ol>
 *   <li>Load the task document from the ML task index.</li>
 *   <li>Verify it is a remote {@code BATCH_PREDICTION} task and that offline batch inference is
 *       enabled.</li>
 *   <li>Load the task's model and check the caller's access to its model group.</li>
 *   <li>Resolve the model's connector, either inline on the model or from the connector index.</li>
 *   <li>Invoke the connector's {@code CANCEL_BATCH_PREDICT} action.</li>
 * </ol>
 * Failures on every path in this class are reported through the caller's listener.
 */
public class CancelBatchJobTransportAction extends HandledTransportAction<ActionRequest, MLCancelBatchJobResponse> {

    /** Index holding ML task documents. */
    private static final String ML_TASK_INDEX = ".plugins-ml-task";
    /** Index holding stand-alone connector documents. */
    private static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
    /** Remote-job parameter naming the transform job; derived from {@link #TRANSFORM_JOB_ARN} when missing. */
    private static final String TRANSFORM_JOB_NAME = "TransformJobName";
    /** Remote-job parameter holding the transform job ARN. */
    private static final String TRANSFORM_JOB_ARN = "TransformJobArn";
    private static final String NO_BATCH_JOB_MSG = "The task ID you provided does not have any associated batch job";

    @Generated
    private static final Logger log = LogManager.getLogger(CancelBatchJobTransportAction.class);
    Client client;
    NamedXContentRegistry xContentRegistry;
    ClusterService clusterService;
    ScriptService scriptService;
    ConnectorAccessControlHelper connectorAccessControlHelper;
    ModelAccessControlHelper modelAccessControlHelper;
    EncryptorImpl encryptor;
    MLModelManager mlModelManager;
    MLTaskManager mlTaskManager;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;

    @Inject
    public CancelBatchJobTransportAction(TransportService transportService, ActionFilters actionFilters, Client client, NamedXContentRegistry namedXContentRegistry, ClusterService clusterService, ScriptService scriptService, ConnectorAccessControlHelper connectorAccessControlHelper, ModelAccessControlHelper modelAccessControlHelper, EncryptorImpl encryptorImpl, MLTaskManager mLTaskManager, MLModelManager mLModelManager, MLFeatureEnabledSetting mLFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/tasks/cancel", transportService, actionFilters, MLCancelBatchJobRequest::new);
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
        this.clusterService = clusterService;
        this.scriptService = scriptService;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.encryptor = encryptorImpl;
        this.mlTaskManager = mLTaskManager;
        this.mlModelManager = mLModelManager;
        this.mlFeatureEnabledSetting = mLFeatureEnabledSetting;
    }

    /**
     * Loads the ML task identified by the request and, if it is a remote batch prediction,
     * cancels its remote job.
     *
     * @param task           the transport task (unused)
     * @param actionRequest  request convertible to {@link MLCancelBatchJobRequest}
     * @param actionListener completed with {@code RestStatus.OK} on success, or with the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLCancelBatchJobResponse> actionListener) {
        String taskId = MLCancelBatchJobRequest.fromActionRequest(actionRequest).getTaskId();
        GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);
        // The caller's thread context is stashed for the get (presumably so the system index can
        // be read without the caller's headers) and restored before the listener runs. With
        // try-with-resources the stash is also closed when get() throws synchronously.
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(getResponse -> {
                log.debug("Completed Get Task Request, id:{}", taskId);
                if (getResponse == null || !getResponse.isExists()) {
                    actionListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND));
                    return;
                }
                // try-with-resources closes the parser even when parsing or dispatch throws.
                try (XContentParser parser = MLNodeUtils.createXContentParserFromRegistry(xContentRegistry, getResponse.getSourceAsBytesRef())) {
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    handleTask(MLTask.parse(parser), actionListener);
                } catch (Exception e) {
                    log.error("Failed to parse ml task {}", getResponse.getId(), e);
                    actionListener.onFailure(e);
                }
            }, getError -> {
                if (getError instanceof IndexNotFoundException) {
                    actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
                } else {
                    log.error("Failed to get ML task {}", taskId, getError);
                    actionListener.onFailure(getError);
                }
            }), context::restore));
        } catch (Exception e) {
            log.error("Failed to get ML task {}", taskId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Dispatches a parsed task. Only remote batch-prediction tasks can be cancelled; every other
     * task fails the listener with an {@link IllegalArgumentException}.
     *
     * @throws IllegalStateException if the task is a batch prediction but offline batch inference is disabled
     */
    private void handleTask(MLTask mlTask, ActionListener<MLCancelBatchJobResponse> actionListener) {
        boolean isBatchPrediction = mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION;
        if (isBatchPrediction && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
            throw new IllegalStateException(MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG);
        }
        if (isBatchPrediction && mlTask.getFunctionName() == FunctionName.REMOTE) {
            processRemoteBatchPrediction(mlTask, actionListener);
        } else {
            actionListener.onFailure(new IllegalArgumentException(NO_BATCH_JOB_MSG));
        }
    }

    /**
     * Loads the task's model and verifies the caller may access its model group.
     * It then runs the cancel action on the model's connector.
     *
     * @param mLTask         a remote batch-prediction task
     * @param actionListener receives the cancel outcome or any failure along the way
     */
    private void processRemoteBatchPrediction(MLTask mLTask, ActionListener<MLCancelBatchJobResponse> actionListener) {
        Map<String, Object> remoteJob = mLTask.getRemoteJob();
        if (remoteJob == null) {
            // Without remote job details there is nothing to address on the remote service.
            actionListener.onFailure(new IllegalArgumentException(NO_BATCH_JOB_MSG));
            return;
        }
        MLInput cancelInput = buildCancelInput(remoteJob);
        String modelId = mLTask.getModelId();
        // Read the user before stashContext() replaces the current thread context.
        User userContext = RestActionUtils.getUserContext(client);
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            mlModelManager.getModel(modelId, null, null, ActionListener.runBefore(ActionListener.wrap(mlModel ->
                modelAccessControlHelper.validateModelGroupAccess(userContext, mlModel.getModelGroupId(), client, ActionListener.wrap(hasAccess -> {
                    if (!Boolean.TRUE.equals(hasAccess)) {
                        actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
                        return;
                    }
                    if (mlModel.getConnector() != null) {
                        executeConnector(mlModel.getConnector(), cancelInput, actionListener);
                    } else {
                        fetchConnectorAndExecute(mlModel.getConnectorId(), cancelInput, actionListener);
                    }
                }, accessError -> {
                    log.error("Failed to validate Access for Model Group {}", mlModel.getModelGroupId(), accessError);
                    actionListener.onFailure(accessError);
                })), modelError -> {
                    log.error("Failed to retrieve the ML model with the given ID", modelError);
                    actionListener.onFailure(new OpenSearchStatusException("Failed to retrieve the ML model for the given task ID", RestStatus.NOT_FOUND));
                }), context::restore));
        } catch (Exception e) {
            log.error("Unable to fetch cancel batch job in ml task ", e);
            // Report through the listener and keep the original cause rather than dropping it.
            actionListener.onFailure(new OpenSearchException("Unable to fetch cancel batch job in ml task " + e.getMessage(), e));
        }
    }

    /**
     * Builds the connector input from the task's remote-job details.
     *
     * <p>Only String-valued entries are kept; other values are logged at debug level and
     * skipped. When {@code TransformJobName} is missing, it is derived from the last
     * {@code /}-separated segment of {@code TransformJobArn}, if present.
     */
    private static MLInput buildCancelInput(Map<String, Object> remoteJob) {
        Map<String, String> parameters = new HashMap<>();
        for (Map.Entry<String, Object> entry : remoteJob.entrySet()) {
            if (entry.getValue() instanceof String) {
                parameters.put(entry.getKey(), (String) entry.getValue());
            } else {
                log.debug("Value for key {} is not a String", entry.getKey());
            }
        }
        parameters.computeIfAbsent(TRANSFORM_JOB_NAME, key -> Optional.ofNullable(parameters.get(TRANSFORM_JOB_ARN))
            .map(arn -> arn.substring(arn.lastIndexOf('/') + 1))
            .orElse(null));
        return MLInput.builder()
            .algorithm(FunctionName.REMOTE)
            .inputDataset(new RemoteInferenceInputDataSet(parameters, ConnectorAction.ActionType.BATCH_PREDICT_STATUS))
            .build();
    }

    /**
     * Looks up a stand-alone connector by id and runs the cancel action on it.
     * Fails the listener if the connector index does not exist.
     */
    private void fetchConnectorAndExecute(String connectorId, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
        if (!clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
            actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + connectorId));
            return;
        }
        ActionListener<Connector> connectorListener = ActionListener.wrap(
            connector -> executeConnector(connector, mlInput, actionListener),
            connectorError -> {
                log.error("Failed to get connector {}", connectorId, connectorError);
                actionListener.onFailure(connectorError);
            }
        );
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.runBefore(connectorListener, context::restore));
        }
    }

    /**
     * Executes the connector's {@code CANCEL_BATCH_PREDICT} action.
     *
     * <p>If the connector has no such action, or the action has no request body, a default
     * action from {@link ConnectorUtils#createConnectorAction} is added first. The connector's
     * credentials are then decrypted. Note: this mutates {@code connector}.
     */
    private void executeConnector(Connector connector, MLInput mLInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
        String cancelAction = ConnectorAction.ActionType.CANCEL_BATCH_PREDICT.name();
        Optional<ConnectorAction> existingAction = connector.findAction(cancelAction);
        if (!existingAction.map(ConnectorAction::getRequestBody).isPresent()) {
            connector.addAction(ConnectorUtils.createConnectorAction(connector, ConnectorAction.ActionType.CANCEL_BATCH_PREDICT));
        }
        connector.decrypt(cancelAction, credential -> encryptor.decrypt(credential));
        RemoteConnectorExecutor executor = (RemoteConnectorExecutor) MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
        executor.setScriptService(scriptService);
        executor.setClusterService(clusterService);
        executor.setClient(client);
        executor.setXContentRegistry(xContentRegistry);
        executor.executeAction(cancelAction, mLInput, ActionListener.wrap(
            taskResponse -> processTaskResponse(taskResponse, actionListener),
            actionListener::onFailure
        ));
    }

    /**
     * Translates the remote service response into the cancel result.
     *
     * <p>Outcomes:
     * <ul>
     *   <li>HTTP 200 from the first model output: succeeds with {@code RestStatus.OK}.</li>
     *   <li>Missing outputs, or any other status code: fails the listener.</li>
     *   <li>Exception while reading the output: logged and also reported to the listener (the
     *       original silently swallowed it, leaving the request hanging).</li>
     * </ul>
     */
    private void processTaskResponse(MLTaskResponse mLTaskResponse, ActionListener<MLCancelBatchJobResponse> actionListener) {
        ModelTensorOutput output;
        try {
            output = (ModelTensorOutput) mLTaskResponse.getOutput();
        } catch (Exception e) {
            log.error("Unable to fetch status for ml task ", e);
            actionListener.onFailure(e);
            return;
        }
        if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) {
            log.debug("ML Model Outputs are null or empty.");
            actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job"));
            return;
        }
        ModelTensors modelTensors = output.getMlModelOutputs().get(0);
        Integer statusCode = modelTensors.getStatusCode();
        if (Integer.valueOf(200).equals(statusCode)) {
            actionListener.onResponse(new MLCancelBatchJobResponse(RestStatus.OK));
        } else {
            log.debug("The status code from remote service is: {}", statusCode);
            actionListener.onFailure(new OpenSearchException("Couldn't cancel the transform job. Please try again"));
        }
    }
}
