package org.deeplearning4j.rl4j.network.ac;

import java.beans.ConstructorProperties;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.class */
/**
 * Factory producing an {@link ActorCriticSeparate} backed by two independent
 * fully-connected networks that share the same hidden-layer layout but differ
 * in their output layer.
 *
 * <p>Both networks are configured with the same seed, SGD optimisation, the
 * ADAM updater, Xavier weight initialisation and the learning rate taken from
 * the supplied {@link Configuration}.
 */
public final class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate {
    /** Argument handed to the {@link ScoreIterationListener} attached to each network. */
    private static final int SCORE_LISTENER_FREQUENCY = 50;

    /** Activation name used by every hidden dense layer. */
    private static final String HIDDEN_ACTIVATION = "relu";

    private final Configuration conf;

    /* loaded from: input_file:org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense$Configuration.class */
    /**
     * Immutable hyper-parameter holder for the factory.
     *
     * <p>NOTE(review): {@code l2} is stored and exposed here, but
     * {@code buildActorCritic} in this class does not pass it to the network
     * builders, so it currently has no effect on the networks built.
     */
    public static final class Configuration {
        private final int numLayer;
        private final int numHiddenNodes;
        private final double learningRate;
        private final double l2;

        /* loaded from: input_file:org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense$Configuration$ConfigurationBuilder.class */
        /**
         * Fluent builder for {@link Configuration}. Values that are never set
         * keep the Java defaults ({@code 0} / {@code 0.0}).
         */
        public static class ConfigurationBuilder {
            private int numLayer;
            private int numHiddenNodes;
            private double learningRate;
            private double l2;

            ConfigurationBuilder() {
            }

            /**
             * @param i number of hidden dense layers
             * @return this builder
             */
            public ConfigurationBuilder numLayer(int i) {
                this.numLayer = i;
                return this;
            }

            /**
             * @param i width of each hidden dense layer
             * @return this builder
             */
            public ConfigurationBuilder numHiddenNodes(int i) {
                this.numHiddenNodes = i;
                return this;
            }

            /**
             * @param d learning rate passed to both network configurations
             * @return this builder
             */
            public ConfigurationBuilder learningRate(double d) {
                this.learningRate = d;
                return this;
            }

            /**
             * @param d l2 value to store in the configuration
             * @return this builder
             */
            public ConfigurationBuilder l2(double d) {
                this.l2 = d;
                return this;
            }

            /**
             * @return a new {@link Configuration} holding the values set so far
             */
            public Configuration build() {
                return new Configuration(this.numLayer, this.numHiddenNodes, this.learningRate, this.l2);
            }

            @Override
            public String toString() {
                return "ActorCriticFactorySeparateStdDense.Configuration.ConfigurationBuilder(numLayer=" + this.numLayer
                        + ", numHiddenNodes=" + this.numHiddenNodes
                        + ", learningRate=" + this.learningRate
                        + ", l2=" + this.l2 + ")";
            }
        }

        Configuration(int numLayer, int numHiddenNodes, double learningRate, double l2) {
            this.numLayer = numLayer;
            this.numHiddenNodes = numHiddenNodes;
            this.learningRate = learningRate;
            this.l2 = l2;
        }

        /**
         * @return a fresh builder with all values at their Java defaults
         */
        public static ConfigurationBuilder builder() {
            return new ConfigurationBuilder();
        }

        public int getNumLayer() {
            return this.numLayer;
        }

        public int getNumHiddenNodes() {
            return this.numHiddenNodes;
        }

        public double getLearningRate() {
            return this.learningRate;
        }

        public double getL2() {
            return this.l2;
        }

