package org.deeplearning4j.rl4j.policy;

import lombok.NonNull;
import org.deeplearning4j.rl4j.environment.IActionSchema;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/EpsGreedy.class */
/**
 * Epsilon-greedy exploration wrapper around another policy.
 *
 * <p>On every decision a value is drawn from {@code rnd}. If it is greater than the current
 * epsilon, the wrapped policy chooses the action; otherwise a random action is returned.
 * Epsilon follows the linear schedule implemented in {@link #getEpsilon()}: it is 1.0 up to
 * {@code annealingStart}, then decreases by {@code 1 / epsilonNbStep} per step, floored at
 * {@code minEpsilon} and capped at 1.0.
 *
 * <p>Two modes exist:
 * <ul>
 *   <li>legacy (the deprecated constructor): random actions come from the MDP's action space and
 *       the step counter is {@code learning.getStepCount()};</li>
 *   <li>current (all other constructors and the builder): random actions come from the
 *       {@link IActionSchema}, and the step counter is the number of
 *       {@link #nextAction(Observation)} calls made on this instance.</li>
 * </ul>
 *
 * @param <A> action type
 */
public class EpsGreedy<A> extends Policy<A> {
    private static final Logger log = LoggerFactory.getLogger(EpsGreedy.class);

    /** The current epsilon is logged whenever the step counter modulo this value equals 1. */
    private static final int LOG_INTERVAL = 500;

    /** Policy consulted for the greedy (non-random) action. */
    private final INeuralNetPolicy<A> policy;
    /** Step at which epsilon starts decreasing from 1.0. */
    private final int annealingStart;
    /** Number of steps over which epsilon decreases by 1.0 (the slope of the schedule). */
    private final int epsilonNbStep;
    /** Source of the draw compared against epsilon. */
    private final Random rnd;
    /** Lower bound of the schedule. */
    private final double minEpsilon;
    /** Source of random actions in the current mode; {@code null} in legacy mode. */
    private final IActionSchema<A> actionSchema;
    /** Legacy-mode source of random actions; {@code null} in the current mode. */
    private final MDP<Encodable, A, ActionSpace<A>> mdp;
    /** Legacy-mode step counter provider; {@code null} in the current mode. */
    private final IEpochTrainer learning;
    /** Number of {@link #nextAction(Observation)} calls in the current mode (saturating). */
    private int annealingStep = 0;

    /** Immutable annealing settings, built with {@link #builder()}. */
    public static class Configuration {
        final int annealingStart;
        final int epsilonNbStep;
        final double minEpsilon;

        /**
         * Self-typed builder for {@link Configuration}.
         * {@code annealingStart} defaults to 0 when not set. The other fields default to the
         * Java zero values.
         */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> {
            private boolean annealingStart$set;
            private int annealingStart$value;
            private int epsilonNbStep;
            private double minEpsilon;

            protected abstract B self();

            public abstract C build();

            public B annealingStart(int annealingStart) {
                this.annealingStart$value = annealingStart;
                this.annealingStart$set = true;
                return self();
            }

            public B epsilonNbStep(int epsilonNbStep) {
                this.epsilonNbStep = epsilonNbStep;
                return self();
            }

            public B minEpsilon(double minEpsilon) {
                this.minEpsilon = minEpsilon;
                return self();
            }

            @Override
            public String toString() {
                return "EpsGreedy.Configuration.ConfigurationBuilder(annealingStart$value=" + this.annealingStart$value + ", epsilonNbStep=" + this.epsilonNbStep + ", minEpsilon=" + this.minEpsilon + ")";
            }
        }

        /** Concrete builder returned by {@link Configuration#builder()}. */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            @Override
            protected ConfigurationBuilderImpl self() {
                return this;
            }

            @Override
            public Configuration build() {
                return new Configuration(this);
            }
        }

        private static int $default$annealingStart() {
            return 0;
        }

        protected Configuration(ConfigurationBuilder<?, ?> builder) {
            this.annealingStart = builder.annealingStart$set
                    ? builder.annealingStart$value
                    : $default$annealingStart();
            this.epsilonNbStep = builder.epsilonNbStep;
            this.minEpsilon = builder.minEpsilon;
        }

        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        public int getAnnealingStart() {
            return this.annealingStart;
        }

        public int getEpsilonNbStep() {
            return this.epsilonNbStep;
        }

        public double getMinEpsilon() {
            return this.minEpsilon;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            return other.canEqual(this)
                    && getAnnealingStart() == other.getAnnealingStart()
                    && getEpsilonNbStep() == other.getEpsilonNbStep()
                    && Double.compare(getMinEpsilon(), other.getMinEpsilon()) == 0;
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + getAnnealingStart();
            result = result * prime + getEpsilonNbStep();
            long bits = Double.doubleToLongBits(getMinEpsilon());
            result = result * prime + (int) ((bits >>> 32) ^ bits);
            return result;
        }

