package org.deeplearning4j.rl4j.learning.sync;

import java.util.Objects;

import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/SyncLearning.class */
/**
 * Base class for learning algorithms that are trained by a synchronous, epoch-by-epoch loop
 * driven from {@link #train()}.
 *
 * <p>Subclasses supply the per-epoch work through {@link #preEpoch()}, {@link #trainEpoch()} and
 * {@link #postEpoch()}. Registered {@link TrainingListener}s are notified at each stage and can stop
 * training early by vetoing a notification.
 *
 * @param <O>  observation type
 * @param <A>  action type
 * @param <AS> action space type
 * @param <NN> neural network type
 */
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> extends Learning<O, A, AS, NN> implements IEpochTrainer {
    private static final Logger log = LoggerFactory.getLogger(SyncLearning.class);

    /** Default number of epochs between {@code notifyTrainingProgress} calls. */
    private static final int DEFAULT_PROGRESS_MONITOR_FREQUENCY = 5;

    private final TrainingListenerList listeners;
    private int progressMonitorFrequency;

    /**
     * Creates a synchronous learner with no listeners and the default progress-monitor frequency.
     *
     * @param lConfiguration the learning configuration, passed to {@link Learning}
     */
    public SyncLearning(ILearning.LConfiguration lConfiguration) {
        super(lConfiguration);
        this.listeners = new TrainingListenerList();
        this.progressMonitorFrequency = DEFAULT_PROGRESS_MONITOR_FREQUENCY;
    }

    /**
     * Registers a listener that will receive the training notifications issued by {@link #train()}.
     *
     * @param trainingListener the listener to add; must not be {@code null}
     * @throws NullPointerException if {@code trainingListener} is {@code null}
     */
    public void addListener(TrainingListener trainingListener) {
        // Fail fast here rather than with an NPE in the middle of a training run.
        this.listeners.add(Objects.requireNonNull(trainingListener, "trainingListener"));
    }

    /**
     * Sets how many epochs elapse between {@code notifyTrainingProgress} calls.
     *
     * <p>Only {@code 0} is rejected. A negative value is accepted. Because Java's {@code %} yields
     * zero exactly when the dividend is a multiple of the divisor's magnitude, a negative value
     * behaves like its absolute value.
     *
     * @param i the progress-monitor frequency, in epochs
     * @throws IllegalArgumentException if {@code i} is {@code 0}
     */
    public void setProgressMonitorFrequency(int i) {
        if (i == 0) {
            throw new IllegalArgumentException("The progressMonitorFrequency cannot be 0");
        }
        this.progressMonitorFrequency = i;
    }

    /**
     * Runs the training loop.
     *
     * <p>The loop continues while {@code getStepCounter() < getConfiguration().getMaxStep()}. It
     * stops early as soon as any listener notification returns {@code false}. Each epoch runs:
     * {@link #preEpoch()}, new-epoch notification, {@link #trainEpoch()}, epoch-result
     * notification, {@link #postEpoch()}, and then, every {@link #getProgressMonitorFrequency()}
     * epochs, a progress notification. After that the epoch is logged and the epoch counter is
     * incremented.
     *
     * <p>{@code notifyTrainingFinished()} is always called on normal return, even when
     * {@code notifyTrainingStarted()} vetoed the run. NOTE(review): if a hook or listener throws,
     * {@code notifyTrainingFinished()} is skipped. Confirm that listeners do not rely on it for
     * cleanup.
     */
    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public void train() {
        log.info("training starting.");
        if (this.listeners.notifyTrainingStarted()) {
            // Presumably subclasses advance the step counter inside trainEpoch(); otherwise this never ends.
            while (getStepCounter() < getConfiguration().getMaxStep()) {
                preEpoch();
                if (!this.listeners.notifyNewEpoch(this)) {
                    break;
                }
                IDataManager.StatEntry statEntry = trainEpoch();
                if (!this.listeners.notifyEpochTrainingResult(this, statEntry)) {
                    break;
                }
                postEpoch();
                // Checked before incrementEpoch(), so epoch 0 also triggers a progress notification.
                if (getEpochCounter() % this.progressMonitorFrequency == 0 && !this.listeners.notifyTrainingProgress(this)) {
                    break;
                }
                // Parameterized form defers message construction until INFO is known to be enabled.
                log.info("Epoch: {}, reward: {}", getEpochCounter(), statEntry.getReward());
                incrementEpoch();
            }
        }
        this.listeners.notifyTrainingFinished();
    }

    /** Hook invoked at the start of every epoch, before listeners are told about the new epoch. */
    protected abstract void preEpoch();

    /** Hook invoked after listeners have received the epoch result and before the progress check. */
    protected abstract void postEpoch();

    /**
     * Trains a single epoch.
     *
     * @return the epoch's statistics. They are passed to the listeners and their reward is logged,
     *         so the result must not be {@code null}.
     */
    protected abstract IDataManager.StatEntry trainEpoch();

    /**
     * Returns how many epochs elapse between {@code notifyTrainingProgress} calls.
     *
     * @return the progress-monitor frequency, in epochs
     */
    public int getProgressMonitorFrequency() {
        return this.progressMonitorFrequency;
    }
}
