package org.opensearch.ml.task;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.MLTaskRequest;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/* loaded from: input_file:org/opensearch/ml/task/MLTaskRunner.class */
/**
 * Base class for executing ML task requests, either on the local node or on a node selected by
 * {@link MLTaskDispatcher}.
 *
 * @param <Request> the ML task request type
 * @param <Response> the transport response type delivered to callers
 */
public abstract class MLTaskRunner<Request extends MLTaskRequest, Response extends TransportResponse> {

    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskRunner.class);

    /** Timeout, in milliseconds, passed to {@link MLTaskManager#updateMLTask} when recording task state. */
    public static final int TIMEOUT_IN_MILLIS = 2000;
    protected final MLTaskManager mlTaskManager;
    protected final MLStats mlStats;
    protected final DiscoveryNodeHelper nodeHelper;
    protected final MLTaskDispatcher mlTaskDispatcher;
    protected final MLCircuitBreakerService mlCircuitBreakerService;
    private final ClusterService clusterService;

    public MLTaskRunner(MLTaskManager mLTaskManager, MLStats mLStats, DiscoveryNodeHelper discoveryNodeHelper, MLTaskDispatcher mLTaskDispatcher, MLCircuitBreakerService mLCircuitBreakerService, ClusterService clusterService) {
        this.mlTaskManager = mLTaskManager;
        this.mlStats = mLStats;
        this.nodeHelper = discoveryNodeHelper;
        this.mlTaskDispatcher = mLTaskDispatcher;
        this.mlCircuitBreakerService = mLCircuitBreakerService;
        this.clusterService = clusterService;
    }

    /**
     * Records a failure for an asynchronous task by setting its {@code "state"} to
     * {@link MLTaskState#FAILED} and its {@code "error"} to the exception message. Does nothing for
     * synchronous tasks.
     *
     * <p>NOTE: originally declared {@code protected}; kept {@code public} for compatibility.
     *
     * @param mLTask the task that failed
     * @param exc the failure cause; if its message is {@code null}, {@code exc.toString()} is stored
     *     instead
     */
    public void handleAsyncMLTaskFailure(MLTask mLTask, Exception exc) {
        if (mLTask.isAsync()) {
            // ImmutableMap rejects null values, so a message-less exception must not be stored as-is,
            // otherwise recording the failure would itself throw a NullPointerException.
            String errorMessage = exc.getMessage() != null ? exc.getMessage() : exc.toString();
            // The trailing `true` is forwarded to updateMLTask (presumably a remove-from-cache flag — confirm).
            this.mlTaskManager.updateMLTask(
                mLTask.getTaskId(),
                ImmutableMap.of("state", MLTaskState.FAILED.name(), "error", errorMessage),
                TIMEOUT_IN_MILLIS,
                true
            );
        }
    }

    /**
     * Marks an asynchronous task as {@link MLTaskState#COMPLETED}, also recording its
     * {@code "model_id"} when the task has one. Does nothing for synchronous tasks.
     *
     * <p>NOTE: originally declared {@code protected}; kept {@code public} for compatibility.
     *
     * @param mLTask the completed task
     */
    public void handleAsyncMLTaskComplete(MLTask mLTask) {
        if (mLTask.isAsync()) {
            Map<String, Object> updatedFields = new HashMap<>();
            updatedFields.put("state", MLTaskState.COMPLETED);
            if (mLTask.getModelId() != null) {
                updatedFields.put("model_id", mLTask.getModelId());
            }
            this.mlTaskManager.updateMLTask(mLTask.getTaskId(), updatedFields, TIMEOUT_IN_MILLIS, true);
        }
    }

    /**
     * Entry point: dispatches the request to a selected node when {@code request.isDispatchTask()}
     * is set, otherwise executes it on this node.
     *
     * @param functionName the ML function the request targets
     * @param request the task request
     * @param transportService used to forward the request when another node is selected
     * @param actionListener receives the response or failure
     */
    public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> actionListener) {
        if (request.isDispatchTask()) {
            dispatchTask(functionName, request, transportService, actionListener);
        } else {
            log.debug("Run ML request {} locally", request.getRequestID());
            checkCBAndExecute(functionName, request, actionListener);
        }
    }

    /**
     * Wraps {@code actionListener} so that, after it is notified (success or failure), the
     * {@code ML_EXECUTING_TASK_COUNT} stat is decremented and task {@code str} is removed from
     * {@link #mlTaskManager}.
     *
     * <p>NOTE: originally declared {@code protected}; kept {@code public} for compatibility.
     *
     * @param actionListener the listener to wrap
     * @param str the task id to remove from the task manager
     * @return the wrapping listener
     */
    public ActionListener<MLTaskResponse> wrappedCleanupListener(ActionListener<MLTaskResponse> actionListener, String str) {
        return ActionListener.runAfter(actionListener, () -> {
            this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
            this.mlTaskManager.remove(str);
        });
    }

    /**
     * Asks {@link #mlTaskDispatcher} for a target node. If it is this node, the request is executed
     * locally. Otherwise, the request is forwarded via {@link #getTransportActionName()}. A
     * dispatcher failure is passed straight to {@code actionListener}.
     *
     * @param functionName the ML function the request targets
     * @param request the task request; its dispatch flag is cleared before forwarding
     * @param transportService used to send the request to a remote node
     * @param actionListener receives the response or failure
     */
    public void dispatchTask(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> actionListener) {
        this.mlTaskDispatcher.dispatch(functionName, ActionListener.wrap(discoveryNode -> {
            String nodeId = discoveryNode.getId();
            if (this.clusterService.localNode().getId().equals(nodeId)) {
                log.debug("Execute ML request {} locally on node {}", request.getRequestID(), nodeId);
                // Use the same circuit-breaker policy as the non-dispatched path in run()
                // (FunctionName.REMOTE skips the check). An exception thrown here is routed by
                // ActionListener.wrap to actionListener.onFailure.
                checkCBAndExecute(functionName, request, actionListener);
            } else {
                log.debug("Execute ML request {} remotely on node {}", request.getRequestID(), nodeId);
                // Clear the flag so that, if the receiving node goes through run(), it executes
                // locally instead of dispatching again (presumably; confirm on the handler side).
                request.setDispatchTask(false);
                transportService.sendRequest(discoveryNode, getTransportActionName(), request, getResponseHandler(actionListener));
            }
        }, actionListener::onFailure));
    }

    /** @return the transport action name used when forwarding a request to another node */
    protected abstract String getTransportActionName();

    /**
     * @param actionListener the listener that should receive the remote node's response
     * @return a handler that relays the remote response to {@code actionListener}
     */
    protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> actionListener);

    /**
     * Executes the request on this node.
     *
     * @param request the task request
     * @param actionListener receives the response or failure
     */
    protected abstract void executeTask(Request request, ActionListener<Response> actionListener);

    /**
     * Checks the circuit breaker (skipped for {@link FunctionName#REMOTE}), then calls
     * {@link #executeTask}.
     *
     * <p>NOTE: originally declared {@code protected}; kept {@code public} for compatibility.
     *
     * @param functionName the ML function the request targets
     * @param request the task request
     * @param actionListener receives the response or failure
     */
    public void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> actionListener) {
        if (functionName != FunctionName.REMOTE) {
            MLNodeUtils.checkOpenCircuitBreaker(this.mlCircuitBreakerService, this.mlStats);
        }
        executeTask(request, actionListener);
    }
}
