package org.deeplearning4j.rl4j.policy;

import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/EpsGreedy.class */
/**
 * Epsilon-greedy wrapper around another {@link Policy}.
 *
 * <p>On every call to {@link #nextAction(INDArray)} a uniform random number is drawn; if it is
 * greater than the current epsilon the wrapped policy chooses the action, otherwise a random
 * action is taken from the MDP's action space.
 *
 * <p>Epsilon is derived from the trainer's step count (see {@link #getEpsilon()}): it is 1.0 up
 * to and including {@code updateStart}, then decreases linearly by {@code 1 / epsilonNbStep} per
 * step, and is bounded below by {@code minEpsilon} and above by 1.0.
 *
 * @param <OBSERVATION> observation type of the MDP
 * @param <A> action type
 * @param <AS> action space type
 */
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
    private static final Logger log = LoggerFactory.getLogger(EpsGreedy.class);

    /** Epsilon is logged whenever the step count modulo this interval equals 1. */
    private static final int LOG_INTERVAL_STEPS = 500;

    /** Policy consulted when the draw says "exploit". */
    private final Policy<A> policy;
    /** Source of the action space used for random (exploratory) actions. */
    private final MDP<OBSERVATION, A, AS> mdp;
    /** Step count at which epsilon starts decreasing from 1.0. */
    private final int updateStart;
    /** Number of steps over which epsilon decreases by a total of 1.0. */
    private final int epsilonNbStep;
    /** Random generator used for the explore/exploit draw. */
    private final Random rnd;
    /** Lower bound applied to epsilon. */
    private final double minEpsilon;
    /** Trainer whose step count drives the epsilon schedule. */
    private final IEpochTrainer learning;

    /**
     * Creates an epsilon-greedy policy.
     *
     * @param policy the wrapped policy used when not exploring
     * @param mdp the MDP whose action space supplies random actions
     * @param updateStart step count at which epsilon starts decreasing
     * @param epsilonNbStep number of steps over which epsilon decreases by 1.0
     * @param rnd random generator used for the explore/exploit draw
     * @param minEpsilon lower bound for epsilon
     * @param learning trainer providing the current step count
     */
    public EpsGreedy(Policy<A> policy, MDP<OBSERVATION, A, AS> mdp, int updateStart, int epsilonNbStep,
                     Random rnd, double minEpsilon, IEpochTrainer learning) {
        this.policy = policy;
        this.mdp = mdp;
        this.updateStart = updateStart;
        this.epsilonNbStep = epsilonNbStep;
        this.rnd = rnd;
        this.minEpsilon = minEpsilon;
        this.learning = learning;
    }

    /**
     * Returns the neural network of the wrapped policy.
     *
     * @return the wrapped policy's network
     */
    @Override
    public NeuralNet getNeuralNet() {
        return this.policy.getNeuralNet();
    }

    /**
     * Chooses the next action: the wrapped policy's action when the random draw exceeds the
     * current epsilon, otherwise a random action from the MDP's action space.
     *
     * @param iNDArray the input passed through to the wrapped policy
     * @return the selected action
     */
    @Override
    public A nextAction(INDArray iNDArray) {
        final double epsilon = getEpsilon();
        if (this.learning.getStepCount() % LOG_INTERVAL_STEPS == 1) {
            // Parameterized logging: the message is only assembled when INFO is enabled.
            log.info("EP: {} {}", epsilon, this.learning.getStepCount());
        }
        if (this.rnd.nextDouble() > epsilon) {
            return this.policy.nextAction(iNDArray);
        }
        return (A) this.mdp.getActionSpace().randomAction();
    }

    /**
     * Chooses the next action for an {@link Observation} by delegating to
     * {@link #nextAction(INDArray)} with the observation's data.
     *
     * @param observation the current observation
     * @return the selected action
     */
    @Override
    public A nextAction(Observation observation) {
        return nextAction(observation.getData());
    }

    /**
     * Computes the current epsilon from the trainer's step count:
     * {@code min(1, max(minEpsilon, 1 - (stepCount - updateStart) / epsilonNbStep))}.
     *
     * <p>Exactly at {@code stepCount == updateStart} the result is 1.0. This case is handled
     * explicitly so that {@code epsilonNbStep == 0} cannot produce {@code 0.0 / 0 = NaN}.
     *
     * @return the exploration probability for the current step
     */
    public double getEpsilon() {
        final double elapsedSteps = (this.learning.getStepCount() - this.updateStart) * 1.0d;
        if (elapsedSteps == 0.0d) {
            // The schedule evaluates to 1.0 here for any non-zero epsilonNbStep; returning it
            // directly avoids NaN when epsilonNbStep is 0. A NaN epsilon made nextAction always
            // explore, which is what epsilon = 1.0 does as well.
            return 1.0d;
        }
        final double annealed = 1.0d - (elapsedSteps / this.epsilonNbStep);
        return Math.min(1.0d, Math.max(this.minEpsilon, annealed));
    }
}
