package org.opensearch.ml.action.model_group;

import com.google.common.collect.ImmutableList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupRequest;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action that applies an {@link MLUpdateModelGroupInput} to an existing model group document
 * stored in the {@code .plugins-ml-model-group} index.
 *
 * <p>Flow: fetch the current model group, validate the request, enforce name uniqueness, write back.
 * <ul>
 *   <li>Validation follows the model access control rules when security and model access control are
 *       enabled. Otherwise any access-control parameters are rejected.</li>
 *   <li>Name uniqueness is checked only when the name actually changes.</li>
 *   <li>The merged document is written back with an immediate refresh.</li>
 * </ul>
 *
 * <p>Index reads and writes run inside a stashed thread context. The context is restored before the
 * caller's listener is notified.
 */
public class TransportUpdateModelGroupAction extends HandledTransportAction<ActionRequest, MLUpdateModelGroupResponse> {

    private static final String MODEL_GROUP_INDEX = ".plugins-ml-model-group";

    @Generated
    private static final Logger log = LogManager.getLogger(TransportUpdateModelGroupAction.class);
    private final TransportService transportService;
    private final ActionFilters actionFilters;
    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    ClusterService clusterService;
    ModelAccessControlHelper modelAccessControlHelper;
    MLModelGroupManager mlModelGroupManager;

    @Inject
    public TransportUpdateModelGroupAction(TransportService transportService, ActionFilters actionFilters, Client client, NamedXContentRegistry namedXContentRegistry, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelGroupManager mLModelGroupManager) {
        super("cluster:admin/opensearch/ml/update_model_group", transportService, actionFilters, MLUpdateModelGroupRequest::new);
        this.actionFilters = actionFilters;
        this.transportService = transportService;
        this.client = client;
        this.xContentRegistry = namedXContentRegistry;
        this.clusterService = clusterService;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlModelGroupManager = mLModelGroupManager;
    }

