package org.deeplearning4j.rl4j.network.dqn;

import java.util.Arrays;
import java.util.Objects;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.class */
public final class DQNFactoryStdConv implements DQNFactory {
    private final NetworkConfiguration conf;

    /* loaded from: input_file:org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv$Configuration.class */
    public static final class Configuration {
        private final double learningRate;
        private final double l2;
        private final IUpdater updater;
        private final TrainingListener[] listeners;

        /* loaded from: input_file:org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv$Configuration$ConfigurationBuilder.class */
        public static class ConfigurationBuilder {
            private double learningRate;
            private double l2;
            private IUpdater updater;
            private TrainingListener[] listeners;

            ConfigurationBuilder() {
            }

            public ConfigurationBuilder learningRate(double d) {
                this.learningRate = d;
                return this;
            }

            public ConfigurationBuilder l2(double d) {
                this.l2 = d;
                return this;
            }

            public ConfigurationBuilder updater(IUpdater iUpdater) {
                this.updater = iUpdater;
                return this;
            }

            public ConfigurationBuilder listeners(TrainingListener[] trainingListenerArr) {
                this.listeners = trainingListenerArr;
                return this;
            }

            public Configuration build() {
                return new Configuration(this.learningRate, this.l2, this.updater, this.listeners);
            }

            public String toString() {
                return "DQNFactoryStdConv.Configuration.ConfigurationBuilder(learningRate=" + this.learningRate + ", l2=" + this.l2 + ", updater=" + this.updater + ", listeners=" + Arrays.deepToString(this.listeners) + ")";
            }
        }

        /* JADX WARN: Type inference failed for: r0v1, types: [org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration$NetworkConfigurationBuilder] */
        public NetworkConfiguration toNetworkConfiguration() {
            NetworkConfiguration.NetworkConfigurationBuilder updater = NetworkConfiguration.builder().learningRate(this.learningRate).l2(this.l2).updater(this.updater);
            if (this.listeners != null) {
                updater.listeners(Arrays.asList(this.listeners));
            }
            return updater.build();
        }

        public static ConfigurationBuilder builder() {
            return new ConfigurationBuilder();
        }

        public double getLearningRate() {
            return this.learningRate;
        }

        public double getL2() {
            return this.l2;
        }

        public IUpdater getUpdater() {
            return this.updater;
        }

        public TrainingListener[] getListeners() {
            return this.listeners;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration configuration = (Configuration) obj;
            if (Double.compare(getLearningRate(), configuration.getLearningRate()) != 0 || Double.compare(getL2(), configuration.getL2()) != 0) {
                return false;
            }
            IUpdater updater = getUpdater();
            IUpdater updater2 = configuration.getUpdater();
            if (updater == null) {
                if (updater2 != null) {
                    return false;
                }
            } else if (!updater.equals(updater2)) {
                return false;
            }
            return Arrays.deepEquals(getListeners(), configuration.getListeners());
        }

        public int hashCode() {
            long doubleToLongBits = Double.doubleToLongBits(getLearningRate());
            int i = (1 * 59) + ((int) ((doubleToLongBits >>> 32) ^ doubleToLongBits));
            long doubleToLongBits2 = Double.doubleToLongBits(getL2());
            int i2 = (i * 59) + ((int) ((doubleToLongBits2 >>> 32) ^ doubleToLongBits2));
            IUpdater updater = getUpdater();
            return (((i2 * 59) + (updater == null ? 43 : updater.hashCode())) * 59) + Arrays.deepHashCode(getListeners());
        }

        public String toString() {
            return "DQNFactoryStdConv.Configuration(learningRate=" + getLearningRate() + ", l2=" + getL2() + ", updater=" + getUpdater() + ", listeners=" + Arrays.deepToString(getListeners()) + ")";
        }

        public Configuration(double d, double d2, IUpdater iUpdater, TrainingListener[] trainingListenerArr) {
            this.learningRate = d;
            this.l2 = d2;
            this.updater = iUpdater;
            this.listeners = trainingListenerArr;
        }
    }

    @Override // org.deeplearning4j.rl4j.network.dqn.DQNFactory
    public DQN buildDQN(int[] iArr, int i) {
        if (iArr.length == 1) {
            throw new AssertionError("Impossible to apply convolutional layer on a shape == 1");
        }
        NeuralNetConfiguration.ListBuilder layer = new NeuralNetConfiguration.Builder().seed(12345L).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(this.conf.getL2()).updater(this.conf.getUpdater() != null ? this.conf.getUpdater() : new Adam()).weightInit(WeightInit.XAVIER).l2(this.conf.getL2()).list().layer(0, new ConvolutionLayer.Builder(new int[]{8, 8}).nIn(iArr[0]).nOut(16).stride(new int[]{4, 4}).activation(Activation.RELU).build());
        layer.layer(1, new ConvolutionLayer.Builder(new int[]{4, 4}).nOut(32).stride(new int[]{2, 2}).activation(Activation.RELU).build());
        layer.layer(2, new DenseLayer.Builder().nOut(256).activation(Activation.RELU).build());
        layer.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nOut(i).build());
        layer.setInputType(InputType.convolutional(iArr[1], iArr[2], iArr[0]));
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(layer.build());
        multiLayerNetwork.init();
        if (this.conf.getListeners() != null) {
            multiLayerNetwork.setListeners(this.conf.getListeners());
        } else {
            multiLayerNetwork.setListeners(new TrainingListener[]{new ScoreIterationListener(50)});
        }
        return new DQN(multiLayerNetwork);
    }

    public DQNFactoryStdConv(NetworkConfiguration networkConfiguration) {
        this.conf = networkConfiguration;
    }

    public NetworkConfiguration getConf() {
        return this.conf;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof DQNFactoryStdConv)) {
            return false;
        }
        NetworkConfiguration conf = getConf();
        NetworkConfiguration conf2 = ((DQNFactoryStdConv) obj).getConf();
        return conf == null ? conf2 == null : conf.equals(conf2);
    }

    public int hashCode() {
        NetworkConfiguration conf = getConf();
        return (1 * 59) + (conf == null ? 43 : conf.hashCode());
    }

    public String toString() {
        return "DQNFactoryStdConv(conf=" + getConf() + ")";
    }
}
