package org.opensearch.knn.training;

import java.io.IOException;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.threadpool.ThreadPool;

/* loaded from: input_file:org/opensearch/knn/training/TrainingJobRunner.class */
/**
 * Runs {@link TrainingJob}s on the {@link KNNConstants#TRAIN_THREAD_POOL} executor.
 *
 * <p>A single {@link Semaphore} permit limits the runner to one job at a time: {@link #execute} fails fast with a
 * {@link ValidationException} while another job holds the permit. Obtain the shared instance through
 * {@link #getInstance()} and call {@link #initialize(ThreadPool, ModelDao)} before executing any job.
 */
public class TrainingJobRunner {
    public static Logger logger = LogManager.getLogger(TrainingJobRunner.class);
    private static TrainingJobRunner INSTANCE;
    private static ModelDao modelDao;
    private static ThreadPool threadPool;

    /** Number of jobs that currently hold the training permit. */
    private final AtomicInteger jobCount = new AtomicInteger(0);
    /** Training capacity: a single permit, so at most one job runs at a time. */
    private final Semaphore semaphore = new Semaphore(1);

    /**
     * Returns the shared runner, creating it on first use.
     *
     * @return the singleton {@code TrainingJobRunner}
     */
    public static synchronized TrainingJobRunner getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new TrainingJobRunner();
        }
        return INSTANCE;
    }

    private TrainingJobRunner() {
    }

    /**
     * Sets the thread pool and model DAO used by all runners. Both are dereferenced without null checks by
     * {@link #execute}, so this must be called before any job is started.
     *
     * @param threadPool thread pool that provides the {@link KNNConstants#TRAIN_THREAD_POOL} executor
     * @param modelDao   DAO used to persist models
     */
    public static void initialize(ThreadPool threadPool, ModelDao modelDao) {
        TrainingJobRunner.threadPool = threadPool;
        TrainingJobRunner.modelDao = modelDao;
    }

    /**
     * Starts a training job.
     *
     * <p>Takes the training permit and persists the job's model with {@code ModelDao#put} before training starts.
     * Once that write succeeds, the response is passed to {@code actionListener} and the job is handed to the
     * training thread pool. If the write fails, the permit is returned and the failure is passed to
     * {@code actionListener}.
     *
     * @param trainingJob    job to run
     * @param actionListener notified with the result of the initial model write
     * @throws ValidationException if another job already holds the training permit
     * @throws IOException         if the initial model write throws synchronously (the permit is returned first;
     *                             runtime exceptions from that write are likewise rethrown after returning it)
     */
    public void execute(TrainingJob trainingJob, ActionListener<IndexResponse> actionListener) throws IOException {
        if (!this.semaphore.tryAcquire()) {
            ValidationException validationException = new ValidationException();
            validationException.addValidationError("Unable to run training job: No training capacity on node.");
            KNNCounter.TRAINING_ERRORS.increment();
            throw validationException;
        }
        this.jobCount.incrementAndGet();
        final Runnable releasePermit = newPermitReleaser();

        try {
            serializeModel(trainingJob, ActionListener.wrap(indexResponse -> {
                actionListener.onResponse(indexResponse);
                train(trainingJob, releasePermit);
            }, exc -> {
                releasePermit.run();
                logger.error("Unable to initialize model serialization: " + exc.getMessage(), exc);
                actionListener.onFailure(exc);
            }), false);
        } catch (IOException | RuntimeException e) {
            // Catch runtime exceptions too (e.g. an NPE when initialize() was never called); otherwise the single
            // permit would never be returned and every later job would be refused. The releaser is idempotent, so
            // this is safe even if the listener above also ran.
            releasePermit.run();
            throw e;
        }
    }

    /**
     * Creates a one-shot callback that returns a job's slot: the first call decrements {@link #jobCount} and
     * releases the semaphore permit; later calls do nothing.
     *
     * <p>Some failures can reach more than one cleanup path. For example, an exception escaping {@code train}
     * after its own cleanup is routed by {@code ActionListener.wrap} to the failure handler in {@link #execute}.
     * The guard keeps the permit from being released twice.
     *
     * @return an idempotent release callback for one job
     */
    private Runnable newPermitReleaser() {
        final AtomicBoolean released = new AtomicBoolean(false);
        return () -> {
            if (released.compareAndSet(false, true)) {
                this.jobCount.decrementAndGet();
                this.semaphore.release();
            }
        };
    }

    /**
     * Submits the job to the training thread pool.
     *
     * <p>After the job runs, the model is persisted with {@code ModelDao#update}. If the pool rejects the job, the
     * model is marked {@link ModelState#FAILED} with an explanatory error, and that state is persisted instead.
     * {@code releasePermit} is run once the job has finished (successfully or not) or has been rejected.
     *
     * @param trainingJob   job to run
     * @param releasePermit idempotent callback that returns this job's training permit
     */
    private void train(TrainingJob trainingJob, Runnable releasePermit) {
        // The outcome of the asynchronous model update is only logged and counted.
        ActionListener<IndexResponse> loggingListener = ActionListener.wrap(
            indexResponse -> logger.debug("[KNN] Model serialization update for \"" + trainingJob.getModelId() + "\" was successful"),
            exc -> {
                logger.error("[KNN] Model serialization update for \"" + trainingJob.getModelId() + "\" failed: " + exc.getMessage(), exc);
                KNNCounter.TRAINING_ERRORS.increment();
            }
        );

        try {
            threadPool.executor(KNNConstants.TRAIN_THREAD_POOL).execute(() -> {
                try {
                    trainingJob.run();
                    serializeModel(trainingJob, loggingListener, true);
                } catch (IOException e) {
                    logger.error("Unable to serialize model \"" + trainingJob.getModelId() + "\": " + e.getMessage(), e);
                    KNNCounter.TRAINING_ERRORS.increment();
                } catch (Exception e) {
                    logger.error("Unable to complete training for \"" + trainingJob.getModelId() + "\": " + e.getMessage(), e);
                    KNNCounter.TRAINING_ERRORS.increment();
                } finally {
                    releasePermit.run();
                }
            });
        } catch (RejectedExecutionException e) {
            logger.error("Unable to train model \"" + trainingJob.getModelId() + "\": " + e.getMessage());

            ModelMetadata modelMetadata = trainingJob.getModel().getModelMetadata();
            modelMetadata.setState(ModelState.FAILED);
            modelMetadata.setError("Training job execution was rejected. Node's training queue is at capacity.");

            try {
                serializeModel(trainingJob, loggingListener, true);
            } catch (IOException ioe) {
                logger.error("Unable to serialize the failure for model \"" + trainingJob.getModelId() + "\": " + ioe, ioe);
            } finally {
                releasePermit.run();
                KNNCounter.TRAINING_ERRORS.increment();
            }
        }
    }

    /**
     * Writes the job's model through the DAO.
     *
     * @param trainingJob job whose model is written
     * @param listener    notified with the result of the write
     * @param update      {@code true} to call {@code ModelDao#update}, {@code false} to call {@code ModelDao#put}
     * @throws IOException if the DAO call throws it
     */
    private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException {
        if (update) {
            modelDao.update(trainingJob.getModel(), listener);
        } else {
            modelDao.put(trainingJob.getModel(), listener);
        }
    }

    /**
     * Returns the number of jobs currently holding the training permit.
     *
     * @return 0 when idle, 1 while a job is being serialized or trained
     */
    public int getJobCount() {
        return this.jobCount.get();
    }
}
