package org.opensearch.ml.action.models;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.model.MLUpdateModelInput;
import org.opensearch.ml.common.transport.model.MLUpdateModelRequest;
import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheAction;
import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest;
import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/models/UpdateModelTransportAction.class */
/**
 * Transport action registered under {@code cluster:admin/opensearch/ml/models/update} that
 * updates a stored ML model document.
 *
 * <p>Flow implemented by this class:
 * <ol>
 *   <li>Load the model (without its content fields) and reject the request when the model is
 *       currently deploying or is neither a text-embedding nor a remote model.</li>
 *   <li>Check the caller's privileges: model-group access for regular models, super-admin for
 *       hidden models.</li>
 *   <li>Optionally validate access to a new stand-alone connector and/or re-link the model to a
 *       different model group (which bumps that group's {@code latest_version}).</li>
 *   <li>Write the model document and, when the model is deployed and a cache-relevant field
 *       changed, ask all nodes to refresh their model cache.</li>
 * </ol>
 */
public class UpdateModelTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(UpdateModelTransportAction.class);

    /** Name this transport action is registered under. */
    private static final String ACTION_NAME = "cluster:admin/opensearch/ml/models/update";
    /** Index holding the model documents. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";
    /** Index holding the model group documents. */
    private static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group";
    /** Model group field holding the latest model version number of the group. */
    private static final String LATEST_VERSION_FIELD = "latest_version";
    /** Model group field holding the last update timestamp (epoch milliseconds). */
    private static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";

    private Client client;
    private Settings settings;
    private ClusterService clusterService;
    private ModelAccessControlHelper modelAccessControlHelper;
    private ConnectorAccessControlHelper connectorAccessControlHelper;
    private MLModelManager mlModelManager;
    private MLModelGroupManager mlModelGroupManager;
    private MLEngine mlEngine;
    /** Refreshed by a cluster-settings update consumer, hence volatile. */
    private volatile List<String> trustedConnectorEndpointsRegex;

    /**
     * Creates the action and subscribes to changes of the trusted connector endpoints setting so
     * that connector URL validation always uses the current value.
     */
    @Inject
    public UpdateModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client, ConnectorAccessControlHelper connectorAccessControlHelper, ModelAccessControlHelper modelAccessControlHelper, MLModelManager mLModelManager, MLModelGroupManager mLModelGroupManager, Settings settings, ClusterService clusterService, MLEngine mLEngine) {
        super(ACTION_NAME, transportService, actionFilters, MLUpdateModelRequest::new);
        this.client = client;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.mlModelManager = mLModelManager;
        this.mlModelGroupManager = mLModelGroupManager;
        this.clusterService = clusterService;
        this.mlEngine = mLEngine;
        this.settings = settings;
        this.trustedConnectorEndpointsRegex = MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
                updatedRegexes -> this.trustedConnectorEndpointsRegex = updatedRegexes
            );
    }

    /**
     * Entry point of the action: validates the model state, the model type and the caller's
     * privileges, then delegates to the actual update logic.
     *
     * @param task           the task this execution belongs to (not used here)
     * @param actionRequest  the incoming request, converted to an {@link MLUpdateModelRequest}
     * @param actionListener receives the update response or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<UpdateResponse> actionListener) {
        MLUpdateModelInput updateModelInput = MLUpdateModelRequest.fromActionRequest(actionRequest).getUpdateModelInput();
        String modelId = updateModelInput.getModelId();
        User user = RestActionUtils.getUserContext(this.client);
        boolean isSuperAdmin = isSuperAdminUserWrapper(this.clusterService, this.client);
        // The model content fields are excluded from the fetched document.
        String[] excludes = { "model_content", "content" };

        // try-with-resources guarantees the stashed context is closed on every path, including
        // when getModel throws synchronously.
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            // Restore the caller's thread context right before the caller is notified.
            ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
            this.mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(model -> {
                if (isModelDeploying(model.getModelState())) {
                    wrappedListener
                        .onFailure(
                            new OpenSearchStatusException(
                                "Model is deploying. Please wait for the model to complete deployment. model ID " + modelId,
                                RestStatus.CONFLICT
                            )
                        );
                    return;
                }
                FunctionName functionName = model.getAlgorithm();
                if (functionName != FunctionName.TEXT_EMBEDDING && functionName != FunctionName.REMOTE) {
                    wrappedListener
                        .onFailure(
                            new OpenSearchStatusException(
                                "The function category " + functionName + " is not supported at this time.",
                                RestStatus.FORBIDDEN
                            )
                        );
                    return;
                }
                if (Boolean.TRUE.equals(model.getIsHidden())) {
                    // Hidden models may only be updated by a super admin.
                    if (isSuperAdmin) {
                        updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, model, user, wrappedListener);
                    } else {
                        wrappedListener
                            .onFailure(
                                new OpenSearchStatusException(
                                    "User doesn't have privilege to perform this operation on this model",
                                    RestStatus.FORBIDDEN
                                )
                            );
                    }
                    return;
                }
                this.modelAccessControlHelper
                    .validateModelGroupAccess(user, model.getModelGroupId(), this.client, ActionListener.wrap(hasAccess -> {
                        if (hasAccess) {
                            updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, model, user, wrappedListener);
                        } else {
                            wrappedListener
                                .onFailure(
                                    new OpenSearchStatusException(
                                        "User doesn't have privilege to perform this operation on this model, model ID " + modelId,
                                        RestStatus.FORBIDDEN
                                    )
                                );
                        }
                    }, e -> {
                        log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, e);
                        wrappedListener.onFailure(e);
                    }));
            }, e -> {
                // Keep the cause in the log; the caller only gets the generic NOT_FOUND message.
                log.error("Failed to find model to update with the provided model id: {}", modelId, e);
                wrappedListener
                    .onFailure(
                        new OpenSearchStatusException(
                            "Failed to find model to update with the provided model id: " + modelId,
                            RestStatus.NOT_FOUND
                        )
                    );
            }));
        } catch (Exception e) {
            log.error("Failed to update ML model for {}", modelId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Works out what the update touches (model group, stand-alone connector, inline connector,
     * rate limiter, ...) and routes the request to the matching update path.
     *
     * @param modelId     ID of the model being updated
     * @param updateInput the requested changes; mutated with the merged rate limiter and connector
     * @param model       the currently stored model
     * @param user        the requesting user
     * @param listener    receives the update response or the failure
     */
    private void updateRemoteOrTextEmbeddingModel(String modelId, MLUpdateModelInput updateInput, MLModel model, User user, ActionListener<UpdateResponse> listener) throws IOException {
        // Only treat the model group as "new" when one is given and it differs from the current one.
        String requestedGroupId = updateInput.getModelGroupId();
        String newModelGroupId = Strings.hasLength(requestedGroupId) && !Objects.equals(requestedGroupId, model.getModelGroupId())
            ? requestedGroupId
            : null;
        String newConnectorId = Strings.hasLength(updateInput.getConnectorId()) ? updateInput.getConnectorId() : null;
        boolean modelDeployed = isModelDeployed(model.getModelState());

        // True when a field that requires refreshing the deployed model's cache is part of the update.
        boolean cacheRelevantChange = updateInput.getConnector() != null
            || newConnectorId != null
            || !Objects.equals(updateInput.getIsEnabled(), model.getIsEnabled())
            || updateInput.getGuardrails() != null
            || updateInput.getModelInterface() != null;
        if (MLRateLimiter.updateValidityPreCheck(model.getRateLimiter(), updateInput.getRateLimiter())) {
            MLRateLimiter mergedRateLimiter = MLRateLimiter.update(model.getRateLimiter(), updateInput.getRateLimiter());
            updateInput.setRateLimiter(mergedRateLimiter);
            cacheRelevantChange = cacheRelevantChange || mergedRateLimiter.isValid();
        }
        // The cache only needs refreshing when the model is actually deployed.
        boolean updateModelCache = cacheRelevantChange && modelDeployed;

        if (model.getAlgorithm() == FunctionName.TEXT_EMBEDDING) {
            if (newConnectorId != null || updateInput.getConnector() != null) {
                listener
                    .onFailure(
                        new OpenSearchStatusException(
                            "Trying to update the connector or connector_id field on a local model.",
                            RestStatus.BAD_REQUEST
                        )
                    );
                return;
            }
            updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateInput, listener, updateModelCache);
            return;
        }
        if (newConnectorId != null) {
            updateModelWithNewStandAloneConnector(modelId, newModelGroupId, newConnectorId, model, user, updateInput, listener, updateModelCache);
            return;
        }
        if (updateInput.getConnector() != null) {
            Connector connector = model.getConnector();
            if (connector == null) {
                // Without this guard the call below would throw an NPE that the enclosing
                // listeners report as a misleading "permission denied" / "model not found" error.
                listener
                    .onFailure(
                        new OpenSearchStatusException(
                            "This remote model does not have an internal connector to update, maybe it uses a stand-alone connector.",
                            RestStatus.BAD_REQUEST
                        )
                    );
                return;
            }
            connector.update(updateInput.getConnector(), this.mlEngine::encrypt);
            connector.validateConnectorURL(this.trustedConnectorEndpointsRegex);
            // Persist the merged connector instead of the raw connector input.
            updateInput.setUpdatedConnector(connector);
            updateInput.setConnector(null);
        }
        updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateInput, listener, updateModelCache);
    }

    /**
     * Switches a remote model to another stand-alone connector after verifying that the model
     * currently uses a stand-alone connector and that access to the new connector is granted.
     */
    private void updateModelWithNewStandAloneConnector(String modelId, String newModelGroupId, String newConnectorId, MLModel model, User user, MLUpdateModelInput updateInput, ActionListener<UpdateResponse> listener, boolean updateModelCache) {
        if (!Strings.hasLength(model.getConnectorId())) {
            listener
                .onFailure(
                    new OpenSearchStatusException(
                        "This remote does not have a connector_id field, maybe it uses an internal connector.",
                        RestStatus.BAD_REQUEST
                    )
                );
            return;
        }
        this.connectorAccessControlHelper.validateConnectorAccess(this.client, newConnectorId, ActionListener.wrap(hasAccess -> {
            if (hasAccess) {
                updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateInput, listener, updateModelCache);
            } else {
                listener
                    .onFailure(
                        new OpenSearchStatusException(
                            "You don't have permission to update the connector, connector id: " + newConnectorId,
                            RestStatus.FORBIDDEN
                        )
                    );
            }
        }, e -> {
            log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", newConnectorId, e);
            listener.onFailure(e);
        }));
    }

    /**
     * Updates the model document and, when {@code newModelGroupId} is non-null, first verifies
     * access to that group and re-links the model to it.
     */
    private void updateModelWithRegisteringToAnotherModelGroup(String modelId, String newModelGroupId, User user, MLUpdateModelInput updateInput, ActionListener<UpdateResponse> listener, boolean updateModelCache) {
        UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId);
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        if (newModelGroupId == null) {
            buildUpdateRequest(modelId, updateRequest, updateInput, listener, updateModelCache);
            return;
        }
        this.modelAccessControlHelper.validateModelGroupAccess(user, newModelGroupId, this.client, ActionListener.wrap(hasAccess -> {
            if (!hasAccess) {
                listener
                    .onFailure(
                        new OpenSearchStatusException(
                            "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID "
                                + newModelGroupId,
                            RestStatus.FORBIDDEN
                        )
                    );
                return;
            }
            this.mlModelGroupManager
                .getModelGroupResponse(
                    newModelGroupId,
                    ActionListener
                        .wrap(
                            modelGroup -> buildUpdateRequest(
                                modelId,
                                newModelGroupId,
                                updateRequest,
                                updateInput,
                                modelGroup,
                                listener,
                                updateModelCache
                            ),
                            e -> listener
                                .onFailure(
                                    new OpenSearchStatusException(
                                        "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: "
                                            + newModelGroupId,
                                        RestStatus.NOT_FOUND
                                    )
                                )
                        )
                );
        }, e -> {
            log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, e);
            listener.onFailure(e);
        }));
    }

    /**
     * Writes the model update when the model stays in its current model group.
     */
    private void buildUpdateRequest(String modelId, UpdateRequest updateRequest, MLUpdateModelInput updateInput, ActionListener<UpdateResponse> listener, boolean updateModelCache) {
        try {
            updateInput.setLastUpdateTime(Instant.now());
            populateModelDocument(updateRequest, updateInput);
            this.client.update(updateRequest, createModelUpdateListener(modelId, listener, updateModelCache));
        } catch (IOException e) {
            log.error("Failed to build update request.", e);
            listener.onFailure(e);
        }
    }

    /**
     * Writes the model update when the model is re-linked to another model group: the group's
     * latest version is incremented first (guarded by the group's sequence number and primary
     * term), and only after that succeeds is the model document itself updated.
     */
    private void buildUpdateRequest(String modelId, String newModelGroupId, UpdateRequest updateRequest, MLUpdateModelInput updateInput, GetResponse modelGroupResponse, ActionListener<UpdateResponse> listener, boolean updateModelCache) {
        Map<String, Object> modelGroupSource = modelGroupResponse.getSourceAsMap();
        String nextVersion = incrementLatestVersion(modelGroupSource);
        updateInput.setVersion(nextVersion);
        updateInput.setLastUpdateTime(Instant.now());
        UpdateRequest modelGroupUpdateRequest = createUpdateModelGroupRequest(
            modelGroupSource,
            newModelGroupId,
            modelGroupResponse.getSeqNo(),
            modelGroupResponse.getPrimaryTerm(),
            Integer.parseInt(nextVersion)
        );
        try {
            populateModelDocument(updateRequest, updateInput);
            ActionListener<UpdateResponse> modelUpdateListener = createModelUpdateListener(modelId, listener, updateModelCache);
            this.client
                .update(
                    modelGroupUpdateRequest,
                    ActionListener.wrap(groupUpdateResponse -> this.client.update(updateRequest, modelUpdateListener), e -> {
                        log
                            .error(
                                "Failed to register ML model with model ID {} to the new model group with model group ID {}",
                                modelId,
                                newModelGroupId,
                                e
                            );
                        listener.onFailure(e);
                    })
                );
        } catch (IOException e) {
            log.error("Failed to build update request.", e);
            listener.onFailure(e);
        }
    }

    /**
     * Serializes the update input into the update request and applies the common write options.
     */
    private void populateModelDocument(UpdateRequest updateRequest, MLUpdateModelInput updateInput) throws IOException {
        updateRequest.doc(updateInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
        updateRequest.docAsUpsert(true);
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
    }

    /**
     * Picks the listener for the model document update: with a follow-up model cache refresh on
     * all nodes when {@code updateModelCache} is set, a plain one otherwise.
     */
    private ActionListener<UpdateResponse> createModelUpdateListener(String modelId, ActionListener<UpdateResponse> listener, boolean updateModelCache) {
        if (updateModelCache) {
            return getUpdateResponseListenerWithUpdateModelCache(
                modelId,
                listener,
                new MLUpdateModelCacheNodesRequest(getAllNodes(), modelId)
            );
        }
        return getUpdateResponseListener(modelId, listener);
    }

    /**
     * Listener for the model document update that additionally triggers a model cache update on
     * the nodes named in {@code cacheRequest} once the document was updated.
     */
    private ActionListener<UpdateResponse> getUpdateResponseListenerWithUpdateModelCache(String modelId, ActionListener<UpdateResponse> listener, MLUpdateModelCacheNodesRequest cacheRequest) {
        return ActionListener.wrap(updateResponse -> {
            if (!isUpdated(updateResponse)) {
                handleNonUpdatedResponse(modelId, updateResponse, listener);
                return;
            }
            this.client.execute(MLUpdateModelCacheAction.INSTANCE, cacheRequest, ActionListener.wrap(cacheResponse -> {
                if (cacheResponse != null && isUpdateModelCacheSuccessOnAllNodes(cacheResponse)) {
                    log.info("Successfully updated ML model cache with model ID {}", modelId);
                    listener.onResponse(updateResponse);
                    return;
                }
                String failedNodes = Arrays.toString(getUpdateModelCacheFailedNodesList(cacheResponse));
                log
                    .error(
                        "Successfully update ML model index with model ID {} but update model cache was failed on following nodes {}, please retry or redeploy model manually.",
                        modelId,
                        failedNodes
                    );
                listener
                    .onFailure(
                        new RuntimeException(
                            "Successfully update ML model index with model ID "
                                + modelId
                                + " but update model cache was failed on following nodes "
                                + failedNodes
                                + ", please retry or redeploy model manually."
                        )
                    );
            }, e -> {
                log.error("Failed to update ML model cache for model: " + modelId, e);
                listener.onFailure(e);
            }));
        }, e -> {
            log.error("Failed to update ML model: " + modelId, e);
            listener.onFailure(e);
        });
    }

    /**
     * Listener for the model document update when no model cache refresh is required.
     */
    private ActionListener<UpdateResponse> getUpdateResponseListener(String modelId, ActionListener<UpdateResponse> listener) {
        return ActionListener.wrap(updateResponse -> {
            if (isUpdated(updateResponse)) {
                log.info("Successfully update ML model with model ID {}", modelId);
                listener.onResponse(updateResponse);
            } else {
                handleNonUpdatedResponse(modelId, updateResponse, listener);
            }
        }, e -> {
            log.error("Failed to update ML model: " + modelId, e);
            listener.onFailure(e);
        });
    }

    /** Returns whether the response is present and reports the {@code UPDATED} result. */
    private boolean isUpdated(UpdateResponse updateResponse) {
        return updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED;
    }

    /**
     * Handles an update response that is not {@code UPDATED}: a missing response is reported as a
     * failure, any other result is logged as a warning and still passed on to the caller.
     */
    private void handleNonUpdatedResponse(String modelId, UpdateResponse updateResponse, ActionListener<UpdateResponse> listener) {
        if (updateResponse == null) {
            log.error("Failed to update ML model: " + modelId);
            listener.onFailure(new RuntimeException("Failed to update ML model: " + modelId));
            return;
        }
        log
            .warn(
                "Update model for model {} got a result status other than update, result status: {}",
                modelId,
                updateResponse.getResult()
            );
        listener.onResponse(updateResponse);
    }

    /**
     * Returns the model group's {@code latest_version} plus one, as a string.
     *
     * @throws IllegalStateException when the model group source has no numeric latest version
     */
    private String incrementLatestVersion(Map<String, Object> modelGroupSource) {
        Object latestVersion = modelGroupSource.get(LATEST_VERSION_FIELD);
        if (!(latestVersion instanceof Number)) {
            throw new IllegalStateException(
                "Model group document does not contain a numeric " + LATEST_VERSION_FIELD + " field, found: " + latestVersion
            );
        }
        return Integer.toString(((Number) latestVersion).intValue() + 1);
    }

    /**
     * Builds the model group update that stores the new latest version and update time. The
     * request carries the given sequence number and primary term as write conditions.
     * Note that {@code modelGroupSource} is modified in place and used as the document body.
     */
    private UpdateRequest createUpdateModelGroupRequest(Map<String, Object> modelGroupSource, String modelGroupId, long seqNo, long primaryTerm, int updatedVersion) {
        modelGroupSource.put(LATEST_VERSION_FIELD, updatedVersion);
        modelGroupSource.put(LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
        UpdateRequest updateRequest = new UpdateRequest();
        updateRequest
            .index(ML_MODEL_GROUP_INDEX)
            .id(modelGroupId)
            .setIfSeqNo(seqNo)
            .setIfPrimaryTerm(primaryTerm)
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
            .doc(modelGroupSource);
        return updateRequest;
    }

    /** Returns whether the state is one of the (partially) loaded/deployed states; null-safe. */
    private boolean isModelDeployed(MLModelState modelState) {
        return modelState == MLModelState.LOADED
            || modelState == MLModelState.PARTIALLY_LOADED
            || modelState == MLModelState.DEPLOYED
            || modelState == MLModelState.PARTIALLY_DEPLOYED;
    }

    /** Returns whether the state is {@code LOADING} or {@code DEPLOYING}; null-safe. */
    private boolean isModelDeploying(MLModelState modelState) {
        return modelState == MLModelState.LOADING || modelState == MLModelState.DEPLOYING;
    }

    /** Returns the IDs of all nodes in the current cluster state. */
    private String[] getAllNodes() {
        List<String> nodeIds = new ArrayList<>();
        for (DiscoveryNode node : this.clusterService.state().nodes()) {
            nodeIds.add(node.getId());
        }
        return nodeIds.toArray(new String[0]);
    }

    /** Returns whether the cache update response reports no node failures. */
    private boolean isUpdateModelCacheSuccessOnAllNodes(MLUpdateModelCacheNodesResponse cacheResponse) {
        return cacheResponse.failures() == null || cacheResponse.failures().isEmpty();
    }

    /**
     * Returns the IDs of the nodes on which the cache update failed; when there is no response
     * at all, every node is reported as failed.
     */
    private String[] getUpdateModelCacheFailedNodesList(MLUpdateModelCacheNodesResponse cacheResponse) {
        if (cacheResponse == null) {
            return getAllNodes();
        }
        List<String> failedNodeIds = new ArrayList<>();
        for (FailedNodeException failure : cacheResponse.failures()) {
            failedNodeIds.add(failure.nodeId());
        }
        return failedNodeIds.toArray(new String[0]);
    }

    /**
     * Thin wrapper around {@link RestActionUtils#isSuperAdminUser(ClusterService, Client)} so that
     * tests can stub the super-admin check.
     */
    @VisibleForTesting
    boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
        return RestActionUtils.isSuperAdminUser(clusterService, client);
    }
}
