package org.deeplearning4j.rl4j.learning.async;

import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/AsyncLearning.class */
/**
 * Base class for learning algorithms that train with several worker threads.
 *
 * <p>{@link #train()} launches {@code getConfiguration().getNumThreads()} worker threads (see
 * {@link #newThread(int, int)}) and then uses the calling thread as a monitor that periodically
 * notifies the registered {@link TrainingListener}s of progress. Monitoring stops when
 * {@link #isTrainingComplete()} returns true, when a listener asks to stop, or when
 * {@link #terminate()} is called (typically from another thread).
 */
public abstract class AsyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet> extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN> implements IAsyncLearning {
    private static final Logger log = LoggerFactory.getLogger(AsyncLearning.class);

    /** Default timeout, in milliseconds, passed to {@code wait} between progress notifications. */
    private static final int DEFAULT_PROGRESS_MONITOR_FREQUENCY = 20000;

    // Thread currently executing monitorTraining(), or null when no monitoring is in progress.
    // Volatile because terminate() reads it from a different thread.
    private volatile Thread monitorThread = null;

    private final TrainingListenerList listeners = new TrainingListenerList();

    // Cleared when a listener asks to stop or when terminate() is called. Volatile because it is
    // written by terminate() on one thread and polled by the monitor thread.
    private volatile boolean canContinue = true;

    // Timeout in milliseconds for each monitor wait; volatile since it may be changed by another thread.
    private volatile int progressMonitorFrequency = DEFAULT_PROGRESS_MONITOR_FREQUENCY;

    /**
     * Registers a listener that is notified when training starts, progresses and finishes.
     *
     * @param trainingListener the listener to add
     */
    public void addListener(TrainingListener trainingListener) {
        this.listeners.add(trainingListener);
    }

    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public abstract IAsyncLearningConfiguration getConfiguration();

    /**
     * Creates (but does not start) one worker thread.
     *
     * @param threadNumber index of the thread, from 0 to {@code getNumThreads() - 1}
     * @param deviceNum device index assigned to the thread, computed by {@link #launchThreads()}
     *     as {@code threadNumber % numberOfDevices}
     * @return the worker thread; the caller starts it
     */
    protected abstract AsyncThread newThread(int threadNumber, int deviceNum);

    /** Returns the state shared by all worker threads (step count, completion flag, ...). */
    protected abstract IAsyncGlobal<NN> getAsyncGlobal();

    /** Returns whether the shared async global reports that training is complete. */
    protected boolean isTrainingComplete() {
        return getAsyncGlobal().isTrainingComplete();
    }

    /** Creates and starts the configured number of worker threads, spreading them over devices. */
    private void launchThreads() {
        // Both values are invariant for the duration of the loop; query them once.
        int numThreads = getConfiguration().getNumThreads();
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int i = 0; i < numThreads; i++) {
            newThread(i, i % numDevices).start();
        }
        log.info("Threads launched.");
    }

    /** Returns the step count held by the shared async global. */
    @Override // org.deeplearning4j.rl4j.learning.Learning, org.deeplearning4j.rl4j.learning.ILearning
    public int getStepCount() {
        return getAsyncGlobal().getStepCount();
    }

    /**
     * Runs training: notifies listeners that training started, launches the worker threads and
     * monitors them on the calling thread, then notifies listeners that training finished.
     * If a listener vetoes the start, no thread is launched but listeners are still notified
     * that training finished.
     */
    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public void train() {
        log.info("AsyncLearning training starting.");
        this.canContinue = this.listeners.notifyTrainingStarted();
        if (this.canContinue) {
            launchThreads();
            monitorTraining();
        }
        this.listeners.notifyTrainingFinished();
    }

    /**
     * Periodically notifies listeners of training progress on the calling thread until training
     * completes, a listener asks to stop, {@link #terminate()} is called, or the thread is
     * interrupted. Between notifications it waits up to {@link #getProgressMonitorFrequency()}
     * milliseconds on this object's monitor.
     */
    protected void monitorTraining() {
        this.monitorThread = Thread.currentThread();
        try {
            while (this.canContinue && !isTrainingComplete()) {
                // Only ever lower the flag here: assigning the listener result directly would
                // overwrite a concurrent terminate() that already set canContinue to false.
                if (!this.listeners.notifyTrainingProgress(this)) {
                    this.canContinue = false;
                    return;
                }
                synchronized (this) {
                    wait(this.progressMonitorFrequency);
                }
            }
        } catch (InterruptedException e) {
            log.error("Training interrupted.", e);
            // An interrupt sent by terminate() (canContinue already false) is our own stop signal
            // and is consumed here. Any other interrupt belongs to the caller: restore the flag.
            if (this.canContinue) {
                Thread.currentThread().interrupt();
            }
        } finally {
            // Always clear the reference, including on the early-return path above.
            this.monitorThread = null;
        }
    }

    /**
     * Requests termination of training. Clears the continue flag so the monitoring loop stops, and
     * interrupts the monitor thread if one is currently running so that a pending wait ends
     * immediately. Calls after the first successful request have no effect.
     */
    @Override // org.deeplearning4j.rl4j.learning.async.IAsyncLearning
    public void terminate() {
        if (this.canContinue) {
            this.canContinue = false;
            // Read the volatile field once so the null check and the interrupt use the same value.
            Thread thread = this.monitorThread;
            if (thread != null) {
                thread.interrupt();
            }
        }
    }

    /**
     * Returns the listener list used to report training events.
     *
     * <p>Note: the decompiler reported this accessor as originally {@code protected}; it is kept
     * public here for compatibility with existing callers.
     */
    public TrainingListenerList getListeners() {
        return this.listeners;
    }

    /** Returns the timeout, in milliseconds, between two progress notifications. */
    public int getProgressMonitorFrequency() {
        return this.progressMonitorFrequency;
    }

    /**
     * Sets the timeout, in milliseconds, between two progress notifications.
     *
     * @param progressMonitorFrequency a strictly positive number of milliseconds
     * @throws IllegalArgumentException if the value is zero or negative ({@code wait(0)} would
     *     block indefinitely and a negative timeout makes {@code wait} throw)
     */
    public void setProgressMonitorFrequency(int progressMonitorFrequency) {
        if (progressMonitorFrequency <= 0) {
            throw new IllegalArgumentException(
                    "progressMonitorFrequency must be > 0, got " + progressMonitorFrequency);
        }
        this.progressMonitorFrequency = progressMonitorFrequency;
    }
}
