package org.opensearch.ml.common;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;

/* loaded from: input_file:org/opensearch/ml/common/MLCommonsClassLoader.class */
/**
 * Registry that maps ML Commons type keys (enum constants or connector names) to the
 * concrete classes annotated for them, and instantiates those classes reflectively.
 *
 * <p>The registry is populated once, in the static initializer at the bottom of this class,
 * by scanning fixed packages for the annotations {@link MLAlgoParameter}, {@link MLAlgoOutput},
 * {@link InputDataSet}, {@link ExecuteInput}, {@link ExecuteOutput}, {@link MLInput} and
 * {@link Connector}.
 *
 * <p>The backing maps are plain {@link HashMap}s with no synchronization in this class.
 */
public class MLCommonsClassLoader {

    @Generated
    private static final Logger log = LogManager.getLogger(MLCommonsClassLoader.class);

    /** Holds algorithm-parameter, output-type and input-dataset classes, keyed by their enum constant. */
    private static final Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();
    /** Execute-input classes keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
    /** Execute-output classes keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();
    /** ML input classes keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> mlInputClassMap = new HashMap<>();
    /** Connector classes keyed by the non-empty name given in {@link Connector#value()}. */
    private static final Map<String, Class<?>> connectorClassMap = new HashMap<>();

    /**
     * Scans all known packages and fills the class maps.
     *
     * <p>The thread's context class loader is temporarily replaced with this class's own loader
     * (presumably so the classpath scan resolves against the plugin's classes) and is always
     * restored afterwards, even if scanning fails.
     */
    public static void loadClassMapping() {
        Thread currentThread = Thread.currentThread();
        ClassLoader previousLoader = currentThread.getContextClassLoader();
        try {
            currentThread.setContextClassLoader(MLCommonsClassLoader.class.getClassLoader());
            loadMLAlgoParameterClassMapping();
            loadMLOutputClassMapping();
            loadMLInputDataSetClassMapping();
            loadExecuteInputClassMapping();
            loadExecuteOutputClassMapping();
            loadMLInputClassMapping();
            loadConnectorClassMapping();
        } finally {
            currentThread.setContextClassLoader(previousLoader);
        }
    }

    /** Creates a scanner rooted at the given package, with no extra scanners supplied. */
    private static Reflections scan(String packageName) {
        return new Reflections(packageName, new Scanner[0]);
    }

    /** Registers {@code type} under every function name; a null or empty array registers nothing. */
    private static void registerAll(Map<Enum<?>, Class<?>> target, FunctionName[] functionNames, Class<?> type) {
        if (functionNames == null) {
            return;
        }
        for (FunctionName functionName : functionNames) {
            target.put(functionName, type);
        }
    }

    /** Registers {@link Connector}-annotated classes under their non-empty connector name. */
    private static void loadConnectorClassMapping() {
        for (Class<?> type : scan("org.opensearch.ml.common.connector").getTypesAnnotatedWith(Connector.class)) {
            Connector annotation = type.getAnnotation(Connector.class);
            if (annotation == null) {
                continue;
            }
            String connectorName = annotation.value();
            if (connectorName != null && !connectorName.isEmpty()) {
                connectorClassMap.put(connectorName, type);
            }
        }
    }

    /**
     * Registers algorithm-parameter classes, plus any {@link MLAlgoOutput}-annotated classes
     * found in the same package, into the parameter map.
     */
    private static void loadMLAlgoParameterClassMapping() {
        Reflections reflections = scan("org.opensearch.ml.common.input.parameter");
        for (Class<?> type : reflections.getTypesAnnotatedWith(MLAlgoParameter.class)) {
            MLAlgoParameter annotation = type.getAnnotation(MLAlgoParameter.class);
            if (annotation != null) {
                registerAll(parameterClassMap, annotation.algorithms(), type);
            }
        }
        for (Class<?> type : reflections.getTypesAnnotatedWith(MLAlgoOutput.class)) {
            MLAlgoOutput annotation = type.getAnnotation(MLAlgoOutput.class);
            // Guard like every sibling loader does: a scanned type that does not itself carry
            // the annotation would otherwise cause a NullPointerException during class init.
            if (annotation == null) {
                continue;
            }
            MLOutputType outputType = annotation.value();
            if (outputType != null) {
                parameterClassMap.put(outputType, type);
            }
        }
    }

