package org.opensearch.ml.action.controller;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import lombok.Generated;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.transport.controller.MLDeployControllerAction;
import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest;
import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse;
import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action that applies an update to an existing model controller.
 *
 * <p>Flow:
 * <ol>
 *   <li>Fail fast if the controller feature is disabled.</li>
 *   <li>Fetch the model (without its content fields); only {@code TEXT_EMBEDDING} and
 *       {@code REMOTE} models are supported.</li>
 *   <li>Verify the calling user may access the model's model group.</li>
 *   <li>Load the stored controller, merge the update into it, and write it back to the
 *       {@code .plugins-ml-controller} index with an immediate refresh.</li>
 *   <li>If the model has worker nodes in the cache and the update requires a redeploy, push the
 *       updated controller to those worker nodes.</li>
 * </ol>
 *
 * <p>Listeners handed to asynchronous calls restore the caller's thread context before the
 * response is delivered.
 */
public class UpdateControllerTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {

    private static final Logger log = LogManager.getLogger(UpdateControllerTransportAction.class);

    private final Client client;
    private final MLModelManager mlModelManager;
    private final MLModelCacheHelper mlModelCacheHelper;
    private final ClusterService clusterService;
    private final ModelAccessControlHelper modelAccessControlHelper;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    @Inject
    public UpdateControllerTransportAction(TransportService transportService, ActionFilters actionFilters, Client client, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelCacheHelper mLModelCacheHelper, MLModelManager mLModelManager, MLFeatureEnabledSetting mLFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/controllers/update", transportService, actionFilters, MLUpdateControllerRequest::new);
        this.client = client;
        this.mlModelManager = mLModelManager;
        this.clusterService = clusterService;
        this.mlModelCacheHelper = mLModelCacheHelper;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mLFeatureEnabledSetting;
    }

