package org.opensearch.ml.model;

import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.profile.MLModelProfile;
import org.opensearch.ml.settings.MLCommonsSettings;

/* loaded from: input_file:org/opensearch/ml/model/MLModelCacheHelper.class */
/**
 * Registry of {@link MLModelCache} entries keyed by model id.
 *
 * <p>Entries are held in a {@link ConcurrentHashMap}. Depending on the method, a missing entry is
 * either tolerated (the method returns {@code null} or does nothing), rejected with an
 * {@link IllegalArgumentException}, or created on demand; each method documents which applies.
 * Some mutators are additionally {@code synchronized} on this helper.
 */
public class MLModelCacheHelper {

    @Generated
    private static final Logger log = LogManager.getLogger(MLModelCacheHelper.class);
    private final Map<String, MLModelCache> modelCaches = new ConcurrentHashMap<>();
    private volatile Long maxRequestCount;

    /**
     * Reads the initial value of {@code ML_COMMONS_MONITORING_REQUEST_COUNT} from {@code settings}
     * and registers a consumer so that later cluster-settings updates replace it.
     *
     * @param clusterService source of the cluster settings the update consumer is registered on
     * @param settings settings the initial request count is read from
     */
    public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
        this.maxRequestCount = (Long) MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT, updated -> this.maxRequestCount = updated);
    }

    /**
     * Creates a fresh cache entry for the model and stores it, replacing nothing: the call is
     * rejected when {@link #isModelRunningOnNode(String)} already reports the model.
     *
     * @param str model id
     * @param mLModelState initial state of the model
     * @param functionName function name recorded for the model
     * @param list target worker nodes recorded for the model
     * @param z value of the deploy-to-all-nodes flag
     * @throws MLLimitExceededException if the model already has a cache entry with a state
     */
    public synchronized void initModelState(String str, MLModelState mLModelState, FunctionName functionName, List<String> list, boolean z) {
        if (isModelRunningOnNode(str)) {
            throw new MLLimitExceededException("Duplicate deploy model task");
        }
        log.debug("init model state for model {}, state: {}", str, mLModelState);
        MLModelCache cache = new MLModelCache();
        cache.setModelState(mLModelState);
        cache.setFunctionName(functionName);
        cache.setTargetWorkerNodes(list);
        cache.setDeployToAllNodes(Boolean.valueOf(z));
        this.modelCaches.put(str, cache);
    }

    /**
     * Updates the state of an already cached model.
     *
     * @param str model id
     * @param mLModelState new state
     * @throws IllegalArgumentException if the model has no cache entry
     */
    public synchronized void setModelState(String str, MLModelState mLModelState) {
        log.debug("Updating State of Model {}  to state {}", str, mLModelState);
        getExistingModelCache(str).setModelState(mLModelState);
    }

    /**
     * Derives a memory estimate from the model size and stores the same value as both the CPU
     * and the GPU estimate of an already cached model.
     *
     * @param str model id
     * @param mLModelFormat model format, selects the scale factor
     * @param l model size that gets scaled
     * @throws IllegalArgumentException if the model has no cache entry
     */
    public synchronized void setMemSizeEstimation(String str, MLModelFormat mLModelFormat, Long l) {
        Long estimate = getMemSizeEstimation(mLModelFormat, l);
        log.debug("Updating memSizeEstimation of Model {}  to {}", str, estimate);
        // Resolve the entry once so both estimates land on the same cache instance.
        MLModelCache cache = getExistingModelCache(str);
        cache.setMemSizeEstimationCPU(estimate);
        cache.setMemSizeEstimationGPU(estimate);
    }

    /**
     * Scales {@code l} by a format-dependent factor: 1.5 for {@code ONNX}, 1.2 for
     * {@code TORCH_SCRIPT}, 1.0 for any other format. The product is truncated to a {@code long}.
     */
    private Long getMemSizeEstimation(MLModelFormat mLModelFormat, Long l) {
        double scale = 1.0d;
        // Switching on the enum itself lets the compiler maintain the ordinal lookup table.
        switch (mLModelFormat) {
            case ONNX:
                scale = 1.5d;
                break;
            case TORCH_SCRIPT:
                scale = 1.2d;
                break;
            default:
                break;
        }
        return (long) (scale * l.longValue());
    }

    /**
     * @param str model id
     * @return the cached CPU memory estimate, or {@code null} if the model has no cache entry
     */
    public Long getMemEstCPU(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache == null ? null : cache.getMemSizeEstimationCPU();
    }

    /**
     * @param str model id
     * @return the cached GPU memory estimate, or {@code null} if the model has no cache entry
     */
    public Long getMemEstGPU(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache == null ? null : cache.getMemSizeEstimationGPU();
    }

    /**
     * @param str model id
     * @return {@code true} only if the model is cached and its state is {@code DEPLOYED}
     */
    public synchronized boolean isModelDeployed(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache != null && cache.getModelState() == MLModelState.DEPLOYED;
    }

    /**
     * @return ids of all cached models whose state is {@code DEPLOYED}
     */
    public String[] getDeployedModels() {
        return this.modelCaches
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue().getModelState() == MLModelState.DEPLOYED)
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    }

    /**
     * @return ids of all cached models whose state is {@code DEPLOYED} and whose function name is
     *     not {@code REMOTE}
     */
    public String[] getLocalDeployedModels() {
        return this.modelCaches
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue().getModelState() == MLModelState.DEPLOYED)
            .filter(entry -> entry.getValue().getFunctionName() != FunctionName.REMOTE)
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    }

    /**
     * @param str model id
     * @return {@code true} if the model has a cache entry whose state is non-null
     */
    public boolean isModelRunningOnNode(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache != null && cache.getModelState() != null;
    }

    /**
     * Stores the predictor of an already cached model.
     *
     * @param str model id
     * @param predictable predictor to store
     * @throws IllegalArgumentException if the model has no cache entry
     */
    public synchronized void setPredictor(String str, Predictable predictable) {
        getExistingModelCache(str).setPredictor(predictable);
    }

    /**
     * Stores the executor of an already cached model.
     *
     * @param str model id
     * @param mLExecutable executor to store
     * @throws IllegalArgumentException if the model has no cache entry
     */
    public synchronized void setMLExecutor(String str, MLExecutable mLExecutable) {
        getExistingModelCache(str).setExecutor(mLExecutable);
    }

    /**
     * @param str model id
     * @return the cached executor, or {@code null} if the model has no cache entry
     */
    public MLExecutable getMLExecutor(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache == null ? null : cache.getExecutor();
    }

    /**
     * @param str model id
     * @return the cached predictor, or {@code null} if the model has no cache entry
     */
    public Predictable getPredictor(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache == null ? null : cache.getPredictor();
    }

    /**
     * Replaces the target worker nodes of a cached model; does nothing if the model is not cached.
     *
     * @param str model id
     * @param list new target worker nodes
     */
    public void setTargetWorkerNodes(String str, List<String> list) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache != null) {
            cache.setTargetWorkerNodes(list);
        }
    }

    /**
     * Clears the model's cache entry and drops it from the registry; does nothing if the model is
     * not cached.
     *
     * @param str model id
     */
    public void removeModel(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache == null) {
            return;
        }
        log.debug("removing model {} from cache", str);
        cache.clear();
        this.modelCaches.remove(str);
    }

    /**
     * @return ids of every model that currently has a cache entry
     */
    public String[] getAllModels() {
        return this.modelCaches.keySet().toArray(new String[0]);
    }

    /**
     * @param str model id
     * @return the worker nodes recorded for the model, or {@code null} if the model has no cache
     *     entry
     */
    public String[] getWorkerNodes(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache == null ? null : cache.getWorkerNodes();
    }

    /**
     * Records a worker node for the model, creating the cache entry if it does not exist yet.
     *
     * @param str model id
     * @param str2 node id to add
     */
    public synchronized void addWorkerNode(String str, String str2) {
        log.debug("add node {} to model routing table for model: {}", str2, str);
        getOrCreateModelCache(str).addWorkerNode(str2);
    }

    /**
     * Removes the given worker nodes from every cached model and drops each cache entry that
     * {@code MLModelCache.isValidCache()} no longer accepts afterwards.
     *
     * @param set node ids to remove
     * @param z forwarded unchanged to {@code MLModelCache.removeWorkerNodes}
     */
    public void removeWorkerNodes(Set<String> set, boolean z) {
        // Walk the entries rather than keySet() + get(): an entry removed concurrently by another
        // caller would otherwise surface as a null value and fail with a NullPointerException.
        for (Map.Entry<String, MLModelCache> entry : this.modelCaches.entrySet()) {
            String modelId = entry.getKey();
            MLModelCache cache = entry.getValue();
            log.debug("remove worker nodes of model {} : {}", modelId, set.toArray(new String[0]));
            cache.removeWorkerNodes(set, z);
            if (!cache.isValidCache()) {
                log.debug("remove model cache {}", modelId);
                this.modelCaches.remove(modelId);
            }
        }
    }

    /**
     * Removes one worker node from a cached model and drops the cache entry if
     * {@code MLModelCache.isValidCache()} no longer accepts it; does nothing if the model is not
     * cached.
     *
     * @param str model id
     * @param str2 node id to remove
     * @param z forwarded unchanged to {@code MLModelCache.removeWorkerNode}
     */
    public void removeWorkerNode(String str, String str2, boolean z) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache == null) {
            return;
        }
        log.debug("remove worker node {} of model {} from cache", str2, str);
        cache.removeWorkerNode(str2, z);
        if (!cache.isValidCache()) {
            log.debug("remove model {} from cache as no node running it", str);
            this.modelCaches.remove(str);
        }
    }

    /**
     * Aligns the registry with the given model-to-worker-nodes mapping: worker nodes are cleared
     * for every cached model missing from {@code map}, then each mapped model (created on demand)
     * is synced with its node set.
     *
     * @param map worker node ids keyed by model id
     */
    public void syncWorkerNodes(Map<String, Set<String>> map) {
        log.debug("sync model worker nodes");
        // Snapshot the keys first; clearWorkerNodes(String) may remove entries as it goes.
        Set<String> staleModelIds = new HashSet<>(this.modelCaches.keySet());
        staleModelIds.removeAll(map.keySet());
        for (String modelId : staleModelIds) {
            clearWorkerNodes(modelId);
        }
        for (Map.Entry<String, Set<String>> entry : map.entrySet()) {
            getOrCreateModelCache(entry.getKey()).syncWorkerNode(entry.getValue());
        }
    }

    /**
     * Clears the worker nodes of every cached model, see {@link #clearWorkerNodes(String)}.
     */
    public void clearWorkerNodes() {
        log.debug("clear all model worker nodes");
        for (String modelId : this.modelCaches.keySet()) {
            clearWorkerNodes(modelId);
        }
    }

    /**
     * Clears the worker nodes of one cached model and drops the cache entry if
     * {@code MLModelCache.isValidCache()} no longer accepts it; does nothing if the model is not
     * cached.
     *
     * @param str model id
     */
    public void clearWorkerNodes(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache == null) {
            return;
        }
        log.debug("clear worker nodes of model {}", str);
        cache.clearWorkerNodes();
        if (!cache.isValidCache()) {
            this.modelCaches.remove(str);
        }
    }

    /**
     * Builds a profile snapshot from the model's cache entry. Predictor, target worker nodes and
     * worker nodes are only set on the profile when present / non-empty.
     *
     * @param str model id
     * @return the profile, or {@code null} if the model has no cache entry
     */
    public MLModelProfile getModelProfile(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache == null) {
            return null;
        }
        MLModelProfile.MLModelProfileBuilder builder = MLModelProfile.builder();
        builder.modelState(cache.getModelState());
        // Read the predictor once so the null check and the dereference see the same reference.
        Predictable predictor = cache.getPredictor();
        if (predictor != null) {
            builder.predictor(predictor.toString());
        }
        String[] targetWorkerNodes = cache.getTargetWorkerNodes();
        if (targetWorkerNodes.length > 0) {
            builder.targetWorkerNodes(targetWorkerNodes);
        }
        String[] workerNodes = cache.getWorkerNodes();
        if (workerNodes.length > 0) {
            builder.workerNodes(workerNodes);
        }
        builder.modelInferenceStats(cache.getInferenceStats(true));
        builder.predictRequestStats(cache.getInferenceStats(false));
        builder.memSizeEstimationCPU(cache.getMemSizeEstimationCPU());
        builder.memSizeEstimationGPU(cache.getMemSizeEstimationGPU());
        return builder.build();
    }

    /**
     * Records a model inference duration, creating the cache entry if it does not exist yet. The
     * current monitoring request count is passed along with the sample.
     *
     * @param str model id
     * @param d duration sample
     */
    public void addModelInferenceDuration(String str, double d) {
        getOrCreateModelCache(str).addModelInferenceDuration(d, this.maxRequestCount.longValue());
    }

    /**
     * Records a predict request duration, creating the cache entry if it does not exist yet. The
     * current monitoring request count is passed along with the sample.
     *
     * @param str model id
     * @param d duration sample
     */
    public void addPredictRequestDuration(String str, double d) {
        getOrCreateModelCache(str).addPredictRequestDuration(d, this.maxRequestCount.longValue());
    }

    /**
     * Asks every cached model to resize its monitoring queue.
     *
     * @param j value forwarded to {@code MLModelCache.resizeMonitoringQueue}
     */
    public void resizeMonitoringQueue(long j) {
        for (MLModelCache cache : this.modelCaches.values()) {
            cache.resizeMonitoringQueue(j);
        }
    }

    /**
     * @param str model id
     * @return the function name recorded for the model
     * @throws IllegalArgumentException if the model has no cache entry
     */
    public FunctionName getFunctionName(String str) {
        return getExistingModelCache(str).getFunctionName();
    }

    /**
     * Lenient variant of {@link #getFunctionName(String)}.
     *
     * @param str model id
     * @return the recorded function name, or empty if the model is not cached or has none
     */
    public Optional<FunctionName> getOptionalFunctionName(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return Optional.ofNullable(cache == null ? null : cache.getFunctionName());
    }

    /**
     * Updates the deploy-to-all-nodes flag of a cached model; does nothing if the model is not
     * cached.
     *
     * @param str model id
     * @param bool new flag value
     */
    public void setDeployToAllNodes(String str, Boolean bool) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache != null) {
            log.info("Starting to set deployToAllNodes flag to modelId: {}, value to: {}", str, bool);
            cache.setDeployToAllNodes(bool);
        }
    }

    /**
     * @param str model id
     * @return the deploy-to-all-nodes flag recorded for the model
     * @throws IllegalArgumentException if the model has no cache entry
     */
    public boolean getDeployToAllNodes(String str) {
        return getExistingModelCache(str).isDeployToAllNodes();
    }

    /**
     * Stores model info on a cached model; does nothing if the model is not cached.
     *
     * @param str model id
     * @param mLModel model info to store
     */
    public void setModelInfo(String str, MLModel mLModel) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache != null) {
            cache.setModelInfo(mLModel);
        }
    }

    /**
     * @param str model id
     * @return the cached model info, or {@code null} if the model has no cache entry
     */
    public MLModel getModelInfo(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        return cache == null ? null : cache.getCachedModelInfo();
    }

    /**
     * Looks up a cache entry that must already exist.
     *
     * @throws IllegalArgumentException if the model has no cache entry
     */
    private MLModelCache getExistingModelCache(String str) {
        MLModelCache cache = this.modelCaches.get(str);
        if (cache == null) {
            throw new IllegalArgumentException("Model not found in cache");
        }
        return cache;
    }

    /**
     * Returns the model's cache entry, atomically inserting an empty one if none exists.
     */
    private MLModelCache getOrCreateModelCache(String str) {
        return this.modelCaches.computeIfAbsent(str, modelId -> new MLModelCache());
    }
}
