package org.deeplearning4j.rl4j.builder;

import org.apache.commons.lang3.builder.Builder;
import org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.BaseTransitionTDAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels;
import org.deeplearning4j.rl4j.agent.learning.update.updater.INeuralNetUpdater;
import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration;
import org.deeplearning4j.rl4j.agent.learning.update.updater.sync.SyncLabelsNeuralNetUpdater;
import org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder;
import org.deeplearning4j.rl4j.builder.BaseDQNAgentLearnerBuilder.Configuration;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.rng.Random;

import java.util.Objects;

/* loaded from: input_file:org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder.class */
public abstract class BaseDQNAgentLearnerBuilder<CONFIGURATION_TYPE extends Configuration> extends BaseAgentLearnerBuilder<Integer, Transition<Integer>, FeaturesLabels, CONFIGURATION_TYPE> {
    private final Random rnd;

    /* loaded from: input_file:org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder$Configuration.class */
    public static class Configuration extends BaseAgentLearnerBuilder.Configuration<Integer> {
        EpsGreedy.Configuration policyConfiguration;
        ReplayMemoryExperienceHandler.Configuration experienceHandlerConfiguration;
        NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration;
        BaseTransitionTDAlgorithm.Configuration updateAlgorithmConfiguration;

        /* loaded from: input_file:org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder$Configuration$ConfigurationBuilder.class */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> extends BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder<Integer, C, B> {
            private EpsGreedy.Configuration policyConfiguration;
            private ReplayMemoryExperienceHandler.Configuration experienceHandlerConfiguration;
            private NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration;
            private BaseTransitionTDAlgorithm.Configuration updateAlgorithmConfiguration;

            /* JADX INFO: Access modifiers changed from: protected */
            @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public abstract B self();

            @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public abstract C build();

            public B policyConfiguration(EpsGreedy.Configuration configuration) {
                this.policyConfiguration = configuration;
                return self();
            }

            public B experienceHandlerConfiguration(ReplayMemoryExperienceHandler.Configuration configuration) {
                this.experienceHandlerConfiguration = configuration;
                return self();
            }

            public B neuralNetUpdaterConfiguration(NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration) {
                this.neuralNetUpdaterConfiguration = neuralNetUpdaterConfiguration;
                return self();
            }

            public B updateAlgorithmConfiguration(BaseTransitionTDAlgorithm.Configuration configuration) {
                this.updateAlgorithmConfiguration = configuration;
                return self();
            }

            @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public String toString() {
                return "BaseDQNAgentLearnerBuilder.Configuration.ConfigurationBuilder(super=" + super.toString() + ", policyConfiguration=" + this.policyConfiguration + ", experienceHandlerConfiguration=" + this.experienceHandlerConfiguration + ", neuralNetUpdaterConfiguration=" + this.neuralNetUpdaterConfiguration + ", updateAlgorithmConfiguration=" + this.updateAlgorithmConfiguration + ")";
            }
        }

        /* loaded from: input_file:org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder$Configuration$ConfigurationBuilderImpl.class */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            /* JADX INFO: Access modifiers changed from: protected */
            @Override // org.deeplearning4j.rl4j.builder.BaseDQNAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public ConfigurationBuilderImpl self() {
                return this;
            }

            @Override // org.deeplearning4j.rl4j.builder.BaseDQNAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public Configuration build() {
                return new Configuration(this);
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public Configuration(ConfigurationBuilder<?, ?> configurationBuilder) {
            super(configurationBuilder);
            this.policyConfiguration = ((ConfigurationBuilder) configurationBuilder).policyConfiguration;
            this.experienceHandlerConfiguration = ((ConfigurationBuilder) configurationBuilder).experienceHandlerConfiguration;
            this.neuralNetUpdaterConfiguration = ((ConfigurationBuilder) configurationBuilder).neuralNetUpdaterConfiguration;
            this.updateAlgorithmConfiguration = ((ConfigurationBuilder) configurationBuilder).updateAlgorithmConfiguration;
        }

        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration configuration = (Configuration) obj;
            if (!configuration.canEqual(this) || !super.equals(obj)) {
                return false;
            }
            EpsGreedy.Configuration policyConfiguration = getPolicyConfiguration();
            EpsGreedy.Configuration policyConfiguration2 = configuration.getPolicyConfiguration();
            if (policyConfiguration == null) {
                if (policyConfiguration2 != null) {
                    return false;
                }
            } else if (!policyConfiguration.equals(policyConfiguration2)) {
                return false;
            }
            ReplayMemoryExperienceHandler.Configuration experienceHandlerConfiguration = getExperienceHandlerConfiguration();
            ReplayMemoryExperienceHandler.Configuration experienceHandlerConfiguration2 = configuration.getExperienceHandlerConfiguration();
            if (experienceHandlerConfiguration == null) {
                if (experienceHandlerConfiguration2 != null) {
                    return false;
                }
            } else if (!experienceHandlerConfiguration.equals(experienceHandlerConfiguration2)) {
                return false;
            }
            NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration = getNeuralNetUpdaterConfiguration();
            NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration2 = configuration.getNeuralNetUpdaterConfiguration();
            if (neuralNetUpdaterConfiguration == null) {
                if (neuralNetUpdaterConfiguration2 != null) {
                    return false;
                }
            } else if (!neuralNetUpdaterConfiguration.equals(neuralNetUpdaterConfiguration2)) {
                return false;
            }
            BaseTransitionTDAlgorithm.Configuration updateAlgorithmConfiguration = getUpdateAlgorithmConfiguration();
            BaseTransitionTDAlgorithm.Configuration updateAlgorithmConfiguration2 = configuration.getUpdateAlgorithmConfiguration();
            return updateAlgorithmConfiguration == null ? updateAlgorithmConfiguration2 == null : updateAlgorithmConfiguration.equals(updateAlgorithmConfiguration2);
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        public int hashCode() {
            int hashCode = super.hashCode();
            EpsGreedy.Configuration policyConfiguration = getPolicyConfiguration();
            int hashCode2 = (hashCode * 59) + (policyConfiguration == null ? 43 : policyConfiguration.hashCode());
            ReplayMemoryExperienceHandler.Configuration experienceHandlerConfiguration = getExperienceHandlerConfiguration();
            int hashCode3 = (hashCode2 * 59) + (experienceHandlerConfiguration == null ? 43 : experienceHandlerConfiguration.hashCode());
            NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration = getNeuralNetUpdaterConfiguration();
            int hashCode4 = (hashCode3 * 59) + (neuralNetUpdaterConfiguration == null ? 43 : neuralNetUpdaterConfiguration.hashCode());
            BaseTransitionTDAlgorithm.Configuration updateAlgorithmConfiguration = getUpdateAlgorithmConfiguration();
            return (hashCode4 * 59) + (updateAlgorithmConfiguration == null ? 43 : updateAlgorithmConfiguration.hashCode());
        }

