package org.deeplearning4j.rl4j.policy;

import java.beans.ConstructorProperties;
import java.util.Random;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/BoltzmannQ.class */
/**
 * Boltzmann (softmax) exploration policy over the Q-values produced by an {@link IDQN}.
 *
 * <p>Each call to {@link #nextAction(INDArray)} samples action {@code i} with probability
 * {@code exp(q_i) / sum_j exp(q_j)}. Draws come from a private {@link Random} seeded with
 * {@code 123}, so every instance produces the same sequence of uniform draws.
 *
 * @param <O> observation type
 */
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
    /** Network whose output is interpreted as one Q-value per action. */
    private final IDQN dqn;
    /** Fixed-seed source of the uniform draw used for each sampled action. */
    private final Random rd = new Random(123);

    /**
     * Creates a Boltzmann policy backed by the given Q-network.
     *
     * @param idqn network used to compute Q-values
     */
    @ConstructorProperties({"dqn"})
    public BoltzmannQ(IDQN idqn) {
        this.dqn = idqn;
    }

    /**
     * Returns the Q-network this policy samples from.
     *
     * @return the wrapped {@link IDQN}
     */
    @Override
    public IDQN getNeuralNet() {
        return this.dqn;
    }

    /**
     * Samples an action index from the softmax distribution of the network's Q-values.
     *
     * <p>Only the first {@code columns()} entries of the network output are read (by linear
     * index), so the output is expected to hold a single row of Q-values. NOTE(review): confirm
     * callers never pass a batch of observations here.
     *
     * @param input observation passed to {@code dqn.output}
     * @return the sampled action index in {@code [0, columns())}, or {@code -1} if the output has
     *     no columns or its values make the distribution undefined (e.g. NaN Q-values)
     */
    @Override
    public Integer nextAction(INDArray input) {
        INDArray qValues = this.dqn.output(input);
        int actionCount = qValues.columns();

        // Subtract the maximum Q-value before exponentiating. The distribution is unchanged,
        // but exp() can no longer overflow to Infinity (which would turn the weights into NaN).
        double maxQ = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < actionCount; i++) {
            maxQ = Math.max(maxQ, qValues.getDouble(i));
        }

        double[] weights = new double[actionCount];
        double total = 0.0;
        for (int i = 0; i < actionCount; i++) {
            weights[i] = Math.exp(qValues.getDouble(i) - maxQ);
            total += weights[i];
        }

        // Roulette-wheel selection: compare the draw against the running (cumulative) weight,
        // not the individual weight, so action i is chosen with probability weights[i] / total.
        double picked = this.rd.nextDouble() * total;
        double cumulative = 0.0;
        for (int i = 0; i < actionCount; i++) {
            cumulative += weights[i];
            if (picked < cumulative) {
                return i;
            }
        }
        // Reached only when there are no actions or the weights are NaN: otherwise the final
        // cumulative value equals total, and picked = u * total < total for u in [0, 1).
        return -1;
    }
}
