package org.deeplearning4j.rl4j.policy;

import java.beans.ConstructorProperties;
import java.io.IOException;
import java.util.Objects;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/DQNPolicy.class */
/**
 * A {@link Policy} that picks actions from the output of a wrapped {@link IDQN}.
 *
 * <p>Each call to {@link #nextAction(INDArray)} runs the network on the given input and
 * delegates to {@link Learning#getMaxAction(INDArray)} to turn that output into an
 * {@code Integer} action.
 *
 * @param <O> observation type, constrained to {@link Encodable}
 */
public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> {
    /** The network whose output drives action selection; never {@code null}. */
    private final IDQN dqn;

    /**
     * Creates a policy backed by the given network.
     *
     * @param dqn the network to query for actions; must not be {@code null}
     * @throws NullPointerException if {@code dqn} is {@code null}
     */
    @ConstructorProperties({"dqn"})
    public DQNPolicy(IDQN dqn) {
        // Fail fast here rather than with an unexplained NPE on the first nextAction/save call.
        this.dqn = Objects.requireNonNull(dqn, "dqn");
    }

    /**
     * Loads a network with {@link DQN#load(String)} and wraps it in a new policy.
     *
     * @param path location passed unchanged to {@link DQN#load(String)}
     * @param <O> observation type of the returned policy
     * @return a new policy backed by the loaded network
     * @throws IOException if {@link DQN#load(String)} fails
     */
    public static <O extends Encodable> DQNPolicy<O> load(String path) throws IOException {
        return new DQNPolicy<>(DQN.load(path));
    }

    /**
     * Returns the network backing this policy.
     *
     * @return the wrapped network (the same instance given to the constructor)
     */
    @Override
    public IDQN getNeuralNet() {
        return dqn;
    }

    /**
     * Chooses an action for the given input.
     *
     * <p>Passes the network's output for {@code input} to {@link Learning#getMaxAction(INDArray)}.
     * Judging by its name, this presumably selects the highest-valued (greedy) action. TODO confirm.
     *
     * @param input network input; its expected shape is defined by the wrapped {@link IDQN}
     * @return the action chosen by {@link Learning#getMaxAction(INDArray)}
     */
    @Override
    public Integer nextAction(INDArray input) {
        INDArray networkOutput = dqn.output(input);
        return Learning.getMaxAction(networkOutput);
    }

    /**
     * Saves the wrapped network by delegating to {@link IDQN#save(String)}.
     *
     * @param path destination passed unchanged to the network's {@code save}
     * @throws IOException if the underlying save fails
     */
    public void save(String path) throws IOException {
        dqn.save(path);
    }
}
