package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import java.util.ArrayList;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.BaseTDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.class */
/**
 * Q-learning over a discrete action space (DQN / Double DQN).
 *
 * <p>Holds an online Q network and a target Q network (initialised as a clone of the online one),
 * a greedy {@link DQNPolicy} over the online network, and an {@link EpsGreedy} exploration policy
 * wrapping it. The TD-target algorithm is {@link DoubleDQN} when
 * {@link QLearning.QLConfiguration#isDoubleDQN()} is true, otherwise {@link StandardDQN}.
 *
 * @param <O> observation type produced by the MDP
 */
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
    private final QLearning.QLConfiguration configuration;
    private final MDP<O, Integer, DiscreteSpace> mdp;
    private DQNPolicy<O> policy;
    private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
    private final IDQN qNetwork;
    private IDQN targetQNetwork;
    /** Action chosen at the last decision step; repeated on skipped frames. */
    private int lastAction;
    /** Observation stack used as the "state" of the pending transition; reset to null each epoch. */
    private INDArray[] history;
    /** Reward (already multiplied by the configured reward factor) accumulated since the last stored transition. */
    private double accuReward;
    ITDTargetAlgorithm tdTargetAlgorithm;

    /**
     * Creates the learner and wires its policies and TD-target algorithm.
     *
     * @param mdp environment to train on; its action space is seeded with {@code conf.getSeed()}
     * @param dqn online Q network; a clone of it becomes the initial target network
     * @param conf Q-learning configuration
     * @param epsilonNbStep forwarded to {@link EpsGreedy} as its fourth constructor argument
     *     (presumably the epsilon-annealing horizon — confirm against EpsGreedy)
     */
    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, int epsilonNbStep) {
        super(conf);
        this.history = null;
        this.accuReward = 0.0d;
        this.configuration = conf;
        this.mdp = mdp;
        this.qNetwork = dqn;
        this.targetQNetwork = dqn.m22clone();
        // Use the argument directly rather than the overridable getQNetwork() during construction.
        this.policy = new DQNPolicy<>(dqn);
        this.egPolicy = new EpsGreedy<>(this.policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(), this);
        mdp.getActionSpace().setSeed(conf.getSeed());
        this.tdTargetAlgorithm = conf.isDoubleDQN()
                ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
                : new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
    }

    /** Stops the history processor's monitor, if a history processor is configured. */
    @Override
    public void postEpoch() {
        if (getHistoryProcessor() != null) {
            getHistoryProcessor().stopMonitor();
        }
    }

    /** Clears the per-epoch state: observation history, last action and accumulated reward. */
    @Override
    public void preEpoch() {
        this.history = null;
        this.lastAction = 0;
        this.accuReward = 0.0d;
    }

    /**
     * Performs one environment step and, when due, stores a transition and fits the Q network.
     *
     * <p>On steps where {@code stepCounter % skipFrame != 0} the previous action is repeated and the
     * network is not queried. Otherwise, the stacked (and, with a history processor, rescaled) history
     * is fed to the Q network to record the Q value of its greedy action, and the epsilon-greedy
     * policy picks the action. A transition is stored when the step counter is a multiple of
     * {@code skipFrame} or the episode ends; the network is fitted on a replay batch once the step
     * counter exceeds the warm-up threshold.
     *
     * @param obs current observation
     * @return step result carrying the greedy Q value (NaN on skipped frames), the network's latest
     *     score and the MDP step reply
     */
    @Override
    protected QLearning.QLStepReturn<O> trainStep(O obs) {
        INDArray input = getInput(obs);
        boolean isHistoryProcessor = getHistoryProcessor() != null;

        if (isHistoryProcessor) {
            getHistoryProcessor().record(input);
        }

        int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
        int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
        // Training starts (batchSize + historyLength) * skipFrame steps after the configured updateStart,
        // presumably so the replay buffer holds enough transitions for a first batch.
        int updateStart = getConfiguration().getUpdateStart()
                + (getConfiguration().getBatchSize() + historyLength) * skipFrame;

        Double maxQ = Double.NaN;
        Integer action;
        if (getStepCounter() % skipFrame != 0) {
            // Skipped frame: repeat the previous action without querying the network.
            action = this.lastAction;
        } else {
            if (this.history == null) {
                if (isHistoryProcessor) {
                    getHistoryProcessor().add(input);
                    this.history = getHistoryProcessor().getHistory();
                } else {
                    this.history = new INDArray[] {input};
                }
            }
            INDArray stacked = Transition.concat(Transition.dup(this.history));
            if (isHistoryProcessor) {
                stacked.muli(1.0d / getHistoryProcessor().getScale());
            }
            if (stacked.shape().length > 2) {
                // Prepend a batch dimension of 1.
                stacked = stacked.reshape(Learning.makeShape(1, ArrayUtil.toInts(stacked.shape())));
            }
            // Evaluate the network once, then read the Q value at its argmax action.
            INDArray qValues = getQNetwork().output(stacked);
            int maxAction = Learning.getMaxAction(qValues).intValue();
            maxQ = qValues.getDouble(maxAction);
            action = getEgPolicy().nextAction(stacked);
        }

        this.lastAction = action;
        StepReply<O> stepReply = getMdp().step(action);
        this.accuReward += stepReply.getReward() * this.configuration.getRewardFactor();

        if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
            INDArray nextInput = getInput(stepReply.getObservation());
            if (isHistoryProcessor) {
                getHistoryProcessor().add(nextInput);
            }
            INDArray[] nextHistory = isHistoryProcessor
                    ? getHistoryProcessor().getHistory()
                    : new INDArray[] {nextInput};
            getExpReplay().store(new Transition<Integer>(this.history, action, this.accuReward, stepReply.isDone(), nextHistory[0]));

            if (getStepCounter() > updateStart) {
                DataSet targets = setTarget(getExpReplay().getBatch());
                getQNetwork().fit(targets.getFeatures(), targets.getLabels());
            }

            this.history = nextHistory;
            this.accuReward = 0.0d;
        }

        return new QLearning.QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
    }

    /**
     * Builds the (features, TD-target labels) dataset for a batch of transitions.
     *
     * <p>Configures the TD-target algorithm's input shape (batch size prepended to the observation
     * shape, or to the history processor's shape when one is configured) and, with a history
     * processor, its scale, then delegates to {@code computeTDTargets}.
     *
     * @param arrayList transitions sampled from experience replay
     * @return dataset whose features/labels are used to fit the Q network
     * @throws IllegalArgumentException if {@code arrayList} is empty
     */
    protected DataSet setTarget(ArrayList<Transition<Integer>> arrayList) {
        if (arrayList.size() == 0) {
            throw new IllegalArgumentException("too few transitions");
        }
        int[] observationShape = getHistoryProcessor() == null
                ? getMdp().getObservationSpace().getShape()
                : getHistoryProcessor().getConf().getShape();
        // NOTE(review): these casts assume tdTargetAlgorithm extends BaseTDTargetAlgorithm, which holds
        // for the DoubleDQN/StandardDQN instances created in the constructor.
        ((BaseTDTargetAlgorithm) this.tdTargetAlgorithm).setNShape(makeShape(arrayList.size(), observationShape));
        if (getHistoryProcessor() != null) {
            ((BaseTDTargetAlgorithm) this.tdTargetAlgorithm).setScale(getHistoryProcessor().getScale());
        }
        return this.tdTargetAlgorithm.computeTDTargets(arrayList);
    }

    /** @return the Q-learning configuration passed at construction */
    @Override
    public QLearning.QLConfiguration getConfiguration() {
        return this.configuration;
    }

    /** @return the environment this learner interacts with */
    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp;
    }

    /** @return the greedy policy over the online Q network */
    @Override
    public DQNPolicy<O> getPolicy() {
        return this.policy;
    }

    /** @return the epsilon-greedy exploration policy used during training */
    @Override
    public EpsGreedy<O, Integer, DiscreteSpace> getEgPolicy() {
        return this.egPolicy;
    }

    /** @return the online Q network being trained */
    @Override
    public IDQN getQNetwork() {
        return this.qNetwork;
    }

    /** @return the current target Q network */
    @Override
    public IDQN getTargetQNetwork() {
        return this.targetQNetwork;
    }

    /**
     * Replaces the target Q network.
     *
     * @param idqn new target network
     */
    @Override
    protected void setTargetQNetwork(IDQN idqn) {
        this.targetQNetwork = idqn;
    }
}
