package org.opensearch.ml.task;

import java.time.Instant;
import java.util.UUID;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLModel;
import org.opensearch.ml.common.parameter.MLTask;
import org.opensearch.ml.common.parameter.MLTaskState;
import org.opensearch.ml.common.parameter.MLTaskType;
import org.opensearch.ml.common.parameter.MLTrainingOutput;
import org.opensearch.ml.common.parameter.Model;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.StatNames;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/task/MLTrainingTaskRunner.class */
/**
 * Runs ML training requests.
 *
 * <p>A request is dispatched (via {@code mlTaskDispatcher}) to a worker node. It is either handled locally
 * or forwarded to that node over the transport layer. Locally, an {@link MLTask} is created, the model is
 * trained with {@link MLEngine#train} on the {@link MachineLearningPlugin#TASK_THREAD_POOL} executor, and
 * the resulting {@link MLModel} is indexed into {@link MLIndicesHandler#ML_MODEL_INDEX}.
 */
public class MLTrainingTaskRunner extends MLTaskRunner<MLTrainingTaskRequest, MLTaskResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(MLTrainingTaskRunner.class);
    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLIndicesHandler mlIndicesHandler;
    private final MLInputDatasetHandler mlInputDatasetHandler;

    /**
     * Creates a training task runner.
     *
     * @param threadPool thread pool providing the {@link MachineLearningPlugin#TASK_THREAD_POOL} executor
     * @param clusterService used to identify the local node
     * @param client client used to index the trained model
     * @param mLTaskManager task manager, passed to {@link MLTaskRunner}
     * @param mLStats stats registry, passed to {@link MLTaskRunner}
     * @param mLIndicesHandler used to create the model index if it is missing
     * @param mLInputDatasetHandler used to turn search-query input into a data frame
     * @param mLTaskDispatcher picks the worker node, passed to {@link MLTaskRunner}
     * @param mLCircuitBreakerService circuit breaker service, passed to {@link MLTaskRunner}
     */
    public MLTrainingTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mLTaskManager, MLStats mLStats, MLIndicesHandler mLIndicesHandler, MLInputDatasetHandler mLInputDatasetHandler, MLTaskDispatcher mLTaskDispatcher, MLCircuitBreakerService mLCircuitBreakerService) {
        super(mLTaskManager, mLStats, mLTaskDispatcher, mLCircuitBreakerService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlIndicesHandler = mLIndicesHandler;
        this.mlInputDatasetHandler = mLInputDatasetHandler;
    }

    /**
     * Dispatches the training request to a worker node.
     *
     * <p>If the chosen node is the local node, the task is created and trained here. Otherwise the request
     * is sent to that node under the {@code cluster:admin/opensearch/ml/train} action.
     *
     * @param mLTrainingTaskRequest the training request
     * @param transportService used to forward the request to a remote node
     * @param actionListener receives the training response, or any dispatch/transport failure
     */
    @Override
    public void executeTask(MLTrainingTaskRequest mLTrainingTaskRequest, TransportService transportService, ActionListener<MLTaskResponse> actionListener) {
        this.mlTaskDispatcher.dispatchTask(ActionListener.wrap(discoveryNode -> {
            String targetNodeId = discoveryNode.getId();
            if (this.clusterService.localNode().getId().equals(targetNodeId)) {
                log.info("execute ML training request {} locally on node {}", mLTrainingTaskRequest, targetNodeId);
                createMLTaskAndTrain(mLTrainingTaskRequest, actionListener);
            } else {
                log.info("execute ML training request {} remotely on node {}", mLTrainingTaskRequest, targetNodeId);
                transportService.sendRequest(
                    discoveryNode,
                    "cluster:admin/opensearch/ml/train",
                    mLTrainingTaskRequest,
                    new ActionListenerResponseHandler<>(actionListener, MLTaskResponse::new)
                );
            }
        }, actionListener::onFailure));
    }

    /**
     * Builds an {@link MLTask} in state {@link MLTaskState#CREATED} and starts training it on this node.
     *
     * <p>Async requests: the task is first persisted via {@code mlTaskManager.createMLTask}. The caller is
     * answered immediately with the persisted task id. Training completion or failure is then reported
     * through {@code handleAsyncMLTaskComplete} / {@code handleAsyncMLTaskFailure}.
     *
     * <p>Sync requests: the task gets a random UUID as its id. The caller receives the training result itself.
     *
     * @param mLTrainingTaskRequest the training request
     * @param actionListener receives the task-created response (async) or the training result (sync)
     */
    public void createMLTaskAndTrain(MLTrainingTaskRequest mLTrainingTaskRequest, ActionListener<MLTaskResponse> actionListener) {
        MLInput mlInput = mLTrainingTaskRequest.getMlInput();
        MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
        Instant now = Instant.now();
        MLTask mlTask = MLTask.builder()
            .taskType(MLTaskType.TRAINING)
            .inputType(inputDataType)
            .functionName(mlInput.getFunctionName())
            .state(MLTaskState.CREATED)
            .workerNode(this.clusterService.localNode().getId())
            .createTime(now)
            .lastUpdateTime(now)
            .async(mLTrainingTaskRequest.isAsync())
            .build();

        if (!mLTrainingTaskRequest.isAsync()) {
            // Sync tasks are not persisted through createMLTask, so give them a locally generated id.
            mlTask.setTaskId(UUID.randomUUID().toString());
            startTrainingTask(mlTask, mlInput, actionListener);
            return;
        }

        this.mlTaskManager.createMLTask(mlTask, ActionListener.wrap(indexResponse -> {
            String taskId = indexResponse.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller right away; the training outcome is recorded on the task afterwards.
            actionListener.onResponse(new MLTaskResponse(new MLTrainingOutput((String) null, taskId, mlTask.getState().name())));
            startTrainingTask(mlTask, mlInput, ActionListener.wrap(trainingResponse -> {
                String modelId = trainingResponse.getOutput().getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, trainingFailure -> {
                log.error("Failed to train ML model for task {}", taskId, trainingFailure);
                handleAsyncMLTaskFailure(mlTask, trainingFailure);
            }));
        }, createFailure -> {
            log.error("Failed to create ML task", createFailure);
            actionListener.onFailure(createFailure);
        }));
    }

    /**
     * Registers the task, updates request stats and schedules training on the task thread pool.
     *
     * <p>The listener is wrapped with {@code wrappedCleanupListener} from {@link MLTaskRunner}, which
     * presumably undoes the bookkeeping done here (executing-task count, task registration). Confirm this
     * in MLTaskRunner.
     *
     * <p>{@link MLInputDataType#SEARCH_QUERY} input is first converted into a data frame by
     * {@link MLInputDatasetHandler}. Training then continues on {@link MachineLearningPlugin#TASK_THREAD_POOL}.
     * Any other input type is trained directly on that pool.
     *
     * @param mLTask the task being run; its id must already be set
     * @param mLInput the training input
     * @param actionListener receives the training result or failure
     */
    public void startTrainingTask(MLTask mLTask, MLInput mLInput, ActionListener<MLTaskResponse> actionListener) {
        ActionListener<MLTaskResponse> cleanupListener = wrappedCleanupListener(actionListener, mLTask.getTaskId());
        this.mlStats.getStat(StatNames.ML_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(StatNames.ML_TOTAL_REQUEST_COUNT).increment();
        this.mlStats.createCounterStatIfAbsent(StatNames.requestCountStat(mLTask.getFunctionName(), ActionName.TRAIN)).increment();
        this.mlTaskManager.add(mLTask);
        try {
            if (mLInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
                // ThreadedActionListener hands the parsed data frame over to the task thread pool.
                this.mlInputDatasetHandler.parseSearchQueryInput(
                    mLInput.getInputDataset(),
                    new ThreadedActionListener<>(log, this.threadPool, MachineLearningPlugin.TASK_THREAD_POOL, ActionListener.wrap(
                        dataFrame -> train(mLTask, mLInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build(), cleanupListener),
                        parseFailure -> {
                            log.error("Failed to generate DataFrame from search query", parseFailure);
                            cleanupListener.onFailure(parseFailure);
                        }
                    ), false)
                );
            } else {
                this.threadPool.executor(MachineLearningPlugin.TASK_THREAD_POOL).execute(() -> train(mLTask, mLInput, cleanupListener));
            }
        } catch (Exception e) {
            log.error("Failed to train {}", mLInput.getAlgorithm(), e);
            cleanupListener.onFailure(e);
        }
    }

    /**
     * Trains the model on the calling thread and persists it.
     *
     * <p>Steps:
     * <ol>
     *   <li>Marks the task {@link MLTaskState#RUNNING}.</li>
     *   <li>Runs {@link MLEngine#train}.</li>
     *   <li>Ensures the model index exists.</li>
     *   <li>Indexes the model.</li>
     * </ol>
     * Every failure increments the per-function and total failure counters before reaching the listener.
     *
     * @param mLTask the running task
     * @param mLInput the training input, already materialized as a data frame where needed
     * @param actionListener receives an {@link MLTrainingOutput} carrying the model id, or the failure
     */
    private void train(MLTask mLTask, MLInput mLInput, ActionListener<MLTaskResponse> actionListener) {
        ActionListener<MLTaskResponse> statsListener = ActionListener.wrap(actionListener::onResponse, e -> {
            this.mlStats.createCounterStatIfAbsent(StatNames.failureCountStat(mLTask.getFunctionName(), ActionName.TRAIN)).increment();
            this.mlStats.getStat(StatNames.ML_TOTAL_FAILURE_COUNT).increment();
            actionListener.onFailure(e);
        });
        try {
            this.mlTaskManager.updateTaskState(mLTask.getTaskId(), MLTaskState.RUNNING, mLTask.isAsync());
            Model trainedModel = MLEngine.train(mLInput);
            this.mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexReady -> {
                if (!indexReady) {
                    statsListener.onFailure(new RuntimeException("No response to create ML model index"));
                    return;
                }
                indexModel(mLTask, new MLModel(mLInput.getAlgorithm(), trainedModel), statsListener);
            }, indexFailure -> {
                log.error("Failed to init ML model index", indexFailure);
                statsListener.onFailure(indexFailure);
            }));
        } catch (Exception e) {
            log.error("Failed to train {}", mLInput.getAlgorithm(), e);
            statsListener.onFailure(e);
        }
    }

    /**
     * Indexes the trained model into {@link MLIndicesHandler#ML_MODEL_INDEX} with immediate refresh.
     *
     * <p>The write runs inside a stashed thread context. The try-with-resources restores the caller's
     * context on this thread on every exit path, including exceptions. The {@code runBefore} hook restores
     * it on the thread that completes the index request. On success the model-count stats are incremented.
     *
     * @param mLTask the running task; its id is reported only for async tasks
     * @param mLModel the model document to index
     * @param listener receives an {@link MLTrainingOutput} with the new model id, or the failure
     */
    private void indexModel(MLTask mLTask, MLModel mLModel, ActionListener<MLTaskResponse> listener) {
        try (ThreadContext.StoredContext stashedContext = this.client.threadPool().getThreadContext().stashContext()) {
            IndexRequest indexRequest = new IndexRequest(MLIndicesHandler.ML_MODEL_INDEX);
            indexRequest.source(mLModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
            indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            this.client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(indexResponse -> {
                log.info("Model data indexing done, result:{}, model id: {}", indexResponse.getResult(), indexResponse.getId());
                this.mlStats.getStat(StatNames.ML_TOTAL_MODEL_COUNT).increment();
                this.mlStats.createCounterStatIfAbsent(StatNames.modelCountStat(mLTask.getFunctionName())).increment();
                String reportedTaskId = mLTask.isAsync() ? mLTask.getTaskId() : null;
                listener.onResponse(MLTaskResponse.builder()
                    .output(new MLTrainingOutput(indexResponse.getId(), reportedTaskId, MLTaskState.COMPLETED.name()))
                    .build());
            }, listener::onFailure), stashedContext::restore));
        } catch (Exception e) {
            log.error("Failed to save ML model", e);
            listener.onFailure(e);
        }
    }
}
