package org.deeplearning4j.rl4j.learning.async.nstep.discrete;

import java.util.List;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Computes gradients for an {@link IDQN} from a trajectory of {@link StateActionPair}s.
 *
 * <p>Regression targets are built by accumulating discounted rewards backward from the last
 * pair in the trajectory. For each step, the target row is the network's own output for that
 * observation with only the taken action's entry replaced by the accumulated return.
 */
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
    /** Per-observation shape; combined with the batch size via {@code Learning.makeShape}. */
    private final int[] shape;
    /** Number of discrete actions, i.e. the number of columns in the target matrix. */
    private final int actionSpaceSize;
    /** Discount factor applied to the running return at each step. */
    private final double gamma;

    /**
     * @param shape           shape of a single observation (without the batch dimension)
     * @param actionSpaceSize number of discrete actions the network outputs values for
     * @param gamma           discount factor used when accumulating rewards
     */
    public QLearningUpdateAlgorithm(int[] shape, int actionSpaceSize, double gamma) {
        this.shape = shape;
        this.actionSpaceSize = actionSpaceSize;
        this.gamma = gamma;
    }

    /**
     * Builds inputs and Q-value targets from {@code experience} and returns the network's
     * gradients for them.
     *
     * @param current    network used both to produce baseline outputs and to compute gradients
     * @param experience trajectory of state/action pairs, in chronological order
     * @return the gradients returned by {@code current.gradient(inputs, targets)}
     * @throws IllegalArgumentException if {@code experience} is empty
     */
    @Override
    public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
        if (experience.isEmpty()) {
            throw new IllegalArgumentException("experience must contain at least one StateActionPair");
        }
        int size = experience.size();
        INDArray inputs = Nd4j.create(Learning.makeShape(size, shape));
        INDArray targets = Nd4j.create(new int[] {size, actionSpaceSize});

        // Bootstrap the running return: 0 for a terminal trajectory end, otherwise the max
        // network output for the last pair's observation.
        // NOTE(review): this bootstraps from the last pair's own observation rather than a
        // successor state — confirm this is the intended n-step target.
        StateActionPair<Integer> last = experience.get(size - 1);
        double discountedReturn = last.isTerminal()
                ? 0.0d
                : Nd4j.max(current.outputAll(last.getObservation().getData())[0]).getDouble(0L);

        // Walk backward so each step's return includes all later (discounted) rewards.
        for (int i = size - 1; i >= 0; i--) {
            StateActionPair<Integer> step = experience.get(i);
            INDArray observation = step.getObservation().getData();
            inputs.putRow(i, observation);

            discountedReturn = step.getReward() + gamma * discountedReturn;

            // Start from the network's own output so non-taken actions keep their current
            // values; putScalar overwrites the taken action's entry in place.
            INDArray qValues = current.outputAll(observation)[0];
            targets.putRow(i, qValues.putScalar(step.getAction().intValue(), discountedReturn));
        }
        return current.gradient(inputs, targets);
    }

    /**
     * Equivalent to {@link #computeGradients(IDQN, List)}.
     *
     * @deprecated decompiler-introduced alias kept for compatibility; use
     *     {@link #computeGradients(IDQN, List)} instead.
     */
    @Deprecated
    public Gradient[] computeGradients2(IDQN idqn, List<StateActionPair<Integer>> list) {
        return computeGradients(idqn, list);
    }
}
