package org.deeplearning4j.rl4j.learning.async.a3c.discrete;

import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.class */
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
    public final A3CConfiguration configuration;
    protected final MDP<O, Integer, DiscreteSpace> mdp;
    private final IActorCritic iActorCritic;
    private final AsyncGlobal asyncGlobal;
    private final ACPolicy<O> policy;

    /* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete$A3CConfiguration.class */
    /**
     * Mutable settings bean for {@link A3CDiscrete}, exposed to the asynchronous learning code
     * through {@link AsyncConfiguration}.
     *
     * <p>Instances are created either with the all-arguments constructor or through
     * {@link #builder()}. Two configurations are equal when all nine settings are equal.
     */
    public static class A3CConfiguration implements AsyncConfiguration {
        /** Optional seed; {@code null} means "no seed configured". */
        Integer seed;
        int maxEpochStep;
        int maxStep;
        int numThread;
        int nstep;
        int updateStart;
        double rewardFactor;
        double gamma;
        double errorClamp;

        /**
         * Creates a fully populated configuration.
         *
         * @param num seed, may be {@code null}
         * @param i   value for {@code maxEpochStep}
         * @param i2  value for {@code maxStep}
         * @param i3  value for {@code numThread}
         * @param i4  value for {@code nstep}
         * @param i5  value for {@code updateStart}
         * @param d   value for {@code rewardFactor}
         * @param d2  value for {@code gamma}
         * @param d3  value for {@code errorClamp}
         */
        public A3CConfiguration(Integer num, int i, int i2, int i3, int i4, int i5, double d, double d2, double d3) {
            seed = num;
            maxEpochStep = i;
            maxStep = i2;
            numThread = i3;
            nstep = i4;
            updateStart = i5;
            rewardFactor = d;
            gamma = d2;
            errorClamp = d3;
        }

        /**
         * Starts a fluent builder for this configuration.
         *
         * @return a new, empty builder
         */
        public static A3CConfigurationBuilder builder() {
            return new A3CConfigurationBuilder();
        }

        /**
         * Always returns {@code -1}; this configuration holds no target-update frequency setting.
         *
         * @return {@code -1}
         */
        @Override
        public int getTargetDqnUpdateFreq() {
            return -1;
        }

        /** @return the configured seed, or {@code null} when none was set */
        @Override
        public Integer getSeed() {
            return seed;
        }

        /** @return the configured {@code maxEpochStep} */
        @Override
        public int getMaxEpochStep() {
            return maxEpochStep;
        }

        /** @return the configured {@code maxStep} */
        @Override
        public int getMaxStep() {
            return maxStep;
        }

        /** @return the configured {@code numThread} */
        @Override
        public int getNumThread() {
            return numThread;
        }

        /** @return the configured {@code nstep} */
        @Override
        public int getNstep() {
            return nstep;
        }

        /** @return the configured {@code updateStart} */
        @Override
        public int getUpdateStart() {
            return updateStart;
        }

        /** @return the configured {@code rewardFactor} */
        @Override
        public double getRewardFactor() {
            return rewardFactor;
        }

        /** @return the configured {@code gamma} */
        @Override
        public double getGamma() {
            return gamma;
        }

        /** @return the configured {@code errorClamp} */
        @Override
        public double getErrorClamp() {
            return errorClamp;
        }

        /** @param num new seed, may be {@code null} */
        public void setSeed(Integer num) {
            seed = num;
        }

        /** @param i new {@code maxEpochStep} */
        public void setMaxEpochStep(int i) {
            maxEpochStep = i;
        }

        /** @param i new {@code maxStep} */
        public void setMaxStep(int i) {
            maxStep = i;
        }

        /** @param i new {@code numThread} */
        public void setNumThread(int i) {
            numThread = i;
        }

        /** @param i new {@code nstep} */
        public void setNstep(int i) {
            nstep = i;
        }

        /** @param i new {@code updateStart} */
        public void setUpdateStart(int i) {
            updateStart = i;
        }

        /** @param d new {@code rewardFactor} */
        public void setRewardFactor(double d) {
            rewardFactor = d;
        }

        /** @param d new {@code gamma} */
        public void setGamma(double d) {
            gamma = d;
        }

        /** @param d new {@code errorClamp} */
        public void setErrorClamp(double d) {
            errorClamp = d;
        }

        /**
         * Field-by-field equality. Values are read through the getters, and the other object is
         * asked via {@link #canEqual(Object)} whether it accepts the comparison.
         *
         * @param obj object to compare with
         * @return {@code true} when {@code obj} is an {@code A3CConfiguration} with identical settings
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof A3CConfiguration)) {
                return false;
            }
            final A3CConfiguration that = (A3CConfiguration) obj;
            if (!that.canEqual(this)) {
                return false;
            }

            // The seed is the only nullable setting, so it needs a null-safe comparison.
            final Integer mine = getSeed();
            final Integer theirs = that.getSeed();
            final boolean sameSeed = (mine == null) ? (theirs == null) : mine.equals(theirs);
            if (!sameSeed) {
                return false;
            }

            if (getMaxEpochStep() != that.getMaxEpochStep()) {
                return false;
            }
            if (getMaxStep() != that.getMaxStep()) {
                return false;
            }
            if (getNumThread() != that.getNumThread()) {
                return false;
            }
            if (getNstep() != that.getNstep()) {
                return false;
            }
            if (getUpdateStart() != that.getUpdateStart()) {
                return false;
            }
            if (Double.compare(getRewardFactor(), that.getRewardFactor()) != 0) {
                return false;
            }
            if (Double.compare(getGamma(), that.getGamma()) != 0) {
                return false;
            }
            return Double.compare(getErrorClamp(), that.getErrorClamp()) == 0;
        }

        /**
         * Hook used by {@link #equals(Object)}: the object being compared against decides
         * whether the caller is an acceptable counterpart.
         *
         * @param obj the object that initiated the comparison
         * @return {@code true} when {@code obj} is an {@code A3CConfiguration}
         */
        protected boolean canEqual(Object obj) {
            return obj instanceof A3CConfiguration;
        }

        /**
         * Hash over all nine settings, combined with multiplier 59; a {@code null} seed
         * contributes 43.
         *
         * @return hash code consistent with {@link #equals(Object)}
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            final Integer seedValue = getSeed();

            int result = 1;
            result = result * prime + (seedValue == null ? 43 : seedValue.hashCode());
            result = result * prime + getMaxEpochStep();
            result = result * prime + getMaxStep();
            result = result * prime + getNumThread();
            result = result * prime + getNstep();
            result = result * prime + getUpdateStart();
            result = result * prime + foldBits(getRewardFactor());
            result = result * prime + foldBits(getGamma());
            result = result * prime + foldBits(getErrorClamp());
            return result;
        }

        /** Folds the 64-bit representation of a double into 32 bits (high word XOR low word). */
        private static int foldBits(double value) {
            final long bits = Double.doubleToLongBits(value);
            return (int) (bits ^ (bits >>> 32));
        }

        /** @return a readable dump of every setting, read through the getters */
        @Override
        public String toString() {
            return "A3CDiscrete.A3CConfiguration(seed=" + getSeed()
                    + ", maxEpochStep=" + getMaxEpochStep()
                    + ", maxStep=" + getMaxStep()
                    + ", numThread=" + getNumThread()
                    + ", nstep=" + getNstep()
                    + ", updateStart=" + getUpdateStart()
                    + ", rewardFactor=" + getRewardFactor()
                    + ", gamma=" + getGamma()
                    + ", errorClamp=" + getErrorClamp()
                    + ")";
        }

        /**
         * Fluent builder for {@link A3CConfiguration}.
         *
         * <p>Settings that are never assigned keep their Java defaults ({@code null} for the seed,
         * {@code 0} / {@code 0.0} for the numeric values) and are passed to the constructor as such.
         */
        public static class A3CConfigurationBuilder {
            private Integer seed;
            private int maxEpochStep;
            private int maxStep;
            private int numThread;
            private int nstep;
            private int updateStart;
            private double rewardFactor;
            private double gamma;
            private double errorClamp;

            /** Package-private: obtain instances through {@link A3CConfiguration#builder()}. */
            A3CConfigurationBuilder() {
            }

            /**
             * @param num seed to use, may be {@code null}
             * @return this builder
             */
            public A3CConfigurationBuilder seed(Integer num) {
                seed = num;
                return this;
            }

            /**
             * @param i value for {@code maxEpochStep}
             * @return this builder
             */
            public A3CConfigurationBuilder maxEpochStep(int i) {
                maxEpochStep = i;
                return this;
            }

            /**
             * @param i value for {@code maxStep}
             * @return this builder
             */
            public A3CConfigurationBuilder maxStep(int i) {
                maxStep = i;
                return this;
            }

            /**
             * @param i value for {@code numThread}
             * @return this builder
             */
            public A3CConfigurationBuilder numThread(int i) {
                numThread = i;
                return this;
            }

            /**
             * @param i value for {@code nstep}
             * @return this builder
             */
            public A3CConfigurationBuilder nstep(int i) {
                nstep = i;
                return this;
            }

            /**
             * @param i value for {@code updateStart}
             * @return this builder
             */
            public A3CConfigurationBuilder updateStart(int i) {
                updateStart = i;
                return this;
            }

            /**
             * @param d value for {@code rewardFactor}
             * @return this builder
             */
            public A3CConfigurationBuilder rewardFactor(double d) {
                rewardFactor = d;
                return this;
            }

            /**
             * @param d value for {@code gamma}
             * @return this builder
             */
            public A3CConfigurationBuilder gamma(double d) {
                gamma = d;
                return this;
            }

            /**
             * @param d value for {@code errorClamp}
             * @return this builder
             */
            public A3CConfigurationBuilder errorClamp(double d) {
                errorClamp = d;
                return this;
            }

            /**
             * Creates the configuration from the values collected so far.
             *
             * @return a new {@link A3CConfiguration}
             */
            public A3CConfiguration build() {
                return new A3CConfiguration(
                        seed,
                        maxEpochStep,
                        maxStep,
                        numThread,
                        nstep,
                        updateStart,
                        rewardFactor,
                        gamma,
                        errorClamp);
            }

            /** @return a readable dump of the builder's current values */
            @Override
            public String toString() {
                return "A3CDiscrete.A3CConfiguration.A3CConfigurationBuilder(seed=" + seed
                        + ", maxEpochStep=" + maxEpochStep
                        + ", maxStep=" + maxStep
                        + ", numThread=" + numThread
                        + ", nstep=" + nstep
                        + ", updateStart=" + updateStart
                        + ", rewardFactor=" + rewardFactor
                        + ", gamma=" + gamma
                        + ", errorClamp=" + errorClamp
                        + ")";
            }
        }
    }

    /**
     * Wires up the learner: stores the MDP, the actor-critic network and the configuration,
     * creates the {@link AsyncGlobal} from the network and configuration, and builds the
     * {@link ACPolicy} around the same network.
     *
     * <p>The policy's random generator is obtained from {@link Nd4j#getRandom()} and is re-seeded
     * only when the configuration carries a non-null seed.
     * NOTE(review): if that generator is shared with other code, seeding it here affects those
     * users as well — confirm against the ND4J documentation.
     *
     * @param mdp              environment; stored and later cloned per worker via {@code newInstance()}
     * @param iActorCritic     network handed to both the async global and the policy
     * @param a3CConfiguration settings for this learner; must not be {@code null}
     */
    public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration a3CConfiguration) {
        this.mdp = mdp;
        this.iActorCritic = iActorCritic;
        this.configuration = a3CConfiguration;
        this.asyncGlobal = new AsyncGlobal(iActorCritic, a3CConfiguration);

        final Integer configuredSeed = a3CConfiguration.getSeed();
        final Random rng = Nd4j.getRandom();
        if (configuredSeed != null) {
            rng.setSeed(configuredSeed.intValue());
        }
        this.policy = new ACPolicy<>(iActorCritic, rng);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates one {@code A3CThreadDiscrete} worker.
     *
     * <p>The worker receives its own MDP obtained from {@code mdp.newInstance()} together with
     * this learner's async global, configuration and listeners.
     *
     * @param i  forwarded as the last constructor argument of {@code A3CThreadDiscrete}
     *           (presumably the thread number — confirm against {@code AsyncLearning})
     * @param i2 forwarded as the fourth constructor argument
     *           (presumably the device number — confirm against {@code AsyncLearning})
     * @return the newly constructed worker
     */
    @Override
    public AsyncThread newThread(int i, int i2) {
        return new A3CThreadDiscrete(
                mdp.newInstance(),
                asyncGlobal,
                getConfiguration(),
                i2,
                getListeners(),
                i);
    }

    /**
     * Returns the actor-critic network that was passed to the constructor.
     *
     * @return the network instance held by this learner
     */
    @Override
    public IActorCritic getNeuralNet() {
        return iActorCritic;
    }

    /**
     * Returns the configuration that was passed to the constructor.
     *
     * @return the {@link A3CConfiguration} held by this learner
     */
    @Override
    public A3CConfiguration getConfiguration() {
        return configuration;
    }

    /**
     * Returns the MDP that was passed to the constructor.
     *
     * @return the MDP held by this learner
     */
    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return mdp;
    }

    /**
     * Returns the async global created in the constructor.
     *
     * <p>The decompiler had renamed this method to {@code getAsyncGlobal2} ("merged with bridge
     * method"), which left {@code @Override} attached to a name that is not the one it was
     * compiled under. The original name is restored here so the override binds to the supertype
     * method again.
     * NOTE(review): assumes {@code AsyncLearning} declares {@code getAsyncGlobal()}, as the
     * decompiler's rename note indicated — confirm against the superclass.
     *
     * @return the async global held by this learner
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncLearning
    public IAsyncGlobal<IActorCritic> getAsyncGlobal() {
        return this.asyncGlobal;
    }

    /**
     * Alias kept so that code written against the decompiler-generated name keeps compiling.
     *
     * @return the same object as {@link #getAsyncGlobal()}
     * @deprecated use {@link #getAsyncGlobal()}
     */
    @Deprecated
    public IAsyncGlobal<IActorCritic> getAsyncGlobal2() {
        return getAsyncGlobal();
    }

    /**
     * Returns the policy built in the constructor from the actor-critic network and the
     * random generator.
     *
     * @return the {@link ACPolicy} held by this learner
     */
    @Override
    public ACPolicy<O> getPolicy() {
        return policy;
    }
}