        @Override
        public String toString() {
            return "EpsGreedy.Configuration(annealingStart=" + getAnnealingStart() + ", epsilonNbStep=" + getEpsilonNbStep() + ", minEpsilon=" + getMinEpsilon() + ")";
        }
    }

    /**
     * Builder for the current-mode constructor.
     * A {@code null} {@code rnd} falls back to {@code Nd4j.getRandom()}.
     */
    public static class EpsGreedyBuilder<A> {
        private INeuralNetPolicy<A> policy;
        private IActionSchema<A> actionSchema;
        private double minEpsilon;
        private int annealingStart;
        private int epsilonNbStep;
        private Random rnd;

        EpsGreedyBuilder() {
        }

        public EpsGreedyBuilder<A> policy(@NonNull INeuralNetPolicy<A> policy) {
            if (policy == null) {
                throw new NullPointerException("policy is marked non-null but is null");
            }
            this.policy = policy;
            return this;
        }

        public EpsGreedyBuilder<A> actionSchema(@NonNull IActionSchema<A> actionSchema) {
            if (actionSchema == null) {
                throw new NullPointerException("actionSchema is marked non-null but is null");
            }
            this.actionSchema = actionSchema;
            return this;
        }

        public EpsGreedyBuilder<A> minEpsilon(double minEpsilon) {
            this.minEpsilon = minEpsilon;
            return this;
        }

        public EpsGreedyBuilder<A> annealingStart(int annealingStart) {
            this.annealingStart = annealingStart;
            return this;
        }

        public EpsGreedyBuilder<A> epsilonNbStep(int epsilonNbStep) {
            this.epsilonNbStep = epsilonNbStep;
            return this;
        }

        public EpsGreedyBuilder<A> rnd(Random rnd) {
            this.rnd = rnd;
            return this;
        }

        public EpsGreedy<A> build() {
            return new EpsGreedy<>(this.policy, this.actionSchema, this.minEpsilon, this.annealingStart, this.epsilonNbStep, this.rnd);
        }

        @Override
        public String toString() {
            return "EpsGreedy.EpsGreedyBuilder(policy=" + this.policy + ", actionSchema=" + this.actionSchema + ", minEpsilon=" + this.minEpsilon + ", annealingStart=" + this.annealingStart + ", epsilonNbStep=" + this.epsilonNbStep + ", rnd=" + this.rnd + ")";
        }
    }

    /**
     * Legacy-mode constructor.
     * Random actions come from {@code mdp.getActionSpace()}, and annealing is driven by
     * {@code learning.getStepCount()}. Unlike the other constructors, {@code rnd} is not
     * defaulted, so it must be non-null.
     *
     * @deprecated use one of the {@link IActionSchema}-based constructors or {@link #builder()}.
     */
    @Deprecated
    public <OBSERVATION extends Encodable, AS extends ActionSpace<A>> EpsGreedy(Policy<A> policy, MDP<Encodable, A, ActionSpace<A>> mdp, int annealingStart, int epsilonNbStep, Random rnd, double minEpsilon, IEpochTrainer learning) {
        this.policy = policy;
        this.mdp = mdp;
        this.annealingStart = annealingStart;
        this.epsilonNbStep = epsilonNbStep;
        this.rnd = rnd;
        this.minEpsilon = minEpsilon;
        this.learning = learning;
        this.actionSchema = null;
    }

    /** Current-mode constructor that uses {@code Nd4j.getRandom()} as the random source. */
    public EpsGreedy(@NonNull Policy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int annealingStart, int epsilonNbStep) {
        // Null checks for policy and actionSchema are performed by the delegated constructor.
        this(policy, actionSchema, minEpsilon, annealingStart, epsilonNbStep, null);
    }

    /**
     * Current-mode constructor.
     *
     * @param policy         policy consulted for the greedy action
     * @param actionSchema   source of random actions
     * @param minEpsilon     lower bound of the epsilon schedule
     * @param annealingStart step at which epsilon starts decreasing from 1.0
     * @param epsilonNbStep  number of steps over which epsilon decreases by 1.0
     * @param rnd            random source; {@code null} selects {@code Nd4j.getRandom()}
     * @throws NullPointerException if {@code policy} or {@code actionSchema} is null
     */
    public EpsGreedy(@NonNull INeuralNetPolicy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int annealingStart, int epsilonNbStep, Random rnd) {
        if (policy == null) {
            throw new NullPointerException("policy is marked non-null but is null");
        }
        if (actionSchema == null) {
            throw new NullPointerException("actionSchema is marked non-null but is null");
        }
        this.policy = policy;
        this.actionSchema = actionSchema;
        this.minEpsilon = minEpsilon;
        this.annealingStart = annealingStart;
        this.epsilonNbStep = epsilonNbStep;
        this.rnd = rnd == null ? Nd4j.getRandom() : rnd;
        this.mdp = null;
        this.learning = null;
    }

    /**
     * Current-mode constructor taking its schedule from a {@link Configuration}.
     *
     * @throws NullPointerException if {@code configuration}, {@code policy} or
     *                              {@code actionSchema} is null
     */
    public EpsGreedy(INeuralNetPolicy<A> policy, IActionSchema<A> actionSchema, @NonNull Configuration configuration, Random rnd) {
        // The null check must run inside the this(...) arguments, because the arguments
        // dereference configuration before any statement of this constructor body can execute.
        this(policy, actionSchema, requireConfiguration(configuration).getMinEpsilon(), configuration.getAnnealingStart(), configuration.getEpsilonNbStep(), rnd);
    }

    /** Returns {@code configuration}, throwing the Lombok-style NPE if it is null. */
    private static Configuration requireConfiguration(Configuration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        return configuration;
    }

    @Override
    public IOutputNeuralNet getNeuralNet() {
        return this.policy.getNeuralNet();
    }

    /**
     * Legacy-mode action selection.
     *
     * @throws RuntimeException if this instance was built in the current mode (with an action schema)
     * @deprecated use {@link #nextAction(Observation)}.
     */
    @Override
    @Deprecated
    @SuppressWarnings("unchecked") // randomAction() result is cast to A, as in the original
    public A nextAction(INDArray input) {
        if (this.actionSchema != null) {
            throw new RuntimeException("nextAction(Observation observation) should be called when using a AgentLearner");
        }
        double epsilon = getEpsilon();
        if (this.learning.getStepCount() % LOG_INTERVAL == 1) {
            log.info("EP: {} {}", epsilon, this.learning.getStepCount());
        }
        if (this.rnd.nextDouble() > epsilon) {
            return this.policy.nextAction(input);
        }
        return (A) this.mdp.getActionSpace().randomAction();
    }

    /**
     * Chooses an action for {@code observation}, exploring with probability governed by epsilon.
     * In the current mode each call advances the annealing counter by one. Legacy instances
     * delegate to {@link #nextAction(INDArray)} with {@code observation.getData()}.
     */
    @Override
    public A nextAction(Observation observation) {
        if (this.actionSchema == null) {
            return nextAction(observation.getData());
        }
        double epsilon = getEpsilon();
        if (this.annealingStep % LOG_INTERVAL == 1) {
            log.info("EP: {} {}", epsilon, this.annealingStep);
        }
        // Saturate rather than wrap. A wrapped (negative) counter would silently reset epsilon
        // to 1.0 after Integer.MAX_VALUE calls.
        if (this.annealingStep < Integer.MAX_VALUE) {
            this.annealingStep++;
        }
        if (this.rnd.nextDouble() > epsilon) {
            return this.policy.nextAction(observation);
        }
        if (getNeuralNet().isRecurrent()) {
            // The greedy action is discarded here. NOTE(review): presumably the call is made so a
            // recurrent network's internal state keeps advancing with every observation — confirm.
            this.policy.nextAction(observation);
        }
        return this.actionSchema.getRandomAction();
    }

    /**
     * Returns the current exploration rate:
     * {@code min(1, max(minEpsilon, 1 - (step - annealingStart) / epsilonNbStep))}.
     * Here {@code step} is the annealing counter in the current mode, or
     * {@code learning.getStepCount()} in legacy mode.
     *
     * <p>When {@code epsilonNbStep == 0} the formula would evaluate to {@code 0.0 / 0 = NaN} at
     * {@code step == annealingStart}. In that case epsilon is 1.0 up to and including
     * {@code annealingStart} (matching the positive-step schedule at that point) and
     * {@code min(1, minEpsilon)} afterwards.
     */
    public double getEpsilon() {
        double step = this.actionSchema != null ? this.annealingStep : this.learning.getStepCount();
        // Subtract in double to avoid int overflow for extreme step/start values.
        double stepsSinceStart = step - this.annealingStart;
        double scheduled;
        if (this.epsilonNbStep == 0) {
            scheduled = stepsSinceStart > 0 ? Double.NEGATIVE_INFINITY : 1.0d;
        } else {
            scheduled = 1.0d - stepsSinceStart / this.epsilonNbStep;
        }
        return Math.min(1.0d, Math.max(this.minEpsilon, scheduled));
    }

    public static <A> EpsGreedyBuilder<A> builder() {
        return new EpsGreedyBuilder<>();
    }
}
