package org.opensearch.ml.task;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.client.Requests;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.threadpool.ThreadPool;

/* loaded from: input_file:org/opensearch/ml/task/MLTaskManager.class */
/**
 * Tracks ML tasks known to this node in an in-memory cache and persists task documents to the
 * {@code .plugins-ml-task} index.
 *
 * <p>The cache ({@code taskCaches}) and the per-task-type running counters
 * ({@code runningTasksCount}) are {@link ConcurrentHashMap}s; the methods that check limits and add
 * tasks are {@code synchronized}.
 */
public class MLTaskManager {
    /** System index holding ML task documents. */
    private static final String TASK_INDEX = ".plugins-ml-task";
    /** Task document field holding the task state. */
    private static final String STATE_FIELD = "state";
    /** Field stamped with the current epoch-millisecond time on every task update. */
    private static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
    /** Conflict retries for updates that move a task into one of {@link #TASK_DONE_STATES}. */
    private static final int DONE_STATE_RETRY_ON_CONFLICT = 3;

    private final Client client;
    private final ThreadPool threadPool;
    private final MLIndicesHandler mlIndicesHandler;

    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskManager.class);
    /** Semaphore wait (milliseconds) used by {@link #updateTaskStateAsRunning(String, boolean)}. */
    public static int TASK_SEMAPHORE_TIMEOUT = MLModelManager.TIMEOUT_IN_MILLIS;
    /** Task states treated as "done"; updates into these states are retried on version conflict. */
    public static final ImmutableSet TASK_DONE_STATES = ImmutableSet.of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);
    private final Map<String, MLTaskCache> taskCaches = new ConcurrentHashMap<>();
    private final Map<MLTaskType, AtomicInteger> runningTasksCount = new ConcurrentHashMap<>();

    public MLTaskManager(Client client, ThreadPool threadPool, MLIndicesHandler mLIndicesHandler) {
        this.client = client;
        this.threadPool = threadPool;
        this.mlIndicesHandler = mLIndicesHandler;
    }

    /**
     * Marks {@code mLTask} as running on this node, provided fewer than {@code num} tasks of the same
     * type are currently counted as running.
     *
     * <p>If a task with the same ID is already cached, the cached instance is set to
     * {@link MLTaskState#RUNNING}; otherwise the given task is set to running and added to the cache.
     * Either way the running counter for the task type is incremented.
     *
     * @param mLTask the task to start
     * @param num maximum number of concurrently running tasks of this task type
     * @throws MLLimitExceededException if the running count for the task type is already {@code >= num}
     */
    public synchronized void checkLimitAndAddRunningTask(MLTask mLTask, Integer num) {
        AtomicInteger runningCount = runningTasksCount.computeIfAbsent(mLTask.getTaskType(), type -> new AtomicInteger(0));
        // Clamp a negative counter back to zero (presumably caused by unmatched decrements in remove()).
        if (runningCount.get() < 0) {
            runningCount.set(0);
        }
        int running = runningCount.get();
        log.debug("Task id: {}, current running task {}: {}", mLTask.getTaskId(), mLTask.getTaskType(), running);
        if (running >= num) {
            log.warn("exceed max running task limit for task {}", mLTask.getTaskId());
            throw new MLLimitExceededException("exceed max running task limit");
        }
        MLTask cached = getMLTask(mLTask.getTaskId());
        if (cached != null) {
            cached.setState(MLTaskState.RUNNING);
        } else {
            mLTask.setState(MLTaskState.RUNNING);
            add(mLTask);
        }
        runningCount.incrementAndGet();
    }

    /**
     * Asynchronously checks whether the task index already holds at least {@code num} tasks of
     * {@code mLTaskType} in state CREATED or RUNNING.
     *
     * <p>The task index is refreshed first so recently written tasks are visible to the search.
     * The caller's thread context is stashed while the requests run and restored before
     * {@code actionListener} is notified.
     *
     * @param mLTaskType task type to count
     * @param num the limit to compare against
     * @param actionListener receives {@code true} if the limit is reached or exceeded, or the failure
     */
    public synchronized void checkMaxBatchJobTask(MLTaskType mLTaskType, Integer num, ActionListener<Boolean> actionListener) {
        try {
            int maxTasks = num;
            SearchSourceBuilder source = new SearchSourceBuilder()
                .query(
                    QueryBuilders
                        .boolQuery()
                        .must(QueryBuilders.termQuery("task_type", mLTaskType.name()))
                        .must(
                            QueryBuilders
                                .boolQuery()
                                .should(QueryBuilders.termQuery(STATE_FIELD, MLTaskState.CREATED))
                                .should(QueryBuilders.termQuery(STATE_FIELD, MLTaskState.RUNNING))
                        )
                )
                // Fetch up to maxTasks hits: with the default page size (10) the hit count could never
                // reach a limit above 10, so such limits would silently never trigger.
                .size(Math.max(maxTasks, 0));
            SearchRequest searchRequest = new SearchRequest(TASK_INDEX).source(source);
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<Boolean> restoringListener = ActionListener.runBefore(actionListener, context::restore);
                client
                    .admin()
                    .indices()
                    .refresh(
                        Requests.refreshRequest(TASK_INDEX),
                        ActionListener
                            .wrap(
                                refreshResponse -> client
                                    .search(
                                        searchRequest,
                                        ActionListener
                                            .wrap(
                                                searchResponse -> restoringListener
                                                    .onResponse(searchResponse.getHits().getHits().length >= maxTasks),
                                                restoringListener::onFailure
                                            )
                                    ),
                                e -> {
                                    log.error("Failed to refresh Task index during search MLTaskType for {}", mLTaskType, e);
                                    restoringListener.onFailure(e);
                                }
                            )
                    );
            }
        } catch (Exception e) {
            log.error("Failed to search ML task for {}", mLTaskType, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Adds a task to the local cache with no worker nodes.
     *
     * @param mLTask the task to cache
     * @throws IllegalArgumentException if a task with the same ID is already cached
     */
    public synchronized void add(MLTask mLTask) {
        add(mLTask, null);
    }

    /**
     * Adds a task to the local cache.
     *
     * @param mLTask the task to cache
     * @param list passed to the {@code MLTaskCache} constructor; presumably the worker node IDs for
     *     the task (may be {@code null}) — confirm against {@code MLTaskCache}
     * @throws IllegalArgumentException if a task with the same ID is already cached
     */
    public synchronized void add(MLTask mLTask, List<String> list) {
        String taskId = mLTask.getTaskId();
        if (contains(taskId)) {
            throw new IllegalArgumentException("Duplicate taskId");
        }
        taskCaches.put(taskId, new MLTaskCache(mLTask, list));
        log.debug("add ML task to cache, taskId: {}, taskType: {} ", taskId, mLTask.getTaskType());
    }

    /** Returns whether a task with the given ID is in the local cache. */
    public boolean contains(String str) {
        return taskCaches.containsKey(str);
    }

    /**
     * Removes a task from the local cache, if present.
     *
     * <p>When the removed task's state is not CREATED, the running counter for its task type (if one
     * exists) is decremented.
     *
     * @param str the task ID
     */
    public void remove(String str) {
        // Single atomic removal avoids an NPE if another thread removes the task between a
        // contains() check and the remove() call.
        MLTaskCache removed = taskCaches.remove(str);
        if (removed == null) {
            return;
        }
        MLTask mlTask = removed.getMlTask();
        if (mlTask.getState() != MLTaskState.CREATED) {
            AtomicInteger runningCount = runningTasksCount.get(mlTask.getTaskType());
            if (runningCount != null) {
                runningCount.decrementAndGet();
            }
        }
        log.debug("remove ML task from cache {}", str);
    }

    /** Returns the cached task with the given ID, or {@code null} if it is not cached. */
    public MLTask getMLTask(String str) {
        MLTaskCache cache = taskCaches.get(str);
        return cache == null ? null : cache.getMlTask();
    }

    /** Returns the cache entry for the given task ID, or {@code null} if it is not cached. */
    public MLTaskCache getMLTaskCache(String str) {
        return taskCaches.get(str);
    }

    /** Returns the worker nodes recorded for the given task ID, or {@code null} if it is not cached. */
    public Set<String> getWorkNodes(String str) {
        MLTaskCache cache = taskCaches.get(str);
        return cache == null ? null : cache.getWorkerNodes();
    }

    /**
     * Records an error reported by a worker node for a cached task; ignored if the task is not cached.
     *
     * @param str the task ID
     * @param str2 the worker node ID
     * @param str3 the error message
     */
    public void addNodeError(String str, String str2, String str3) {
        log.debug("add task error: taskId: {}, workerNodeId: {}, error: {}", str, str2, str3);
        MLTaskCache cache = taskCaches.get(str);
        if (cache != null) {
            cache.addError(str2, str3);
        }
    }

    /** Returns the IDs of all cached tasks. */
    public String[] getAllTaskIds() {
        return Strings.toStringArray(taskCaches.keySet());
    }

    /** Returns how many cached tasks are currently in state {@link MLTaskState#RUNNING}. */
    public int getRunningTaskCount() {
        int running = 0;
        for (MLTaskCache cache : taskCaches.values()) {
            if (cache.getMlTask().getState() == MLTaskState.RUNNING) {
                running++;
            }
        }
        return running;
    }

    /** Removes every task from the local cache. Does not reset the per-type running counters. */
    public void clear() {
        taskCaches.clear();
    }

    /**
     * Ensures the task index exists, then indexes {@code mLTask} into it with an immediate refresh.
     *
     * <p>The caller's thread context is stashed for the index request and restored before
     * {@code actionListener} is notified.
     *
     * @param mLTask the task to persist
     * @param actionListener receives the index response or the failure
     */
    public void createMLTask(MLTask mLTask, ActionListener<IndexResponse> actionListener) {
        mlIndicesHandler.initMLTaskIndex(ActionListener.wrap(indexReady -> {
            if (!Boolean.TRUE.equals(indexReady)) {
                actionListener.onFailure(new RuntimeException("No response to create ML task index"));
                return;
            }
            IndexRequest indexRequest = new IndexRequest(TASK_INDEX);
            // try-with-resources closes the builder on the failure path too.
            try (
                XContentBuilder builder = XContentFactory.jsonBuilder();
                ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
            ) {
                indexRequest
                    .source(mLTask.toXContent(builder, ToXContent.EMPTY_PARAMS))
                    .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexRequest, ActionListener.runBefore(actionListener, context::restore));
            } catch (Exception e) {
                log.error("Failed to create ML task for " + mLTask.getFunctionName() + ", " + mLTask.getTaskType(), e);
                actionListener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create ML index", e);
            actionListener.onFailure(e);
        }));
    }

    /**
     * Sets a cached task's state to RUNNING and optionally persists that state to the task index.
     *
     * @param str the task ID
     * @param z if {@code true}, also write {@code state=RUNNING} to the index (the task stays cached)
     * @throws IllegalArgumentException if the task is not cached
     */
    public void updateTaskStateAsRunning(String str, boolean z) {
        MLTask task = getMLTask(str);
        if (task == null) {
            throw new IllegalArgumentException("Task not found");
        }
        task.setState(MLTaskState.RUNNING);
        if (z) {
            updateMLTask(str, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
        }
    }

    /**
     * Fire-and-forget variant of
     * {@link #updateMLTask(String, Map, ActionListener, long, boolean)} that only logs the outcome.
     *
     * @param str the task ID
     * @param map fields to write
     * @param j semaphore wait in milliseconds
     * @param z if {@code true}, remove the task from the local cache first
     */
    public void updateMLTask(String str, Map<String, Object> map, long j, boolean z) {
        updateMLTask(str, map, ActionListener.wrap(updateResponse -> {
            if (updateResponse.status() == RestStatus.OK) {
                log.debug("Updated ML task successfully: {}, taskId: {}, updatedFields: {}", updateResponse.status(), str, map);
            } else {
                log.error("Failed to update ML task {}, status: {}, updatedFields: {}", str, updateResponse.status(), map);
            }
        }, e -> MLExceptionUtils.logException("Failed to update ML task: " + str, e, log)), j, z);
    }

    /**
     * Asynchronously writes {@code map} to the document of a cached task. Concurrent updates for
     * the same task are serialised through the cache entry's update semaphore (when it has one).
     *
     * <p>The work runs on the {@code MachineLearningPlugin.GENERAL_THREAD_POOL} executor. Once the
     * semaphore has been acquired, it is released after {@code actionListener} is notified on every
     * path, including invalid input and dispatch failures.
     *
     * @param str the task ID
     * @param map fields to write; {@code last_update_time} is added automatically
     * @param actionListener receives the update response or the failure
     * @param j how long to wait, in milliseconds, for the per-task update semaphore
     * @param z if {@code true}, remove the task from the local cache before issuing the update
     */
    public void updateMLTask(String str, Map<String, Object> map, ActionListener<UpdateResponse> actionListener, long j, boolean z) {
        MLTaskCache mLTaskCache = taskCaches.get(str);
        if (z) {
            remove(str);
        }
        if (mLTaskCache == null) {
            actionListener.onFailure(new MLResourceNotFoundException("Can't find task in cache: " + str));
            return;
        }
        threadPool.executor(MachineLearningPlugin.GENERAL_THREAD_POOL).execute(() -> {
            Semaphore semaphore = mLTaskCache.getUpdateTaskIndexSemaphore();
            if (semaphore != null) {
                try {
                    if (!semaphore.tryAcquire(j, TimeUnit.MILLISECONDS)) {
                        actionListener.onFailure(new MLException("Other updating request not finished yet"));
                        return;
                    }
                } catch (InterruptedException e) {
                    // Preserve the interrupt for the executor that owns this thread.
                    Thread.currentThread().interrupt();
                    log.error("Failed to acquire semaphore for ML task " + str, e);
                    actionListener.onFailure(e);
                    return;
                }
            }
            // From here on every outcome goes through this listener, so the permit is released exactly once.
            ActionListener<UpdateResponse> releasingListener = semaphore == null
                ? actionListener
                : ActionListener.runAfter(actionListener, semaphore::release);
            if (map == null || map.isEmpty()) {
                releasingListener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
                return;
            }
            UpdateRequest updateRequest;
            try {
                updateRequest = buildTaskUpdateRequest(str, map);
            } catch (Exception e) {
                log.error("Failed to update ML task " + str, e);
                releasingListener.onFailure(e);
                return;
            }
            sendTaskUpdate(updateRequest, releasingListener);
        });
    }

    /**
     * Fire-and-forget variant of {@link #updateMLTaskDirectly(String, Map, ActionListener)} that
     * only logs the outcome.
     */
    public void updateMLTaskDirectly(String str, Map<String, Object> map) {
        updateMLTaskDirectly(
            str,
            map,
            ActionListener.wrap(updateResponse -> log.debug("updated ML task directly: {}", str), e -> log.error("Failed to update ML task " + str, e))
        );
    }

    /**
     * Writes {@code map} to a task document without consulting the local cache or its semaphore.
     *
     * @param str the task ID
     * @param map fields to write; {@code last_update_time} is added automatically
     * @param actionListener receives the update response or the failure (including
     *     {@link IllegalArgumentException} when {@code map} is null or empty)
     */
    public void updateMLTaskDirectly(String str, Map<String, Object> map, ActionListener<UpdateResponse> actionListener) {
        if (map == null || map.isEmpty()) {
            actionListener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
            return;
        }
        UpdateRequest updateRequest;
        try {
            updateRequest = buildTaskUpdateRequest(str, map);
        } catch (Exception e) {
            log.error("Failed to update ML task " + str, e);
            actionListener.onFailure(e);
            return;
        }
        sendTaskUpdate(updateRequest, actionListener);
    }

    /** Builds a partial-document update for a task, stamping {@code last_update_time}. */
    private UpdateRequest buildTaskUpdateRequest(String taskId, Map<String, Object> updatedFields) {
        Map<String, Object> doc = new HashMap<>(updatedFields);
        doc.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
        UpdateRequest updateRequest = new UpdateRequest(TASK_INDEX, taskId);
        updateRequest.doc(doc);
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        // Look up the actual new state value; the old code tested a Boolean against the state set,
        // which never matched, so retries were never enabled.
        if (TASK_DONE_STATES.contains(updatedFields.get(STATE_FIELD))) {
            updateRequest.retryOnConflict(DONE_STATE_RETRY_ON_CONFLICT);
        }
        return updateRequest;
    }

    /** Sends an update with the caller's thread context stashed, restoring it before notifying {@code listener}. */
    private void sendTaskUpdate(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /** Returns whether any cached task refers to the given model ID. */
    public boolean containsModel(String str) {
        for (MLTaskCache cache : taskCaches.values()) {
            if (str.equals(cache.getMlTask().getModelId())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects cached DEPLOY_MODEL tasks whose state is not CREATED.
     *
     * @return a two-element list: element 0 holds the task IDs and element 1 holds the corresponding
     *     model IDs, in matching order
     */
    public List<String[]> getLocalRunningDeployModelTasks() {
        List<String> taskIds = new ArrayList<>();
        List<String> modelIds = new ArrayList<>();
        for (Map.Entry<String, MLTaskCache> entry : taskCaches.entrySet()) {
            MLTask mlTask = entry.getValue().getMlTask();
            if (mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL && mlTask.getState() != MLTaskState.CREATED) {
                taskIds.add(entry.getKey());
                modelIds.add(mlTask.getModelId());
            }
        }
        return Arrays.asList(taskIds.toArray(new String[0]), modelIds.toArray(new String[0]));
    }
}