    /**
     * Entry point: validates the feature flag, fetches the target model and continues the
     * update asynchronously.
     *
     * @param task           the transport task (unused)
     * @param request        an {@link MLUpdateControllerRequest} (or convertible request)
     * @param actionListener notified with the index {@link UpdateResponse} or a failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> actionListener) {
        MLController updateControllerInput = MLUpdateControllerRequest.fromActionRequest(request).getUpdateControllerInput();
        String modelId = updateControllerInput.getModelId();
        // Read the user before stashing: stashContext() swaps out the current thread context.
        User user = RestActionUtils.getUserContext(client);
        // Model content is not needed to update the controller, so leave it out of the fetch.
        String[] excludes = new String[] { "model_content", "content" };

        // try-with-resources guarantees the stashed context is restored even when an exception is
        // thrown synchronously (e.g. the feature-disabled check below).
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            if (!Boolean.TRUE.equals(mlFeatureEnabledSetting.isControllerEnabled())) {
                throw new IllegalStateException(MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG);
            }
            ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
            mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(
                mlModel -> onModelFetched(mlModel, modelId, user, updateControllerInput, wrappedListener),
                e -> {
                    log.error("Failed to find model to update the corresponding model controller", e);
                    // Keep the original failure as the cause instead of discarding it.
                    wrappedListener.onFailure(new OpenSearchStatusException(
                        "Failed to find model to create the corresponding model controller with the provided model ID",
                        RestStatus.NOT_FOUND,
                        e));
                }));
        } catch (Exception e) {
            log.error("Failed to create model controller for the provided model", e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Rejects unsupported model types, then checks model-group access for {@code user} before
     * updating the controller.
     */
    private void onModelFetched(MLModel mlModel, String modelId, User user, MLController updateControllerInput, ActionListener<UpdateResponse> listener) {
        FunctionName functionName = mlModel.getAlgorithm();
        Boolean isHidden = mlModel.getIsHidden();
        if (functionName != FunctionName.TEXT_EMBEDDING && functionName != FunctionName.REMOTE) {
            listener.onFailure(new OpenSearchStatusException(
                "Creating model controller on this operation on the function category " + functionName + " is not supported.",
                RestStatus.FORBIDDEN));
            return;
        }
        modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
            if (Boolean.TRUE.equals(hasPermission)) {
                updateExistingController(mlModel, modelId, updateControllerInput, listener);
            } else {
                listener.onFailure(new OpenSearchStatusException(
                    StringUtils.getErrorMessage("User doesn't have privilege to perform this operation on this model controller.", modelId, isHidden),
                    RestStatus.FORBIDDEN));
            }
        }, e -> {
            log.error(StringUtils.getErrorMessage("Permission denied: Unable to create the model controller for the model. Details: ", modelId, isHidden), e);
            listener.onFailure(e);
        }));
    }

    /**
     * Loads the stored controller, merges {@code updateControllerInput} into it and persists it.
     * A missing controller on a model whose controller flag is not set is reported as CONFLICT.
     */
    private void updateExistingController(MLModel mlModel, String modelId, MLController updateControllerInput, ActionListener<UpdateResponse> listener) {
        Boolean isHidden = mlModel.getIsHidden();
        mlModelManager.getController(modelId, ActionListener.wrap(controller -> {
            // Evaluated before update() merges the input into the controller; presumably it
            // compares the stored config with the incoming one.
            boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput);
            controller.update(updateControllerInput);
            updateController(mlModel, controller, isDeployRequiredAfterUpdate, listener);
        }, e -> {
            if (Boolean.TRUE.equals(mlModel.getIsControllerEnabled())) {
                log.error(e);
                listener.onFailure(e);
            } else {
                String errorMessage = StringUtils.getErrorMessage("Model controller haven't been created for the model. Consider calling create model controller api instead.", modelId, isHidden);
                log.error(errorMessage, e);
                listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.CONFLICT));
            }
        }));
    }

    /**
     * Writes the merged controller to the controller index (refresh = IMMEDIATE) under a stashed
     * thread context, then hands the index response to {@link #onControllerIndexUpdated}.
     */
    private void updateController(MLModel mLModel, MLController mLController, boolean z, ActionListener<UpdateResponse> actionListener) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            String modelId = mLModel.getModelId();
            UpdateRequest updateRequest = new UpdateRequest(".plugins-ml-controller", modelId);
            updateRequest.doc(mLController.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
            updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

            ActionListener<UpdateResponse> indexListener = ActionListener.wrap(
                updateResponse -> onControllerIndexUpdated(updateResponse, mLModel, z, actionListener),
                actionListener::onFailure);
            client.update(updateRequest, ActionListener.runBefore(indexListener, context::restore));
        } catch (Exception e) {
            log.error("Failed to update model controller.", e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Handles the controller index write result.
     * <ul>
     *   <li>A null response is reported as a failure.</li>
     *   <li>A non-UPDATED result is logged and returned as-is.</li>
     *   <li>An UPDATED result optionally redeploys the controller to the model's worker nodes.</li>
     * </ul>
     */
    private void onControllerIndexUpdated(UpdateResponse updateResponse, MLModel mlModel, boolean isDeployRequiredAfterUpdate, ActionListener<UpdateResponse> actionListener) {
        String modelId = mlModel.getModelId();
        Boolean isHidden = mlModel.getIsHidden();

        if (updateResponse == null) {
            String errorMessage = StringUtils.getErrorMessage("Failed to update model controller.", modelId, isHidden);
            log.error(errorMessage);
            actionListener.onFailure(new RuntimeException(errorMessage));
            return;
        }
        if (updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
            log.warn(StringUtils.getErrorMessage("Update model controller got a result status other than update, result status: {}", modelId, isHidden), updateResponse.getResult());
            actionListener.onResponse(updateResponse);
            return;
        }

        log.info(StringUtils.getErrorMessage("Model controller successfully updated to index, result: {}", modelId, isHidden), updateResponse.getResult());
        // Nothing to redeploy if the update does not require it or the model has no worker nodes in the cache.
        if (!isDeployRequiredAfterUpdate || ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId))) {
            actionListener.onResponse(updateResponse);
            return;
        }

        log.info(StringUtils.getErrorMessage("The model is deployed and the user rate limiter config is constructable. Start to deploy the model controller into cache.", modelId, isHidden));
        String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm());
        MLDeployControllerNodesRequest deployRequest = new MLDeployControllerNodesRequest(targetNodeIds, modelId);
        client.execute(MLDeployControllerAction.INSTANCE, deployRequest, ActionListener.wrap(deployResponse -> {
            if (deployResponse != null && isDeployControllerSuccessOnAllNodes(deployResponse)) {
                log.info(StringUtils.getErrorMessage("Successfully update model controller and deploy it into cache", modelId, isHidden));
                actionListener.onResponse(updateResponse);
            } else {
                String errorMessage = StringUtils.getErrorMessage("Successfully update model controller index but deploy model controller to cache was failed on following nodes " + Arrays.toString(getDeployControllerFailedNodesList(deployResponse)) + ", please retry.", modelId, isHidden);
                log.error(errorMessage);
                actionListener.onFailure(new RuntimeException(errorMessage));
            }
        }, e -> {
            // Include the exception so the root cause of the deploy failure is not lost.
            log.error(StringUtils.getErrorMessage("Failed to deploy model controller for model", modelId, isHidden), e);
            actionListener.onFailure(e);
        }));
    }

    /** Returns true when the deploy response reports no per-node failures. */
    private boolean isDeployControllerSuccessOnAllNodes(MLDeployControllerNodesResponse mLDeployControllerNodesResponse) {
        return mLDeployControllerNodesResponse.failures() == null || mLDeployControllerNodesResponse.failures().isEmpty();
    }

    /**
     * Lists the node IDs on which the controller deploy failed.
     * <ul>
     *   <li>A null response yields every node in the cluster, since the outcome is unknown.</li>
     *   <li>A response with null failures yields an empty array.</li>
     * </ul>
     */
    private String[] getDeployControllerFailedNodesList(MLDeployControllerNodesResponse mLDeployControllerNodesResponse) {
        if (mLDeployControllerNodesResponse == null) {
            return getAllNodes();
        }
        if (mLDeployControllerNodesResponse.failures() == null) {
            return new String[0];
        }
        return mLDeployControllerNodesResponse.failures().stream().map(FailedNodeException::nodeId).toArray(String[]::new);
    }

    /** Returns the IDs of all nodes in the current cluster state, in iteration order. */
    private String[] getAllNodes() {
        ArrayList<String> nodeIds = new ArrayList<>();
        for (DiscoveryNode node : clusterService.state().nodes()) {
            nodeIds.add(node.getId());
        }
        return nodeIds.toArray(new String[0]);
    }
}
