package org.deeplearning4j.rl4j.agent.learning.algorithm.dqn;

import org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.BaseTransitionTDAlgorithm;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQN.class */
/**
 * DQN learning algorithm whose per-transition target is built from the maximum value of
 * {@code targetQNetworkNextObservation} along the action dimension:
 * {@code y = reward} for terminal transitions, otherwise
 * {@code y = reward + gamma * max_a targetQNetworkNextObservation[batchIdx, a]}.
 *
 * <p>NOTE(review): {@code targetQNetworkNextObservation} and {@code gamma} are inherited from
 * the base classes; presumably the former holds the target network's output for the next
 * observations. Confirm against {@code BaseDQNAlgorithm}.
 */
public class StandardDQN extends BaseDQNAlgorithm {

    /** Dimension of the network output that is reduced with {@code max} (the action axis). */
    private static final int ACTION_DIMENSION_IDX = 1;

    /** Row-wise maximum of {@code targetQNetworkNextObservation}, refreshed by {@link #initComputation}. */
    private INDArray maxActionsFromQTargetNextObservation;

    /**
     * Creates the algorithm; all arguments are forwarded unchanged to {@code BaseDQNAlgorithm}.
     *
     * @param qNetwork       first network argument; presumably the online Q-network (TODO confirm)
     * @param targetQNetwork second network argument; presumably the target Q-network (TODO confirm)
     * @param configuration  algorithm configuration passed to the base class
     */
    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork,
                       BaseTransitionTDAlgorithm.Configuration configuration) {
        super(qNetwork, targetQNetwork, configuration);
    }

    /**
     * Runs the base-class preparation, then caches the maximum of
     * {@code targetQNetworkNextObservation} along {@link #ACTION_DIMENSION_IDX} so that
     * {@link #computeTarget} can read it per batch index.
     *
     * <p>The decompiler reported this method was originally {@code protected}; it is kept
     * {@code public} so existing callers keep compiling.
     *
     * @param observations     first argument forwarded to the base implementation
     * @param nextObservations second argument forwarded to the base implementation
     */
    @Override
    public void initComputation(INDArray observations, INDArray nextObservations) {
        super.initComputation(observations, nextObservations);
        // Must run after super: it reads targetQNetworkNextObservation, presumably set there.
        maxActionsFromQTargetNextObservation =
                Nd4j.max(targetQNetworkNextObservation, ACTION_DIMENSION_IDX);
    }

    /**
     * Computes the TD target for one transition of the batch.
     *
     * @param batchIdx   index of the transition within the batch
     * @param reward     reward of the transition
     * @param isTerminal {@code true} if the transition ends the episode; no bootstrapping then
     * @return {@code reward} when terminal, otherwise
     *         {@code reward + gamma * maxActionsFromQTargetNextObservation[batchIdx]}
     */
    @Override
    protected double computeTarget(int batchIdx, double reward, boolean isTerminal) {
        if (isTerminal) {
            return reward;
        }
        return reward + gamma * maxActionsFromQTargetNextObservation.getDouble(batchIdx);
    }
}
