package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import java.util.List;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.class */
/**
 * Skeleton for temporal-difference target algorithms over discrete (integer) actions.
 *
 * <p>For a batch of transitions it starts from the Q-network's current output for the stacked
 * observations. It then overwrites, in each row, only the entry of the action that was taken with
 * the target produced by {@link #computeTarget(int, double, boolean)}. That target is optionally
 * clamped to {@code [q - errorClamp, q + errorClamp]}, where {@code q} is the network's current
 * value for that entry.
 */
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

    /** Provides the Q-network whose current output forms the base of the returned labels. */
    protected final QNetworkSource qNetworkSource;

    /** Presumably the discount factor; stored for subclasses, never read in this class. */
    protected final double gamma;

    /** Half-width of the clamp window around the current Q-value; {@code NaN} disables clamping. */
    private final double errorClamp;

    /** {@code true} whenever {@link #errorClamp} is not {@code NaN}. */
    private final boolean isClamped;

    /**
     * Creates an algorithm whose targets may be clamped around the current Q-value.
     *
     * @param qNetworkSource supplier of the Q-network used to compute current values
     * @param gamma value exposed to subclasses through {@link #gamma}
     * @param errorClamp clamp half-width; pass {@code Double.NaN} to disable clamping
     */
    // NOTE(review): JADX reports this constructor was originally protected; kept public so the
    // decompiled interface is unchanged.
    public BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
        this.qNetworkSource = qNetworkSource;
        this.gamma = gamma;
        this.errorClamp = errorClamp;
        // Only NaN acts as the "off" sentinel; any other value, infinities included, enables it.
        this.isClamped = !Double.isNaN(errorClamp);
    }

    /**
     * Creates an algorithm that never clamps its targets.
     *
     * @param qNetworkSource supplier of the Q-network used to compute current values
     * @param gamma value exposed to subclasses through {@link #gamma}
     */
    public BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
        this(qNetworkSource, gamma, Double.NaN);
    }

    /**
     * Hook invoked once per batch, before the Q-network evaluates the current observations and
     * before any {@link #computeTarget(int, double, boolean)} call. The base implementation does
     * nothing.
     *
     * @param observations stacked observations of the batch
     * @param nextObservations stacked next-observations of the batch
     */
    public void initComputation(INDArray observations, INDArray nextObservations) {
        // Intentionally empty: subclasses override this to precompute per-batch data.
    }

    /**
     * Computes the unclamped target for one transition of the batch.
     *
     * @param batchIdx position of the transition in the list given to
     *     {@link #computeTDTargets(List)}
     * @param reward reward stored in that transition
     * @param isTerminal terminal flag stored in that transition
     * @return the target value for the action taken in that transition
     */
    protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);

    /**
     * Builds a training set whose features are the stacked observations and whose labels are the
     * network's current outputs, with the taken-action entry of each row replaced by its target.
     *
     * <p>The array returned by the network's {@code output} call is modified in place and used
     * directly as the labels.
     *
     * @param transitions the batch of transitions
     * @return features = stacked observations, labels = updated Q-values
     */
    @Override
    public DataSet computeTDTargets(List<Transition<Integer>> transitions) {
        final int batchSize = transitions.size();

        final INDArray observations = Transition.buildStackedObservations(transitions);
        final INDArray nextObservations = Transition.buildStackedNextObservations(transitions);
        initComputation(observations, nextObservations);

        final INDArray labels = qNetworkSource.getQNetwork().output(observations);

        for (int row = 0; row < batchSize; row++) {
            final Transition<Integer> transition = transitions.get(row);
            double target = computeTarget(row, transition.getReward(), transition.isTerminal());
            final int action = transition.getAction();

            if (isClamped) {
                target = clampAround(labels.getDouble(row, action), target);
            }
            labels.putScalar(row, action, target);
        }

        return new org.nd4j.linalg.dataset.DataSet(observations, labels);
    }

    /**
     * Restricts {@code value} to {@code [center - errorClamp, center + errorClamp]}.
     * NaN inputs propagate exactly as in {@link Math#min}/{@link Math#max}.
     */
    private double clampAround(double center, double value) {
        final double lowBound = center - errorClamp;
        final double highBound = center + errorClamp;
        return Math.min(highBound, Math.max(value, lowBound));
    }
}