    /** Registers {@link MLAlgoOutput}-annotated output classes under their {@link MLOutputType}. */
    private static void loadMLOutputClassMapping() {
        for (Class<?> type : scan("org.opensearch.ml.common.output").getTypesAnnotatedWith(MLAlgoOutput.class)) {
            MLAlgoOutput annotation = type.getAnnotation(MLAlgoOutput.class);
            if (annotation == null) {
                continue;
            }
            MLOutputType outputType = annotation.value();
            if (outputType != null) {
                parameterClassMap.put(outputType, type);
            }
        }
    }

    /** Registers {@link InputDataSet}-annotated classes under their {@link MLInputDataType}. */
    private static void loadMLInputDataSetClassMapping() {
        for (Class<?> type : scan("org.opensearch.ml.common.dataset").getTypesAnnotatedWith(InputDataSet.class)) {
            InputDataSet annotation = type.getAnnotation(InputDataSet.class);
            if (annotation == null) {
                continue;
            }
            MLInputDataType dataType = annotation.value();
            if (dataType != null) {
                parameterClassMap.put(dataType, type);
            }
        }
    }

    /** Registers {@link ExecuteInput}-annotated classes under each of their function names. */
    private static void loadExecuteInputClassMapping() {
        for (Class<?> type : scan("org.opensearch.ml.common.input.execute").getTypesAnnotatedWith(ExecuteInput.class)) {
            ExecuteInput annotation = type.getAnnotation(ExecuteInput.class);
            if (annotation != null) {
                registerAll(executeInputClassMap, annotation.algorithms(), type);
            }
        }
    }

    /** Registers {@link ExecuteOutput}-annotated classes under each of their function names. */
    private static void loadExecuteOutputClassMapping() {
        for (Class<?> type : scan("org.opensearch.ml.common.output.execute").getTypesAnnotatedWith(ExecuteOutput.class)) {
            ExecuteOutput annotation = type.getAnnotation(ExecuteOutput.class);
            if (annotation != null) {
                registerAll(executeOutputClassMap, annotation.algorithms(), type);
            }
        }
    }

    /** Registers {@link MLInput}-annotated classes under each of their function names. */
    private static void loadMLInputClassMapping() {
        for (Class<?> type : scan("org.opensearch.ml.common.input").getTypesAnnotatedWith(MLInput.class)) {
            MLInput annotation = type.getAnnotation(MLInput.class);
            if (annotation != null) {
                registerAll(mlInputClassMap, annotation.functionNames(), type);
            }
        }
    }

    /**
     * Instantiates the class registered in the parameter map for {@code t}.
     *
     * @param t   enum key identifying the registered class
     * @param i   the single constructor argument
     * @param cls declared parameter type of the constructor to invoke
     * @return the new instance, or {@code null} if construction failed for a reason other than
     *         an {@link MLException} or {@link IllegalArgumentException} raised by the constructor
     * @throws IllegalArgumentException if no class is registered for {@code t}
     */
    public static <T extends Enum<T>, S, I> S initMLInstance(T t, I i, Class<?> cls) {
        return instantiate(parameterClassMap, t, i, cls);
    }

    /**
     * Instantiates the execute-input class registered for {@code t}; if that attempt throws,
     * retries against the ML input map.
     *
     * @param t   enum key identifying the registered class
     * @param i   the single constructor argument
     * @param cls declared parameter type of the constructor to invoke
     * @return the new instance, or {@code null} if reflective construction failed and was logged
     * @throws IllegalArgumentException if the fallback lookup finds no class for {@code t}
     */
    public static <T extends Enum<T>, S, I> S initExecuteInputInstance(T t, I i, Class<?> cls) {
        try {
            return instantiate(executeInputClassMap, t, i, cls);
        } catch (Exception e) {
            // Deliberate fallback: the key may be registered as a general ML input instead.
            return instantiate(mlInputClassMap, t, i, cls);
        }
    }

