package org.deeplearning4j.rl4j.builder;

import org.apache.commons.lang3.builder.Builder;
import org.deeplearning4j.rl4j.agent.learning.algorithm.IUpdateAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.algorithm.nstepqlearning.NStepQLearning;
import org.deeplearning4j.rl4j.agent.learning.update.Gradients;
import org.deeplearning4j.rl4j.agent.learning.update.updater.async.AsyncSharedNetworksUpdateHandler;
import org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.rng.Random;

/**
 * Builder of agent learners that are trained with {@link NStepQLearning}.
 *
 * <p>On top of what {@link BaseAsyncAgentLearnerBuilder} provides, this class supplies:
 * <ul>
 *   <li>an {@link EpsGreedy} policy wrapping a {@link DQNPolicy} ({@link #buildPolicy()}),</li>
 *   <li>an {@link NStepQLearning} update algorithm ({@link #buildUpdateAlgorithm()}),</li>
 *   <li>an {@link AsyncSharedNetworksUpdateHandler} ({@link #buildAsyncSharedNetworksUpdateHandler()}).</li>
 * </ul>
 */
// The type argument is written as NStepQLearningBuilder.Configuration on purpose: member types
// are only in scope inside the class body, so the simple name cannot be used in the class header.
public class NStepQLearningBuilder extends BaseAsyncAgentLearnerBuilder<NStepQLearningBuilder.Configuration> {
    /** Random number generator passed to the {@link EpsGreedy} policy. */
    private final Random rnd;

    /**
     * Configuration of a {@link NStepQLearningBuilder}: the inherited settings plus the
     * configuration of the {@link NStepQLearning} update algorithm.
     */
    public static class Configuration extends BaseAsyncAgentLearnerBuilder.Configuration {
        /** Configuration handed to the {@link NStepQLearning} update algorithm. */
        NStepQLearning.Configuration nstepQLearningConfiguration;

        /**
         * Self-typed builder of {@link Configuration} instances.
         *
         * @param <C> the concrete configuration type being built
         * @param <B> the concrete builder type, returned by the fluent setters
         */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> extends BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder<C, B> {
            private NStepQLearning.Configuration nstepQLearningConfiguration;

            /**
             * @return this builder, typed as the concrete builder type
             */
            @Override
            public abstract B self();

            /**
             * @return a new configuration populated from this builder
             */
            @Override
            public abstract C build();

            /**
             * Sets the configuration of the {@link NStepQLearning} update algorithm.
             *
             * @param configuration the value to store; no validation is performed here
             * @return this builder, for chaining
             */
            public B nstepQLearningConfiguration(NStepQLearning.Configuration configuration) {
                this.nstepQLearningConfiguration = configuration;
                return self();
            }

            @Override
            public String toString() {
                return "NStepQLearningBuilder.Configuration.ConfigurationBuilder(super=" + super.toString() + ", nstepQLearningConfiguration=" + this.nstepQLearningConfiguration + ")";
            }
        }

        /** Concrete builder returned by {@link Configuration#builder()}. */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            @Override
            public ConfigurationBuilderImpl self() {
                return this;
            }

            @Override
            public Configuration build() {
                return new Configuration(this);
            }
        }

        /**
         * Creates a configuration from the values held by a builder.
         *
         * @param configurationBuilder the builder to copy the values from
         */
        protected Configuration(ConfigurationBuilder<?, ?> configurationBuilder) {
            super(configurationBuilder);
            // Read the field through the wildcard-typed parameter; no raw-type cast is needed.
            this.nstepQLearningConfiguration = configurationBuilder.nstepQLearningConfiguration;
        }

        /**
         * @return a new, empty configuration builder
         */
        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        /**
         * Two configurations are equal when the other object is a {@link Configuration} that
         * accepts this one through {@link #canEqual(Object)}, the superclass considers them equal,
         * and their n-step Q-learning configurations are both {@code null} or equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            if (!other.canEqual(this) || !super.equals(obj)) {
                return false;
            }
            NStepQLearning.Configuration mine = getNstepQLearningConfiguration();
            NStepQLearning.Configuration theirs = other.getNstepQLearningConfiguration();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        /**
         * Hook used by {@link #equals(Object)} to let the other side veto the comparison.
         *
         * @param obj the object that wants to be compared with this one
         * @return {@code true} if {@code obj} is a {@link Configuration}
         */
        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        /**
         * Combines the superclass hash with the hash of the n-step Q-learning configuration.
         */
        @Override
        public int hashCode() {
            int result = super.hashCode();
            NStepQLearning.Configuration value = getNstepQLearningConfiguration();
            // 43 stands in for the hash of a null value.
            return (result * 59) + (value == null ? 43 : value.hashCode());
        }

