package org.deeplearning4j.rl4j.agent.learning.algorithm.nstepqlearning;

import java.util.List;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link NStepQLearningHelper} implementation for non-recurrent (feed-forward) networks.
 * Each training sample is stored as a single row of the feature and label arrays.
 */
public class NonRecurrentNStepQLearningHelper extends NStepQLearningHelper {
    /** Number of available actions; used as the column count of the label array. */
    private final int actionSpaceSize;

    /**
     * @param actionSpaceSize number of available actions (width of each label row)
     */
    public NonRecurrentNStepQLearningHelper(int actionSpaceSize) {
        this.actionSpaceSize = actionSpaceSize;
    }

    /**
     * Allocates the feature array for a batch by delegating to
     * {@code INDArrayHelper.createBatchForShape}.
     *
     * @param size             number of samples in the batch
     * @param observationShape shape of a single observation
     * @return the newly allocated feature array
     */
    @Override
    protected INDArray createFeatureArray(int size, long[] observationShape) {
        return INDArrayHelper.createBatchForShape(size, observationShape);
    }

    /**
     * Writes one sample's observation into row {@code idx} of the feature array.
     *
     * @param features the feature array to write into
     * @param idx      row index of the sample
     * @param data     the observation data for that sample
     */
    @Override
    protected void setFeature(INDArray features, long idx, INDArray data) {
        features.putRow(idx, data);
    }

    /**
     * Allocates a label array with one row per sample and one column per action.
     *
     * @param trainingBatchSize number of rows to allocate
     * @return a new array of shape {@code [trainingBatchSize, actionSpaceSize]}
     */
    @Override
    public INDArray createLabels(int trainingBatchSize) {
        final int[] labelShape = {trainingBatchSize, actionSpaceSize};
        return Nd4j.create(labelShape);
    }

    /**
     * Writes one sample's label values into row {@code idx} of the label array.
     *
     * @param labels the label array to write into
     * @param idx    row index of the sample
     * @param data   the label values for that sample
     */
    @Override
    public void setLabels(INDArray labels, long idx, INDArray data) {
        labels.putRow(idx, data);
    }

    /**
     * Extracts the expected Q-values of a single sample.
     *
     * @param allExpectedQValues Q-values for the whole batch, one row per sample
     * @param idx                row index of the sample
     * @return row {@code idx} of {@code allExpectedQValues}
     */
    @Override
    public INDArray getExpectedQValues(INDArray allExpectedQValues, int idx) {
        return allExpectedQValues.getRow(idx);
    }

    /**
     * Runs the target network on the observation of the last element of the batch and
     * returns its {@code "Q"} output.
     *
     * @param target        the network to evaluate
     * @param trainingBatch the batch of state/action pairs; expected to be non-empty
     *                      ({@link List#get} throws otherwise)
     * @param features      not used by this non-recurrent implementation
     * @return the {@code "Q"} output of {@code target} for the last observation
     */
    @Override
    public INDArray getTargetExpectedQValuesOfLast(IOutputNeuralNet target,
                                                  List<StateActionPair<Integer>> trainingBatch,
                                                  INDArray features) {
        // Only the final observation is needed here; the batch features are ignored.
        final int lastIndex = trainingBatch.size() - 1;
        final StateActionPair<Integer> lastPair = trainingBatch.get(lastIndex);
        return target.output(lastPair.getObservation()).get("Q");
    }
}