    /**
     * Loads the target model group, validates the update request against it and, if valid, applies the update.
     *
     * @param task the task executing this action (unused)
     * @param actionRequest the incoming request, converted via {@link MLUpdateModelGroupRequest#fromActionRequest}
     * @param actionListener notified with the update result, or with the first failure encountered
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLUpdateModelGroupResponse> actionListener) {
        MLUpdateModelGroupInput updateModelGroupInput = MLUpdateModelGroupRequest.fromActionRequest(actionRequest).getUpdateModelGroupInput();
        String modelGroupID = updateModelGroupInput.getModelGroupID();
        // Read the user before stashing: the stashed context no longer carries the caller's identity.
        User userContext = RestActionUtils.getUserContext(this.client);
        GetRequest getRequest = new GetRequest(MODEL_GROUP_INDEX).id(modelGroupID);
        // try-with-resources restores the stashed context even if client.get throws synchronously.
        try (ThreadContext.StoredContext stashContext = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> stashContext.restore());
            this.client.get(getRequest, ActionListener.wrap(getResponse -> {
                if (!getResponse.isExists()) {
                    wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
                    return;
                }
                // The parser is closed on every path, including when parsing or validation throws.
                try (XContentParser parser = MLNodeUtils.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, getResponse.getSourceAsBytesRef())) {
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLModelGroup modelGroup = MLModelGroup.parse(parser);
                    if (this.modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(userContext)) {
                        validateRequestForAccessControl(updateModelGroupInput, userContext, modelGroup);
                    } else {
                        validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
                    }
                    updateModelGroup(modelGroupID, getResponse.getSource(), updateModelGroupInput, wrappedListener, userContext);
                } catch (Exception e) {
                    log.error("Failed to parse ml model group " + getResponse.getId(), e);
                    wrappedListener.onFailure(e);
                }
            }, exc -> {
                if (exc instanceof IndexNotFoundException) {
                    wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
                } else {
                    MLExceptionUtils.logException("Failed to get model group", exc, log);
                    wrappedListener.onFailure(exc);
                }
            }));
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to Update model group", e, log);
            actionListener.onFailure(e);
        }
    }

    /**
     * Merges the requested changes into the current model group source and persists the result.
     *
     * <p>If the name changes, it is first checked for uniqueness via
     * {@link MLModelGroupManager#validateUniqueModelGroupName}. A name clash fails the listener exactly once.
     *
     * @param modelGroupId id of the model group document
     * @param source current document source; mutated in place and used as the update doc
     * @param updateModelGroupInput the requested changes
     * @param listener notified with the update result or failure
     * @param user the calling user; its backend roles are copied when "add all backend roles" is requested
     */
    private void updateModelGroup(String modelGroupId, Map<String, Object> source, MLUpdateModelGroupInput updateModelGroupInput, ActionListener<MLUpdateModelGroupResponse> listener, User user) {
        String existingName = (String) source.get("name");
        AccessMode requestedAccessMode = updateModelGroupInput.getModelAccessMode();
        boolean addAllBackendRoles = Boolean.TRUE.equals(updateModelGroupInput.getIsAddAllBackendRoles());

        if (requestedAccessMode != null) {
            source.put("access", requestedAccessMode.getValue());
            // Backend roles are cleared for any access mode other than restricted.
            if (AccessMode.RESTRICTED != requestedAccessMode) {
                source.put("backend_roles", ImmutableList.of());
            }
        } else if (updateModelGroupInput.getBackendRoles() != null || addAllBackendRoles) {
            // Supplying backend roles without an explicit access mode implies the restricted mode.
            source.put("access", AccessMode.RESTRICTED.getValue());
        }
        if (updateModelGroupInput.getBackendRoles() != null) {
            source.put("backend_roles", updateModelGroupInput.getBackendRoles());
        }
        if (addAllBackendRoles) {
            source.put("backend_roles", user.getBackendRoles());
        }
        if (StringUtils.isNotBlank(updateModelGroupInput.getDescription())) {
            source.put("description", updateModelGroupInput.getDescription());
        }

        String newName = updateModelGroupInput.getName();
        if (StringUtils.isBlank(newName) || newName.equals(existingName)) {
            // The name is not changing, so no uniqueness check is needed.
            updateModelGroup(modelGroupId, source, listener);
            return;
        }
        this.mlModelGroupManager.validateUniqueModelGroupName(newName, ActionListener.wrap(searchResponse -> {
            if (searchResponse == null || searchResponse.getHits().getTotalHits() == null || searchResponse.getHits().getTotalHits().value == 0) {
                source.put("name", newName);
                updateModelGroup(modelGroupId, source, listener);
                return;
            }
            // Notify the listener exactly once. This avoids multiple onFailure calls when several hits match,
            // and a hung request when hits are counted but none are returned.
            Iterator<SearchHit> hits = searchResponse.getHits().iterator();
            if (hits.hasNext()) {
                listener.onFailure(new IllegalArgumentException("The name you provided is already being used by another model with ID: " + hits.next().getId() + ". Please provide a different name"));
            } else {
                listener.onFailure(new IllegalArgumentException("The name you provided is already being used by another model group. Please provide a different name"));
            }
        }, exc -> {
            log.error("Failed to search model group index", exc);
            listener.onFailure(exc);
        }));
    }

