package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import java.util.ArrayList;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Q-learning for MDPs with a discrete ({@link DiscreteSpace}) action space.
 *
 * <p>Actions are chosen by an {@link EpsGreedy} policy wrapped around a {@link DQNPolicy}. Each
 * non-skipped step stores a {@link Transition} in the experience replay. Once the step counter
 * exceeds the update start, a batch is drawn from the replay and the Q-network is fitted towards
 * TD targets. Those targets come from {@link DoubleDQN} or {@link StandardDQN}, selected by
 * {@code QLConfiguration#isDoubleDQN()}.
 *
 * @param <O> observation type of the wrapped MDP
 */
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
    private final QLearning.QLConfiguration configuration;
    private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
    private final DQNPolicy<O> policy;
    private final EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
    private final IDQN qNetwork;
    // Reassigned through setTargetQNetwork(...), so it cannot be final.
    private IDQN targetQNetwork;
    // Action repeated on skipped frames.
    private int lastAction;
    // Reward accumulated since the last stored transition (summed over skipped frames).
    private double accuReward;
    // NOTE(review): package-private and non-final, presumably so code in this package (e.g. tests)
    // can substitute the algorithm — confirm before tightening.
    ITDTargetAlgorithm tdTargetAlgorithm;

    @Override
    protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
        return this.mdp;
    }

    /**
     * Creates a learner whose epsilon-greedy policy draws from a new random generator seeded with
     * {@code conf.getSeed()}.
     *
     * @param mdp           environment to learn from
     * @param dqn           Q-network to train; a clone of it becomes the initial target network
     * @param conf          Q-learning configuration; its seed must be non-null (it is dereferenced)
     * @param epsilonNbStep forwarded to {@link EpsGreedy} (presumably the number of steps over which
     *                      epsilon is annealed — confirm against EpsGreedy)
     */
    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
                             int epsilonNbStep) {
        this(mdp, dqn, conf, epsilonNbStep,
                Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed().intValue()));
    }

    /**
     * Creates a learner using the supplied random generator for epsilon-greedy exploration.
     *
     * @param mdp           environment to learn from; it is wrapped in a {@link LegacyMDPWrapper}
     * @param dqn           Q-network to train; a clone of it becomes the initial target network
     * @param conf          Q-learning configuration
     * @param epsilonNbStep forwarded to {@link EpsGreedy}
     * @param random        random generator used by the epsilon-greedy policy
     */
    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
                             int epsilonNbStep, Random random) {
        super(conf);
        this.configuration = conf;
        this.mdp = new LegacyMDPWrapper<>(mdp, this);
        this.qNetwork = dqn;
        this.targetQNetwork = dqn.clone();
        // Use the field, not the overridable getQNetwork(): a subclass override would run before
        // the subclass is initialised.
        this.policy = new DQNPolicy<>(this.qNetwork);
        this.egPolicy = new EpsGreedy<>(this.policy, mdp, conf.getUpdateStart(), epsilonNbStep, random,
                conf.getMinEpsilon(), this);
        this.tdTargetAlgorithm = conf.isDoubleDQN()
                ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
                : new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
    }

    /** Returns the original (unwrapped) MDP passed to the constructor. */
    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp.getWrappedMDP();
    }

    /** Stops the history processor's monitor, if a history processor is configured. */
    @Override
    public void postEpoch() {
        if (getHistoryProcessor() != null) {
            getHistoryProcessor().stopMonitor();
        }
    }

    /** Resets per-epoch state: the repeated action and the accumulated reward. */
    @Override
    public void preEpoch() {
        this.lastAction = 0;
        this.accuReward = 0.0d;
    }

    /**
     * Performs one environment step and, on non-skipped frames, one learning update.
     *
     * <p>On skipped frames (step counter not a multiple of the skip-frame setting), the previous
     * action is repeated and the reward is only accumulated. On other frames, or when the episode
     * ends, a transition is stored. Once the step counter exceeds the update start, the Q-network
     * is also fitted on a replay batch.
     *
     * @param obs current observation
     * @return max Q-value (NaN when the network was not queried this step), the network's latest
     *         score, and the environment's step reply
     */
    @Override
    protected QLearning.QLStepReturn<Observation> trainStep(Observation obs) {
        boolean hasHistoryProcessor = getHistoryProcessor() != null;
        int skipFrame = hasHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
        int historyLength = hasHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
        int updateStart = getConfiguration().getUpdateStart()
                + (getConfiguration().getBatchSize() + historyLength) * skipFrame;

        // Stays NaN when no Q-value is computed on this step.
        Double maxQ = Double.NaN;
        Integer action;
        if (getStepCounter() % skipFrame != 0) {
            // Skipped frame: repeat the previous action without querying the network.
            action = this.lastAction;
        } else {
            // Q-value of the greedy (arg-max) action == the largest entry of the network output.
            maxQ = getQNetwork().output(obs).maxNumber().doubleValue();
            action = getEgPolicy().nextAction(obs);
        }
        this.lastAction = action;

        StepReply<Observation> stepReply = this.mdp.step(action);
        Observation nextObs = stepReply.getObservation();
        this.accuReward += stepReply.getReward() * this.configuration.getRewardFactor();

        // Learn on non-skipped frames, and also when the episode ends during a skip.
        if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
            getExpReplay().store(new Transition<>(obs, action, this.accuReward, stepReply.isDone(), nextObs));
            if (getStepCounter() > updateStart) {
                DataSet targets = setTarget(getExpReplay().getBatch());
                getQNetwork().fit(targets.getFeatures(), targets.getLabels());
            }
            this.accuReward = 0.0d;
        }
        return new QLearning.QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
    }

    /**
     * Computes the training targets for a batch of transitions using the configured TD algorithm.
     *
     * @param transitions replay batch to compute targets for
     * @return features/labels data set to fit the Q-network on
     * @throws IllegalArgumentException if {@code transitions} is empty
     */
    protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
        if (transitions.isEmpty()) {
            throw new IllegalArgumentException("too few transitions");
        }
        return this.tdTargetAlgorithm.computeTDTargets(transitions);
    }

    /** Returns the Q-learning configuration passed to the constructor. */
    @Override
    public QLearning.QLConfiguration getConfiguration() {
        return this.configuration;
    }

    /** Returns the greedy policy backed by the Q-network. */
    @Override
    public DQNPolicy<O> getPolicy() {
        return this.policy;
    }

    /** Returns the epsilon-greedy policy used to select actions during training. */
    @Override
    public EpsGreedy<O, Integer, DiscreteSpace> getEgPolicy() {
        return this.egPolicy;
    }

    /** Returns the Q-network being trained. */
    @Override
    public IDQN getQNetwork() {
        return this.qNetwork;
    }

    /** Returns the current target Q-network. */
    @Override
    public IDQN getTargetQNetwork() {
        return this.targetQNetwork;
    }

    /** Replaces the target Q-network. */
    @Override
    protected void setTargetQNetwork(IDQN dqn) {
        this.targetQNetwork = dqn;
    }
}