        public EpsGreedy.Configuration getPolicyConfiguration() {
            return this.policyConfiguration;
        }

        public ReplayMemoryExperienceHandler.Configuration getExperienceHandlerConfiguration() {
            return this.experienceHandlerConfiguration;
        }

        public NeuralNetUpdaterConfiguration getNeuralNetUpdaterConfiguration() {
            return this.neuralNetUpdaterConfiguration;
        }

        public BaseTransitionTDAlgorithm.Configuration getUpdateAlgorithmConfiguration() {
            return this.updateAlgorithmConfiguration;
        }

        public void setPolicyConfiguration(EpsGreedy.Configuration configuration) {
            this.policyConfiguration = configuration;
        }

        public void setExperienceHandlerConfiguration(ReplayMemoryExperienceHandler.Configuration configuration) {
            this.experienceHandlerConfiguration = configuration;
        }

        public void setNeuralNetUpdaterConfiguration(NeuralNetUpdaterConfiguration neuralNetUpdaterConfiguration) {
            this.neuralNetUpdaterConfiguration = neuralNetUpdaterConfiguration;
        }

        public void setUpdateAlgorithmConfiguration(BaseTransitionTDAlgorithm.Configuration configuration) {
            this.updateAlgorithmConfiguration = configuration;
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        public String toString() {
            return "BaseDQNAgentLearnerBuilder.Configuration(policyConfiguration=" + getPolicyConfiguration() + ", experienceHandlerConfiguration=" + getExperienceHandlerConfiguration() + ", neuralNetUpdaterConfiguration=" + getNeuralNetUpdaterConfiguration() + ", updateAlgorithmConfiguration=" + getUpdateAlgorithmConfiguration() + ")";
        }
    }

    public BaseDQNAgentLearnerBuilder(CONFIGURATION_TYPE configuration_type, ITrainableNeuralNet iTrainableNeuralNet, Builder<Environment<Integer>> builder, Builder<TransformProcess> builder2, Random random) {
        super(configuration_type, iTrainableNeuralNet, builder, builder2);
        Preconditions.checkArgument(!iTrainableNeuralNet.isRecurrent(), "Recurrent networks are not yet supported with DQN.");
        this.rnd = random;
    }

    @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder
    protected IPolicy<Integer> buildPolicy() {
        return new EpsGreedy(new DQNPolicy(this.networks.getThreadCurrentNetwork()), getEnvironment().getSchema().getActionSchema(), ((Configuration) this.configuration).getPolicyConfiguration(), this.rnd);
    }

    @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder
    protected ExperienceHandler<Integer, Transition<Integer>> buildExperienceHandler() {
        return new ReplayMemoryExperienceHandler(((Configuration) this.configuration).getExperienceHandlerConfiguration(), this.rnd);
    }

    @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder
    protected INeuralNetUpdater<FeaturesLabels> buildNeuralNetUpdater() {
        if (((Configuration) this.configuration).isAsynchronous()) {
            throw new UnsupportedOperationException("Only synchronized use is currently supported");
        }
        return new SyncLabelsNeuralNetUpdater(this.networks.getThreadCurrentNetwork(), this.networks.getTargetNetwork(), ((Configuration) this.configuration).getNeuralNetUpdaterConfiguration());
    }
}
