package org.opensearch.ml.model;

import com.google.common.math.Quantiles;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.profile.MLPredictRequestStats;

/* loaded from: input_file:org/opensearch/ml/model/MLModelCache.class */
/**
 * In-memory cache entry for a single ML model.
 *
 * <p>Holds the model's state, function name, loaded {@link Predictable} / {@link MLExecutable},
 * model metadata, CPU/GPU memory-size estimates, the set of nodes the model is targeted at and the
 * set of nodes it is currently on, plus two bounded queues of recent durations used to compute
 * inference statistics.
 *
 * <p>Thread-safety: the node sets and duration queues are concurrent collections. The scalar
 * fields are plain (non-volatile) fields with no synchronization.
 */
public class MLModelCache {

    @Generated
    private static final Logger log = LogManager.getLogger(MLModelCache.class);
    private MLModelState modelState;
    private FunctionName functionName;
    private Predictable predictor;
    private MLExecutable executor;
    private MLModel modelInfo;
    private Long memSizeEstimationCPU;
    private Long memSizeEstimationGPU;
    // Tri-state Boolean; null is treated as false by isDeployToAllNodes().
    private Boolean deployToAllNodes;
    private final Set<String> targetWorkerNodes = ConcurrentHashMap.newKeySet();
    private final Set<String> workerNodes = ConcurrentHashMap.newKeySet();
    private final Queue<Double> modelInferenceDurationQueue = new ConcurrentLinkedQueue<>();
    private final Queue<Double> predictRequestDurationQueue = new ConcurrentLinkedQueue<>();

    /**
     * Replaces the set of target worker nodes.
     *
     * @param targetNodes new target node ids; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code targetNodes} is null or empty
     */
    public void setTargetWorkerNodes(List<String> targetNodes) {
        if (targetNodes == null || targetNodes.isEmpty()) {
            throw new IllegalArgumentException("Null or empty target worker nodes");
        }
        this.targetWorkerNodes.clear();
        this.targetWorkerNodes.addAll(targetNodes);
    }

    /** Returns a snapshot array of the target worker node ids. */
    public String[] getTargetWorkerNodes() {
        return this.targetWorkerNodes.toArray(new String[0]);
    }

    /**
     * Removes one node from the worker nodes.
     *
     * <p>The node is also removed from the target nodes when deploy-to-all is enabled or
     * {@code removeFromTargets} is true. If {@code removeFromTargets} is true, deploy-to-all is
     * switched off. The cached model info is dropped once either node set becomes empty.
     *
     * @param nodeId id of the node to remove
     * @param removeFromTargets whether to also remove the node from the targets and disable
     *     deploy-to-all
     */
    public void removeWorkerNode(String nodeId, boolean removeFromTargets) {
        if (isDeployToAllNodes() || removeFromTargets) {
            this.targetWorkerNodes.remove(nodeId);
        }
        if (removeFromTargets) {
            this.deployToAllNodes = false;
        }
        this.workerNodes.remove(nodeId);
        dropModelInfoIfNoNodes();
    }

    /**
     * Bulk variant of {@link #removeWorkerNode(String, boolean)} with identical semantics.
     *
     * @param nodeIds ids of the nodes to remove
     * @param removeFromTargets whether to also remove the nodes from the targets and disable
     *     deploy-to-all
     */
    public void removeWorkerNodes(Set<String> nodeIds, boolean removeFromTargets) {
        if (isDeployToAllNodes() || removeFromTargets) {
            this.targetWorkerNodes.removeAll(nodeIds);
        }
        if (removeFromTargets) {
            this.deployToAllNodes = false;
        }
        this.workerNodes.removeAll(nodeIds);
        dropModelInfoIfNoNodes();
    }

    /** Clears the cached model info when there are no target nodes or no worker nodes left. */
    private void dropModelInfoIfNoNodes() {
        if (this.targetWorkerNodes.isEmpty() || this.workerNodes.isEmpty()) {
            this.modelInfo = null;
        }
    }

    /**
     * Adds a worker node. When deploy-to-all is enabled, the node is also added to the targets.
     *
     * @param nodeId id of the node to add
     */
    public void addWorkerNode(String nodeId) {
        if (isDeployToAllNodes()) {
            this.targetWorkerNodes.add(nodeId);
        }
        this.workerNodes.add(nodeId);
    }