    /**
     * Instantiates the execute-output class registered for {@code t}; if that attempt throws and
     * the argument is a {@link StreamInput}, falls back to {@link MLOutput#fromStream(StreamInput)}.
     *
     * @param t   enum key identifying the registered class
     * @param i   the single constructor argument
     * @param cls declared parameter type of the constructor to invoke
     * @return the new instance, or {@code null} if reflective construction failed and was logged
     * @throws RuntimeException wrapping an {@link IOException} raised by the stream fallback
     */
    @SuppressWarnings("unchecked")
    public static <T extends Enum<T>, S, I> S initExecuteOutputInstance(T t, I i, Class<?> cls) {
        try {
            return instantiate(executeOutputClassMap, t, i, cls);
        } catch (Exception e) {
            if (!(i instanceof StreamInput)) {
                throw e;
            }
            try {
                return (S) MLOutput.fromStream((StreamInput) i);
            } catch (IOException ioException) {
                throw new RuntimeException(ioException);
            }
        }
    }

    /**
     * Reports whether an ML input class is registered for the given function name.
     *
     * @param functionName key to look up
     * @return {@code true} if the ML input map contains {@code functionName}
     */
    public static boolean canInitMLInput(FunctionName functionName) {
        return mlInputClassMap.containsKey(functionName);
    }

    /**
     * Instantiates the connector class registered under {@code str}.
     *
     * @param str    connector name
     * @param objArr constructor arguments
     * @param clsArr declared constructor parameter types, matching {@code objArr}
     * @return the new instance, or {@code null} if reflective construction failed and was logged
     * @throws IllegalArgumentException if no class is registered for {@code str}
     */
    public static <S> S initConnector(String str, Object[] objArr, Class<?>... clsArr) {
        return instantiateWithArgs(connectorClassMap, str, objArr, clsArr);
    }

    /**
     * Instantiates the ML input class registered for {@code t}.
     *
     * @param t      enum key identifying the registered class
     * @param objArr constructor arguments
     * @param clsArr declared constructor parameter types, matching {@code objArr}
     * @return the new instance, or {@code null} if reflective construction failed and was logged
     * @throws IllegalArgumentException if no class is registered for {@code t}
     */
    public static <T extends Enum<T>, S> S initMLInput(T t, Object[] objArr, Class<?>... clsArr) {
        return instantiateWithArgs(mlInputClassMap, t, objArr, clsArr);
    }

    /** Looks up the class registered for {@code key}, failing fast when there is none. */
    private static Class<?> requireRegistered(Map<?, Class<?>> registry, Object key) {
        Class<?> registered = registry.get(key);
        if (registered == null) {
            throw new IllegalArgumentException("Can't find class for type " + key);
        }
        return registered;
    }

    /**
     * Shared failure policy for reflective construction: rethrow an {@link MLException} or
     * {@link IllegalArgumentException} found as the cause (e.g. thrown by the invoked
     * constructor); otherwise log the failure and yield {@code null}.
     */
    private static <S> S rethrowKnownCauseOrLog(Object key, Exception failure) {
        Throwable cause = failure.getCause();
        if (cause instanceof MLException) {
            throw (MLException) cause;
        }
        if (cause instanceof IllegalArgumentException) {
            throw (IllegalArgumentException) cause;
        }
        log.error("Failed to init instance for type " + key, failure);
        return null;
    }

    /** Invokes the registered class's one-argument constructor declared with {@code paramType}. */
    @SuppressWarnings("unchecked")
    private static <S> S instantiate(Map<?, Class<?>> registry, Object key, Object arg, Class<?> paramType) {
        Class<?> registered = requireRegistered(registry, key);
        try {
            return (S) registered.getConstructor(paramType).newInstance(arg);
        } catch (Exception e) {
            return rethrowKnownCauseOrLog(key, e);
        }
    }

    /** Invokes the registered class's constructor declared with {@code paramTypes}. */
    @SuppressWarnings("unchecked")
    private static <S> S instantiateWithArgs(Map<?, Class<?>> registry, Object key, Object[] args, Class<?>[] paramTypes) {
        Class<?> registered = requireRegistered(registry, key);
        try {
            return (S) registered.getConstructor(paramTypes).newInstance(args);
        } catch (Exception e) {
            return rethrowKnownCauseOrLog(key, e);
        }
    }

    // Must stay below the map declarations: static initializers run in textual order and the
    // loaders write into those maps.
    static {
        try {
            // The explicit cast selects the PrivilegedExceptionAction overload; a bare lambda is
            // ambiguous between it and the PrivilegedAction overload.
            AccessController.doPrivileged((java.security.PrivilegedExceptionAction<Void>) () -> {
                loadClassMapping();
                return null;
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load class mapping in ML commons", e);
        }
    }
}
