package org.deeplearning4j.rl4j.learning.async.nstep.discrete;

import java.util.Random;
import java.util.Stack;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.class */
/**
 * {@link AsyncThreadDiscrete} specialisation whose network type is {@link IDQN}.
 *
 * <p>Each instance keeps the configuration, MDP, shared global object, thread number and data
 * manager it was constructed with, plus a private {@link Random} seeded from the configured seed
 * offset by the thread number. It builds an {@link EpsGreedy} policy around a {@link DQNPolicy}
 * and turns a stack of {@link MiniTrans} entries into gradients via {@link IDQN#gradient}.
 *
 * @param <O> observation type handled by the MDP
 */
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
    /** Configuration supplied at construction; read for seed, gamma and epsilon settings. */
    protected final AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
    /** MDP supplied at construction; its action space is seeded in the constructor. */
    protected final MDP<O, Integer, DiscreteSpace> mdp;
    /** Shared global object, also forwarded to the superclass constructor. */
    protected final IAsyncGlobal<IDQN> asyncGlobal;
    /** Thread number, also used as the per-thread seed offset. */
    protected final int threadNumber;
    /** Data manager supplied at construction; only stored and exposed here. */
    protected final IDataManager dataManager;
    /** Random source handed to the epsilon-greedy policy. */
    private final Random random;

    /**
     * Stores the collaborators and seeds both the MDP's action space and this thread's
     * {@link Random} with {@code getSeed() + i}.
     *
     * @param mdp                      the MDP this thread works with
     * @param iAsyncGlobal             shared global object (also passed to the superclass)
     * @param asyncNStepQLConfiguration configuration to read settings from
     * @param i                        thread number (also passed to the superclass)
     * @param iDataManager             data manager to expose through {@link #getDataManager()}
     * @param i2                       forwarded unchanged to the superclass constructor;
     *                                 its meaning is defined there
     */
    public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> iAsyncGlobal, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration asyncNStepQLConfiguration, int i, IDataManager iDataManager, int i2) {
        super(iAsyncGlobal, i, i2);

        this.conf = asyncNStepQLConfiguration;
        this.asyncGlobal = iAsyncGlobal;
        this.threadNumber = i;
        this.mdp = mdp;
        this.dataManager = iDataManager;

        // Same seed expression for both consumers: configured seed shifted by the thread number.
        mdp.getActionSpace().setSeed(asyncNStepQLConfiguration.getSeed() + i);
        this.random = new Random(asyncNStepQLConfiguration.getSeed() + i);
    }

    /**
     * Builds an {@link EpsGreedy} policy that wraps a {@link DQNPolicy} over the given network.
     *
     * @param idqn network the wrapped {@link DQNPolicy} is created from
     * @return a new epsilon-greedy policy configured from {@link #conf}
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public Policy<O, Integer> getPolicy(IDQN idqn) {
        DQNPolicy wrappedPolicy = new DQNPolicy(idqn);
        return new EpsGreedy(
                wrappedPolicy,
                mdp,
                conf.getUpdateStart(),
                conf.getEpsilonNbStep(),
                random,
                conf.getMinEpsilon(),
                this);
    }

    /**
     * Consumes the whole stack and computes gradients for the given network.
     *
     * <p>The top entry is popped first and only contributes its reward as the starting value of
     * the running return. Every remaining entry is then popped in turn (filling rows from the
     * last index down to 0): the running return becomes
     * {@code entry.reward + gamma * runningReturn}, the entry's observation becomes the input row,
     * and the entry's first output array, with the running return written at the index of the
     * entry's action, becomes the target row.
     *
     * <p>Decompiler note: this method was renamed from {@code calcGradient} to avoid a collision
     * with the raw-typed override {@link #calcGradient(IDQN, Stack)}.
     *
     * @param idqn  network whose {@code gradient(input, targets)} is invoked
     * @param stack transitions to consume; emptied by this call
     * @return whatever {@code idqn.gradient} returns for the assembled input and targets
     */
    public Gradient[] calcGradient2(IDQN idqn, Stack<MiniTrans<Integer>> stack) {
        // Top of the stack only seeds the return; it gets no row of its own.
        MiniTrans<Integer> seedTransition = stack.pop();
        int rowCount = stack.size();

        INDArray observations = Nd4j.create(Learning.makeShape(rowCount, getHistoryProcessor() == null ? mdp.getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape()));
        INDArray targets = Nd4j.create(rowCount, mdp.getActionSpace().getSize());

        double runningReturn = seedTransition.getReward();
        int row = rowCount - 1;
        while (row >= 0) {
            MiniTrans<Integer> step = stack.pop();
            runningReturn = step.getReward() + (conf.getGamma() * runningReturn);

            observations.putRow(row, step.getObs());

            // Overwrite the taken action's entry in the recorded output with the running return.
            INDArray targetRow = step.getOutput()[0].putScalar(step.getAction().intValue(), runningReturn);
            targets.putRow(row, targetRow);

            row--;
        }

        return idqn.gradient(observations, targets);
    }

    /** @return the configuration passed to the constructor */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration getConf() {
        return conf;
    }

    /** @return the MDP passed to the constructor */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return mdp;
    }

    /** @return the shared global object passed to the constructor */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public IAsyncGlobal<IDQN> getAsyncGlobal() {
        return asyncGlobal;
    }

    /** @return the thread number passed to the constructor */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public int getThreadNumber() {
        return threadNumber;
    }

    /** @return the data manager passed to the constructor */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public IDataManager getDataManager() {
        return dataManager;
    }

    /**
     * Raw-typed override (marked bridge/synthetic by the decompiler) that casts the stack to
     * {@code Stack<MiniTrans<Integer>>} and delegates to {@link #calcGradient2(IDQN, Stack)}.
     *
     * @param idqn  network forwarded to {@code calcGradient2}
     * @param stack raw stack; cast unchecked before delegation
     * @return the result of {@code calcGradient2}
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete
    public /* bridge */ /* synthetic */ Gradient[] calcGradient(IDQN idqn, Stack stack) {
        Stack<MiniTrans<Integer>> typedStack = (Stack<MiniTrans<Integer>>) stack;
        return calcGradient2(idqn, typedStack);
    }
}