    /** Returns a snapshot array of the current worker node ids. */
    public String[] getWorkerNodes() {
        return this.workerNodes.toArray(new String[0]);
    }

    /** Sets the cached model metadata (may be null). */
    public void setModelInfo(MLModel mLModel) {
        this.modelInfo = mLModel;
    }

    /** Returns the cached model metadata, or null if none is cached. */
    public MLModel getCachedModelInfo() {
        return this.modelInfo;
    }

    /**
     * Replaces the worker nodes with exactly the given set.
     *
     * @param nodeIds the authoritative set of worker node ids
     */
    public void syncWorkerNode(Set<String> nodeIds) {
        this.workerNodes.clear();
        this.workerNodes.addAll(nodeIds);
    }

    /** Returns true only if deploy-to-all has been explicitly set to {@code true}. */
    public boolean isDeployToAllNodes() {
        return Boolean.TRUE.equals(this.deployToAllNodes);
    }

    /** Removes all worker nodes; target nodes are left untouched. */
    public void clearWorkerNodes() {
        this.workerNodes.clear();
    }

    /**
     * Resets the cache entry and closes the loaded predictor and executor.
     *
     * <p>Target nodes and the deploy-to-all flag are not reset. The predictor and executor
     * references are nulled before closing. A closed instance is therefore never handed out again,
     * and a repeated {@code clear()} does not close it twice. The executor is closed even if
     * closing the predictor throws; that exception still propagates.
     */
    public void clear() {
        this.modelState = null;
        this.functionName = null;
        this.workerNodes.clear();
        this.modelInfo = null;
        this.modelInferenceDurationQueue.clear();
        this.predictRequestDurationQueue.clear();
        this.memSizeEstimationCPU = 0L;
        this.memSizeEstimationGPU = 0L;

        Predictable closingPredictor = this.predictor;
        MLExecutable closingExecutor = this.executor;
        this.predictor = null;
        this.executor = null;
        try {
            if (closingPredictor != null) {
                closingPredictor.close();
            }
        } finally {
            if (closingExecutor != null) {
                closingExecutor.close();
            }
        }
    }

    /**
     * Records a model-inference duration, keeping at most {@code maxQueueSize} entries.
     *
     * @param duration the duration to record
     * @param maxQueueSize queue capacity; a value {@code <= 0} clears the queue and records nothing
     */
    public void addModelInferenceDuration(double duration, long maxQueueSize) {
        addInferenceDuration(duration, maxQueueSize, this.modelInferenceDurationQueue);
    }

    /**
     * Records a predict-request duration, keeping at most {@code maxQueueSize} entries.
     *
     * @param duration the duration to record
     * @param maxQueueSize queue capacity; a value {@code <= 0} clears the queue and records nothing
     */
    public void addPredictRequestDuration(double duration, long maxQueueSize) {
        addInferenceDuration(duration, maxQueueSize, this.predictRequestDurationQueue);
    }

    private void addInferenceDuration(double duration, long maxQueueSize, Queue<Double> queue) {
        // Trim first so the new element fits within maxQueueSize.
        resizeInferenceQueue(maxQueueSize, queue);
        if (maxQueueSize > 0) {
            queue.add(duration);
        }
    }

    /**
     * Trims both duration queues for a new capacity.
     *
     * @param maxQueueSize new capacity; a value {@code <= 0} clears both queues. Otherwise each
     *     queue is trimmed to fewer than {@code maxQueueSize} entries, leaving room for the next
     *     recorded duration.
     */
    public void resizeMonitoringQueue(long maxQueueSize) {
        log.debug("resize inference duration monitoring queue with size {}", maxQueueSize);
        resizeInferenceQueue(maxQueueSize, this.predictRequestDurationQueue);
        resizeInferenceQueue(maxQueueSize, this.modelInferenceDurationQueue);
    }

    /** Drops the oldest entries until the queue holds fewer than {@code maxQueueSize} elements. */
    private void resizeInferenceQueue(long maxQueueSize, Queue<Double> queue) {
        if (maxQueueSize <= 0) {
            queue.clear();
            return;
        }
        // The poll() != null guard stops the loop if another thread drains the queue concurrently.
        while (queue.size() >= maxQueueSize && queue.poll() != null) {
            // keep discarding the oldest entry
        }
    }

