package org.opensearch.ml.task;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.parameter.MLTask;
import org.opensearch.ml.common.parameter.MLTaskState;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.StatNames;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/task/MLTaskRunner.class */
/**
 * Base class for ML task runners.
 *
 * <p>Holds the shared collaborators (task manager, stats, dispatcher, circuit breaker) and provides helpers
 * for updating async task state and for cleaning up after a task finishes. Subclasses implement
 * {@link #executeTask}; callers enter through {@link #run}, which checks the circuit breaker first.
 *
 * @param <Request> request type accepted by the runner
 * @param <Response> response type delivered to the listener
 */
public abstract class MLTaskRunner<Request, Response> {
    /** Timeout in milliseconds passed to {@code MLTaskManager.updateMLTask} by the async state helpers. */
    public static final int TIMEOUT_IN_MILLIS = 2000;
    protected final MLTaskManager mlTaskManager;
    protected final MLStats mlStats;
    protected final MLTaskDispatcher mlTaskDispatcher;
    protected final MLCircuitBreakerService mlCircuitBreakerService;
    protected static final String TASK_ID = "task_id";
    protected static final String ALGORITHM = "algorithm";
    protected static final String MODEL_NAME = "model_name";
    protected static final String MODEL_VERSION = "model_version";
    protected static final String MODEL_CONTENT = "model_content";
    protected static final String USER = "user";

    /**
     * Creates a runner wired to the given collaborators.
     *
     * @param mLTaskManager task manager used to update and remove tasks
     * @param mLStats stats registry; the executing-task counter is decremented on cleanup
     * @param mLTaskDispatcher dispatcher made available to subclasses
     * @param mLCircuitBreakerService circuit breaker checked by {@link #run}
     */
    public MLTaskRunner(MLTaskManager mLTaskManager, MLStats mLStats, MLTaskDispatcher mLTaskDispatcher, MLCircuitBreakerService mLCircuitBreakerService) {
        this.mlTaskManager = mLTaskManager;
        this.mlStats = mLStats;
        this.mlTaskDispatcher = mLTaskDispatcher;
        this.mlCircuitBreakerService = mLCircuitBreakerService;
    }

    /**
     * Marks an async task as {@code FAILED} and records an error string; does nothing for non-async tasks.
     *
     * <p>Decompiler metadata indicates this was originally {@code protected}; it stays {@code public} so
     * existing callers keep compiling.
     *
     * @param mlTask the task that failed
     * @param exception the failure cause; its message is stored under {@code "error"}, or its class name when
     *     the message is {@code null}
     */
    public void handleAsyncMLTaskFailure(MLTask mlTask, Exception exception) {
        if (!mlTask.isAsync()) {
            return;
        }
        // ImmutableMap rejects null values: without this fallback, an exception that carries no message
        // would throw NullPointerException here and the FAILED state would never be recorded.
        String message = exception.getMessage();
        String error = message != null ? message : exception.getClass().getName();
        Map<String, Object> updatedFields = ImmutableMap.of("state", MLTaskState.FAILED.name(), "error", error);
        mlTaskManager.updateMLTask(mlTask.getTaskId(), updatedFields, TIMEOUT_IN_MILLIS);
    }

    /**
     * Marks an async task as {@code COMPLETED}, also recording its model id when one is set; does nothing for
     * non-async tasks.
     *
     * <p>Decompiler metadata indicates this was originally {@code protected}; it stays {@code public} so
     * existing callers keep compiling.
     *
     * @param mlTask the task that completed
     */
    public void handleAsyncMLTaskComplete(MLTask mlTask) {
        if (!mlTask.isAsync()) {
            return;
        }
        Map<String, Object> updatedFields = new HashMap<>();
        // NOTE(review): the failure path stores FAILED.name() (a String) but this stores the enum constant.
        // Left unchanged in case MLTaskManager compares the value against MLTaskState; confirm before unifying.
        updatedFields.put("state", MLTaskState.COMPLETED);
        String modelId = mlTask.getModelId();
        if (modelId != null) {
            updatedFields.put(RestActionUtils.PARAMETER_MODEL_ID, modelId);
        }
        mlTaskManager.updateMLTask(mlTask.getTaskId(), updatedFields, TIMEOUT_IN_MILLIS);
    }

    /**
     * Entry point: rejects the request if the ML circuit breaker is open, otherwise delegates to
     * {@link #executeTask}.
     *
     * @param request the task request
     * @param transportService transport service passed through to {@link #executeTask}
     * @param actionListener listener receiving the response or failure
     * @throws MLLimitExceededException if the circuit breaker is open
     */
    public void run(Request request, TransportService transportService, ActionListener<Response> actionListener) {
        if (mlCircuitBreakerService.isOpen().booleanValue()) {
            throw new MLLimitExceededException("Circuit breaker is open");
        }
        executeTask(request, transportService, actionListener);
    }

    /**
     * Wraps a listener so that, after it is notified, the executing-task counter is decremented and the task is
     * removed from the task manager.
     *
     * <p>Decompiler metadata indicates this was originally {@code protected}; it stays {@code public} so
     * existing callers keep compiling.
     *
     * @param actionListener the listener to wrap
     * @param taskId id of the task to remove from the task manager
     * @return the wrapped listener
     */
    public ActionListener<MLTaskResponse> wrappedCleanupListener(ActionListener<MLTaskResponse> actionListener, String taskId) {
        return ActionListener.runAfter(actionListener, () -> {
            try {
                mlStats.getStat(StatNames.ML_EXECUTING_TASK_COUNT).decrement();
            } finally {
                // Remove the task even if the stat update fails, so it is not left behind in the manager.
                mlTaskManager.remove(taskId);
            }
        });
    }

    /**
     * Executes the task; implemented by subclasses.
     *
     * @param request the task request
     * @param transportService transport service available to the implementation
     * @param actionListener listener to notify with the response or failure
     */
    public abstract void executeTask(Request request, TransportService transportService, ActionListener<Response> actionListener);
}
