package org.opensearch.ml.action.prediction;

import java.util.Map;
import java.util.Optional;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/prediction/TransportPredictionTaskAction.class */
/**
 * Transport action registered under {@code cluster:admin/opensearch/ml/predict}.
 *
 * <p>For each request the action:
 * <ol>
 *   <li>resolves the requesting user (from the request, else from the client's user context);</li>
 *   <li>obtains the model metadata from {@link MLModelCacheHelper}, falling back to
 *       {@link MLModelManager#getModel} when nothing is cached;</li>
 *   <li>asks {@link ModelAccessControlHelper} whether the user may access the model's group;</li>
 *   <li>rejects disabled models and, for DL models, throttled requests;</li>
 *   <li>validates the input against the cached model interface and runs the prediction.</li>
 * </ol>
 */
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);
    private MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private TransportService transportService;
    private MLModelCacheHelper modelCacheHelper;
    private Client client;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    private MLModelManager mlModelManager;
    private ModelAccessControlHelper modelAccessControlHelper;
    // Mirrors ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; updated from the cluster-settings consumer
    // registered in the constructor. Not read anywhere in this class.
    private volatile boolean enableAutomaticDeployment;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Wires the collaborators and subscribes to updates of the auto-deploy setting.
     */
    @Inject
    public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLModelCacheHelper mLModelCacheHelper, MLPredictTaskRunner mLPredictTaskRunner, ClusterService clusterService, Client client, NamedXContentRegistry namedXContentRegistry, MLModelManager mLModelManager, ModelAccessControlHelper modelAccessControlHelper, MLFeatureEnabledSetting mLFeatureEnabledSetting, Settings settings) {
        super("cluster:admin/opensearch/ml/predict", transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mLPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = mLModelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
        this.mlModelManager = mLModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mLFeatureEnabledSetting;
        this.enableAutomaticDeployment = MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(
            MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
            enabled -> this.enableAutomaticDeployment = enabled
        );
    }

    /**
     * Entry point for a predict request.
     *
     * <p>The thread context is stashed for the duration of the synchronous part of this method
     * and restored before the caller's listener is notified.
     *
     * @param task           the transport task (unused here)
     * @param actionRequest  request convertible through {@link MLPredictionTaskRequest#fromActionRequest}
     * @param actionListener receives the prediction response or the failure
     * @throws IllegalStateException if the model is already cached, is a DL model, and local models
     *                               are disabled (thrown synchronously from the model listener)
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLTaskResponse> actionListener) {
        final MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.fromActionRequest(actionRequest);
        final String modelId = predictionRequest.getModelId();

        User requestUser = predictionRequest.getUser();
        if (requestUser == null) {
            // No user on the request: take it from the client's context and record it on the request.
            requestUser = RestActionUtils.getUserContext(this.client);
            predictionRequest.setUser(requestUser);
        }
        final User user = requestUser;

        try (ThreadContext.StoredContext stashContext = this.client.threadPool().getThreadContext().stashContext()) {
            final ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(actionListener, stashContext::restore);
            final MLModel cachedModel = this.modelCacheHelper.getModelInfo(modelId);

            final ActionListener<MLModel> modelListener = new ActionListener<MLModel>() {
                @Override
                public void onResponse(MLModel mlModel) {
                    stashContext.restore();
                    modelCacheHelper.setModelInfo(modelId, mlModel);

                    final FunctionName algorithm = mlModel.getAlgorithm();
                    if (FunctionName.isDLModel(algorithm) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
                        // Deliberately thrown (not sent to the listener) to keep the existing contract.
                        throw new IllegalStateException(MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG);
                    }
                    predictionRequest.getMlInput().setAlgorithm(algorithm);

                    modelAccessControlHelper.validateModelGroupAccess(
                        user,
                        mlModel.getModelGroupId(),
                        client,
                        ActionListener.wrap(
                            hasAccess -> predictIfPermitted(hasAccess, algorithm, user, predictionRequest, wrappedListener, modelId),
                            e -> handleAccessValidationFailure(e, modelId, wrappedListener)
                        )
                    );
                }

                @Override
                public void onFailure(Exception e) {
                    log.error("Failed to find model {}", modelId, e);
                    wrappedListener.onFailure(e);
                }
            };

            if (cachedModel != null) {
                modelListener.onResponse(cachedModel);
            } else {
                this.mlModelManager.getModel(modelId, modelListener);
            }
        }
    }

    /**
     * Runs the post-access-check gates (permission, model enabled, rate limits for DL models)
     * and, if all pass, validates the input and executes the prediction.
     */
    private void predictIfPermitted(Boolean hasAccess, FunctionName algorithm, User user, MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener, String modelId) {
        if (!hasAccess) {
            listener.onFailure(new OpenSearchStatusException("User Doesn't have privilege to perform this operation on this model", RestStatus.FORBIDDEN));
            return;
        }

        // Read the flag once so the null check and the value check see the same cache state.
        final Boolean modelEnabled = this.modelCacheHelper.getIsModelEnabled(modelId);
        if (modelEnabled != null && !modelEnabled) {
            listener.onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN));
            return;
        }

        // Rate limiting is only applied to DL models; the model-level limiter is consulted first.
        if (FunctionName.isDLModel(algorithm)) {
            if (isThrottledAtModelLevel(modelId)) {
                listener.onFailure(new OpenSearchStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS));
                return;
            }
            if (isThrottledAtUserLevel(modelId, user)) {
                listener.onFailure(new OpenSearchStatusException("Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", RestStatus.TOO_MANY_REQUESTS));
                return;
            }
        }

        validateInputSchema(modelId, request.getMlInput());
        executePredict(request, listener, modelId);
    }

    /**
     * Returns {@code true} when a model-level rate limiter exists and refuses this request.
     * The limiter is fetched exactly once.
     */
    private boolean isThrottledAtModelLevel(String modelId) {
        return Optional.ofNullable(this.modelCacheHelper.getRateLimiter(modelId))
            .map(limiter -> !limiter.request())
            .orElse(false);
    }

    /**
     * Returns {@code true} when {@code user} is non-null, a limiter exists for that user name on
     * this model, and the limiter refuses this request. The limiter is fetched exactly once.
     */
    private boolean isThrottledAtUserLevel(String modelId, User user) {
        if (user == null) {
            return false;
        }
        return Optional.ofNullable(this.modelCacheHelper.getUserRateLimiter(modelId, user.getName()))
            .map(limiter -> !limiter.request())
            .orElse(false);
    }

    /**
     * Logs a failure from the access-validation step (including exceptions thrown by the success
     * callback) and translates it into the failure reported to the caller.
     */
    private void handleAccessValidationFailure(Exception e, String modelId, ActionListener<MLTaskResponse> listener) {
        log.error("Failed to Validate Access for ModelId {}", modelId, e);
        if (e instanceof OpenSearchStatusException) {
            final RestStatus status = RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus());
            listener.onFailure(new OpenSearchStatusException(e.getMessage(), status));
        } else if (e instanceof MLResourceNotFoundException) {
            listener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
        } else if (e instanceof CircuitBreakingException) {
            listener.onFailure(e);
        } else {
            listener.onFailure(new OpenSearchStatusException("Failed to Validate Access for ModelId " + modelId, RestStatus.FORBIDDEN));
        }
    }

    /**
     * Hands the request to the predict task runner and, once the listener has been notified,
     * records the request duration (milliseconds) and refreshes the model's last-access time.
     */
    private void executePredict(MLPredictionTaskRequest mLPredictionTaskRequest, ActionListener<MLTaskResponse> actionListener, String str) {
        final String requestId = mLPredictionTaskRequest.getRequestID();
        log.debug("receive predict request {} for model {}", requestId, mLPredictionTaskRequest.getModelId());
        final long startNanos = System.nanoTime();

        // Prefer the function name known to the cache; otherwise use the algorithm on the input.
        final FunctionName functionName = this.modelCacheHelper
            .getOptionalFunctionName(str)
            .orElse(mLPredictionTaskRequest.getMlInput().getAlgorithm());

        final ActionListener<MLTaskResponse> trackingListener = ActionListener.runAfter(actionListener, () -> {
            this.modelCacheHelper.addPredictRequestDuration(str, (System.nanoTime() - startNanos) / 1000000.0d);
            this.modelCacheHelper.refreshLastAccessTime(str);
            log.debug("completed predict request {} for model {}", requestId, str);
        });
        this.mlPredictTaskRunner.run(functionName, mLPredictionTaskRequest, this.transportService, trackingListener);
    }

    /**
     * Validates {@code mLInput} against the {@code "input"} schema of the cached model interface.
     * Does nothing when the model has no cached interface or the interface has no input schema.
     *
     * @param str     the model id
     * @param mLInput the prediction input to validate
     * @throws OpenSearchStatusException with {@link RestStatus#BAD_REQUEST} if serialization or
     *                                   schema validation fails
     */
    public void validateInputSchema(String str, MLInput mLInput) {
        // Fetch the interface once so the null checks and the lookup use the same snapshot.
        final Map<String, String> modelInterface = this.modelCacheHelper.getModelInterface(str);
        if (modelInterface == null) {
            return;
        }
        final String inputSchema = modelInterface.get("input");
        if (inputSchema == null) {
            return;
        }
        try {
            final String inputJson = mLInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
            final String processedInput = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(inputJson);
            MLNodeUtils.validateSchema(inputSchema, processedInput);
        } catch (Exception e) {
            throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST);
        }
    }
}
