package org.opensearch.knn.index.memory;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheStats;
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import java.io.Closeable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.knn.common.exception.OutOfNativeMemoryException;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.plugin.stats.StatNames;

/**
 * Singleton manager for {@link NativeMemoryAllocation}s, held in a Guava {@link Cache} keyed by
 * {@link NativeMemoryEntryContext#getKey()}.
 *
 * <p>When the k-NN memory circuit breaker setting is enabled, entries are weighed by their size in KB and the
 * cache's maximum weight is the circuit breaker limit in KB. When item expiry is enabled, entries expire after
 * the configured number of minutes without access. Every entry removed from the cache, for any reason, is
 * closed by the removal listener. Size-based evictions also trip the circuit breaker and set the
 * "capacity reached" flag.
 */
public class NativeMemoryCacheManager implements Closeable {
    /** Stats key under which {@link #getIndicesCacheStats()} reports the number of cached graphs of an index. */
    public static String GRAPH_COUNT = "graph_count";

    private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class);
    private static NativeMemoryCacheManager INSTANCE;

    // Replaced by initialize(), which rebuildCache() runs on the executor thread; volatile so that callers on
    // other threads observe the rebuilt cache instead of a stale reference.
    private volatile Cache<String, NativeMemoryAllocation> cache;
    // Runs cache rebuilds asynchronously; a single-thread executor runs submitted rebuilds one at a time.
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    // Reset (not replaced) on every (re)initialization so the field itself can stay final.
    private final AtomicBoolean cacheCapacityReached = new AtomicBoolean(false);
    // Capacity in KB used by get()'s admission check; Long.MAX_VALUE when the circuit breaker is disabled.
    // Volatile for the same cross-thread rebuild reason as `cache`.
    private volatile long maxWeight = Long.MAX_VALUE;

    NativeMemoryCacheManager() {
        initialize();
    }

    /**
     * Returns the process-wide instance, creating it on first use.
     *
     * @return the shared cache manager
     */
    public static synchronized NativeMemoryCacheManager getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new NativeMemoryCacheManager();
        }
        return INSTANCE;
    }

    /**
     * Builds a fresh cache from the current k-NN settings and publishes it, resetting the capacity flag.
     */
    private void initialize() {
        CacheBuilder<String, NativeMemoryAllocation> cacheBuilder = CacheBuilder.newBuilder()
            .recordStats()
            .concurrencyLevel(1)
            .removalListener(this::onRemoval);

        if ((Boolean) KNNSettings.state().getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)) {
            long limitInKB = KNNSettings.getCircuitBreakerLimit().getKb();
            cacheBuilder.maximumWeight(limitInKB).weigher((key, allocation) -> allocation.getSizeInKB());
            maxWeight = limitInKB;
        } else {
            // Without this reset, a rebuild after disabling the breaker would keep enforcing the old limit.
            maxWeight = Long.MAX_VALUE;
        }

        if ((Boolean) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) {
            TimeValue expiry = (TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES);
            cacheBuilder.expireAfterAccess(expiry.getMinutes(), TimeUnit.MINUTES);
        }

        cacheCapacityReached.set(false);
        cache = cacheBuilder.build();
    }

    /**
     * Asynchronously invalidates every entry of the current cache (closing each allocation through the removal
     * listener) and then rebuilds the cache from the current settings.
     */
    public synchronized void rebuildCache() {
        logger.info("KNN Cache rebuilding.");
        executor.execute(() -> {
            cache.invalidateAll();
            initialize();
        });
    }

    /** Shuts down the rebuild executor. Entries still in the cache are not invalidated here. */
    @Override
    public void close() {
        executor.shutdown();
    }

    /**
     * @return the summed size in KB of every allocation currently in the cache
     */
    public long getCacheSizeInKilobytes() {
        return cache.asMap().values().stream().mapToLong(NativeMemoryAllocation::getSizeInKB).sum();
    }

    /**
     * @return the cache size as a percentage of the circuit breaker limit (0 when the limit is 0)
     */
    public Float getCacheSizeAsPercentage() {
        return getSizeAsPercentage(getCacheSizeInKilobytes());
    }

    /**
     * @return the summed size in KB of all {@link NativeMemoryAllocation.IndexAllocation}s in the cache
     */
    public long getIndicesSizeInKilobytes() {
        return cache.asMap()
            .values()
            .stream()
            .filter(allocation -> allocation instanceof NativeMemoryAllocation.IndexAllocation)
            .mapToLong(NativeMemoryAllocation::getSizeInKB)
            .sum();
    }

    /**
     * @return the size of all index allocations as a percentage of the circuit breaker limit
     */
    public Float getIndicesSizeAsPercentage() {
        return getSizeAsPercentage(getIndicesSizeInKilobytes());
    }

    /**
     * Sums the sizes of the index allocations that belong to one OpenSearch index.
     *
     * @param indexName OpenSearch index name; must not be null
     * @return the summed size in KB (0 when the index has nothing cached)
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public Long getIndexSizeInKilobytes(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return cache.asMap()
            .values()
            .stream()
            .filter(allocation -> isAllocationOfIndex(allocation, indexName))
            .mapToLong(NativeMemoryAllocation::getSizeInKB)
            .sum();
    }

    /**
     * @param indexName OpenSearch index name; must not be null
     * @return the index's cached size as a percentage of the circuit breaker limit
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public Float getIndexSizeAsPercentage(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return getSizeAsPercentage(getIndexSizeInKilobytes(indexName));
    }

    /**
     * @return the summed size in KB of training-data and anonymous allocations in the cache
     */
    public long getTrainingSizeInKilobytes() {
        return cache.asMap()
            .values()
            .stream()
            .filter(
                allocation -> allocation instanceof NativeMemoryAllocation.TrainingDataAllocation
                    || allocation instanceof NativeMemoryAllocation.AnonymousAllocation
            )
            .mapToLong(NativeMemoryAllocation::getSizeInKB)
            .sum();
    }

    /**
     * @return the training/anonymous allocation size as a percentage of the circuit breaker limit
     */
    public Float getTrainingSizeAsPercentage() {
        return getSizeAsPercentage(getTrainingSizeInKilobytes());
    }

    /**
     * @return the capacity in KB used for admission checks ({@link Long#MAX_VALUE} when the breaker is disabled)
     */
    public long getMaxCacheSizeInKilobytes() {
        return maxWeight;
    }

    /**
     * Counts the index allocations (graphs) cached for one OpenSearch index.
     *
     * @param indexName OpenSearch index name; must not be null
     * @return the number of cached index allocations for that index
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public int getIndexGraphCount(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return (int) cache.asMap().values().stream().filter(allocation -> isAllocationOfIndex(allocation, indexName)).count();
    }

    /**
     * @return Guava statistics of the current cache (stats are recorded because the builder calls recordStats())
     */
    public CacheStats getCacheStats() {
        return cache.stats();
    }

    /**
     * Returns the cached allocation for the context's key, loading it through
     * {@link NativeMemoryEntryContext#load()} when absent.
     *
     * <p>A new entry is admitted if eviction is allowed, if the key is already cached, or if the entry would
     * leave strictly positive headroom below {@link #getMaxCacheSizeInKilobytes()}. Otherwise the load is
     * refused.
     *
     * @param nativeMemoryEntryContext describes the entry to fetch or load
     * @param isAbleToTriggerEviction when true, skip the headroom check and let the cache evict as needed
     * @return the cached or freshly loaded allocation
     * @throws ExecutionException if loading the entry fails
     * @throws OutOfNativeMemoryException if the entry may not evict and would not fit
     */
    public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryContext, boolean isAbleToTriggerEviction)
        throws ExecutionException {
        String key = nativeMemoryEntryContext.getKey();
        if (isAbleToTriggerEviction || cache.asMap().containsKey(key)) {
            return cache.get(key, nativeMemoryEntryContext::load);
        }

        // Computed once so the exception reports exactly the numbers the decision was based on.
        long cacheSizeInKB = getCacheSizeInKilobytes();
        Integer entrySizeInKB = nativeMemoryEntryContext.calculateSizeInKB();
        if (maxWeight - cacheSizeInKB - entrySizeInKB > 0) {
            return cache.get(key, nativeMemoryEntryContext::load);
        }

        throw new OutOfNativeMemoryException(
            "Entry cannot be loaded into cache because it would not fit. Entry size: "
                + entrySizeInKB
                + " KB Current Cache Size: "
                + cacheSizeInKB
                + " KB Max Cache Size: "
                + maxWeight
        );
    }

    /**
     * Removes one entry; the removal listener closes its allocation.
     *
     * @param key cache key to invalidate
     */
    public void invalidate(String key) {
        cache.invalidate(key);
    }

    /** Removes every entry; the removal listener closes each allocation. */
    public void invalidateAll() {
        cache.invalidateAll();
    }

    /**
     * @return whether a size-based eviction has happened since the cache was last (re)built or the flag reset
     */
    public Boolean isCacheCapacityReached() {
        return cacheCapacityReached.get();
    }

    /**
     * @param capacityReached new value of the capacity-reached flag; must not be null
     */
    public void setCacheCapacityReached(Boolean capacityReached) {
        cacheCapacityReached.set(capacityReached);
    }

    /**
     * Builds per-index stats for every OpenSearch index with cached index allocations: graph count
     * ({@link #GRAPH_COUNT}), memory usage in KB, and memory usage as a percentage of the breaker limit.
     *
     * <p>Uses a single pass over the cache, so all reported figures come from the same traversal.
     *
     * @return index name mapped to its stats map
     */
    public Map<String, Map<String, Object>> getIndicesCacheStats() {
        Map<String, Integer> graphCountByIndex = new HashMap<>();
        Map<String, Long> sizeInKBByIndex = new HashMap<>();
        for (NativeMemoryAllocation allocation : cache.asMap().values()) {
            if (allocation instanceof NativeMemoryAllocation.IndexAllocation) {
                String indexName = ((NativeMemoryAllocation.IndexAllocation) allocation).getOpenSearchIndexName();
                graphCountByIndex.merge(indexName, 1, Integer::sum);
                sizeInKBByIndex.merge(indexName, (long) allocation.getSizeInKB(), Long::sum);
            }
        }

        Map<String, Map<String, Object>> statsByIndex = new HashMap<>();
        for (Map.Entry<String, Integer> entry : graphCountByIndex.entrySet()) {
            String indexName = entry.getKey();
            long indexSizeInKB = sizeInKBByIndex.get(indexName);
            Map<String, Object> indexStats = new HashMap<>();
            indexStats.put(GRAPH_COUNT, entry.getValue());
            indexStats.put(StatNames.GRAPH_MEMORY_USAGE.getName(), indexSizeInKB);
            indexStats.put(StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName(), getSizeAsPercentage(indexSizeInKB));
            statsByIndex.put(indexName, indexStats);
        }
        return statsByIndex;
    }

    // Removal listener: closes the removed allocation and, on size-based eviction, trips the circuit breaker.
    private void onRemoval(RemovalNotification<String, NativeMemoryAllocation> removalNotification) {
        removalNotification.getValue().close();
        if (RemovalCause.SIZE == removalNotification.getCause()) {
            KNNSettings.state().updateCircuitBreakerSettings(true);
            setCacheCapacityReached(true);
        }
        logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause());
    }

    // True when the allocation is an index allocation belonging to the given OpenSearch index.
    private static boolean isAllocationOfIndex(NativeMemoryAllocation allocation, String indexName) {
        return allocation instanceof NativeMemoryAllocation.IndexAllocation
            && indexName.equals(((NativeMemoryAllocation.IndexAllocation) allocation).getOpenSearchIndexName());
    }

    // Percentage of the configured circuit breaker limit; 0 when the limit is 0 to avoid dividing by zero.
    private Float getSizeAsPercentage(long sizeInKB) {
        long limitInKB = KNNSettings.getCircuitBreakerLimit().getKb();
        return limitInKB == 0 ? 0.0f : (float) (100 * sizeInKB) / (float) limitInKB;
    }
}
