package org.opensearch.ml.action.register;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/register/TransportRegisterModelAction.class */
/**
 * Transport action behind {@code cluster:admin/opensearch/ml/register_model}.
 *
 * <p>Flow: resolve (or access-check) the target model group, check connector access / connector content for
 * remote models, create a model group when none was given, then create an ML task. Non-remote models are
 * forwarded to a worker node chosen by {@link MLTaskDispatcher} and the caller is answered immediately with
 * the task id; remote models are registered through {@link MLModelManager#registerMLRemoteModel}, which
 * receives the caller's listener.
 */
public class TransportRegisterModelAction extends HandledTransportAction<ActionRequest, MLRegisterModelResponse> {

    /** Internal transport action used to forward registration work to the selected worker node. */
    private static final String FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";

    @Generated
    private static final Logger log = LogManager.getLogger(TransportRegisterModelAction.class);

    TransportService transportService;
    ModelHelper modelHelper;
    MLIndicesHandler mlIndicesHandler;
    MLModelManager mlModelManager;
    MLTaskManager mlTaskManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    DiscoveryNodeHelper nodeFilter;
    MLTaskDispatcher mlTaskDispatcher;
    MLStats mlStats;
    /** Regex a model artifact URL must match; replaced whenever the cluster setting is updated. */
    volatile String trustedUrlRegex;
    /**
     * Regexes that connector endpoints are validated against; replaced whenever the cluster setting is updated.
     * Declared {@code volatile} (like {@link #trustedUrlRegex}) because it is written by the settings-update
     * consumer registered in the constructor and read while handling requests.
     */
    private volatile List<String> trustedConnectorEndpointsRegex;
    ModelAccessControlHelper modelAccessControlHelper;
    ConnectorAccessControlHelper connectorAccessControlHelper;
    MLModelGroupManager mlModelGroupManager;

