package org.opensearch.ml.action.prediction;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action registered as {@code cluster:admin/opensearch/ml/predict}.
 *
 * <p>Resolves the target model (from {@link MLModelCacheHelper}, or via {@link MLModelManager} on a
 * cache miss), checks through {@link ModelAccessControlHelper} that the requesting user may access
 * the model's group, and then hands the request to the prediction task runner.
 */
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);

    private final MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private final TransportService transportService;
    private final MLModelCacheHelper modelCacheHelper;
    private final Client client;
    // NOTE(review): clusterService and xContentRegistry are injected but not read in this class.
    private final ClusterService clusterService;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final ModelAccessControlHelper modelAccessControlHelper;

    @Inject
    public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLPredictTaskRunner mLPredictTaskRunner, MLModelCacheHelper mLModelCacheHelper, ClusterService clusterService, Client client, NamedXContentRegistry namedXContentRegistry, MLModelManager mLModelManager, ModelAccessControlHelper modelAccessControlHelper) {
        super("cluster:admin/opensearch/ml/predict", transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mLPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = mLModelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
        this.mlModelManager = mLModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
    }

    /**
     * Validates access to the requested model and, if permitted, runs the prediction.
     *
     * <p>If the request carries no user, the user is read from the current thread context via
     * {@link RestActionUtils#getUserContext} and stored on the request. The thread context is
     * stashed while the model is looked up, and the caller's context is restored before access
     * validation runs and before {@code actionListener} is notified.
     *
     * <p>Model-lookup and access-validation errors are logged and forwarded to
     * {@code actionListener}. A denied access check fails it with {@link MLValidationException}.
     *
     * @param task the transport task (unused here)
     * @param actionRequest the incoming request, converted to an {@link MLPredictionTaskRequest}
     * @param actionListener receives the prediction response or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLTaskResponse> actionListener) {
        final MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.fromActionRequest(actionRequest);
        final String modelId = predictionRequest.getModelId();

        // Fall back to the user from the current thread context when the request does not carry one.
        User user = predictionRequest.getUser();
        if (user == null) {
            user = RestActionUtils.getUserContext(client);
            predictionRequest.setUser(user);
        }
        final User requestUser = user;

        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            // Whatever the outcome, the caller's thread context is restored before its listener runs.
            final ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
            final MLModel cachedModel = modelCacheHelper.getModelInfo(modelId);

            final ActionListener<MLModel> modelListener = new ActionListener<MLModel>() {
                @Override
                public void onResponse(MLModel mlModel) {
                    // Switch back to the caller's context before validating access and predicting.
                    context.restore();
                    modelCacheHelper.setModelInfo(modelId, mlModel);
                    predictionRequest.getMlInput().setAlgorithm(mlModel.getAlgorithm());
                    modelAccessControlHelper
                        .validateModelGroupAccess(requestUser, mlModel.getModelGroupId(), client, ActionListener.wrap(hasAccess -> {
                            // A missing (null) answer is treated as "no access".
                            if (Boolean.TRUE.equals(hasAccess)) {
                                executePredict(predictionRequest, wrappedListener, modelId);
                            } else {
                                wrappedListener
                                    .onFailure(
                                        new MLValidationException("User Doesn't have privilege to perform this operation on this model")
                                    );
                            }
                        }, validationError -> {
                            log.error("Failed to Validate Access for ModelId {}", modelId, validationError);
                            wrappedListener.onFailure(validationError);
                        }));
                }

                @Override
                public void onFailure(Exception e) {
                    log.error("Failed to find model {}", modelId, e);
                    wrappedListener.onFailure(e);
                }
            };

            if (cachedModel != null) {
                modelListener.onResponse(cachedModel);
            } else {
                mlModelManager.getModel(modelId, modelListener);
            }
        }
    }

    /**
     * Runs the prediction through the task runner and records its latency for the model.
     *
     * <p>The function name comes from the model cache when present. Otherwise the algorithm set on
     * the request's ML input is used.
     *
     * @param request the prediction request to execute
     * @param listener receives the prediction result
     * @param modelId the model the prediction runs against; used as the key for latency stats
     */
    private void executePredict(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener, String modelId) {
        final String requestId = request.getRequestID();
        log.debug("receive predict request {} for model {}", requestId, request.getModelId());
        final long startNanos = System.nanoTime();
        mlPredictTaskRunner
            .run(
                modelCacheHelper.getOptionalFunctionName(modelId).orElse(request.getMlInput().getAlgorithm()),
                request,
                transportService,
                ActionListener.runAfter(listener, () -> {
                    // Elapsed time converted from nanoseconds to milliseconds.
                    modelCacheHelper.addPredictRequestDuration(modelId, (System.nanoTime() - startNanos) / 1_000_000.0d);
                    log.debug("completed predict request {} for model {}", requestId, modelId);
                })
            );
    }
}
