package org.deeplearning4j.rl4j.agent.learning.algorithm.actorcritic;

import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelper.class */
/**
 * {@link ActorCriticHelper} implementation for non-recurrent (feed-forward) networks.
 *
 * <p>All arrays produced here are indexed by a leading batch dimension: one row per training
 * sample. Features and labels are written row by row via {@link #setFeature} and
 * {@link #setPolicy}.
 */
public class NonRecurrentActorCriticHelper extends ActorCriticHelper {

    /** Number of columns in the policy-label array (one per available action). */
    private final int actionSpaceSize;

    /**
     * Creates a helper for an action space of the given size.
     *
     * @param actionSpaceSize number of discrete actions; used as the column count of the policy
     *     labels
     */
    public NonRecurrentActorCriticHelper(int actionSpaceSize) {
        this.actionSpaceSize = actionSpaceSize;
    }

    /**
     * Allocates the feature batch by delegating to {@link INDArrayHelper#createBatchForShape}.
     *
     * @param batchSize number of samples in the batch
     * @param observationShape shape of a single observation
     * @return the batch array built by {@code INDArrayHelper}
     */
    @Override
    protected INDArray createFeatureArray(int batchSize, long[] observationShape) {
        INDArray batch = INDArrayHelper.createBatchForShape(batchSize, observationShape);
        return batch;
    }

    /**
     * Allocates the value-label array with shape {@code [batchSize, 1]} via {@code Nd4j.create}.
     *
     * @param batchSize number of rows (training samples)
     * @return a new {@code [batchSize, 1]} array
     */
    @Override
    public INDArray createValueLabels(int batchSize) {
        int[] valueShape = {batchSize, 1};
        return Nd4j.create(valueShape);
    }

    /**
     * Allocates a zero-filled policy-label array with shape {@code [batchSize, actionSpaceSize]}.
     *
     * @param batchSize number of rows (training samples)
     * @return a new zero-filled {@code [batchSize, actionSpaceSize]} array
     */
    @Override
    public INDArray createPolicyLabels(int batchSize) {
        int[] policyShape = {batchSize, actionSpaceSize};
        return Nd4j.zeros(policyShape);
    }

    /**
     * Copies one sample's features into row {@code rowIndex} of the feature batch.
     *
     * @param features destination batch array (modified in place)
     * @param rowIndex row to overwrite
     * @param observation data written into that row
     */
    @Override
    protected void setFeature(INDArray features, long rowIndex, INDArray observation) {
        features.putRow(rowIndex, observation);
    }

    /**
     * Writes {@code value} into the policy-label array at {@code (rowIndex, action)}.
     *
     * <p>NOTE(review): {@code value} is presumably the advantage for the taken action; confirm
     * against the caller in {@code ActorCriticHelper}.
     *
     * @param policy policy-label array (modified in place)
     * @param rowIndex sample row
     * @param action column index of the chosen action
     * @param value scalar stored at that position
     */
    @Override
    public void setPolicy(INDArray policy, long rowIndex, int action, double value) {
        policy.putScalar(rowIndex, action, value);
    }
}