    @Inject
    public TransportRegisterModelAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ModelHelper modelHelper,
        MLIndicesHandler mlIndicesHandler,
        MLModelManager mlModelManager,
        MLTaskManager mlTaskManager,
        ClusterService clusterService,
        Settings settings,
        ThreadPool threadPool,
        Client client,
        DiscoveryNodeHelper nodeFilter,
        MLTaskDispatcher mlTaskDispatcher,
        MLStats mlStats,
        ModelAccessControlHelper modelAccessControlHelper,
        ConnectorAccessControlHelper connectorAccessControlHelper,
        MLModelGroupManager mlModelGroupManager
    ) {
        super("cluster:admin/opensearch/ml/register_model", transportService, actionFilters, MLRegisterModelRequest::new);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mlModelManager = mlModelManager;
        this.mlTaskManager = mlTaskManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.nodeFilter = nodeFilter;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlStats = mlStats;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.mlModelGroupManager = mlModelGroupManager;

        // Read the initial values, then keep both fields in sync with dynamic cluster-setting updates.
        this.trustedUrlRegex = MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, regex -> this.trustedUrlRegex = regex);
        this.trustedConnectorEndpointsRegex = MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
                regexes -> this.trustedConnectorEndpointsRegex = regexes
            );
    }

    /**
     * Entry point. When the request carries a model group id, the caller's access to that group is checked.
     * Otherwise a model group named after the model is searched for: if one exists its id is adopted and access
     * is checked; if none exists, registration proceeds and a new group is created later.
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLRegisterModelResponse> actionListener) {
        MLRegisterModelInput registerModelInput = MLRegisterModelRequest.fromActionRequest(actionRequest).getRegisterModelInput();
        if (!StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
            checkUserAccess(registerModelInput, actionListener, false);
            return;
        }
        mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(searchResponse -> {
            boolean nameTaken = searchResponse != null
                && searchResponse.getHits().getTotalHits() != null
                && searchResponse.getHits().getTotalHits().value != 0;
            if (nameTaken) {
                // Reuse the existing group with this name, but only if the caller may access it.
                registerModelInput.setModelGroupId(searchResponse.getHits().getAt(0).getId());
                checkUserAccess(registerModelInput, actionListener, true);
            } else {
                doRegister(registerModelInput, actionListener);
            }
        }, e -> {
            log.error("Failed to search model group index", e);
            actionListener.onFailure(e);
        }));
    }

    /**
     * Verifies the current user may access {@code registerModelInput.getModelGroupId()} and continues
     * registration if so; otherwise fails the listener with an {@link IllegalArgumentException}.
     *
     * @param isModelNameAlreadyExisting true when the group id was adopted from an existing group whose name
     *     equals the model name (selects the more specific error message)
     */
    private void checkUserAccess(
        MLRegisterModelInput registerModelInput,
        ActionListener<MLRegisterModelResponse> actionListener,
        boolean isModelNameAlreadyExisting
    ) {
        User user = RestActionUtils.getUserContext(client);
        CheckedConsumer<Boolean, Exception> onAccessChecked = hasAccess -> {
            if (Boolean.TRUE.equals(hasAccess)) {
                doRegister(registerModelInput, actionListener);
            } else if (!isModelNameAlreadyExisting) {
                actionListener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
            } else if (registerModelInput.getUrl() == null
                && registerModelInput.getFunctionName() != FunctionName.REMOTE
                && registerModelInput.getConnectorId() == null) {
                actionListener
                    .onFailure(
                        new IllegalArgumentException(
                            "Without a model group ID, the system will use the model name {"
                                + registerModelInput.getModelName()
                                + "} to create a new model group. However, this name is taken by another group with id {"
                                + registerModelInput.getModelGroupId()
                                + "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."
                        )
                    );
            } else {
                actionListener
                    .onFailure(
                        new IllegalArgumentException(
                            "The name {"
                                + registerModelInput.getModelName()
                                + "} you provided is unavailable because it is used by another model group with id {"
                                + registerModelInput.getModelGroupId()
                                + "} to which you do not have access. Please provide a different name."
                        )
                    );
            }
        };
        modelAccessControlHelper
            .validateModelGroupAccess(
                user,
                registerModelInput.getModelGroupId(),
                client,
                ActionListener.wrap(onAccessChecked, actionListener::onFailure)
            );
    }

    /**
     * Performs remote-model prerequisites before creating the model group: connector access when a connector
     * id is given, or connector-content validation plus a dry-run connector creation otherwise. Non-remote
     * models go straight to model-group creation. Validation errors are reported through the listener.
     */
    private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> actionListener) {
        if (registerModelInput.getFunctionName() != FunctionName.REMOTE) {
            createModelGroup(registerModelInput, actionListener);
            return;
        }
        String connectorId = registerModelInput.getConnectorId();
        if (Strings.isNotBlank(connectorId)) {
            connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasAccess -> {
                if (Boolean.TRUE.equals(hasAccess)) {
                    createModelGroup(registerModelInput, actionListener);
                } else {
                    actionListener
                        .onFailure(
                            new IllegalArgumentException("You don't have permission to use the connector provided, connector id: " + connectorId)
                        );
                }
            }, e -> {
                log.error("You don't have permission to use the connector provided, connector id: {}", connectorId, e);
                actionListener.onFailure(e);
            }));
            return;
        }
        try {
            validateInternalConnector(registerModelInput);
        } catch (IllegalArgumentException e) {
            // Report directly instead of relying on an enclosing listener to catch the throw.
            actionListener.onFailure(e);
            return;
        }
        client.execute(MLCreateConnectorAction.INSTANCE, createDryRunConnectorRequest(), ActionListener.wrap(response -> {
            log.info("Dry run create connector successfully");
            createModelGroup(registerModelInput, actionListener);
        }, e -> {
            log.error(e.getMessage(), e);
            actionListener.onFailure(e);
        }));
    }

    /**
     * Creates a model group from the input when it has no group id (recording that this version created the
     * group), then registers the model.
     */
    private void createModelGroup(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> actionListener) {
        if (!Strings.isEmpty(registerModelInput.getModelGroupId())) {
            registerModelInput.setDoesVersionCreateModelGroup(false);
            registerModel(registerModelInput, actionListener);
            return;
        }
        mlModelGroupManager.createModelGroup(createRegisterModelGroupRequest(registerModelInput), ActionListener.wrap(modelGroupId -> {
            registerModelInput.setModelGroupId(modelGroupId);
            registerModelInput.setDoesVersionCreateModelGroup(true);
            registerModel(registerModelInput, actionListener);
        }, e -> {
            MLExceptionUtils.logException("Failed to create Model Group", e, log);
            actionListener.onFailure(e);
        }));
    }

    /** Builds a connector-creation request flagged as a dry run (no connector content attached). */
    private MLCreateConnectorRequest createDryRunConnectorRequest() {
        return new MLCreateConnectorRequest(MLCreateConnectorInput.builder().dryRun(true).build());
    }

    /**
     * Checks that an inline connector is present, has a predict endpoint, and that its URL passes the
     * trusted-endpoint regexes.
     *
     * @throws IllegalArgumentException if the connector is missing or has no predict endpoint; also whatever
     *     {@code validateConnectorURL} throws for an untrusted URL
     */
    private void validateInternalConnector(MLRegisterModelInput registerModelInput) {
        if (registerModelInput.getConnector() == null) {
            log.error("You must provide connector content when creating a remote model without providing connector id!");
            throw new IllegalArgumentException("You must provide connector content when creating a remote model without connector id!");
        }
        if (registerModelInput.getConnector().getPredictEndpoint(registerModelInput.getConnector().getParameters()) == null) {
            log.error("Connector endpoint is required when creating a remote model without connector id!");
            throw new IllegalArgumentException("Connector endpoint is required when creating a remote model without connector id!");
        }
        registerModelInput.getConnector().validateConnectorURL(trustedConnectorEndpointsRegex);
    }

    /**
     * Validates the model URL against the trusted-URL regex, builds the REGISTER_MODEL task and hands it to
     * the asynchronous (non-remote) or in-line (remote) registration path.
     */
    private void registerModel(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> actionListener) {
        String url = registerModelInput.getUrl();
        if (url != null && !Pattern.compile(trustedUrlRegex).matcher(url).find()) {
            actionListener.onFailure(new IllegalArgumentException("URL can't match trusted url regex"));
            return;
        }
        boolean runAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
        MLTask mlTask = MLTask
            .builder()
            .async(runAsync)
            .taskType(MLTaskType.REGISTER_MODEL)
            .functionName(registerModelInput.getFunctionName())
            .createTime(Instant.now())
            .lastUpdateTime(Instant.now())
            .state(MLTaskState.CREATED)
            .workerNodes(ImmutableList.of(clusterService.localNode().getId()))
            .build();
        if (runAsync) {
            dispatchRegisterTask(registerModelInput, mlTask, actionListener);
        } else {
            registerRemoteModel(registerModelInput, mlTask, actionListener);
        }
    }

    /**
     * Picks a worker node, persists the task, answers the caller with the task id, then forwards the
     * registration to the worker node.
     */
    private void dispatchRegisterTask(
        MLRegisterModelInput registerModelInput,
        MLTask mlTask,
        ActionListener<MLRegisterModelResponse> actionListener
    ) {
        mlTaskDispatcher.dispatch(registerModelInput.getFunctionName(), ActionListener.wrap(workerNode -> {
            String workerNodeId = workerNode.getId();
            mlTask.setWorkerNodes(ImmutableList.of(workerNodeId));
            mlTaskManager.createMLTask(mlTask, ActionListener.wrap(indexResponse -> {
                String taskId = indexResponse.getId();
                mlTask.setTaskId(taskId);
                // The caller only waits for task creation; the actual registration continues in the background.
                actionListener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name()));

                ActionListener<MLForwardResponse> forwardListener = createForwardResponseListener(workerNodeId, taskId);
                // try-with-resources restores the stashed thread context on every path, including when
                // add/sendRequest throw; the catch then runs with the original context restored.
                try (ThreadContext.StoredContext stashedContext = client.threadPool().getThreadContext().stashContext()) {
                    mlTaskManager.add(mlTask, Arrays.asList(workerNodeId));
                    MLForwardRequest forwardRequest = new MLForwardRequest(
                        MLForwardInput
                            .builder()
                            .requestType(MLForwardRequestType.REGISTER_MODEL)
                            .registerModelInput(registerModelInput)
                            .mlTask(mlTask)
                            .build()
                    );
                    transportService
                        .sendRequest(
                            workerNode,
                            FORWARD_ACTION_NAME,
                            forwardRequest,
                            new ActionListenerResponseHandler<>(forwardListener, MLForwardResponse::new)
                        );
                } catch (Exception e) {
                    forwardListener.onFailure(e);
                }
            }, e -> {
                MLExceptionUtils.logException("Failed to register model", e, log);
                actionListener.onFailure(e);
            }));
        }, e -> {
            MLExceptionUtils.logException("Failed to register model", e, log);
            actionListener.onFailure(e);
        }));
    }

    /**
     * Builds the listener for the forwarded registration request. On success the local task-cache entry is
     * removed only when the worker is a different node; on failure the task is updated to FAILED with the
     * root-cause message.
     */
    private ActionListener<MLForwardResponse> createForwardResponseListener(String workerNodeId, String taskId) {
        return ActionListener.wrap(response -> {
            log.debug("Register model response: {}", response);
            // NOTE(review): when this node is the worker the entry is left in place — presumably the
            // worker-side handler cleans it up; verify.
            if (!clusterService.localNode().getId().equals(workerNodeId)) {
                mlTaskManager.remove(taskId);
            }
        }, e -> {
            MLExceptionUtils.logException("Failed to register model", e, log);
            mlTaskManager
                .updateMLTask(
                    taskId,
                    ImmutableMap.of("error", MLExceptionUtils.getRootCauseMessage(e), "state", MLTaskState.FAILED),
                    MLTaskManager.TASK_SEMAPHORE_TIMEOUT,
                    true
                );
        });
    }

    /** Persists the task, then delegates remote-model registration (and the caller's response) to the model manager. */
    private void registerRemoteModel(
        MLRegisterModelInput registerModelInput,
        MLTask mlTask,
        ActionListener<MLRegisterModelResponse> actionListener
    ) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(indexResponse -> {
            mlTask.setTaskId(indexResponse.getId());
            mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, actionListener);
        }, e -> {
            MLExceptionUtils.logException("Failed to register model", e, log);
            actionListener.onFailure(e);
        }));
    }

    /** Derives the model-group creation input (name, description, access settings) from the register input. */
    private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterModelInput registerModelInput) {
        return MLRegisterModelGroupInput
            .builder()
            .name(registerModelInput.getModelName())
            .description(registerModelInput.getDescription())
            .backendRoles(registerModelInput.getBackendRoles())
            .modelAccessMode(registerModelInput.getAccessMode())
            .isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles())
            .build();
    }
}