        /**
         * Two configurations are equal when all four values match; doubles are
         * compared with {@link Double#compare(double, double)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            if (this.numLayer != other.numLayer) {
                return false;
            }
            if (this.numHiddenNodes != other.numHiddenNodes) {
                return false;
            }
            if (Double.compare(this.learningRate, other.learningRate) != 0) {
                return false;
            }
            return Double.compare(this.l2, other.l2) == 0;
        }

        @Override
        public int hashCode() {
            // Same 59-based accumulation as before so hash values are unchanged.
            int result = 1;
            result = (result * 59) + this.numLayer;
            result = (result * 59) + this.numHiddenNodes;
            long learningRateBits = Double.doubleToLongBits(this.learningRate);
            result = (result * 59) + ((int) ((learningRateBits >>> 32) ^ learningRateBits));
            long l2Bits = Double.doubleToLongBits(this.l2);
            result = (result * 59) + ((int) ((l2Bits >>> 32) ^ l2Bits));
            return result;
        }

        @Override
        public String toString() {
            return "ActorCriticFactorySeparateStdDense.Configuration(numLayer=" + this.numLayer
                    + ", numHiddenNodes=" + this.numHiddenNodes
                    + ", learningRate=" + this.learningRate
                    + ", l2=" + this.l2 + ")";
        }
    }

    /**
     * Builds the two networks and wraps them in an {@link ActorCriticSeparate}.
     *
     * <p>The first network ends in a single-output MSE layer with identity
     * activation; the second ends in an MCXENT layer with softmax activation
     * and {@code i} outputs. They are passed to {@link ActorCriticSeparate} in
     * that order.
     *
     * @param iArr input shape; only {@code iArr[0]} is read, as the input width
     * @param i    number of outputs of the softmax network
     * @return the pair of freshly initialised networks
     * @throws IllegalArgumentException if the configured {@code numLayer} is
     *         less than 1, since the output layer would then be placed at
     *         index 0 or below and the network would not be laid out as intended
     */
    @Override // org.deeplearning4j.rl4j.network.ac.ActorCriticFactorySeparate
    public ActorCriticSeparate buildActorCritic(int[] iArr, int i) {
        int hiddenLayerCount = this.conf.getNumLayer();
        if (hiddenLayerCount < 1) {
            throw new IllegalArgumentException("numLayer must be at least 1, got " + hiddenLayerCount);
        }
        int numInputs = iArr[0];
        // Build order (single-output network first) matches the original code.
        MultiLayerNetwork valueNetwork =
                buildDenseNetwork(numInputs, LossFunctions.LossFunction.MSE, "identity", 1);
        MultiLayerNetwork policyNetwork =
                buildDenseNetwork(numInputs, LossFunctions.LossFunction.MCXENT, "softmax", i);
        return new ActorCriticSeparate(valueNetwork, policyNetwork);
    }

    /**
     * Creates, initialises and attaches a score listener to one dense network:
     * {@code numLayer} hidden relu layers followed by the described output layer.
     */
    private MultiLayerNetwork buildDenseNetwork(int numInputs, LossFunctions.LossFunction lossFunction,
            String outputActivation, int numOutputs) {
        int hiddenNodes = this.conf.getNumHiddenNodes();
        int hiddenLayerCount = this.conf.getNumLayer();

        NeuralNetConfiguration.ListBuilder layers = new NeuralNetConfiguration.Builder()
                .seed(Constants.NEURAL_NET_SEED)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(this.conf.getLearningRate())
                .updater(Updater.ADAM)
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(hiddenNodes).activation(HIDDEN_ACTIVATION).build());

        // Layer 0 is the input-facing layer above; the rest are hidden-to-hidden.
        for (int index = 1; index < hiddenLayerCount; index++) {
            layers.layer(index, new DenseLayer.Builder().nIn(hiddenNodes).nOut(hiddenNodes).activation(HIDDEN_ACTIVATION).build());
        }

        // The output layer sits directly after the last hidden layer.
        layers.layer(hiddenLayerCount, new OutputLayer.Builder(lossFunction).activation(outputActivation).nIn(hiddenNodes).nOut(numOutputs).build());

        MultiLayerNetwork network = new MultiLayerNetwork(layers.pretrain(false).backprop(true).build());
        network.init();
        network.setListeners(new IterationListener[]{new ScoreIterationListener(SCORE_LISTENER_FREQUENCY)});
        return network;
    }

    /**
     * @param configuration hyper-parameters used by {@code buildActorCritic};
     *        stored as-is without validation
     */
    @ConstructorProperties({"conf"})
    public ActorCriticFactorySeparateStdDense(Configuration configuration) {
        this.conf = configuration;
    }

    public Configuration getConf() {
        return this.conf;
    }

    /**
     * Factories are equal when their configurations are equal (two
     * {@code null} configurations also count as equal).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ActorCriticFactorySeparateStdDense)) {
            return false;
        }
        Configuration mine = this.conf;
        Configuration theirs = ((ActorCriticFactorySeparateStdDense) obj).conf;
        if (mine == null) {
            return theirs == null;
        }
        return mine.equals(theirs);
    }

    @Override
    public int hashCode() {
        // 43 stands in for a null configuration; values match the previous implementation.
        int confHash = (this.conf == null) ? 43 : this.conf.hashCode();
        return (1 * 59) + confHash;
    }

    @Override
    public String toString() {
        return "ActorCriticFactorySeparateStdDense(conf=" + this.conf + ")";
    }
}