    /**
     * Computes count, min, max, average and p50/p90/p99 over the recorded durations.
     *
     * @param modelInference true to use the model-inference durations, false to use the
     *     predict-request durations
     * @return the statistics, or null when no durations are recorded
     */
    public MLPredictRequestStats getInferenceStats(boolean modelInference) {
        Queue<Double> queue = modelInference ? this.modelInferenceDurationQueue : this.predictRequestDurationQueue;
        // Read the concurrently mutated queue exactly once. Separate reads could disagree with each
        // other, or see an empty queue after the emptiness check; Quantiles.compute would then
        // throw IllegalArgumentException.
        double[] durations = queue.stream().mapToDouble(Double::doubleValue).toArray();
        if (durations.length == 0) {
            return null;
        }
        DoubleSummaryStatistics summaryStatistics = new DoubleSummaryStatistics();
        for (double duration : durations) {
            summaryStatistics.accept(duration);
        }
        MLPredictRequestStats.MLPredictRequestStatsBuilder builder = MLPredictRequestStats.builder();
        builder.count(Long.valueOf(summaryStatistics.getCount()));
        builder.max(Double.valueOf(summaryStatistics.getMax()));
        builder.min(Double.valueOf(summaryStatistics.getMin()));
        builder.average(Double.valueOf(summaryStatistics.getAverage()));
        // Quantiles.compute(double...) works on a copy, so the snapshot array is not reordered.
        Quantiles.Scale percentiles = Quantiles.percentiles();
        builder.p50(Double.valueOf(percentiles.index(50).compute(durations)));
        builder.p90(Double.valueOf(percentiles.index(90).compute(durations)));
        builder.p99(Double.valueOf(percentiles.index(99).compute(durations)));
        return builder.build();
    }

    /** Returns true if a model state is set or at least one worker node is recorded. */
    public boolean isValidCache() {
        return this.modelState != null || !this.workerNodes.isEmpty();
    }

    /** Sets the model state. */
    @Generated
    public void setModelState(MLModelState mLModelState) {
        this.modelState = mLModelState;
    }

    /** Returns the model state, or null if unset. */
    @Generated
    public MLModelState getModelState() {
        return this.modelState;
    }

    /** Sets the model's function name. */
    @Generated
    public void setFunctionName(FunctionName functionName) {
        this.functionName = functionName;
    }

    /** Returns the model's function name, or null if unset. */
    @Generated
    public FunctionName getFunctionName() {
        return this.functionName;
    }

    /** Sets the loaded predictor. */
    @Generated
    public void setPredictor(Predictable predictable) {
        this.predictor = predictable;
    }

    /** Returns the loaded predictor, or null if none is set. */
    @Generated
    public Predictable getPredictor() {
        return this.predictor;
    }

    /** Sets the loaded executor. */
    @Generated
    public void setExecutor(MLExecutable mLExecutable) {
        this.executor = mLExecutable;
    }

    /** Returns the loaded executor, or null if none is set. */
    @Generated
    public MLExecutable getExecutor() {
        return this.executor;
    }

    /** Sets the CPU memory-size estimate. */
    @Generated
    public void setMemSizeEstimationCPU(Long l) {
        this.memSizeEstimationCPU = l;
    }

    /** Returns the CPU memory-size estimate; {@link #clear()} resets it to 0. */
    @Generated
    public Long getMemSizeEstimationCPU() {
        return this.memSizeEstimationCPU;
    }

    /** Sets the GPU memory-size estimate. */
    @Generated
    public void setMemSizeEstimationGPU(Long l) {
        this.memSizeEstimationGPU = l;
    }

    /** Returns the GPU memory-size estimate; {@link #clear()} resets it to 0. */
    @Generated
    public Long getMemSizeEstimationGPU() {
        return this.memSizeEstimationGPU;
    }

    /** Sets the deploy-to-all flag; null behaves as false. */
    @Generated
    public void setDeployToAllNodes(Boolean bool) {
        this.deployToAllNodes = bool;
    }
}