        /**
         * @return the configuration of the {@link NStepQLearning} update algorithm, possibly {@code null}
         */
        public NStepQLearning.Configuration getNstepQLearningConfiguration() {
            return this.nstepQLearningConfiguration;
        }

        /**
         * @param configuration the new configuration of the {@link NStepQLearning} update algorithm
         */
        public void setNstepQLearningConfiguration(NStepQLearning.Configuration configuration) {
            this.nstepQLearningConfiguration = configuration;
        }

        /**
         * Note: only the field declared by this class is printed; the superclass state is not included.
         */
        @Override
        public String toString() {
            return "NStepQLearningBuilder.Configuration(nstepQLearningConfiguration=" + getNstepQLearningConfiguration() + ")";
        }
    }

    /**
     * Creates the builder.
     *
     * <p>A recurrent network is only accepted when the batch size of the experience handler is
     * {@link Integer#MAX_VALUE}; any other combination is rejected by
     * {@code Preconditions.checkArgument}. The check runs after the superclass constructor because
     * the {@code super(...)} call has to be the first statement.
     *
     * @param configuration the builder configuration
     * @param iTrainableNeuralNet the neural network to train
     * @param builder builder of the environment
     * @param builder2 builder of the observation transform process
     * @param random random number generator passed to the epsilon-greedy policy
     * @throws IllegalArgumentException presumably (thrown by {@code Preconditions.checkArgument} -
     *         confirm against ND4J) when the network is recurrent and the batch size is not
     *         {@link Integer#MAX_VALUE}
     */
    public NStepQLearningBuilder(Configuration configuration, ITrainableNeuralNet iTrainableNeuralNet, Builder<Environment<Integer>> builder, Builder<TransformProcess> builder2, Random random) {
        super(configuration, iTrainableNeuralNet, builder, builder2);
        Preconditions.checkArgument(!iTrainableNeuralNet.isRecurrent() || configuration.getExperienceHandlerConfiguration().getBatchSize() == Integer.MAX_VALUE, "RL with a recurrent network currently only works with whole-trajectory updates. Until RNN are fully supported, please set the batch size of your experience handler to Integer.MAX_VALUE");
        this.rnd = random;
    }

    /**
     * Builds an {@link EpsGreedy} policy around a {@link DQNPolicy} that uses the thread-current
     * network, the action schema of the environment, the configured policy configuration and the
     * random number generator given at construction.
     *
     * @return the policy
     */
    @Override
    protected IPolicy<Integer> buildPolicy() {
        return new EpsGreedy(new DQNPolicy(this.networks.getThreadCurrentNetwork()), getEnvironment().getSchema().getActionSchema(), this.configuration.getPolicyConfiguration(), this.rnd);
    }

    /**
     * Builds the {@link NStepQLearning} update algorithm from the thread-current network, the
     * target network, the size of the environment's action space and the configured n-step
     * Q-learning configuration.
     *
     * @return the update algorithm
     */
    @Override
    protected IUpdateAlgorithm<Gradients, StateActionPair<Integer>> buildUpdateAlgorithm() {
        return new NStepQLearning(this.networks.getThreadCurrentNetwork(), this.networks.getTargetNetwork(), getEnvironment().getSchema().getActionSchema().getActionSpaceSize(), this.configuration.getNstepQLearningConfiguration());
    }

    /**
     * Builds the update handler from the global-current network, the target network and the
     * configured neural-net updater configuration.
     *
     * @return the update handler
     */
    @Override
    protected AsyncSharedNetworksUpdateHandler buildAsyncSharedNetworksUpdateHandler() {
        return new AsyncSharedNetworksUpdateHandler(this.networks.getGlobalCurrentNetwork(), this.networks.getTargetNetwork(), this.configuration.getNeuralNetUpdaterConfiguration());
    }
}
