package org.deeplearning4j.rl4j.util;

import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link SyncTrainingListener} that forwards synchronous-training progress to an
 * {@link IDataManager}.
 *
 * <ul>
 *   <li>On training start, and again at the end of every epoch, the learning's info is written
 *       through the data manager.</li>
 *   <li>At the end of every epoch the epoch's stat entry is appended, and the learning is saved
 *       once at least {@code saveFrequency} steps have elapsed since the previous save.</li>
 *   <li>At the start of an epoch, video monitoring is started on the learning's history processor
 *       once at least {@code monitorFrequency} steps have elapsed since the previous start,
 *       provided a history processor exists and the data manager reports {@code isSaveData()}.</li>
 * </ul>
 *
 * <p>Any exception thrown while writing info, saving, or appending stats is logged and reported
 * as {@code STOP}. Instances are obtained via {@link #builder(IDataManager)}.
 */
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
    private static final Logger log = LoggerFactory.getLogger(DataManagerSyncTrainingListener.class);

    private final IDataManager dataManager;
    private final int saveFrequency;
    private final int monitorFrequency;

    // Step-counter values at which the last save / last monitor start happened. They are
    // initialised one full period in the past so that the first check made with a non-negative
    // step counter fires (ignoring int overflow for extreme frequency values).
    private int lastSave;
    private int lastMonitor;

    /**
     * Creates a {@link Builder} bound to the given data manager.
     *
     * @param dataManager the data manager the built listener will delegate to
     * @return a new builder using the default save and monitor frequencies
     */
    public static Builder builder(IDataManager dataManager) {
        return new Builder(dataManager);
    }

    private DataManagerSyncTrainingListener(Builder source) {
        dataManager = source.dataManager;
        saveFrequency = source.saveFrequency;
        monitorFrequency = source.monitorFrequency;
        lastSave = -saveFrequency;
        lastMonitor = -monitorFrequency;
    }

    /**
     * Writes the learning's info through the data manager.
     *
     * @param event the training-start event carrying the learning
     * @return {@code CONTINUE} on success, {@code STOP} if writing threw
     */
    @Override
    public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
        try {
            dataManager.writeInfo(event.getLearning());
        } catch (Exception e) {
            log.error("Training failed.", e);
            return ListenerResponse.STOP;
        }
        return ListenerResponse.CONTINUE;
    }

    /** No action is taken when training ends. */
    @Override
    public void onTrainingEnd() {
    }

    /**
     * Starts video monitoring when it is due, a history processor is present, and the data
     * manager is saving data. The video file is written under the data manager's video
     * directory and named after the current epoch and step counters.
     *
     * @param event the epoch-start event carrying the learning
     * @return always {@code CONTINUE}
     */
    @Override
    public ListenerResponse onEpochStart(SyncTrainingEvent event) {
        int stepCounter = event.getLearning().getStepCounter();
        boolean monitorDue = stepCounter - lastMonitor >= monitorFrequency;
        if (!monitorDue
                || event.getLearning().getHistoryProcessor() == null
                || !dataManager.isSaveData()) {
            return ListenerResponse.CONTINUE;
        }

        // Recorded before starting the monitor, so a failing start still counts as "done".
        lastMonitor = stepCounter;
        String videoFile = dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter()
                + "-" + stepCounter + ".mp4";
        event.getLearning().getHistoryProcessor().startMonitor(videoFile,
                event.getLearning().getMdp().getObservationSpace().getShape());
        return ListenerResponse.CONTINUE;
    }

    /**
     * Saves the learning when a save is due, then appends the epoch's stat entry and writes
     * the learning's info.
     *
     * @param event the epoch-end event carrying the learning and its stat entry
     * @return {@code CONTINUE} on success, {@code STOP} if any data-manager call threw
     */
    @Override
    public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
        try {
            int stepCounter = event.getLearning().getStepCounter();
            boolean saveDue = stepCounter - lastSave >= saveFrequency;
            if (saveDue) {
                dataManager.save(event.getLearning());
                // Only advanced after a successful save, so a failed save is retried next time.
                lastSave = stepCounter;
            }
            dataManager.appendStat(event.getStatEntry());
            dataManager.writeInfo(event.getLearning());
        } catch (Exception e) {
            log.error("Training failed.", e);
            return ListenerResponse.STOP;
        }
        return ListenerResponse.CONTINUE;
    }

    /**
     * Fluent builder for {@link DataManagerSyncTrainingListener}. The save and monitor
     * frequencies default to {@link Constants#MODEL_SAVE_FREQ} and
     * {@link Constants#MONITOR_FREQ} respectively.
     */
    public static class Builder {
        private final IDataManager dataManager;
        private int saveFrequency = Constants.MODEL_SAVE_FREQ;
        private int monitorFrequency = Constants.MONITOR_FREQ;

        /**
         * @param dataManager the data manager the built listener will delegate to
         */
        public Builder(IDataManager dataManager) {
            this.dataManager = dataManager;
        }

        /**
         * Sets how many steps must elapse since the previous save before the learning is saved
         * again. The check is made at each epoch end.
         *
         * @param frequency the save interval, in steps
         * @return this builder
         */
        public Builder saveFrequency(int frequency) {
            this.saveFrequency = frequency;
            return this;
        }

        /**
         * Sets how many steps must elapse since the previous monitor start before video
         * monitoring is started again. The check is made at each epoch start.
         *
         * @param frequency the monitor interval, in steps
         * @return this builder
         */
        public Builder monitorFrequency(int frequency) {
            this.monitorFrequency = frequency;
            return this;
        }

        /**
         * @return a new listener configured from this builder's current settings
         */
        public DataManagerSyncTrainingListener build() {
            return new DataManagerSyncTrainingListener(this);
        }
    }
}
