package org.deeplearning4j.rl4j.util;

import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/util/DataManagerTrainingListener.class */
/**
 * {@link TrainingListener} that forwards training events to an {@link IDataManager}.
 *
 * <p>On each new epoch it starts a video monitor on the trainer's history processor (when one is
 * present). At the end of each epoch it stops that monitor and appends the epoch statistics. On
 * training progress it periodically saves the learning and always writes training info.
 *
 * <p>An exception thrown by the data manager while appending stats, saving, or writing info is
 * logged and turned into {@link TrainingListener.ListenerResponse#STOP}.
 */
public class DataManagerTrainingListener implements TrainingListener {
    private static final Logger log = LoggerFactory.getLogger(DataManagerTrainingListener.class);

    /** Default minimum number of steps between two consecutive saves of the learning. */
    private static final int DEFAULT_SAVE_FREQUENCY = 100000;

    private final IDataManager dataManager;

    /** Minimum number of steps between two consecutive saves; always positive. */
    private final int saveFrequency;

    /**
     * Step count at which the learning was last saved. It starts at {@code -saveFrequency} so the
     * first progress event with a non-negative step count triggers a save.
     */
    private int lastSave;

    /**
     * Creates a listener that saves the learning at most once every
     * {@value #DEFAULT_SAVE_FREQUENCY} steps.
     *
     * @param iDataManager data manager receiving stats, saves and info writes
     */
    public DataManagerTrainingListener(IDataManager iDataManager) {
        this(iDataManager, DEFAULT_SAVE_FREQUENCY);
    }

    /**
     * Creates a listener with a custom save frequency.
     *
     * @param iDataManager data manager receiving stats, saves and info writes
     * @param saveFrequency minimum number of steps between two saves; must be positive
     * @throws IllegalArgumentException if {@code saveFrequency <= 0}
     */
    public DataManagerTrainingListener(IDataManager iDataManager, int saveFrequency) {
        if (saveFrequency <= 0) {
            throw new IllegalArgumentException("saveFrequency must be positive, got " + saveFrequency);
        }
        this.dataManager = iDataManager;
        this.saveFrequency = saveFrequency;
        this.lastSave = -saveFrequency;
    }

    /** Does nothing; always returns {@code CONTINUE}. */
    @Override // org.deeplearning4j.rl4j.learning.listener.TrainingListener
    public TrainingListener.ListenerResponse onTrainingStart() {
        return TrainingListener.ListenerResponse.CONTINUE;
    }

    /** Does nothing. */
    @Override // org.deeplearning4j.rl4j.learning.listener.TrainingListener
    public void onTrainingEnd() {
    }

    /**
     * Starts a video monitor for the new epoch if the trainer has a history processor.
     *
     * <p>The video file is written to
     * {@code <videoDir>/video-[<threadNumber>-]<epochCount>-<stepCount>.mp4}. The thread number is
     * included only when the trainer is an {@link AsyncThread}.
     *
     * @param iEpochTrainer trainer starting a new epoch
     * @return always {@code CONTINUE}
     */
    @Override // org.deeplearning4j.rl4j.learning.listener.TrainingListener
    public TrainingListener.ListenerResponse onNewEpoch(IEpochTrainer iEpochTrainer) {
        IHistoryProcessor historyProcessor = iEpochTrainer.getHistoryProcessor();
        if (historyProcessor != null) {
            int[] shape = iEpochTrainer.getMdp().getObservationSpace().getShape();
            String str = this.dataManager.getVideoDir() + "/video-";
            if (iEpochTrainer instanceof AsyncThread) {
                // Async trainers run concurrently; the thread number keeps their file names distinct.
                str = str + ((AsyncThread) iEpochTrainer).getThreadNumber() + "-";
            }
            historyProcessor.startMonitor(str + iEpochTrainer.getEpochCount() + "-" + iEpochTrainer.getStepCount() + ".mp4", shape);
        }
        return TrainingListener.ListenerResponse.CONTINUE;
    }

    /**
     * Stops the epoch's video monitor (if any) and appends the epoch statistics.
     *
     * @param iEpochTrainer trainer that finished the epoch
     * @param statEntry statistics of the finished epoch
     * @return {@code CONTINUE}, or {@code STOP} if appending the statistics threw
     */
    @Override // org.deeplearning4j.rl4j.learning.listener.TrainingListener
    public TrainingListener.ListenerResponse onEpochTrainingResult(IEpochTrainer iEpochTrainer, IDataManager.StatEntry statEntry) {
        IHistoryProcessor historyProcessor = iEpochTrainer.getHistoryProcessor();
        if (historyProcessor != null) {
            historyProcessor.stopMonitor();
        }
        try {
            this.dataManager.appendStat(statEntry);
            return TrainingListener.ListenerResponse.CONTINUE;
        } catch (Exception e) {
            log.error("Training failed.", e);
            return TrainingListener.ListenerResponse.STOP;
        }
    }

    /**
     * Saves the learning when at least {@code saveFrequency} steps have elapsed since the last
     * save, then writes training info.
     *
     * @param iLearning learning reporting progress
     * @return {@code CONTINUE}, or {@code STOP} if saving or writing info threw
     */
    @Override // org.deeplearning4j.rl4j.learning.listener.TrainingListener
    public TrainingListener.ListenerResponse onTrainingProgress(ILearning iLearning) {
        try {
            int stepCount = iLearning.getStepCount();
            // Widen to long: with lastSave negative, int subtraction could overflow for large
            // step counts and silently skip the save.
            if ((long) stepCount - this.lastSave >= this.saveFrequency) {
                this.dataManager.save(iLearning);
                this.lastSave = stepCount;
            }
            this.dataManager.writeInfo(iLearning);
            return TrainingListener.ListenerResponse.CONTINUE;
        } catch (Exception e) {
            log.error("Training failed.", e);
            return TrainingListener.ListenerResponse.STOP;
        }
    }
}