    /**
     * Writes {@code source} as the partial document of model group {@code modelGroupId} with an immediate refresh.
     *
     * @param modelGroupId id of the model group document
     * @param source the document fields to write
     * @param listener receives {@code MLUpdateModelGroupResponse("Updated")} on success. On failure it receives:
     *     <ul>
     *       <li>{@link MLResourceNotFoundException} if the index is missing;</li>
     *       <li>{@link MLValidationException} for any other update error;</li>
     *       <li>the original exception if the request could not be sent.</li>
     *     </ul>
     */
    private void updateModelGroup(String modelGroupId, Map<String, Object> source, ActionListener<MLUpdateModelGroupResponse> listener) {
        UpdateRequest updateRequest = new UpdateRequest();
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        updateRequest.index(MODEL_GROUP_INDEX).id(modelGroupId).doc(source);
        // try-with-resources restores the stashed context even if client.update throws synchronously.
        try (ThreadContext.StoredContext stashContext = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> stashContext.restore());
            this.client.update(updateRequest, ActionListener.wrap(updateResponse -> wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")), exc -> {
                if (exc instanceof IndexNotFoundException) {
                    wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
                } else {
                    // Pass the exception as the throwable argument so its stack trace is actually logged.
                    log.error("Failed to update model group", exc);
                    wrappedListener.onFailure(new MLValidationException("Failed to update Model Group"));
                }
            }));
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to Update model group ", e, log);
            listener.onFailure(e);
        }
    }

    /**
     * Validates an update request when security and model access control are enabled.
     *
     * @param updateModelGroupInput the requested changes
     * @param user the calling user
     * @param modelGroup the model group as currently stored
     * @throws IllegalArgumentException if the user may not perform the requested change, or the
     *     access-mode / backend-role combination is inconsistent
     */
    private void validateRequestForAccessControl(MLUpdateModelGroupInput updateModelGroupInput, User user, MLModelGroup modelGroup) {
        boolean isAdmin = this.modelAccessControlHelper.isAdmin(user);
        boolean isOwner = this.modelAccessControlHelper.isOwner(modelGroup.getOwner(), user);
        if (hasAccessControlChange(updateModelGroupInput) && !isOwner && !isAdmin) {
            throw new IllegalArgumentException("Only owner or admin can update access control data.");
        }
        if (!isAdmin && !isOwner && !this.modelAccessControlHelper.isUserHasBackendRole(user, modelGroup)) {
            throw new IllegalArgumentException("You don't have permission to update this model group.");
        }
        if (isOwner && !isAdmin && !this.modelAccessControlHelper.isOwnerStillHasPermission(user, modelGroup)) {
            throw new IllegalArgumentException("You don't have the specified backend role to update this model group. For more information, contact your administrator.");
        }

        AccessMode accessMode = updateModelGroupInput.getModelAccessMode();
        boolean hasBackendRoles = !CollectionUtils.isEmpty(updateModelGroupInput.getBackendRoles());
        boolean addAllBackendRoles = Boolean.TRUE.equals(updateModelGroupInput.getIsAddAllBackendRoles());
        if ((AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode) && (hasBackendRoles || addAllBackendRoles)) {
            throw new IllegalArgumentException("You can specify backend roles only for a model group with the restricted access mode.");
        }
        // The remaining rules apply only when the group is, or may become, restricted.
        if (accessMode == null || AccessMode.RESTRICTED == accessMode) {
            if (isAdmin && addAllBackendRoles) {
                throw new IllegalArgumentException("Admin users cannot add all backend roles to a model group.");
            }
            if (addAllBackendRoles && CollectionUtils.isEmpty(user.getBackendRoles())) {
                throw new IllegalArgumentException("You don't have any backend roles.");
            }
            if (!hasBackendRoles && Boolean.FALSE.equals(updateModelGroupInput.getIsAddAllBackendRoles())) {
                throw new IllegalArgumentException("You have to specify backend roles when add all backend roles is set to false.");
            }
            if (hasBackendRoles && addAllBackendRoles) {
                throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
            }
            if (AccessMode.RESTRICTED == accessMode && !hasBackendRoles && !addAllBackendRoles) {
                throw new IllegalArgumentException("You must specify one or more backend roles or add all backend roles to register a restricted model group.");
            }
            if (!isAdmin && hasBackendRoles && !new HashSet<>(user.getBackendRoles()).containsAll(updateModelGroupInput.getBackendRoles())) {
                throw new IllegalArgumentException("You don't have the backend roles specified.");
            }
        }
    }

    /**
     * Returns whether the request touches any access-control field: access mode, add-all flag, or backend roles.
     */
    private boolean hasAccessControlChange(MLUpdateModelGroupInput mLUpdateModelGroupInput) {
        return mLUpdateModelGroupInput.getModelAccessMode() != null
            || mLUpdateModelGroupInput.getIsAddAllBackendRoles() != null
            || mLUpdateModelGroupInput.getBackendRoles() != null;
    }

    /**
     * Rejects access-control parameters when security or model access control is disabled.
     *
     * @throws IllegalArgumentException if an access mode, add-all flag, or non-empty backend role list is supplied
     */
    private void validateSecurityDisabledOrModelAccessControlDisabled(MLUpdateModelGroupInput mLUpdateModelGroupInput) {
        if (mLUpdateModelGroupInput.getModelAccessMode() != null || mLUpdateModelGroupInput.getIsAddAllBackendRoles() != null || !CollectionUtils.isEmpty(mLUpdateModelGroupInput.getBackendRoles())) {
            throw new IllegalArgumentException("You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.");
        }
    }
}
