package org.deeplearning4j.rl4j.learning.sync.qlearning;

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.class */
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>> extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource, IEpochTrainer {
    private static final Logger log = LoggerFactory.getLogger(QLearning.class);
    private int episodeCount;
    private int currentEpisodeStepCount = 0;

    /**
     * Legacy, mutable bundle of Q-learning hyper-parameters.
     *
     * <p>Instances are created through {@link #builder()} (which is also the builder Jackson is
     * told to use for deserialization) or through the all-arguments constructor, and can be
     * converted into a {@link QLearningConfiguration} with {@link #toLearningConfiguration()}.
     *
     * <p>{@code seed} is the only nullable property; every other property is a primitive.
     *
     * @deprecated superseded by {@link QLearningConfiguration}; see
     *     {@link #toLearningConfiguration()} for the conversion.
     */
    @JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
    @Deprecated
    public static class QLConfiguration {
        Integer seed;
        int maxEpochStep;
        int maxStep;
        int expRepMaxSize;
        int batchSize;
        int targetDqnUpdateFreq;
        int updateStart;
        double rewardFactor;
        double gamma;
        double errorClamp;
        float minEpsilon;
        int epsilonNbStep;
        boolean doubleDQN;

        /**
         * Fluent builder for {@link QLConfiguration}. Setter methods carry no prefix, which is
         * what {@code @JsonPOJOBuilder(withPrefix = "")} declares to Jackson. Properties that are
         * never set keep their Java defaults ({@code null}, {@code 0}, {@code 0.0}, {@code false}).
         */
        @JsonPOJOBuilder(withPrefix = "")
        public static final class QLConfigurationBuilder {
            private Integer seed;
            private int maxEpochStep;
            private int maxStep;
            private int expRepMaxSize;
            private int batchSize;
            private int targetDqnUpdateFreq;
            private int updateStart;
            private double rewardFactor;
            private double gamma;
            private double errorClamp;
            private float minEpsilon;
            private int epsilonNbStep;
            private boolean doubleDQN;

            QLConfigurationBuilder() {
            }

            /** Sets the seed (may be {@code null}) and returns this builder. */
            public QLConfigurationBuilder seed(Integer num) {
                this.seed = num;
                return this;
            }

            /** Sets {@code maxEpochStep} and returns this builder. */
            public QLConfigurationBuilder maxEpochStep(int i) {
                this.maxEpochStep = i;
                return this;
            }

            /** Sets {@code maxStep} and returns this builder. */
            public QLConfigurationBuilder maxStep(int i) {
                this.maxStep = i;
                return this;
            }

            /** Sets {@code expRepMaxSize} and returns this builder. */
            public QLConfigurationBuilder expRepMaxSize(int i) {
                this.expRepMaxSize = i;
                return this;
            }

            /** Sets {@code batchSize} and returns this builder. */
            public QLConfigurationBuilder batchSize(int i) {
                this.batchSize = i;
                return this;
            }

            /** Sets {@code targetDqnUpdateFreq} and returns this builder. */
            public QLConfigurationBuilder targetDqnUpdateFreq(int i) {
                this.targetDqnUpdateFreq = i;
                return this;
            }

            /** Sets {@code updateStart} and returns this builder. */
            public QLConfigurationBuilder updateStart(int i) {
                this.updateStart = i;
                return this;
            }

            /** Sets {@code rewardFactor} and returns this builder. */
            public QLConfigurationBuilder rewardFactor(double d) {
                this.rewardFactor = d;
                return this;
            }

            /** Sets {@code gamma} and returns this builder. */
            public QLConfigurationBuilder gamma(double d) {
                this.gamma = d;
                return this;
            }

            /** Sets {@code errorClamp} and returns this builder. */
            public QLConfigurationBuilder errorClamp(double d) {
                this.errorClamp = d;
                return this;
            }

            /** Sets {@code minEpsilon} and returns this builder. */
            public QLConfigurationBuilder minEpsilon(float f) {
                this.minEpsilon = f;
                return this;
            }

            /** Sets {@code epsilonNbStep} and returns this builder. */
            public QLConfigurationBuilder epsilonNbStep(int i) {
                this.epsilonNbStep = i;
                return this;
            }

            /** Sets {@code doubleDQN} and returns this builder. */
            public QLConfigurationBuilder doubleDQN(boolean z) {
                this.doubleDQN = z;
                return this;
            }

            /** Creates a {@link QLConfiguration} from the values collected so far. */
            public QLConfiguration build() {
                return new QLConfiguration(
                        this.seed,
                        this.maxEpochStep,
                        this.maxStep,
                        this.expRepMaxSize,
                        this.batchSize,
                        this.targetDqnUpdateFreq,
                        this.updateStart,
                        this.rewardFactor,
                        this.gamma,
                        this.errorClamp,
                        this.minEpsilon,
                        this.epsilonNbStep,
                        this.doubleDQN);
            }

            @Override
            public String toString() {
                return "QLearning.QLConfiguration.QLConfigurationBuilder(seed=" + this.seed
                        + ", maxEpochStep=" + this.maxEpochStep
                        + ", maxStep=" + this.maxStep
                        + ", expRepMaxSize=" + this.expRepMaxSize
                        + ", batchSize=" + this.batchSize
                        + ", targetDqnUpdateFreq=" + this.targetDqnUpdateFreq
                        + ", updateStart=" + this.updateStart
                        + ", rewardFactor=" + this.rewardFactor
                        + ", gamma=" + this.gamma
                        + ", errorClamp=" + this.errorClamp
                        + ", minEpsilon=" + this.minEpsilon
                        + ", epsilonNbStep=" + this.epsilonNbStep
                        + ", doubleDQN=" + this.doubleDQN + ")";
            }
        }

        /**
         * Copies every property of this legacy configuration into a new
         * {@link QLearningConfiguration}.
         *
         * <p>The seed is widened from {@code Integer} to {@code Long}. When this configuration has
         * no seed ({@code null}), the seed is simply not set on the target builder, so whatever
         * that builder does for an unset seed applies. Previously this case failed with a
         * {@link NullPointerException}.
         *
         * @return the equivalent {@link QLearningConfiguration}
         */
        // The raw builder type and the casts mirror how the inherited builder methods are typed:
        // several of them return the parent builder type and must be narrowed again.
        @SuppressWarnings({"rawtypes", "unchecked"})
        public QLearningConfiguration toLearningConfiguration() {
            QLearningConfiguration.QLearningConfigurationBuilder builder =
                    (QLearningConfiguration.QLearningConfigurationBuilder) QLearningConfiguration.builder();
            if (this.seed != null) {
                // Re-assign the result so this works whether or not the builder returns itself.
                builder = (QLearningConfiguration.QLearningConfigurationBuilder)
                        builder.seed(Long.valueOf(this.seed.longValue()));
            }
            builder = (QLearningConfiguration.QLearningConfigurationBuilder)
                    builder.maxEpochStep(this.maxEpochStep);
            builder = (QLearningConfiguration.QLearningConfigurationBuilder)
                    builder.maxStep(this.maxStep);
            builder = (QLearningConfiguration.QLearningConfigurationBuilder)
                    builder.expRepMaxSize(this.expRepMaxSize)
                            .batchSize(this.batchSize)
                            .targetDqnUpdateFreq(this.targetDqnUpdateFreq)
                            .updateStart(this.updateStart)
                            .rewardFactor(this.rewardFactor);
            builder = (QLearningConfiguration.QLearningConfigurationBuilder)
                    builder.gamma(this.gamma);
            return builder.errorClamp(this.errorClamp)
                    .minEpsilon(this.minEpsilon)
                    .epsilonNbStep(this.epsilonNbStep)
                    .doubleDQN(this.doubleDQN)
                    .build();
        }

        /** Returns a fresh, empty {@link QLConfigurationBuilder}. */
        public static QLConfigurationBuilder builder() {
            return new QLConfigurationBuilder();
        }

        /** Returns the seed, or {@code null} when none was set. */
        public Integer getSeed() {
            return this.seed;
        }

        public int getMaxEpochStep() {
            return this.maxEpochStep;
        }

        public int getMaxStep() {
            return this.maxStep;
        }

        public int getExpRepMaxSize() {
            return this.expRepMaxSize;
        }

        public int getBatchSize() {
            return this.batchSize;
        }

        public int getTargetDqnUpdateFreq() {
            return this.targetDqnUpdateFreq;
        }

        public int getUpdateStart() {
            return this.updateStart;
        }

        public double getRewardFactor() {
            return this.rewardFactor;
        }

        public double getGamma() {
            return this.gamma;
        }

        public double getErrorClamp() {
            return this.errorClamp;
        }

        public float getMinEpsilon() {
            return this.minEpsilon;
        }

        public int getEpsilonNbStep() {
            return this.epsilonNbStep;
        }

        public boolean isDoubleDQN() {
            return this.doubleDQN;
        }

        /** Replaces the seed; {@code null} is accepted. */
        public void setSeed(Integer num) {
            this.seed = num;
        }

        public void setMaxEpochStep(int i) {
            this.maxEpochStep = i;
        }

        public void setMaxStep(int i) {
            this.maxStep = i;
        }

        public void setExpRepMaxSize(int i) {
            this.expRepMaxSize = i;
        }

        public void setBatchSize(int i) {
            this.batchSize = i;
        }

        public void setTargetDqnUpdateFreq(int i) {
            this.targetDqnUpdateFreq = i;
        }

        public void setUpdateStart(int i) {
            this.updateStart = i;
        }

        public void setRewardFactor(double d) {
            this.rewardFactor = d;
        }

        public void setGamma(double d) {
            this.gamma = d;
        }

        public void setErrorClamp(double d) {
            this.errorClamp = d;
        }

        public void setMinEpsilon(float f) {
            this.minEpsilon = f;
        }

        public void setEpsilonNbStep(int i) {
            this.epsilonNbStep = i;
        }

        public void setDoubleDQN(boolean z) {
            this.doubleDQN = z;
        }

        @Override
        public String toString() {
            return "QLearning.QLConfiguration(seed=" + getSeed()
                    + ", maxEpochStep=" + getMaxEpochStep()
                    + ", maxStep=" + getMaxStep()
                    + ", expRepMaxSize=" + getExpRepMaxSize()
                    + ", batchSize=" + getBatchSize()
                    + ", targetDqnUpdateFreq=" + getTargetDqnUpdateFreq()
                    + ", updateStart=" + getUpdateStart()
                    + ", rewardFactor=" + getRewardFactor()
                    + ", gamma=" + getGamma()
                    + ", errorClamp=" + getErrorClamp()
                    + ", minEpsilon=" + getMinEpsilon()
                    + ", epsilonNbStep=" + getEpsilonNbStep()
                    + ", doubleDQN=" + isDoubleDQN() + ")";
        }

        /**
         * All-arguments constructor; parameters are, in order: seed, maxEpochStep, maxStep,
         * expRepMaxSize, batchSize, targetDqnUpdateFreq, updateStart, rewardFactor, gamma,
         * errorClamp, minEpsilon, epsilonNbStep, doubleDQN.
         */
        public QLConfiguration(Integer num, int i, int i2, int i3, int i4, int i5, int i6, double d, double d2, double d3, float f, int i7, boolean z) {
            this.seed = num;
            this.maxEpochStep = i;
            this.maxStep = i2;
            this.expRepMaxSize = i3;
            this.batchSize = i4;
            this.targetDqnUpdateFreq = i5;
            this.updateStart = i6;
            this.rewardFactor = d;
            this.gamma = d2;
            this.errorClamp = d3;
            this.minEpsilon = f;
            this.epsilonNbStep = i7;
            this.doubleDQN = z;
        }

        /**
         * Property-by-property equality. Floating-point properties are compared with
         * {@code Double.compare}/{@code Float.compare}, and the other object must agree through
         * {@link #canEqual(Object)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof QLConfiguration)) {
                return false;
            }
            QLConfiguration other = (QLConfiguration) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            Integer mySeed = getSeed();
            Integer otherSeed = other.getSeed();
            if (mySeed == null ? otherSeed != null : !mySeed.equals(otherSeed)) {
                return false;
            }
            return getMaxEpochStep() == other.getMaxEpochStep()
                    && getMaxStep() == other.getMaxStep()
                    && getExpRepMaxSize() == other.getExpRepMaxSize()
                    && getBatchSize() == other.getBatchSize()
                    && getTargetDqnUpdateFreq() == other.getTargetDqnUpdateFreq()
                    && getUpdateStart() == other.getUpdateStart()
                    && Double.compare(getRewardFactor(), other.getRewardFactor()) == 0
                    && Double.compare(getGamma(), other.getGamma()) == 0
                    && Double.compare(getErrorClamp(), other.getErrorClamp()) == 0
                    && Float.compare(getMinEpsilon(), other.getMinEpsilon()) == 0
                    && getEpsilonNbStep() == other.getEpsilonNbStep()
                    && isDoubleDQN() == other.isDoubleDQN();
        }

        /** Hook used by {@link #equals(Object)} so that subclasses can refuse equality. */
        protected boolean canEqual(Object obj) {
            return obj instanceof QLConfiguration;
        }

        /** Hash over all properties, folded with multiplier 59 in declaration order. */
        @Override
        public int hashCode() {
            Integer mySeed = getSeed();
            int result = 1;
            result = result * 59 + (mySeed == null ? 43 : mySeed.hashCode());
            result = result * 59 + getMaxEpochStep();
            result = result * 59 + getMaxStep();
            result = result * 59 + getExpRepMaxSize();
            result = result * 59 + getBatchSize();
            result = result * 59 + getTargetDqnUpdateFreq();
            result = result * 59 + getUpdateStart();
            long rewardFactorBits = Double.doubleToLongBits(getRewardFactor());
            result = result * 59 + (int) ((rewardFactorBits >>> 32) ^ rewardFactorBits);
            long gammaBits = Double.doubleToLongBits(getGamma());
            result = result * 59 + (int) ((gammaBits >>> 32) ^ gammaBits);
            long errorClampBits = Double.doubleToLongBits(getErrorClamp());
            result = result * 59 + (int) ((errorClampBits >>> 32) ^ errorClampBits);
            result = result * 59 + Float.floatToIntBits(getMinEpsilon());
            result = result * 59 + getEpsilonNbStep();
            result = result * 59 + (isDoubleDQN() ? 79 : 97);
            return result;
        }
    }

    /**
     * Immutable per-epoch statistics record: step and epoch counters, accumulated reward,
     * episode length, the recorded scores, the epsilon value, and the first and mean Q values.
     * The {@code scores} list is stored by reference, not copied.
     */
    public static final class QLStatEntry implements IDataManager.StatEntry {
        private final int stepCounter;
        private final int epochCounter;
        private final double reward;
        private final int episodeLength;
        private final List<Double> scores;
        private final double epsilon;
        private final double startQ;
        private final double meanQ;

        /** Fluent builder for {@link QLStatEntry}; unset properties keep their Java defaults. */
        public static class QLStatEntryBuilder {
            private int stepCounter;
            private int epochCounter;
            private double reward;
            private int episodeLength;
            private List<Double> scores;
            private double epsilon;
            private double startQ;
            private double meanQ;

            QLStatEntryBuilder() {
            }

            public QLStatEntryBuilder stepCounter(int value) {
                stepCounter = value;
                return this;
            }

            public QLStatEntryBuilder epochCounter(int value) {
                epochCounter = value;
                return this;
            }

            public QLStatEntryBuilder reward(double value) {
                reward = value;
                return this;
            }

            public QLStatEntryBuilder episodeLength(int value) {
                episodeLength = value;
                return this;
            }

            public QLStatEntryBuilder scores(List<Double> value) {
                scores = value;
                return this;
            }

            public QLStatEntryBuilder epsilon(double value) {
                epsilon = value;
                return this;
            }

            public QLStatEntryBuilder startQ(double value) {
                startQ = value;
                return this;
            }

            public QLStatEntryBuilder meanQ(double value) {
                meanQ = value;
                return this;
            }

            /** Creates the entry from the values collected so far. */
            public QLStatEntry build() {
                return new QLStatEntry(stepCounter, epochCounter, reward, episodeLength, scores, epsilon, startQ, meanQ);
            }

            public String toString() {
                StringBuilder text = new StringBuilder("QLearning.QLStatEntry.QLStatEntryBuilder(stepCounter=");
                text.append(stepCounter);
                text.append(", epochCounter=").append(epochCounter);
                text.append(", reward=").append(reward);
                text.append(", episodeLength=").append(episodeLength);
                text.append(", scores=").append(scores);
                text.append(", epsilon=").append(epsilon);
                text.append(", startQ=").append(startQ);
                text.append(", meanQ=").append(meanQ);
                text.append(")");
                return text.toString();
            }
        }

        /** Returns a fresh, empty {@link QLStatEntryBuilder}. */
        public static QLStatEntryBuilder builder() {
            return new QLStatEntryBuilder();
        }

        @Override // org.deeplearning4j.rl4j.util.IDataManager.StatEntry
        public int getStepCounter() {
            return stepCounter;
        }

        @Override // org.deeplearning4j.rl4j.util.IDataManager.StatEntry
        public int getEpochCounter() {
            return epochCounter;
        }

        @Override // org.deeplearning4j.rl4j.util.IDataManager.StatEntry
        public double getReward() {
            return reward;
        }

        public int getEpisodeLength() {
            return episodeLength;
        }

        public List<Double> getScores() {
            return scores;
        }

        public double getEpsilon() {
            return epsilon;
        }

        public double getStartQ() {
            return startQ;
        }

        public double getMeanQ() {
            return meanQ;
        }

        /** Field-by-field equality; doubles are compared with {@code Double.compare}. */
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof QLStatEntry)) {
                return false;
            }
            QLStatEntry that = (QLStatEntry) obj;
            if (stepCounter != that.stepCounter) {
                return false;
            }
            if (epochCounter != that.epochCounter) {
                return false;
            }
            if (Double.compare(reward, that.reward) != 0) {
                return false;
            }
            if (episodeLength != that.episodeLength) {
                return false;
            }
            boolean sameScores = scores == null ? that.scores == null : scores.equals(that.scores);
            if (!sameScores) {
                return false;
            }
            return Double.compare(epsilon, that.epsilon) == 0
                    && Double.compare(startQ, that.startQ) == 0
                    && Double.compare(meanQ, that.meanQ) == 0;
        }

        /** Hash over all fields, folded with multiplier 59 in declaration order. */
        public int hashCode() {
            int result = 1;
            result = result * 59 + stepCounter;
            result = result * 59 + epochCounter;
            result = result * 59 + foldDouble(reward);
            result = result * 59 + episodeLength;
            result = result * 59 + (scores == null ? 43 : scores.hashCode());
            result = result * 59 + foldDouble(epsilon);
            result = result * 59 + foldDouble(startQ);
            result = result * 59 + foldDouble(meanQ);
            return result;
        }

        /** XOR-folds the 64-bit representation of a double into 32 bits. */
        private static int foldDouble(double value) {
            long bits = Double.doubleToLongBits(value);
            return (int) ((bits >>> 32) ^ bits);
        }

        public String toString() {
            StringBuilder text = new StringBuilder("QLearning.QLStatEntry(stepCounter=");
            text.append(stepCounter);
            text.append(", epochCounter=").append(epochCounter);
            text.append(", reward=").append(reward);
            text.append(", episodeLength=").append(episodeLength);
            text.append(", scores=").append(scores);
            text.append(", epsilon=").append(epsilon);
            text.append(", startQ=").append(startQ);
            text.append(", meanQ=").append(meanQ);
            text.append(")");
            return text.toString();
        }

        /**
         * All-arguments constructor; parameters are, in order: stepCounter, epochCounter, reward,
         * episodeLength, scores, epsilon, startQ, meanQ.
         */
        public QLStatEntry(int i, int i2, double d, int i3, List<Double> list, double d2, double d3, double d4) {
            stepCounter = i;
            epochCounter = i2;
            reward = d;
            episodeLength = i3;
            scores = list;
            epsilon = d2;
            startQ = d3;
            meanQ = d4;
        }
    }

    /**
     * Immutable result of a single training step: the (boxed, possibly {@code null}) max Q
     * value, the score, and the environment's {@link StepReply}.
     *
     * @param <O> observation type carried by the step reply
     */
    public static final class QLStepReturn<O> {
        private final Double maxQ;
        private final double score;
        private final StepReply<O> stepReply;

        /** Fluent builder for {@link QLStepReturn}; unset properties keep their Java defaults. */
        public static class QLStepReturnBuilder<O> {
            private Double maxQ;
            private double score;
            private StepReply<O> stepReply;

            QLStepReturnBuilder() {
            }

            public QLStepReturnBuilder<O> maxQ(Double value) {
                maxQ = value;
                return this;
            }

            public QLStepReturnBuilder<O> score(double value) {
                score = value;
                return this;
            }

            public QLStepReturnBuilder<O> stepReply(StepReply<O> value) {
                stepReply = value;
                return this;
            }

            /** Creates the step return from the values collected so far. */
            public QLStepReturn<O> build() {
                return new QLStepReturn<>(maxQ, score, stepReply);
            }

            public String toString() {
                StringBuilder text = new StringBuilder("QLearning.QLStepReturn.QLStepReturnBuilder(maxQ=");
                text.append(maxQ);
                text.append(", score=").append(score);
                text.append(", stepReply=").append(stepReply);
                text.append(")");
                return text.toString();
            }
        }

        /** Returns a fresh, empty {@link QLStepReturnBuilder}. */
        public static <O> QLStepReturnBuilder<O> builder() {
            return new QLStepReturnBuilder<>();
        }

        public Double getMaxQ() {
            return maxQ;
        }

        public double getScore() {
            return score;
        }

        public StepReply<O> getStepReply() {
            return stepReply;
        }

        /** Field-by-field equality; both nullable fields are compared null-safely. */
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof QLStepReturn)) {
                return false;
            }
            QLStepReturn<?> that = (QLStepReturn<?>) obj;
            boolean sameMaxQ = maxQ == null ? that.maxQ == null : maxQ.equals(that.maxQ);
            if (!sameMaxQ) {
                return false;
            }
            if (Double.compare(score, that.score) != 0) {
                return false;
            }
            if (stepReply == null) {
                return that.stepReply == null;
            }
            return stepReply.equals(that.stepReply);
        }

        /** Hash over all fields, folded with multiplier 59 in declaration order. */
        public int hashCode() {
            int result = 1;
            result = result * 59 + (maxQ == null ? 43 : maxQ.hashCode());
            long scoreBits = Double.doubleToLongBits(score);
            result = result * 59 + (int) ((scoreBits >>> 32) ^ scoreBits);
            result = result * 59 + (stepReply == null ? 43 : stepReply.hashCode());
            return result;
        }

        public String toString() {
            StringBuilder text = new StringBuilder("QLearning.QLStepReturn(maxQ=");
            text.append(maxQ);
            text.append(", score=").append(score);
            text.append(", stepReply=").append(stepReply);
            text.append(")");
            return text.toString();
        }

        /** All-arguments constructor; parameters are, in order: maxQ, score, stepReply. */
        public QLStepReturn(Double d, double d2, StepReply<O> stepReply) {
            this.maxQ = d;
            this.score = d2;
            this.stepReply = stepReply;
        }
    }

    /**
     * Supplies the {@link LegacyMDPWrapper} around the environment. {@code refacInitMdp()} uses
     * it at the start of each epoch to reset the environment and to step it with the no-op
     * action while the observation is flagged as skipped.
     */
    protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();

    /**
     * Supplies the epsilon-greedy policy. Within this class it is only read by
     * {@link #trainEpoch()}, which stores its current epsilon in the returned statistics entry.
     */
    protected abstract EpsGreedy<O, A, AS> getEgPolicy();

    /**
     * Returns the environment (MDP). {@link #trainEpoch()} polls its {@code isDone()} flag to
     * decide when the epoch loop ends.
     */
    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public abstract MDP<O, A, AS> getMdp();

    /**
     * Returns the current Q-network. It is the network exposed by {@link #getNeuralNet()},
     * cloned by {@link #updateTargetNetwork()} and reset by {@link #resetNetworks()}.
     */
    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource
    public abstract IDQN getQNetwork();

    /**
     * Returns the target Q-network, i.e. the network last installed through
     * {@link #setTargetQNetwork(IDQN)} (NOTE(review): confirm in implementations). It is reset
     * together with the Q-network by {@link #resetNetworks()}.
     */
    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource
    public abstract IDQN getTargetQNetwork();

    /**
     * Installs a new target Q-network. {@link #updateTargetNetwork()} calls this with a clone
     * of {@link #getQNetwork()}.
     *
     * @param idqn the network to use as target from now on
     */
    protected abstract void setTargetQNetwork(IDQN idqn);

    /**
     * Replaces the target network with a clone of the current Q-network and logs the update.
     * {@link #trainEpoch()} calls this whenever the global step count is a multiple of the
     * configured target update frequency.
     */
    protected void updateTargetNetwork() {
        log.info("Update target network");
        // Call the real clone() method; the former "m26clone" was only a decompiler alias
        // for it and does not exist in the compiled IDQN type.
        IDQN snapshot = getQNetwork().clone();
        setTargetQNetwork(snapshot);
    }

    /**
     * Exposes the Q-network as this learner's neural net.
     *
     * @return whatever {@link #getQNetwork()} currently returns
     */
    @Override // org.deeplearning4j.rl4j.learning.Learning, org.deeplearning4j.rl4j.learning.NeuralNetFetchable
    public IDQN getNeuralNet() {
        final IDQN qNetwork = getQNetwork();
        return qNetwork;
    }

    /**
     * Returns the Q-learning configuration. {@link #trainEpoch()} reads the maximum number of
     * steps per epoch and the target-network update frequency from it on every loop iteration.
     */
    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public abstract QLearningConfiguration getConfiguration();

    /**
     * Epoch hook declared by {@code SyncLearning}. It is not called from this class;
     * presumably the superclass invokes it before each {@link #trainEpoch()} — confirm there.
     */
    @Override // org.deeplearning4j.rl4j.learning.sync.SyncLearning
    protected abstract void preEpoch();

    /**
     * Epoch hook declared by {@code SyncLearning}. It is not called from this class;
     * presumably the superclass invokes it after each {@link #trainEpoch()} — confirm there.
     */
    @Override // org.deeplearning4j.rl4j.learning.sync.SyncLearning
    protected abstract void postEpoch();

    /**
     * Performs one training step starting from the given observation. {@link #trainEpoch()}
     * calls this once per loop iteration and consumes the returned max Q value (expected to be
     * non-null; {@code NaN} means "no Q value this step"), the score ({@code 0.0} is not
     * recorded), and the step reply's reward and next observation.
     *
     * @param observation the observation to act on
     * @return the outcome of the step
     */
    protected abstract QLStepReturn<Observation> trainStep(Observation observation);

    /**
     * Runs one training epoch (episode) and summarizes it.
     *
     * <p>The networks are reset and the environment is re-initialized. Training steps are then
     * executed until either the per-epoch step limit from the configuration is reached or the
     * environment reports it is done. Before each step the target network is refreshed when the
     * global step count is a multiple of the configured update frequency.
     *
     * @return a {@link QLStatEntry} holding the step and epoch counters, the accumulated reward
     *     (including reward collected during initialization), the number of steps in this
     *     episode, all non-zero scores, the policy's current epsilon, the first non-NaN max Q
     *     value ({@code NaN} if there was none) and the approximate mean of the non-NaN max Q
     *     values
     */
    @Override // org.deeplearning4j.rl4j.learning.sync.SyncLearning
    protected IDataManager.StatEntry trainEpoch() {
        resetNetworks();
        Learning.InitMdp<Observation> initMdp = refacInitMdp();
        Observation observation = initMdp.getLastObs();
        double totalReward = initMdp.getReward();
        double startQ = Double.NaN;
        double sumQ = 0.0d;
        int countQ = 0;
        List<Double> scores = new ArrayList<>();
        while (this.currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
            if (getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
                updateTargetNetwork();
            }
            QLStepReturn<Observation> stepReturn = trainStep(observation);

            // A NaN max Q marks a step without a Q value; such steps are left out of the stats.
            Double maxQ = stepReturn.getMaxQ();
            if (!maxQ.isNaN()) {
                if (Double.isNaN(startQ)) {
                    startQ = maxQ.doubleValue();
                }
                countQ++;
                sumQ += maxQ.doubleValue();
            }

            // Scores that are exactly zero are not recorded.
            double score = stepReturn.getScore();
            if (score != 0.0d) {
                scores.add(Double.valueOf(score));
            }

            StepReply<Observation> reply = stepReturn.getStepReply();
            totalReward += reply.getReward();
            observation = (Observation) reply.getObservation();
            incrementStep();
        }
        finishEpoch(observation);
        // The small constant keeps the division defined when no Q value was recorded.
        double meanQ = sumQ / (countQ + 0.001d);
        return new QLStatEntry(getStepCount(), getEpochCount(), totalReward, this.currentEpisodeStepCount, scores, getEgPolicy().getEpsilon(), startQ, meanQ);
    }

    /**
     * End-of-epoch hook called by {@link #trainEpoch()} after its step loop. This base
     * implementation only advances the episode counter and ignores its argument.
     *
     * @param observation the last observation of the epoch (unused here)
     */
    protected void finishEpoch(Observation observation) {
        this.episodeCount += 1;
    }

    /**
     * Advances the step count maintained by the superclass and then the count of steps taken
     * in the current episode.
     */
    @Override // org.deeplearning4j.rl4j.learning.Learning
    public void incrementStep() {
        super.incrementStep();
        this.currentEpisodeStepCount += 1;
    }

    /**
     * Calls {@code reset()} on the Q-network first and on the target Q-network second.
     * {@link #trainEpoch()} invokes this at the start of every epoch.
     */
    protected void resetNetworks() {
        final IDQN qNetwork = getQNetwork();
        qNetwork.reset();
        final IDQN targetNetwork = getTargetQNetwork();
        targetNetwork.reset();
    }

    /**
     * Starts a new episode: zeroes the per-episode step counter, resets the wrapped
     * environment and, while the current observation is flagged as skipped and the environment
     * is not done, steps it with the action space's no-op action. Every such step is counted
     * through {@link #incrementStep()} and its reward is accumulated.
     *
     * @return an {@code InitMdp} built from {@code 0}, the first observation that is not
     *     skipped (or the last one obtained if the environment finished first), and the reward
     *     accumulated while skipping
     */
    private Learning.InitMdp<Observation> refacInitMdp() {
        this.currentEpisodeStepCount = 0;
        double accumulatedReward = 0.0d;
        LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
        // Call the real reset() method; the former "m34reset" was only a decompiler alias.
        Observation observation = mdp.reset();
        // Typed as A (not Object) so that it is a valid argument for step(...).
        A noOpAction = mdp.getActionSpace().noOp();
        while (observation.isSkipped() && !mdp.isDone()) {
            StepReply<Observation> reply = mdp.step(noOpAction);
            accumulatedReward += reply.getReward();
            observation = (Observation) reply.getObservation();
            incrementStep();
        }
        return new Learning.InitMdp<>(0, observation, accumulatedReward);
    }

    /**
     * Returns how many episodes have been finished, i.e. how often
     * {@link #finishEpoch(Observation)} has advanced the counter.
     */
    @Override // org.deeplearning4j.rl4j.learning.IEpochTrainer
    public int getEpisodeCount() {
        final int finishedEpisodes = this.episodeCount;
        return finishedEpisodes;
    }

    /**
     * Returns the number of steps counted since the current episode was initialized; the
     * counter is zeroed at the start of each epoch and advanced by {@link #incrementStep()}.
     */
    @Override // org.deeplearning4j.rl4j.learning.IEpochTrainer
    public int getCurrentEpisodeStepCount() {
        final int stepsThisEpisode = this.currentEpisodeStepCount;
        return stepsThisEpisode;
    }
}
