package org.opensearch.ml.task;

import com.google.common.collect.ImmutableList;
import java.time.Instant;
import java.util.UUID;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.permission.AccessController;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/task/MLPredictTaskRunner.class */
public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(MLPredictTaskRunner.class);
    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLInputDatasetHandler mlInputDatasetHandler;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final DiscoveryNodeHelper nodeHelper;
    private final MLEngine mlEngine;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.opensearch.ml.task.MLPredictTaskRunner$1, reason: invalid class name */
    /* loaded from: input_file:org/opensearch/ml/task/MLPredictTaskRunner$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$opensearch$ml$common$dataset$MLInputDataType = new int[MLInputDataType.values().length];

        static {
            try {
                $SwitchMap$org$opensearch$ml$common$dataset$MLInputDataType[MLInputDataType.SEARCH_QUERY.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$opensearch$ml$common$dataset$MLInputDataType[MLInputDataType.DATA_FRAME.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$opensearch$ml$common$dataset$MLInputDataType[MLInputDataType.TEXT_DOCS.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    public MLPredictTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mLTaskManager, MLStats mLStats, MLInputDatasetHandler mLInputDatasetHandler, MLTaskDispatcher mLTaskDispatcher, MLCircuitBreakerService mLCircuitBreakerService, NamedXContentRegistry namedXContentRegistry, MLModelManager mLModelManager, DiscoveryNodeHelper discoveryNodeHelper, MLEngine mLEngine) {
        super(mLTaskManager, mLStats, discoveryNodeHelper, mLTaskDispatcher, mLCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlInputDatasetHandler = mLInputDatasetHandler;
        this.xContentRegistry = namedXContentRegistry;
        this.mlModelManager = mLModelManager;
        this.nodeHelper = discoveryNodeHelper;
        this.mlEngine = mLEngine;
    }

    @Override // org.opensearch.ml.task.MLTaskRunner
    protected String getTransportActionName() {
        return "cluster:admin/opensearch/ml/predict";
    }

    @Override // org.opensearch.ml.task.MLTaskRunner
    protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> actionListener) {
        return new ActionListenerResponseHandler(actionListener, MLTaskResponse::new);
    }

    @Override // org.opensearch.ml.task.MLTaskRunner
    public void dispatchTask(FunctionName functionName, MLPredictionTaskRequest mLPredictionTaskRequest, TransportService transportService, ActionListener<MLTaskResponse> actionListener) {
        String modelId = mLPredictionTaskRequest.getModelId();
        try {
            ActionListener<DiscoveryNode> wrap = ActionListener.wrap(discoveryNode -> {
                if (this.clusterService.localNode().getId().equals(discoveryNode.getId())) {
                    log.debug("Execute ML predict request {} locally on node {}", mLPredictionTaskRequest.getRequestID(), discoveryNode.getId());
                    mLPredictionTaskRequest.setDispatchTask(false);
                    executeTask(mLPredictionTaskRequest, (ActionListener<MLTaskResponse>) actionListener);
                } else {
                    log.debug("Execute ML predict request {} remotely on node {}", mLPredictionTaskRequest.getRequestID(), discoveryNode.getId());
                    mLPredictionTaskRequest.setDispatchTask(false);
                    transportService.sendRequest(discoveryNode, getTransportActionName(), mLPredictionTaskRequest, getResponseHandler(actionListener));
                }
            }, exc -> {
                actionListener.onFailure(exc);
            });
            String[] workerNodes = this.mlModelManager.getWorkerNodes(modelId, functionName, true);
            if (workerNodes == null || workerNodes.length == 0) {
                if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) {
                    actionListener.onFailure(new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + modelId + "/_deploy"));
                    return;
                }
                workerNodes = this.nodeHelper.getEligibleNodeIds(functionName);
            }
            this.mlTaskDispatcher.dispatchPredictTask(workerNodes, wrap);
        } catch (Exception e) {
            log.error("Failed to predict model " + modelId, e);
            actionListener.onFailure(e);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.opensearch.ml.task.MLTaskRunner
    public void executeTask(MLPredictionTaskRequest mLPredictionTaskRequest, ActionListener<MLTaskResponse> actionListener) {
        MLInputDataType inputDataType = mLPredictionTaskRequest.getMlInput().getInputDataset().getInputDataType();
        Instant now = Instant.now();
        String modelId = mLPredictionTaskRequest.getModelId();
        MLTask build = MLTask.builder().taskId(UUID.randomUUID().toString()).modelId(modelId).taskType(MLTaskType.PREDICTION).inputType(inputDataType).functionName(mLPredictionTaskRequest.getMlInput().getFunctionName()).state(MLTaskState.CREATED).workerNodes(ImmutableList.of(this.clusterService.localNode().getId())).createTime(now).lastUpdateTime(now).async(false).build();
        MLInput mlInput = mLPredictionTaskRequest.getMlInput();
        switch (AnonymousClass1.$SwitchMap$org$opensearch$ml$common$dataset$MLInputDataType[inputDataType.ordinal()]) {
            case 1:
                this.mlInputDatasetHandler.parseSearchQueryInput(mlInput.getInputDataset(), threadedActionListener(ActionListener.wrap(mLInputDataset -> {
                    predict(modelId, build, mlInput.toBuilder().inputDataset(mLInputDataset).build(), actionListener);
                }, exc -> {
                    log.error("Failed to generate DataFrame from search query", exc);
                    handleAsyncMLTaskFailure(build, exc);
                    actionListener.onFailure(exc);
                })));
                return;
            case 2:
            case 3:
            default:
                this.threadPool.executor(MachineLearningPlugin.PREDICT_THREAD_POOL).execute(() -> {
                    predict(modelId, build, mlInput, actionListener);
                });
                return;
        }
    }

    private void predict(String str, MLTask mLTask, MLInput mLInput, ActionListener<MLTaskResponse> actionListener) {
        ActionListener<MLTaskResponse> wrappedCleanupListener = wrappedCleanupListener(actionListener, mLTask.getTaskId());
        this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
        this.mlStats.createCounterStatIfAbsent(mLTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        if (str != null) {
            this.mlStats.createModelCounterStatIfAbsent(str, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        }
        mLTask.setState(MLTaskState.RUNNING);
        this.mlTaskManager.add(mLTask);
        FunctionName algorithm = mLInput.getAlgorithm();
        if (str == null) {
            IllegalArgumentException illegalArgumentException = new IllegalArgumentException("ModelId is invalid");
            log.error("ModelId is invalid", illegalArgumentException);
            handlePredictFailure(mLTask, wrappedCleanupListener, illegalArgumentException, false, str);
            return;
        }
        Predictable predictor = this.mlModelManager.getPredictor(str);
        if (predictor != null) {
            try {
                if (!predictor.isModelReady()) {
                    throw new IllegalArgumentException("Model not ready: " + str);
                }
                MLPredictionOutput mLPredictionOutput = (MLOutput) this.mlModelManager.trackPredictDuration(str, () -> {
                    return predictor.predict(mLInput);
                });
                if (mLPredictionOutput instanceof MLPredictionOutput) {
                    mLPredictionOutput.setStatus(MLTaskState.COMPLETED.name());
                }
                handleAsyncMLTaskComplete(mLTask);
                wrappedCleanupListener.onResponse(MLTaskResponse.builder().output(mLPredictionOutput).build());
                return;
            } catch (Exception e) {
                handlePredictFailure(mLTask, wrappedCleanupListener, e, false, str);
                return;
            }
        }
        if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
            throw new IllegalArgumentException("Model not ready to be used: " + str);
        }
        try {
            ThreadContext.StoredContext stashContext = this.threadPool.getThreadContext().stashContext();
            try {
                ActionListener wrap = ActionListener.wrap(getResponse -> {
                    if (getResponse == null || !getResponse.isExists()) {
                        wrappedCleanupListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.", new Object[0]));
                        return;
                    }
                    try {
                        XContentParser createParser = XContentType.JSON.xContent().createParser(this.xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString());
                        try {
                            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, createParser.nextToken(), createParser);
                            MLModel parse = MLModel.parse(createParser, getResponse.getSource().get(RestActionUtils.PARAMETER_ALGORITHM).toString());
                            parse.setModelId(str);
                            User user = parse.getUser();
                            User userContext = AccessController.getUserContext(this.client);
                            if (!AccessController.checkUserPermissions(userContext, user, str)) {
                                handlePredictFailure(mLTask, wrappedCleanupListener, new OpenSearchException("User: " + userContext.getName() + " does not have permissions to run predict by model: " + str, new Object[0]), false, str);
                                if (createParser != null) {
                                    createParser.close();
                                    return;
                                }
                                return;
                            }
                            if (this.mlTaskManager.contains(mLTask.getTaskId())) {
                                this.mlTaskManager.updateTaskStateAsRunning(mLTask.getTaskId(), mLTask.isAsync());
                            }
                            MLPredictionOutput predict = this.mlEngine.predict(mLInput, parse);
                            if (predict instanceof MLPredictionOutput) {
                                predict.setStatus(MLTaskState.COMPLETED.name());
                            }
                            handleAsyncMLTaskComplete(mLTask);
                            wrappedCleanupListener.onResponse(MLTaskResponse.builder().output(predict).build());
                            if (createParser != null) {
                                createParser.close();
                            }
                        } finally {
                        }
                    } catch (Exception e2) {
                        log.error("Failed to predict model " + str, e2);
                        wrappedCleanupListener.onFailure(e2);
                    }
                }, exc -> {
                    log.error("Failed to predict " + mLInput.getAlgorithm() + ", modelId: " + mLTask.getModelId(), exc);
                    handlePredictFailure(mLTask, wrappedCleanupListener, exc, true, str);
                });
                this.client.get(new GetRequest(".plugins-ml-model", mLTask.getModelId()), threadedActionListener(ActionListener.runBefore(wrap, () -> {
                    stashContext.restore();
                })));
                if (stashContext != null) {
                    stashContext.close();
                }
            } finally {
            }
        } catch (Exception e2) {
            log.error("Failed to get model " + mLTask.getModelId(), e2);
            handlePredictFailure(mLTask, wrappedCleanupListener, e2, true, str);
        }
    }

    private <T> ThreadedActionListener<T> threadedActionListener(ActionListener<T> actionListener) {
        return new ThreadedActionListener<>(log, this.threadPool, MachineLearningPlugin.PREDICT_THREAD_POOL, actionListener, false);
    }

    private void handlePredictFailure(MLTask mLTask, ActionListener<MLTaskResponse> actionListener, Exception exc, boolean z, String str) {
        if (z) {
            this.mlStats.createCounterStatIfAbsent(mLTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            this.mlStats.createModelCounterStatIfAbsent(str, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT);
            this.mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
        }
        handleAsyncMLTaskFailure(mLTask, exc);
        actionListener.onFailure(exc);
    }
}
