package org.opensearch.ml.action.batch;

import com.jayway.jsonpath.PathNotFoundException;
import java.time.Instant;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.ingest.Ingestable;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/action/batch/TransportBatchIngestionAction.class */
/**
 * Transport action for offline batch ingestion.
 *
 * <p>A request is validated, optionally enriched with credentials fetched from a connector, and then
 * recorded as an asynchronous {@link MLTaskType#BATCH_INGEST} task. The caller receives the task id
 * as soon as the task exists; the ingestion itself runs on the
 * {@link MachineLearningPlugin#INGEST_THREAD_POOL} executor and its outcome is written back to the
 * task as {@code COMPLETED} or {@code FAILED}.
 */
public class TransportBatchIngestionAction extends HandledTransportAction<ActionRequest, MLBatchIngestionResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(TransportBatchIngestionAction.class);
    private static final String S3_URI_REGEX = "^s3://([a-zA-Z0-9.-]+)(/.*)?$";
    private static final Pattern S3_URI_PATTERN = Pattern.compile(S3_URI_REGEX);
    private static final String MAX_BATCH_INGEST_TASKS_ERR_MSG =
        "Exceeded maximum limit for BATCH_INGEST tasks. To increase the limit, update the plugins.ml_commons.max_batch_ingestion_tasks setting.";
    /** Data-source key naming the ingestion implementation, e.g. {@code s3}. */
    public static final String TYPE = "type";
    /** Data-source key holding the input to ingest; for {@code s3} this is a list of URIs. */
    public static final String SOURCE = "source";

    TransportService transportService;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    private final Client client;
    private final ThreadPool threadPool;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    // Kept in sync with ML_COMMONS_BATCH_INGESTION_BULK_SIZE by the consumer registered in the constructor.
    private volatile Integer batchIngestionBulkSize;

    @Inject
    public TransportBatchIngestionAction(
        ClusterService clusterService,
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        MLTaskManager mLTaskManager,
        ThreadPool threadPool,
        MLModelManager mLModelManager,
        MLFeatureEnabledSetting mLFeatureEnabledSetting,
        Settings settings
    ) {
        super("cluster:admin/opensearch/ml/batch_ingestion", transportService, actionFilters, MLBatchIngestionRequest::new);
        this.transportService = transportService;
        this.client = client;
        this.mlTaskManager = mLTaskManager;
        this.threadPool = threadPool;
        this.mlModelManager = mLModelManager;
        this.mlFeatureEnabledSetting = mLFeatureEnabledSetting;
        this.batchIngestionBulkSize = (Integer) MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE, num -> this.batchIngestionBulkSize = num);
    }

    /**
     * Validates the request and starts a batch ingestion task.
     *
     * <p>If a connector id is set and no non-empty inline credential is present, the credential is
     * fetched from the connector first. Invalid input is reported as {@link RestStatus#BAD_REQUEST};
     * any other failure is passed to {@code actionListener} unchanged.
     */
    @Override
    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<MLBatchIngestionResponse> actionListener) {
        try {
            // Parsed inside the try so that a malformed request is reported through the listener.
            MLBatchIngestionInput mlBatchIngestionInput = MLBatchIngestionRequest.fromActionRequest(actionRequest).getMlBatchIngestionInput();
            if (!this.mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()) {
                throw new IllegalStateException(MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
            }
            validateBatchIngestInput(mlBatchIngestionInput);
            boolean hasInlineCredential = mlBatchIngestionInput.getCredential() != null
                && !mlBatchIngestionInput.getCredential().isEmpty();
            if (mlBatchIngestionInput.getConnectorId() == null || hasInlineCredential) {
                createMLTaskandExecute(mlBatchIngestionInput, actionListener);
            } else {
                this.mlModelManager.getConnectorCredential(mlBatchIngestionInput.getConnectorId(), ActionListener.wrap(credential -> {
                    mlBatchIngestionInput.setCredential(credential);
                    createMLTaskandExecute(mlBatchIngestionInput, actionListener);
                }, exc -> {
                    log.error("Failed to fetch credentials from connector {}", mlBatchIngestionInput.getConnectorId(), exc);
                    actionListener
                        .onFailure(
                            new OpenSearchStatusException(
                                "Fail to fetch credentials from the connector in the batch ingestion input: " + exc.getMessage(),
                                RestStatus.BAD_REQUEST
                            )
                        );
                }));
            }
        } catch (IllegalArgumentException e) {
            log.error("Invalid batch ingestion input", e);
            actionListener
                .onFailure(
                    new OpenSearchStatusException("IllegalArgumentException in the batch ingestion input: " + e.getMessage(), RestStatus.BAD_REQUEST)
                );
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Creates a {@code BATCH_INGEST} task (subject to the maximum-task limit) and starts the ingestion.
     *
     * @param mLBatchIngestionInput validated ingestion input; its data-source {@link #TYPE} must be a String
     * @param actionListener answered exactly once, with the new task id or with the failure
     */
    protected void createMLTaskandExecute(MLBatchIngestionInput mLBatchIngestionInput, ActionListener<MLBatchIngestionResponse> actionListener) {
        MLTask mlTask = MLTask
            .builder()
            .async(true)
            .taskType(MLTaskType.BATCH_INGEST)
            .createTime(Instant.now())
            .lastUpdateTime(Instant.now())
            .state(MLTaskState.CREATED)
            .build();
        this.mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedsLimit -> {
            if (exceedsLimit) {
                // The task has not been persisted yet, so there is no task id to include in the log.
                log.warn(MAX_BATCH_INGEST_TASKS_ERR_MSG);
                actionListener.onFailure(new OpenSearchStatusException(MAX_BATCH_INGEST_TASKS_ERR_MSG, RestStatus.TOO_MANY_REQUESTS));
                return;
            }
            this.mlTaskManager
                .createMLTask(
                    mlTask,
                    ActionListener
                        .wrap(indexResponse -> startIngestion(mlTask, indexResponse.getId(), mLBatchIngestionInput, actionListener), exc -> {
                            log.error("Failed to create batch ingestion task", exc);
                            actionListener.onFailure(exc);
                        })
                );
        }, exc -> {
            log.error("Failed to check the maximum BATCH_INGEST Task limits", exc);
            actionListener.onFailure(exc);
        }));
    }

    /**
     * Registers the persisted task, answers the caller with its id, and submits the ingestion to the
     * ingest thread pool.
     *
     * <p>A failure before the caller has been answered is reported through {@code actionListener}. A
     * failure afterwards only marks the task FAILED, so the listener is never notified twice.
     */
    private void startIngestion(
        MLTask mlTask,
        String taskId,
        MLBatchIngestionInput input,
        ActionListener<MLBatchIngestionResponse> actionListener
    ) {
        boolean responded = false;
        try {
            mlTask.setTaskId(taskId);
            this.mlTaskManager.add(mlTask);
            // Resolve the implementation before responding: an unsupported data-source type must fail the
            // request instead of producing a success response followed by a failure.
            String ingestType = (String) input.getDataSources().get(TYPE);
            Ingestable ingestable = (Ingestable) MLEngineClassLoader
                .initInstance(ingestType.toLowerCase(Locale.ROOT), this.client, Client.class);
            actionListener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
            responded = true;
            this.threadPool
                .executor(MachineLearningPlugin.INGEST_THREAD_POOL)
                .execute(
                    () -> executeWithErrorHandling(
                        () -> handleSuccessRate(ingestable.ingest(input, this.batchIngestionBulkSize.intValue()), taskId),
                        taskId
                    )
                );
        } catch (Exception e) {
            log.error("Failed in batch ingestion", e);
            markTaskFailed(taskId, MLExceptionUtils.getRootCauseMessage(e));
            if (!responded) {
                actionListener.onFailure(e);
            }
        }
    }

    /**
     * Runs {@code runnable} and marks task {@code str} as FAILED if it throws.
     *
     * @param runnable the ingestion work
     * @param str id of the task to update on failure
     */
    protected void executeWithErrorHandling(Runnable runnable, String str) {
        try {
            runnable.run();
        } catch (PathNotFoundException e) {
            // Must precede the generic handler: PathNotFoundException is unchecked, so a catch placed
            // after catch (Exception) is rejected by the compiler and could never be reached.
            log.error("Error in jsonParse fields", e);
            markTaskFailed(str, e.getMessage());
        } catch (Exception e) {
            log.error("Error in ingest, failed to produce a successRate", e);
            markTaskFailed(str, MLExceptionUtils.getRootCauseMessage(e));
        }
    }

    /**
     * Records the ingestion outcome on task {@code str}.
     *
     * @param d success rate in percent; exactly 100 marks the task COMPLETED, anything else FAILED
     * @param str id of the task to update
     */
    protected void handleSuccessRate(double d, String str) {
        if (d == 100.0d) {
            this.mlTaskManager.updateMLTask(str, Map.of("state", MLTaskState.COMPLETED), 5000L, true);
        } else if (d > 0.0d) {
            markTaskFailed(str, "batch ingestion successful rate is " + d);
        } else {
            markTaskFailed(str, "batch ingestion successful rate is 0");
        }
    }

    /** Marks {@code taskId} as FAILED with {@code error} as its error text. */
    private void markTaskFailed(String taskId, String error) {
        // Map.of rejects null values; an exception without a message must still mark the task failed.
        String message = error != null ? error : "Unknown error";
        this.mlTaskManager
            .updateMLTask(taskId, Map.of("state", MLTaskState.FAILED, "error", message), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
    }

    /**
     * Checks that the input names a data source and either a credential or a connector id. For
     * {@code s3} sources it also checks that every URI matches {@code s3://bucket[/path]}.
     *
     * @throws IllegalArgumentException if any check fails
     */
    private void validateBatchIngestInput(MLBatchIngestionInput mLBatchIngestionInput) {
        if (mLBatchIngestionInput == null || mLBatchIngestionInput.getDataSources() == null || mLBatchIngestionInput.getDataSources().isEmpty()) {
            throw new IllegalArgumentException("The batch ingest input data source cannot be null");
        }
        if (mLBatchIngestionInput.getCredential() == null && mLBatchIngestionInput.getConnectorId() == null) {
            throw new IllegalArgumentException("The batch ingest credential or connector_id cannot be null");
        }
        Map<?, ?> dataSources = mLBatchIngestionInput.getDataSources();
        Object type = dataSources.get(TYPE);
        Object source = dataSources.get(SOURCE);
        if (type == null || source == null) {
            throw new IllegalArgumentException("The batch ingest input data source is missing data type or source");
        }
        // Reject malformed values here so they surface as BAD_REQUEST instead of a ClassCastException.
        if (!(type instanceof String)) {
            throw new IllegalArgumentException("The batch ingest input data source type must be a string");
        }
        if (((String) type).equalsIgnoreCase("s3")) {
            if (!(source instanceof List)) {
                throw new IllegalArgumentException("The batch ingest input s3Uris must be a list");
            }
            List<?> s3Uris = (List<?>) source;
            if (s3Uris.isEmpty()) {
                throw new IllegalArgumentException("The batch ingest input s3Uris is empty");
            }
            // Null or non-string entries are reported as invalid rather than failing with an NPE/CCE.
            List<?> invalidUris = s3Uris
                .stream()
                .filter(uri -> !(uri instanceof String) || !S3_URI_PATTERN.matcher((String) uri).matches())
                .collect(Collectors.toList());
            if (!invalidUris.isEmpty()) {
                throw new IllegalArgumentException("The following batch ingest input S3 URIs are invalid: " + invalidUris);
            }
        }
    }
}
