package org.deeplearning4j.rl4j.learning.sync;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/SyncLearning.class */
/**
 * Base class for learning algorithms whose training loop runs epoch by epoch in the thread
 * that calls {@link #train()}.
 *
 * <p>Subclasses supply the per-epoch work through {@link #preEpoch()}, {@link #trainEpoch()}
 * and {@link #postEpoch()}. Registered {@link SyncTrainingListener}s are notified at training
 * start/end and at each epoch start/end. Any start/end-of-epoch or start-of-training callback
 * may return {@code STOP} to end training early.
 *
 * @param <O>  observation type
 * @param <A>  action type
 * @param <AS> action space type
 * @param <NN> neural network type
 */
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> extends Learning<O, A, AS, NN> {
    private static final Logger log = LoggerFactory.getLogger(SyncLearning.class);

    /** Registered listeners, notified in registration order. Not synchronized. */
    private final List<SyncTrainingListener> listeners;

    /**
     * Creates a learner with the given configuration.
     *
     * @param lConfiguration the learning configuration, passed unchanged to {@link Learning}
     */
    public SyncLearning(ILearning.LConfiguration lConfiguration) {
        super(lConfiguration);
        this.listeners = new ArrayList<>();
    }

    /**
     * Registers a listener to be notified of training and epoch start/end events.
     * Listeners are invoked in the order in which they were added.
     *
     * @param syncTrainingListener the listener to add; must not be {@code null}
     * @throws NullPointerException if {@code syncTrainingListener} is {@code null}
     */
    public void addListener(SyncTrainingListener syncTrainingListener) {
        // Fail fast here instead of with an NPE while iterating listeners inside train().
        listeners.add(Objects.requireNonNull(syncTrainingListener, "syncTrainingListener"));
    }

    /**
     * Runs the training loop.
     *
     * <p>Epochs run while the step counter is below {@code getConfiguration().getMaxStep()}.
     * Each epoch does the following in order:
     * <ol>
     *   <li>calls {@link #preEpoch()}</li>
     *   <li>notifies listeners of the epoch start</li>
     *   <li>runs {@link #trainEpoch()}</li>
     *   <li>calls {@link #postEpoch()}</li>
     *   <li>notifies listeners of the epoch end</li>
     *   <li>logs the epoch reward and increments the epoch counter</li>
     * </ol>
     *
     * <p>If any listener answers {@code STOP}, the loop ends immediately. An epoch stopped at
     * its end is therefore neither logged nor counted. The listeners' training-end callback is
     * invoked whenever this method reaches its end normally, including when training was
     * vetoed at start.
     */
    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public void train() {
        log.info("training starting.");
        if (notifyTrainingStarted()) {
            while (getStepCounter() < getConfiguration().getMaxStep()) {
                preEpoch();
                if (!notifyEpochStarted()) {
                    break;
                }
                IDataManager.StatEntry statEntry = trainEpoch();
                postEpoch();
                if (!notifyEpochFinished(statEntry)) {
                    break;
                }
                // Parameterized form: the message string is only built if INFO is enabled.
                log.info("Epoch: {}, reward: {}", getEpochCounter(), statEntry.getReward());
                incrementEpoch();
            }
        }
        notifyTrainingFinished();
    }

    /**
     * Notifies listeners that training is starting.
     *
     * @return {@code false} as soon as a listener returns {@code STOP} (later listeners are
     *     not notified), {@code true} otherwise
     */
    private boolean notifyTrainingStarted() {
        SyncTrainingEvent event = new SyncTrainingEvent(this);
        for (SyncTrainingListener listener : listeners) {
            if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
                return false;
            }
        }
        return true;
    }

    /** Notifies every listener that training has ended. */
    private void notifyTrainingFinished() {
        for (SyncTrainingListener listener : listeners) {
            listener.onTrainingEnd();
        }
    }

    /**
     * Notifies listeners that an epoch is starting.
     *
     * @return {@code false} as soon as a listener returns {@code STOP} (later listeners are
     *     not notified), {@code true} otherwise
     */
    private boolean notifyEpochStarted() {
        SyncTrainingEvent event = new SyncTrainingEvent(this);
        for (SyncTrainingListener listener : listeners) {
            if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
                return false;
            }
        }
        return true;
    }

    /**
     * Notifies listeners that an epoch has finished.
     *
     * @param statEntry the statistics returned by {@link #trainEpoch()} for this epoch
     * @return {@code false} as soon as a listener returns {@code STOP} (later listeners are
     *     not notified), {@code true} otherwise
     */
    private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
        SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
        for (SyncTrainingListener listener : listeners) {
            if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
                return false;
            }
        }
        return true;
    }

    /** Hook invoked at the start of each epoch, before listeners are notified. */
    protected abstract void preEpoch();

    /** Hook invoked after {@link #trainEpoch()} and before listeners are notified of the epoch end. */
    protected abstract void postEpoch();

    /**
     * Runs one epoch of training.
     *
     * @return statistics for the epoch; its reward is logged by {@link #train()}
     */
    protected abstract IDataManager.StatEntry trainEpoch();
}
