package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import java.util.List;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.class */
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
    /** Source of the Q-network whose outputs on the observation batch are used as the base targets. */
    protected final QNetworkSource qNetworkSource;
    /** Discount factor supplied at construction; stored for subclasses (not read by this class). */
    protected final double gamma;
    /** Half-width of the clamp applied around the current Q-value; only used when {@link #isClamped}. */
    private final double errorClamp;
    /** True when an error clamp was supplied, i.e. {@code errorClamp} is not NaN. */
    private final boolean isClamped;
    /** Shape used to allocate the observation batches; must be set via {@link #setNShape} before use. */
    private int[] nShape;
    /** Observations are divided by this value before use; 1.0 disables scaling. */
    private double scale = 1.0;

    /**
     * Creates an algorithm whose targets are optionally clamped around the network's current
     * Q-value for the taken action.
     *
     * @param qNetworkSource source of the Q-network evaluated on the observation batch
     * @param gamma discount factor, stored for use by subclasses
     * @param errorClamp clamp half-width; {@link Double#NaN} disables clamping
     */
    // NOTE(review): the decompiler reports this constructor was originally protected.
    public BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
        this.qNetworkSource = qNetworkSource;
        this.gamma = gamma;
        this.errorClamp = errorClamp;
        this.isClamped = !Double.isNaN(errorClamp);
    }

    /**
     * Creates an algorithm without error clamping.
     *
     * @param qNetworkSource source of the Q-network evaluated on the observation batch
     * @param gamma discount factor, stored for use by subclasses
     */
    // NOTE(review): the decompiler reports this constructor was originally protected.
    public BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
        this(qNetworkSource, gamma, Double.NaN);
    }

    /**
     * Hook invoked once per {@link #computeTDTargets} call.
     *
     * <p>It runs after the (already scaled) observation and next-observation batches are built,
     * and before any {@link #computeTarget} call. The default implementation does nothing.
     * Subclasses presumably use it to precompute per-batch values that {@code computeTarget}
     * reads; verify against the concrete subclasses.
     *
     * @param observations batch of current observations
     * @param nextObservations batch of next observations
     */
    // NOTE(review): the decompiler reports this method was originally protected.
    public void initComputation(INDArray observations, INDArray nextObservations) {
    }

    /**
     * Computes the TD target for one transition of the current batch.
     *
     * @param batchIdx row index of the transition within the batch
     * @param reward reward of the transition
     * @param isTerminal whether the transition ended the episode
     * @return the (unclamped) target value for the taken action
     */
    protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);

    /**
     * Builds a training {@link DataSet} whose labels are the network's current Q-values with the
     * entry of each transition's taken action replaced by its TD target.
     *
     * <p>If an error clamp was configured, each target is first clamped to
     * {@code [q - errorClamp, q + errorClamp]}, where {@code q} is the current Q-value of the
     * taken action.
     *
     * @param transitions transitions forming the batch; row {@code i} corresponds to element {@code i}
     * @return a dataset of (scaled observations, updated Q-values)
     * @throws IllegalStateException if {@link #setNShape} has not been called
     */
    @Override
    public DataSet computeTDTargets(List<Transition<Integer>> transitions) {
        if (nShape == null) {
            throw new IllegalStateException("nShape must be set with setNShape() before computing TD targets");
        }
        int size = transitions.size();

        INDArray observations = Nd4j.create(nShape);
        INDArray nextObservations = Nd4j.create(nShape);
        for (int i = 0; i < size; i++) {
            Transition<Integer> transition = transitions.get(i);
            INDArray[] observation = transition.getObservation();
            putObservation(observations, i, observation);
            // Transition.append combines the current observation array with the next observation
            // to form the next-state input (presumably a shifted history window — confirm in Transition).
            putObservation(nextObservations, i,
                    Transition.append(observation, transition.getNextObservation()));
        }

        if (scale != 1.0) {
            double inverseScale = 1.0 / scale;
            observations.muli(inverseScale);
            nextObservations.muli(inverseScale);
        }

        initComputation(observations, nextObservations);

        // Start from the network's current predictions so only the taken action's entry changes.
        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
        for (int i = 0; i < size; i++) {
            Transition<Integer> transition = transitions.get(i);
            double target = computeTarget(i, transition.getReward(), transition.isTerminal());
            int action = transition.getAction();
            if (isClamped) {
                double currentQ = updatedQValues.getDouble(i, action);
                double lowBound = currentQ - errorClamp;
                double highBound = currentQ + errorClamp;
                target = Math.min(highBound, Math.max(target, lowBound));
            }
            updatedQValues.putScalar(i, action, target);
        }
        return new org.nd4j.linalg.dataset.DataSet(observations, updatedQValues);
    }

    /**
     * Copies one transition's observation array into row {@code row} of {@code batch}.
     *
     * <p>For a rank-2 batch only {@code observation[0]} is stored, as a row. Otherwise each
     * element {@code j} is stored at index {@code [row, j]}.
     */
    private static void putObservation(INDArray batch, int row, INDArray[] observation) {
        if (batch.rank() == 2) {
            batch.putRow(row, observation[0]);
        } else {
            for (int j = 0; j < observation.length; j++) {
                batch.put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.point(j)}, observation[j]);
            }
        }
    }

    /**
     * Sets the shape used to allocate the observation and next-observation batches.
     *
     * <p>The array is stored by reference.
     *
     * @param nShape the batch shape
     */
    public void setNShape(int[] nShape) {
        this.nShape = nShape;
    }

    /**
     * Sets the divisor applied to observations before they are fed to the network.
     *
     * @param scale the divisor; 1.0 disables scaling
     */
    public void setScale(double scale) {
        this.scale = scale;
    }
}
