package org.deeplearning4j.nn.conf;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializer;
import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer;
import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializer;
import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializer;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.serde.json.LegacyIActivationDeserializer;
import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/conf/NeuralNetConfiguration.class */
public class NeuralNetConfiguration implements Serializable, Cloneable {
    /** Layer configuration carried by this network configuration. */
    protected Layer layer;
    /** Copied from {@code Builder.maxNumLineSearchIterations} in {@code Builder.build()}. */
    protected int maxNumLineSearchIterations;
    /** Copied from {@code Builder.seed} in {@code Builder.build()}. */
    protected long seed;
    protected OptimizationAlgorithm optimizationAlgo;
    protected StepFunction stepFunction;
    protected CacheMode cacheMode;
    private static final Logger log = LoggerFactory.getLogger(NeuralNetConfiguration.class);
    // NOTE(review): presumably the base types for which custom subclasses may be registered
    // for JSON (de)serialization - confirm against the code that reads this list.
    private static List<Class<?>> REGISTERABLE_CUSTOM_CLASSES = Arrays.asList(Layer.class, GraphVertex.class, InputPreProcessor.class, IActivation.class, ILossFunction.class, ReconstructionDistribution.class);
    protected boolean miniBatch = true;
    // Typed with the diamond operator instead of a raw ArrayList to avoid an unchecked assignment.
    protected List<String> variables = new ArrayList<>();
    protected boolean minimize = true;
    protected int iterationCount = 0;
    protected int epochCount = 0;

    /* loaded from: input_file:org/deeplearning4j/nn/conf/NeuralNetConfiguration$Builder.class */
    public static class Builder implements Cloneable {
        // ---- Global defaults copied onto layers by copyConfigToLayer(...) when the layer leaves them unset ----
        protected IActivation activationFn;
        protected IWeightInit weightInitFn;
        protected double biasInit;
        // Initialised to NaN by the constructors; passed to LayerValidation.generalValidation(...) in configureLayer.
        protected double l1;
        protected double l2;
        protected double l1Bias;
        protected double l2Bias;
        // Cloned before being handed to a layer (see copyConfigToLayer); null means no builder-level dropout.
        protected IDropout idropOut;
        protected IWeightNoise weightNoise;
        protected IUpdater iUpdater;
        // null by default; only applied to layers when non-null.
        protected IUpdater biasUpdater;
        // ---- Values transferred to the NeuralNetConfiguration by build() ----
        protected Layer layer;
        protected boolean miniBatch;
        protected int maxNumLineSearchIterations;
        protected long seed;
        protected OptimizationAlgorithm optimizationAlgo;
        protected StepFunction stepFunction;
        protected boolean minimize;
        protected GradientNormalization gradientNormalization;
        protected double gradientNormalizationThreshold;
        // Constraint lists are left null until one of the constrain*(...) methods or setters is called.
        protected List<LayerConstraint> allParamConstraints;
        protected List<LayerConstraint> weightConstraints;
        protected List<LayerConstraint> biasConstraints;
        protected boolean legacyBatchScaledL2;
        protected WorkspaceMode trainingWorkspaceMode;
        protected WorkspaceMode inferenceWorkspaceMode;
        // setTWM / setIWM record whether the corresponding workspace mode was set through the fluent
        // trainingWorkspaceMode(...) / inferenceWorkspaceMode(...) methods.
        protected boolean setTWM;
        protected boolean setIWM;
        protected CacheMode cacheMode;
        // Applied to convolution / subsampling layers whose own mode is null (see configureLayer).
        protected ConvolutionMode convolutionMode;
        protected ConvolutionLayer.AlgoMode cudnnAlgoMode;

        /**
         * Creates a builder pre-populated with the default global settings:
         * sigmoid activation, Xavier weight initialisation, SGD updater, NaN
         * regularisation coefficients, a time-based seed and enabled workspaces.
         */
        public Builder() {
            // Per-layer defaults.
            activationFn = new ActivationSigmoid();
            weightInitFn = new WeightInitXavier();
            // NOTE(review): this constant comes from an evaluation class unrelated to bias
            // initialisation; presumably a decompiler substitution for a literal - confirm.
            biasInit = EvaluationBinary.DEFAULT_EDGE_VALUE;
            // NaN marks "not configured" for the regularisation coefficients.
            l1 = Double.NaN;
            l2 = Double.NaN;
            l1Bias = Double.NaN;
            l2Bias = Double.NaN;
            iUpdater = new Sgd();
            biasUpdater = null;

            // Optimisation settings.
            miniBatch = true;
            maxNumLineSearchIterations = 5;
            seed = System.currentTimeMillis();
            optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
            stepFunction = null;
            minimize = true;
            gradientNormalization = GradientNormalization.None;
            gradientNormalizationThreshold = 1.0d;
            legacyBatchScaledL2 = false;

            // Workspace / cache settings; the "set" flags stay false until set explicitly.
            trainingWorkspaceMode = WorkspaceMode.ENABLED;
            inferenceWorkspaceMode = WorkspaceMode.ENABLED;
            setTWM = false;
            setIWM = false;
            cacheMode = CacheMode.NONE;

            // Convolution defaults.
            convolutionMode = ConvolutionMode.Truncate;
            cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
        }

        /**
         * Creates a builder seeded from an existing configuration.
         *
         * <p>All fields first receive the same defaults as {@link #Builder()}; the values
         * that {@link #build()} writes onto a {@code NeuralNetConfiguration} are then copied
         * back from {@code neuralNetConfiguration}. The layer is copied by reference, not cloned.
         *
         * @param neuralNetConfiguration configuration to copy from; when {@code null} the
         *                               result is identical to the no-arg constructor
         */
        public Builder(NeuralNetConfiguration neuralNetConfiguration) {
            // Reuse the default initialisation instead of duplicating it here.
            this();
            if (neuralNetConfiguration == null) {
                return;
            }
            minimize = neuralNetConfiguration.minimize;
            maxNumLineSearchIterations = neuralNetConfiguration.maxNumLineSearchIterations;
            layer = neuralNetConfiguration.layer;
            optimizationAlgo = neuralNetConfiguration.optimizationAlgo;
            seed = neuralNetConfiguration.seed;
            stepFunction = neuralNetConfiguration.stepFunction;
            miniBatch = neuralNetConfiguration.miniBatch;
            // build() stores cacheMode on the configuration, so carry it back as well; otherwise a
            // build -> new Builder(conf) round trip silently resets it. Keep the default when the
            // configuration has none.
            if (neuralNetConfiguration.cacheMode != null) {
                cacheMode = neuralNetConfiguration.cacheMode;
            }
        }

        /**
         * Sets the mini-batch flag.
         *
         * @param z new value of the flag
         * @return this builder, for chaining
         */
        public Builder miniBatch(boolean z) {
            miniBatch = z;
            return this;
        }

        /**
         * Sets the training workspace mode and records that it was set explicitly.
         *
         * @param workspaceMode mode to use; must not be {@code null}
         * @return this builder, for chaining
         * @throws NullPointerException if {@code workspaceMode} is {@code null}
         */
        public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
            if (workspaceMode != null) {
                trainingWorkspaceMode = workspaceMode;
                setTWM = true;
                return this;
            }
            throw new NullPointerException("workspaceMode is marked @NonNull but is null");
        }

        /**
         * Sets the inference workspace mode and records that it was set explicitly.
         *
         * @param workspaceMode mode to use; must not be {@code null}
         * @return this builder, for chaining
         * @throws NullPointerException if {@code workspaceMode} is {@code null}
         */
        public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
            if (workspaceMode != null) {
                inferenceWorkspaceMode = workspaceMode;
                setIWM = true;
                return this;
            }
            throw new NullPointerException("workspaceMode is marked @NonNull but is null");
        }

        /**
         * Sets the cache mode that {@link #build()} copies onto the configuration.
         *
         * @param cacheMode mode to use; must not be {@code null}
         * @return this builder, for chaining
         * @throws NullPointerException if {@code cacheMode} is {@code null}
         */
        public Builder cacheMode(@NonNull CacheMode cacheMode) {
            if (cacheMode != null) {
                this.cacheMode = cacheMode;
                return this;
            }
            throw new NullPointerException("cacheMode is marked @NonNull but is null");
        }

        /**
         * Sets the minimize flag.
         *
         * @param z new value of the flag
         * @return this builder, for chaining
         */
        public Builder minimize(boolean z) {
            minimize = z;
            return this;
        }

        /**
         * Sets the maximum number of line search iterations (5 by default).
         *
         * @param i new limit; not validated here
         * @return this builder, for chaining
         */
        public Builder maxNumLineSearchIterations(int i) {
            maxNumLineSearchIterations = i;
            return this;
        }

        /**
         * Sets the layer this builder configures. The reference is stored as-is (no copy).
         *
         * @param layer layer configuration; may be {@code null}
         * @return this builder, for chaining
         */
        public Builder layer(Layer layer) {
            this.layer = layer;
            return this;
        }

        /**
         * Sets the step function that {@link #build()} copies onto the configuration.
         *
         * @param stepFunction step function; may be {@code null}
         * @return this builder, for chaining
         * @deprecated marked deprecated in the original API
         */
        @Deprecated
        public Builder stepFunction(StepFunction stepFunction) {
            this.stepFunction = stepFunction;
            return this;
        }

        /**
         * Starts a list-style configuration that uses this builder as its global configuration.
         *
         * @return a new {@code ListBuilder} wrapping this builder
         */
        public ListBuilder list() {
            ListBuilder listBuilder = new ListBuilder(this);
            return listBuilder;
        }

        /**
         * Starts a list-style configuration from the given layers.
         *
         * <p>For every layer a clone of this builder is created (see {@link #m34clone()}),
         * the layer is assigned to that clone, and the clone is stored under the layer's
         * index in the array.
         *
         * @param layerArr layers in network order; must contain at least one element
         * @return a new {@code ListBuilder} holding this builder and the per-layer builders
         * @throws IllegalArgumentException if {@code layerArr} is {@code null} or empty
         */
        public ListBuilder list(Layer... layerArr) {
            if (layerArr == null || layerArr.length == 0) {
                throw new IllegalArgumentException("Cannot create network with no layers");
            }
            // Typed map instead of a raw HashMap: avoids unchecked warnings and manual boxing.
            Map<Integer, Builder> layerBuilders = new HashMap<>();
            for (int i = 0; i < layerArr.length; i++) {
                Builder perLayerBuilder = m34clone();
                perLayerBuilder.layer(layerArr[i]);
                layerBuilders.put(i, perLayerBuilder);
            }
            return new ListBuilder(this, layerBuilders);
        }

        /**
         * Starts a computation-graph configuration that uses this builder as its global configuration.
         *
         * @return a new {@code GraphBuilder} wrapping this builder
         */
        public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
            ComputationGraphConfiguration.GraphBuilder graph = new ComputationGraphConfiguration.GraphBuilder(this);
            return graph;
        }

        /**
         * Stores the seed and, as a side effect, immediately seeds the random
         * generator returned by {@code Nd4j.getRandom()}.
         *
         * @param j seed value
         * @return this builder, for chaining
         */
        public Builder seed(long j) {
            seed = j;
            // Side effect beyond this builder: the Nd4j random generator is reseeded right away.
            Nd4j.getRandom().setSeed(j);
            return this;
        }

        /**
         * Sets the optimization algorithm that {@link #build()} copies onto the configuration.
         *
         * @param optimizationAlgorithm algorithm to use
         * @return this builder, for chaining
         */
        public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgorithm) {
            optimizationAlgo = optimizationAlgorithm;
            return this;
        }

        /**
         * Copies this builder (decompiler name for {@code clone}: renamed because it was
         * merged with a bridge method).
         *
         * <p>The copy is produced by {@code Object.clone()}, so fields are copied by
         * reference; only {@code layer} and {@code stepFunction} are additionally cloned
         * when they are non-null.
         *
         * @return the copied builder
         * @throws RuntimeException wrapping a {@link CloneNotSupportedException}
         */
        public Builder m34clone() {
            try {
                Builder copy = (Builder) super.clone();
                Layer sourceLayer = copy.layer;
                if (sourceLayer != null) {
                    copy.layer = sourceLayer.mo55clone();
                }
                StepFunction sourceStep = copy.stepFunction;
                if (sourceStep != null) {
                    copy.stepFunction = sourceStep.m108clone();
                }
                return copy;
            } catch (CloneNotSupportedException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Sets the default activation function, applied to layers whose own activation is {@code null}.
         *
         * @param iActivation activation function instance
         * @return this builder, for chaining
         */
        public Builder activation(IActivation iActivation) {
            activationFn = iActivation;
            return this;
        }

        /**
         * Sets the default activation function from an {@code Activation} value.
         *
         * @param activation source of the activation function; must not be {@code null}
         * @return this builder, for chaining
         * @throws NullPointerException if {@code activation} is {@code null}
         */
        public Builder activation(Activation activation) {
            IActivation resolved = activation.getActivationFunction();
            return activation(resolved);
        }

        /**
         * Sets the default weight initialisation, applied to layers whose own value is {@code null}.
         *
         * @param iWeightInit weight initialisation function
         * @return this builder, for chaining
         */
        public Builder weightInit(IWeightInit iWeightInit) {
            weightInitFn = iWeightInit;
            return this;
        }

        /**
         * Sets the default weight initialisation from a {@code WeightInit} value.
         *
         * <p>An empty {@code if (weightInit == WeightInit.DISTRIBUTION)} branch that had no
         * effect was removed; behaviour is unchanged.
         * NOTE(review): for distribution-based initialisation {@link #weightInit(Distribution)}
         * is presumably the intended entry point - confirm how
         * {@code getWeightInitFunction()} treats {@code DISTRIBUTION}.
         *
         * @param weightInit source of the weight initialisation function; must not be {@code null}
         * @return this builder, for chaining
         * @throws NullPointerException if {@code weightInit} is {@code null}
         */
        public Builder weightInit(WeightInit weightInit) {
            this.weightInitFn = weightInit.getWeightInitFunction();
            return this;
        }

        /**
         * Sets the default weight initialisation to a {@code WeightInitDistribution}
         * wrapping the given distribution.
         *
         * @param distribution distribution passed to the {@code WeightInitDistribution} constructor
         * @return this builder, for chaining
         */
        public Builder weightInit(Distribution distribution) {
            WeightInitDistribution fromDistribution = new WeightInitDistribution(distribution);
            return weightInit(fromDistribution);
        }

        /**
         * Sets the default bias initialisation, applied to base layers whose own value is NaN.
         *
         * @param d bias initialisation value
         * @return this builder, for chaining
         */
        public Builder biasInit(double d) {
            biasInit = d;
            return this;
        }

        /**
         * Delegates to {@link #weightInit(Distribution)}.
         *
         * @param distribution distribution to initialise weights from
         * @return this builder, for chaining
         * @deprecated use {@link #weightInit(Distribution)}, which this method forwards to
         */
        @Deprecated
        public Builder dist(Distribution distribution) {
            Builder self = weightInit(distribution);
            return self;
        }

        /**
         * Sets the builder-level {@code l1} value (NaN by default), applied to base layers whose own value is NaN.
         *
         * @param d new value
         * @return this builder, for chaining
         */
        public Builder l1(double d) {
            l1 = d;
            return this;
        }

        /**
         * Sets the builder-level {@code l2} value (NaN by default), applied to base layers whose own value is NaN.
         *
         * @param d new value
         * @return this builder, for chaining
         */
        public Builder l2(double d) {
            l2 = d;
            return this;
        }

        /**
         * Sets the builder-level {@code l1Bias} value (NaN by default).
         *
         * @param d new value
         * @return this builder, for chaining
         */
        public Builder l1Bias(double d) {
            l1Bias = d;
            return this;
        }

        /**
         * Sets the builder-level {@code l2Bias} value (NaN by default).
         *
         * @param d new value
         * @return this builder, for chaining
         */
        public Builder l2Bias(double d) {
            l2Bias = d;
            return this;
        }

        /**
         * Sets builder-level dropout from a numeric value.
         *
         * <p>A value equal to {@code EvaluationBinary.DEFAULT_EDGE_VALUE} clears the
         * dropout; any other value is wrapped in a new {@code Dropout}.
         * NOTE(review): the constant belongs to an unrelated evaluation class and is
         * presumably a decompiler substitution for a literal - confirm.
         *
         * @param d value passed to the {@code Dropout} constructor
         * @return this builder, for chaining
         */
        public Builder dropOut(double d) {
            if (d == EvaluationBinary.DEFAULT_EDGE_VALUE) {
                return dropOut((IDropout) null);
            }
            return dropOut(new Dropout(d));
        }

        /**
         * Sets builder-level dropout. A non-null argument is cloned before being stored.
         *
         * @param iDropout dropout configuration, or {@code null} to clear it
         * @return this builder, for chaining
         */
        public Builder dropOut(IDropout iDropout) {
            if (iDropout == null) {
                this.idropOut = null;
            } else {
                this.idropOut = iDropout.m43clone();
            }
            return this;
        }

        /**
         * Sets builder-level weight noise, cloned onto base layers that have none of their own.
         *
         * @param iWeightNoise weight noise configuration; may be {@code null}
         * @return this builder, for chaining
         */
        public Builder weightNoise(IWeightNoise iWeightNoise) {
            weightNoise = iWeightNoise;
            return this;
        }

        /**
         * Sets the default updater from an {@code Updater} value, using the result of
         * {@code getIUpdaterWithDefaultConfig()}.
         *
         * @param updater source of the updater; must not be {@code null}
         * @return this builder, for chaining
         * @throws NullPointerException if {@code updater} is {@code null}
         * @deprecated use {@link #updater(IUpdater)}, which this method forwards to
         */
        @Deprecated
        public Builder updater(Updater updater) {
            IUpdater converted = updater.getIUpdaterWithDefaultConfig();
            return updater(converted);
        }

        /**
         * Sets the default updater, cloned onto base layers that have none of their own.
         *
         * @param iUpdater updater configuration; may be {@code null}
         * @return this builder, for chaining
         */
        public Builder updater(IUpdater iUpdater) {
            this.iUpdater = iUpdater;
            return this;
        }

        /**
         * Sets the default bias updater, cloned onto base layers that have none of their own.
         *
         * @param iUpdater bias updater configuration; may be {@code null}
         * @return this builder, for chaining
         */
        public Builder biasUpdater(IUpdater iUpdater) {
            this.biasUpdater = iUpdater;
            return this;
        }

        /**
         * Sets the default gradient normalization, applied to base layers whose own value is {@code null}.
         *
         * @param gradientNormalization normalization strategy
         * @return this builder, for chaining
         */
        public Builder gradientNormalization(GradientNormalization gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return this;
        }

        /**
         * Sets the default gradient normalization threshold (1.0 by default),
         * applied to base layers whose own threshold is NaN.
         *
         * @param d threshold value
         * @return this builder, for chaining
         */
        public Builder gradientNormalizationThreshold(double d) {
            gradientNormalizationThreshold = d;
            return this;
        }

        /**
         * Sets the default convolution mode, applied to convolution and
         * subsampling layers whose own mode is {@code null}.
         *
         * @param convolutionMode mode to use
         * @return this builder, for chaining
         */
        public Builder convolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
            return this;
        }

        /**
         * Sets the default cuDNN algorithm mode, applied to convolution layers whose own mode is {@code null}.
         *
         * @param algoMode mode to use
         * @return this builder, for chaining
         */
        public Builder cudnnAlgoMode(ConvolutionLayer.AlgoMode algoMode) {
            cudnnAlgoMode = algoMode;
            return this;
        }

        /**
         * Replaces the all-parameter constraints with a fixed-size list view of the given array.
         *
         * @param layerConstraintArr constraints; must not be a {@code null} array
         * @return this builder, for chaining
         * @throws NullPointerException if {@code layerConstraintArr} is {@code null}
         */
        public Builder constrainAllParameters(LayerConstraint... layerConstraintArr) {
            List<LayerConstraint> constraints = Arrays.asList(layerConstraintArr);
            allParamConstraints = constraints;
            return this;
        }

        /**
         * Replaces the bias constraints with a fixed-size list view of the given array.
         *
         * @param layerConstraintArr constraints; must not be a {@code null} array
         * @return this builder, for chaining
         * @throws NullPointerException if {@code layerConstraintArr} is {@code null}
         */
        public Builder constrainBias(LayerConstraint... layerConstraintArr) {
            List<LayerConstraint> constraints = Arrays.asList(layerConstraintArr);
            biasConstraints = constraints;
            return this;
        }

        /**
         * Replaces the weight constraints with a fixed-size list view of the given array.
         *
         * @param layerConstraintArr constraints; must not be a {@code null} array
         * @return this builder, for chaining
         * @throws NullPointerException if {@code layerConstraintArr} is {@code null}
         */
        public Builder constrainWeights(LayerConstraint... layerConstraintArr) {
            List<LayerConstraint> constraints = Arrays.asList(layerConstraintArr);
            weightConstraints = constraints;
            return this;
        }

        /**
         * Sets the legacy batch-scaled L2 flag, which is pushed onto every
         * {@code BaseOutputLayer} when layers are configured.
         *
         * @param z new value of the flag
         * @return this builder, for chaining
         */
        public Builder legacyBatchScaledL2(boolean z) {
            legacyBatchScaledL2 = z;
            return this;
        }

        /**
         * Creates a {@code NeuralNetConfiguration} from this builder.
         *
         * <p>The scalar settings and the layer reference are copied onto the new
         * configuration; the layer (and, for frozen wrappers, the wrapped layer) is then
         * filled in with this builder's defaults via {@code configureLayer}. The layer
         * object itself is shared with the builder, not copied.
         *
         * @return the new configuration
         */
        public NeuralNetConfiguration build() {
            NeuralNetConfiguration conf = new NeuralNetConfiguration();
            conf.minimize = minimize;
            conf.maxNumLineSearchIterations = maxNumLineSearchIterations;
            conf.layer = layer;
            conf.optimizationAlgo = optimizationAlgo;
            conf.seed = seed;
            conf.stepFunction = stepFunction;
            conf.miniBatch = miniBatch;
            conf.cacheMode = cacheMode;

            configureLayer(layer);
            // Frozen wrappers: also run the full configuration pass on the wrapped layer.
            if (layer instanceof FrozenLayer) {
                FrozenLayer frozen = (FrozenLayer) layer;
                configureLayer(frozen.getLayer());
            }
            if (layer instanceof FrozenLayerWithBackprop) {
                FrozenLayerWithBackprop frozenWithBackprop = (FrozenLayerWithBackprop) layer;
                configureLayer(frozenWithBackprop.getUnderlying());
            }
            return conf;
        }

        /**
         * Applies this builder's global settings to a layer and the layers it wraps,
         * then hands the layer to {@code LayerValidation.generalValidation}.
         *
         * @param layer layer to configure; may be {@code null}, in which case only the
         *              validation call is made
         */
        private void configureLayer(Layer layer) {
            String layerName;
            if (layer != null && layer.getLayerName() != null) {
                layerName = layer.getLayerName();
            } else {
                layerName = "Layer not named";
            }

            if (layer instanceof AbstractSameDiffLayer) {
                AbstractSameDiffLayer sameDiffLayer = (AbstractSameDiffLayer) layer;
                sameDiffLayer.applyGlobalConfig(this);
            }
            if (layer != null) {
                copyConfigToLayer(layerName, layer);
            }

            // Wrapper layers: push the defaults down to the wrapped layer(s) as well.
            if (layer instanceof FrozenLayer) {
                FrozenLayer frozen = (FrozenLayer) layer;
                copyConfigToLayer(layerName, frozen.getLayer());
            }
            if (layer instanceof FrozenLayerWithBackprop) {
                FrozenLayerWithBackprop frozenWithBackprop = (FrozenLayerWithBackprop) layer;
                copyConfigToLayer(layerName, frozenWithBackprop.getUnderlying());
            }
            if (layer instanceof Bidirectional) {
                Bidirectional bidi = (Bidirectional) layer;
                // Each direction is configured under its own layer name.
                Layer forward = bidi.getFwd();
                copyConfigToLayer(forward.getLayerName(), forward);
                Layer backward = bidi.getBwd();
                copyConfigToLayer(backward.getLayerName(), backward);
            }
            if (layer instanceof BaseWrapperLayer) {
                BaseWrapperLayer wrapper = (BaseWrapperLayer) layer;
                configureLayer(wrapper.getUnderlying());
            }

            // Convolution-specific defaults are only filled in where the layer left them null.
            if (layer instanceof ConvolutionLayer) {
                ConvolutionLayer conv = (ConvolutionLayer) layer;
                if (conv.getConvolutionMode() == null) {
                    conv.setConvolutionMode(convolutionMode);
                }
                if (conv.getCudnnAlgoMode() == null) {
                    conv.setCudnnAlgoMode(cudnnAlgoMode);
                }
            }
            if (layer instanceof SubsamplingLayer) {
                SubsamplingLayer pooling = (SubsamplingLayer) layer;
                if (pooling.getConvolutionMode() == null) {
                    pooling.setConvolutionMode(convolutionMode);
                }
            }

            LayerValidation.generalValidation(layerName, layer, idropOut, l2, l2Bias, l1, l1Bias,
                    allParamConstraints, weightConstraints, biasConstraints);
        }

        /**
         * Copies builder-level defaults onto a single layer, touching only the values
         * the layer has not set itself ({@code null} or NaN). The one exception is the
         * legacy batch-scaled L2 flag, which is always written onto output layers.
         *
         * @param str   layer name used in the warning logged when an updater has to be defaulted
         * @param layer layer to update; must not be {@code null}
         */
        private void copyConfigToLayer(String str, Layer layer) {
            if (layer.getIDropout() == null) {
                // Each layer receives its own clone of the builder-level dropout.
                IDropout inheritedDropout = null;
                if (idropOut != null) {
                    inheritedDropout = idropOut.m43clone();
                }
                layer.setIDropout(inheritedDropout);
            }
            if (layer instanceof BaseLayer) {
                applyBaseLayerDefaults(str, (BaseLayer) layer);
            }
            if (layer instanceof ActivationLayer) {
                ActivationLayer activationOnly = (ActivationLayer) layer;
                if (activationOnly.getActivationFn() == null) {
                    activationOnly.setActivationFn(activationFn);
                }
            }
            if (layer instanceof BaseOutputLayer) {
                BaseOutputLayer outputLayer = (BaseOutputLayer) layer;
                outputLayer.setLegacyBatchScaledL2(legacyBatchScaledL2);
            }
        }

        /** Fills in the unset regularisation, activation, init, updater and gradient settings of a base layer. */
        private void applyBaseLayerDefaults(String str, BaseLayer target) {
            if (Double.isNaN(target.getL1())) {
                target.setL1(l1);
            }
            if (Double.isNaN(target.getL2())) {
                target.setL2(l2);
            }
            if (target.getActivationFn() == null) {
                target.setActivationFn(activationFn);
            }
            if (target.getWeightInitFn() == null) {
                target.setWeightInitFn(weightInitFn);
            }
            if (Double.isNaN(target.getBiasInit())) {
                target.setBiasInit(biasInit);
            }
            // Stateful helpers are cloned so layers never share instances with the builder.
            if (weightNoise != null && target.getWeightNoise() == null) {
                target.setWeightNoise(weightNoise.clone());
            }
            if (iUpdater != null && target.getIUpdater() == null) {
                target.setIUpdater(iUpdater.clone());
            }
            if (biasUpdater != null && target.getBiasUpdater() == null) {
                target.setBiasUpdater(biasUpdater.clone());
            }
            // Last resort: neither builder nor layer has an updater but the layer has parameters.
            boolean noUpdaterAnywhere = target.getIUpdater() == null && iUpdater == null;
            if (noUpdaterAnywhere && target.initializer().numParams(target) > 0) {
                Sgd fallback = new Sgd();
                target.setIUpdater(fallback);
                NeuralNetConfiguration.log.warn("*** No updater configuration is set for layer {} - defaulting to {} ***", str, fallback);
            }
            if (target.getGradientNormalization() == null) {
                target.setGradientNormalization(gradientNormalization);
            }
            if (Double.isNaN(target.getGradientNormalizationThreshold())) {
                target.setGradientNormalizationThreshold(gradientNormalizationThreshold);
            }
        }

        /** @return the builder-level default activation function */
        public IActivation getActivationFn() {
            return activationFn;
        }

        /** @return the builder-level default weight initialisation */
        public IWeightInit getWeightInitFn() {
            return weightInitFn;
        }

        /** @return the builder-level default bias initialisation value */
        public double getBiasInit() {
            return biasInit;
        }

        /** @return the builder-level {@code l1} value; NaN unless set */
        public double getL1() {
            return l1;
        }

        /** @return the builder-level {@code l2} value; NaN unless set */
        public double getL2() {
            return l2;
        }

        /** @return the builder-level {@code l1Bias} value; NaN unless set */
        public double getL1Bias() {
            return l1Bias;
        }

        /** @return the builder-level {@code l2Bias} value; NaN unless set */
        public double getL2Bias() {
            return l2Bias;
        }

        /** @return the builder-level dropout instance itself (not a copy), or {@code null} */
        public IDropout getIdropOut() {
            return idropOut;
        }

        /** @return the builder-level weight noise, or {@code null} */
        public IWeightNoise getWeightNoise() {
            return weightNoise;
        }

        /** @return the builder-level default updater, or {@code null} */
        public IUpdater getIUpdater() {
            return iUpdater;
        }

        /** @return the builder-level bias updater, or {@code null} */
        public IUpdater getBiasUpdater() {
            return biasUpdater;
        }

        /** @return the layer assigned to this builder, or {@code null} */
        public Layer getLayer() {
            return layer;
        }

        /** @return the mini-batch flag */
        public boolean isMiniBatch() {
            return miniBatch;
        }

        /** @return the maximum number of line search iterations */
        public int getMaxNumLineSearchIterations() {
            return maxNumLineSearchIterations;
        }

        /** @return the seed stored on this builder */
        public long getSeed() {
            return seed;
        }

        /** @return the configured optimization algorithm */
        public OptimizationAlgorithm getOptimizationAlgo() {
            return optimizationAlgo;
        }

        /** @return the configured step function, or {@code null} */
        public StepFunction getStepFunction() {
            return stepFunction;
        }

        /** @return the minimize flag */
        public boolean isMinimize() {
            return minimize;
        }

        /** @return the builder-level gradient normalization */
        public GradientNormalization getGradientNormalization() {
            return gradientNormalization;
        }

        /** @return the builder-level gradient normalization threshold */
        public double getGradientNormalizationThreshold() {
            return gradientNormalizationThreshold;
        }

        /** @return the all-parameter constraint list itself (not a copy), or {@code null} if none were set */
        public List<LayerConstraint> getAllParamConstraints() {
            return allParamConstraints;
        }

        /** @return the weight constraint list itself (not a copy), or {@code null} if none were set */
        public List<LayerConstraint> getWeightConstraints() {
            return weightConstraints;
        }

        /** @return the bias constraint list itself (not a copy), or {@code null} if none were set */
        public List<LayerConstraint> getBiasConstraints() {
            return biasConstraints;
        }

        /** @return the legacy batch-scaled L2 flag */
        public boolean isLegacyBatchScaledL2() {
            return legacyBatchScaledL2;
        }

        /** @return the training workspace mode */
        public WorkspaceMode getTrainingWorkspaceMode() {
            return trainingWorkspaceMode;
        }

        /** @return the inference workspace mode */
        public WorkspaceMode getInferenceWorkspaceMode() {
            return inferenceWorkspaceMode;
        }

        /** @return the flag that {@code trainingWorkspaceMode(WorkspaceMode)} sets to {@code true} */
        public boolean isSetTWM() {
            return setTWM;
        }

        /** @return the flag that {@code inferenceWorkspaceMode(WorkspaceMode)} sets to {@code true} */
        public boolean isSetIWM() {
            return setIWM;
        }

        /** @return the configured cache mode */
        public CacheMode getCacheMode() {
            return cacheMode;
        }

        /** @return the builder-level default convolution mode */
        public ConvolutionMode getConvolutionMode() {
            return convolutionMode;
        }

        /** @return the builder-level default cuDNN algorithm mode */
        public ConvolutionLayer.AlgoMode getCudnnAlgoMode() {
            return cudnnAlgoMode;
        }

        /** @param iActivation new builder-level default activation function */
        public void setActivationFn(IActivation iActivation) {
            activationFn = iActivation;
        }

        /** @param iWeightInit new builder-level default weight initialisation */
        public void setWeightInitFn(IWeightInit iWeightInit) {
            weightInitFn = iWeightInit;
        }

        /** @param d new builder-level default bias initialisation value */
        public void setBiasInit(double d) {
            biasInit = d;
        }

        /** @param d new builder-level {@code l1} value */
        public void setL1(double d) {
            l1 = d;
        }

        /** @param d new builder-level {@code l2} value */
        public void setL2(double d) {
            l2 = d;
        }

        /** @param d new builder-level {@code l1Bias} value */
        public void setL1Bias(double d) {
            l1Bias = d;
        }

        /** @param d new builder-level {@code l2Bias} value */
        public void setL2Bias(double d) {
            l2Bias = d;
        }

        /**
         * Stores the dropout reference as-is; unlike {@code dropOut(IDropout)} no clone is made.
         *
         * @param iDropout new builder-level dropout, or {@code null}
         */
        public void setIdropOut(IDropout iDropout) {
            this.idropOut = iDropout;
        }

        /** @param iWeightNoise new builder-level weight noise, or {@code null} */
        public void setWeightNoise(IWeightNoise iWeightNoise) {
            weightNoise = iWeightNoise;
        }

        /** @param iUpdater new builder-level default updater, or {@code null} */
        public void setIUpdater(IUpdater iUpdater) {
            this.iUpdater = iUpdater;
        }

        /** @param iUpdater new builder-level bias updater, or {@code null} */
        public void setBiasUpdater(IUpdater iUpdater) {
            this.biasUpdater = iUpdater;
        }

        /** @param layer layer to assign to this builder; stored by reference */
        public void setLayer(Layer layer) {
            this.layer = layer;
        }

        /** @param z new value of the mini-batch flag */
        public void setMiniBatch(boolean z) {
            miniBatch = z;
        }

        /** @param i new maximum number of line search iterations */
        public void setMaxNumLineSearchIterations(int i) {
            maxNumLineSearchIterations = i;
        }

        /**
         * Stores the seed only; unlike {@code seed(long)} it does not reseed the
         * generator returned by {@code Nd4j.getRandom()}.
         *
         * @param j new seed value
         */
        public void setSeed(long j) {
            seed = j;
        }

        /** @param optimizationAlgorithm new optimization algorithm */
        public void setOptimizationAlgo(OptimizationAlgorithm optimizationAlgorithm) {
            optimizationAlgo = optimizationAlgorithm;
        }

        /** @param stepFunction new step function, or {@code null} */
        public void setStepFunction(StepFunction stepFunction) {
            this.stepFunction = stepFunction;
        }

        /** @param z new value of the minimize flag */
        public void setMinimize(boolean z) {
            minimize = z;
        }

        /** @param gradientNormalization new builder-level gradient normalization */
        public void setGradientNormalization(GradientNormalization gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
        }

        /** @param d new builder-level gradient normalization threshold */
        public void setGradientNormalizationThreshold(double d) {
            gradientNormalizationThreshold = d;
        }

        /** @param list new all-parameter constraint list; stored by reference, may be {@code null} */
        public void setAllParamConstraints(List<LayerConstraint> list) {
            allParamConstraints = list;
        }

        /** @param list new weight constraint list; stored by reference, may be {@code null} */
        public void setWeightConstraints(List<LayerConstraint> list) {
            weightConstraints = list;
        }

        /** @param list new bias constraint list; stored by reference, may be {@code null} */
        public void setBiasConstraints(List<LayerConstraint> list) {
            biasConstraints = list;
        }

        /** @param z new value of the legacy batch-scaled L2 flag */
        public void setLegacyBatchScaledL2(boolean z) {
            legacyBatchScaledL2 = z;
        }

        /**
         * Stores the training workspace mode. Unlike {@code trainingWorkspaceMode(WorkspaceMode)}
         * it performs no null check and leaves {@code setTWM} untouched.
         *
         * @param workspaceMode new training workspace mode
         */
        public void setTrainingWorkspaceMode(WorkspaceMode workspaceMode) {
            trainingWorkspaceMode = workspaceMode;
        }

        /**
         * Stores the inference workspace mode. Unlike {@code inferenceWorkspaceMode(WorkspaceMode)}
         * it performs no null check and leaves {@code setIWM} untouched.
         *
         * @param workspaceMode new inference workspace mode
         */
        public void setInferenceWorkspaceMode(WorkspaceMode workspaceMode) {
            inferenceWorkspaceMode = workspaceMode;
        }

        /** @param z new value of the {@code setTWM} flag */
        public void setSetTWM(boolean z) {
            setTWM = z;
        }

        /** @param z new value of the {@code setIWM} flag */
        public void setSetIWM(boolean z) {
            setIWM = z;
        }

        /**
         * Stores the cache mode. Unlike {@code cacheMode(CacheMode)} it performs no null check.
         *
         * @param cacheMode new cache mode
         */
        public void setCacheMode(CacheMode cacheMode) {
            this.cacheMode = cacheMode;
        }

        /** @param convolutionMode new builder-level default convolution mode */
        public void setConvolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
        }

        /** @param algoMode new builder-level default cuDNN algorithm mode */
        public void setCudnnAlgoMode(ConvolutionLayer.AlgoMode algoMode) {
            cudnnAlgoMode = algoMode;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Builder)) {
                return false;
            }
            Builder builder = (Builder) obj;
            if (!builder.canEqual(this)) {
                return false;
            }
            IActivation activationFn = getActivationFn();
            IActivation activationFn2 = builder.getActivationFn();
            if (activationFn == null) {
                if (activationFn2 != null) {
                    return false;
                }
            } else if (!activationFn.equals(activationFn2)) {
                return false;
            }
            IWeightInit weightInitFn = getWeightInitFn();
            IWeightInit weightInitFn2 = builder.getWeightInitFn();
            if (weightInitFn == null) {
                if (weightInitFn2 != null) {
                    return false;
                }
            } else if (!weightInitFn.equals(weightInitFn2)) {
                return false;
            }
            if (Double.compare(getBiasInit(), builder.getBiasInit()) != 0 || Double.compare(getL1(), builder.getL1()) != 0 || Double.compare(getL2(), builder.getL2()) != 0 || Double.compare(getL1Bias(), builder.getL1Bias()) != 0 || Double.compare(getL2Bias(), builder.getL2Bias()) != 0) {
                return false;
            }
            IDropout idropOut = getIdropOut();
            IDropout idropOut2 = builder.getIdropOut();
            if (idropOut == null) {
                if (idropOut2 != null) {
                    return false;
                }
            } else if (!idropOut.equals(idropOut2)) {
                return false;
            }
            IWeightNoise weightNoise = getWeightNoise();
            IWeightNoise weightNoise2 = builder.getWeightNoise();
            if (weightNoise == null) {
                if (weightNoise2 != null) {
                    return false;
                }
            } else if (!weightNoise.equals(weightNoise2)) {
                return false;
            }
            IUpdater iUpdater = getIUpdater();
            IUpdater iUpdater2 = builder.getIUpdater();
            if (iUpdater == null) {
                if (iUpdater2 != null) {
                    return false;
                }
            } else if (!iUpdater.equals(iUpdater2)) {
                return false;
            }
            IUpdater biasUpdater = getBiasUpdater();
            IUpdater biasUpdater2 = builder.getBiasUpdater();
            if (biasUpdater == null) {
                if (biasUpdater2 != null) {
                    return false;
                }
            } else if (!biasUpdater.equals(biasUpdater2)) {
                return false;
            }
            Layer layer = getLayer();
            Layer layer2 = builder.getLayer();
            if (layer == null) {
                if (layer2 != null) {
                    return false;
                }
            } else if (!layer.equals(layer2)) {
                return false;
            }
            if (isMiniBatch() != builder.isMiniBatch() || getMaxNumLineSearchIterations() != builder.getMaxNumLineSearchIterations() || getSeed() != builder.getSeed()) {
                return false;
            }
            OptimizationAlgorithm optimizationAlgo = getOptimizationAlgo();
            OptimizationAlgorithm optimizationAlgo2 = builder.getOptimizationAlgo();
            if (optimizationAlgo == null) {
                if (optimizationAlgo2 != null) {
                    return false;
                }
            } else if (!optimizationAlgo.equals(optimizationAlgo2)) {
                return false;
            }
            StepFunction stepFunction = getStepFunction();
            StepFunction stepFunction2 = builder.getStepFunction();
            if (stepFunction == null) {
                if (stepFunction2 != null) {
                    return false;
                }
            } else if (!stepFunction.equals(stepFunction2)) {
                return false;
            }
            if (isMinimize() != builder.isMinimize()) {
                return false;
            }
            GradientNormalization gradientNormalization = getGradientNormalization();
            GradientNormalization gradientNormalization2 = builder.getGradientNormalization();
            if (gradientNormalization == null) {
                if (gradientNormalization2 != null) {
                    return false;
                }
            } else if (!gradientNormalization.equals(gradientNormalization2)) {
                return false;
            }
            if (Double.compare(getGradientNormalizationThreshold(), builder.getGradientNormalizationThreshold()) != 0) {
                return false;
            }
            List<LayerConstraint> allParamConstraints = getAllParamConstraints();
            List<LayerConstraint> allParamConstraints2 = builder.getAllParamConstraints();
            if (allParamConstraints == null) {
                if (allParamConstraints2 != null) {
                    return false;
                }
            } else if (!allParamConstraints.equals(allParamConstraints2)) {
                return false;
            }
            List<LayerConstraint> weightConstraints = getWeightConstraints();
            List<LayerConstraint> weightConstraints2 = builder.getWeightConstraints();
            if (weightConstraints == null) {
                if (weightConstraints2 != null) {
                    return false;
                }
            } else if (!weightConstraints.equals(weightConstraints2)) {
                return false;
            }
            List<LayerConstraint> biasConstraints = getBiasConstraints();
            List<LayerConstraint> biasConstraints2 = builder.getBiasConstraints();
            if (biasConstraints == null) {
                if (biasConstraints2 != null) {
                    return false;
                }
            } else if (!biasConstraints.equals(biasConstraints2)) {
                return false;
            }
            if (isLegacyBatchScaledL2() != builder.isLegacyBatchScaledL2()) {
                return false;
            }
            WorkspaceMode trainingWorkspaceMode = getTrainingWorkspaceMode();
            WorkspaceMode trainingWorkspaceMode2 = builder.getTrainingWorkspaceMode();
            if (trainingWorkspaceMode == null) {
                if (trainingWorkspaceMode2 != null) {
                    return false;
                }
            } else if (!trainingWorkspaceMode.equals(trainingWorkspaceMode2)) {
                return false;
            }
            WorkspaceMode inferenceWorkspaceMode = getInferenceWorkspaceMode();
            WorkspaceMode inferenceWorkspaceMode2 = builder.getInferenceWorkspaceMode();
            if (inferenceWorkspaceMode == null) {
                if (inferenceWorkspaceMode2 != null) {
                    return false;
                }
            } else if (!inferenceWorkspaceMode.equals(inferenceWorkspaceMode2)) {
                return false;
            }
            if (isSetTWM() != builder.isSetTWM() || isSetIWM() != builder.isSetIWM()) {
                return false;
            }
            CacheMode cacheMode = getCacheMode();
            CacheMode cacheMode2 = builder.getCacheMode();
            if (cacheMode == null) {
                if (cacheMode2 != null) {
                    return false;
                }
            } else if (!cacheMode.equals(cacheMode2)) {
                return false;
            }
            ConvolutionMode convolutionMode = getConvolutionMode();
            ConvolutionMode convolutionMode2 = builder.getConvolutionMode();
            if (convolutionMode == null) {
                if (convolutionMode2 != null) {
                    return false;
                }
            } else if (!convolutionMode.equals(convolutionMode2)) {
                return false;
            }
            ConvolutionLayer.AlgoMode cudnnAlgoMode = getCudnnAlgoMode();
            ConvolutionLayer.AlgoMode cudnnAlgoMode2 = builder.getCudnnAlgoMode();
            return cudnnAlgoMode == null ? cudnnAlgoMode2 == null : cudnnAlgoMode.equals(cudnnAlgoMode2);
        }

        /**
         * Reports whether the given object is an instance of {@code Builder}.
         *
         * @param obj the candidate object; may be {@code null}
         * @return {@code true} if {@code obj} is a {@code Builder}, {@code false} otherwise
         */
        protected boolean canEqual(Object obj) {
            return obj instanceof Builder;
        }

        /**
         * Computes a hash code from every configuration value exposed by this builder's getters.
         *
         * <p>The values are folded in a fixed order with multiplier 59. A {@code null} reference
         * contributes 43, a boolean contributes 79 ({@code true}) or 97 ({@code false}), and
         * {@code long}/{@code double} values are folded by XOR-ing their upper and lower 32 bits.
         *
         * @return the combined hash code
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int result = 1;
            long bits;

            IActivation activationFn = getActivationFn();
            result = result * prime + (activationFn == null ? nullHash : activationFn.hashCode());

            IWeightInit weightInitFn = getWeightInitFn();
            result = result * prime + (weightInitFn == null ? nullHash : weightInitFn.hashCode());

            bits = Double.doubleToLongBits(getBiasInit());
            result = result * prime + (int) (bits ^ (bits >>> 32));

            bits = Double.doubleToLongBits(getL1());
            result = result * prime + (int) (bits ^ (bits >>> 32));

            bits = Double.doubleToLongBits(getL2());
            result = result * prime + (int) (bits ^ (bits >>> 32));

            bits = Double.doubleToLongBits(getL1Bias());
            result = result * prime + (int) (bits ^ (bits >>> 32));

            bits = Double.doubleToLongBits(getL2Bias());
            result = result * prime + (int) (bits ^ (bits >>> 32));

            IDropout idropOut = getIdropOut();
            result = result * prime + (idropOut == null ? nullHash : idropOut.hashCode());

            IWeightNoise weightNoise = getWeightNoise();
            result = result * prime + (weightNoise == null ? nullHash : weightNoise.hashCode());

            IUpdater iUpdater = getIUpdater();
            result = result * prime + (iUpdater == null ? nullHash : iUpdater.hashCode());

            IUpdater biasUpdater = getBiasUpdater();
            result = result * prime + (biasUpdater == null ? nullHash : biasUpdater.hashCode());

            Layer layer = getLayer();
            result = result * prime + (layer == null ? nullHash : layer.hashCode());

            result = result * prime + (isMiniBatch() ? 79 : 97);
            result = result * prime + getMaxNumLineSearchIterations();

            long seed = getSeed();
            result = result * prime + (int) (seed ^ (seed >>> 32));

            OptimizationAlgorithm optimizationAlgo = getOptimizationAlgo();
            result = result * prime + (optimizationAlgo == null ? nullHash : optimizationAlgo.hashCode());

            StepFunction stepFunction = getStepFunction();
            result = result * prime + (stepFunction == null ? nullHash : stepFunction.hashCode());

            result = result * prime + (isMinimize() ? 79 : 97);

            GradientNormalization gradientNormalization = getGradientNormalization();
            result = result * prime + (gradientNormalization == null ? nullHash : gradientNormalization.hashCode());

            bits = Double.doubleToLongBits(getGradientNormalizationThreshold());
            result = result * prime + (int) (bits ^ (bits >>> 32));

            List<LayerConstraint> allParamConstraints = getAllParamConstraints();
            result = result * prime + (allParamConstraints == null ? nullHash : allParamConstraints.hashCode());

            List<LayerConstraint> weightConstraints = getWeightConstraints();
            result = result * prime + (weightConstraints == null ? nullHash : weightConstraints.hashCode());

            List<LayerConstraint> biasConstraints = getBiasConstraints();
            result = result * prime + (biasConstraints == null ? nullHash : biasConstraints.hashCode());

            result = result * prime + (isLegacyBatchScaledL2() ? 79 : 97);

            WorkspaceMode trainingWorkspaceMode = getTrainingWorkspaceMode();
            result = result * prime + (trainingWorkspaceMode == null ? nullHash : trainingWorkspaceMode.hashCode());

            WorkspaceMode inferenceWorkspaceMode = getInferenceWorkspaceMode();
            result = result * prime + (inferenceWorkspaceMode == null ? nullHash : inferenceWorkspaceMode.hashCode());

            result = result * prime + (isSetTWM() ? 79 : 97);
            result = result * prime + (isSetIWM() ? 79 : 97);

            CacheMode cacheMode = getCacheMode();
            result = result * prime + (cacheMode == null ? nullHash : cacheMode.hashCode());

            ConvolutionMode convolutionMode = getConvolutionMode();
            result = result * prime + (convolutionMode == null ? nullHash : convolutionMode.hashCode());

            ConvolutionLayer.AlgoMode cudnnAlgoMode = getCudnnAlgoMode();
            result = result * prime + (cudnnAlgoMode == null ? nullHash : cudnnAlgoMode.hashCode());

            return result;
        }

        /**
         * Renders every configuration value of this builder as
         * {@code NeuralNetConfiguration.Builder(name=value, ...)}.
         *
         * @return the textual representation of this builder
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("NeuralNetConfiguration.Builder(");
            text.append("activationFn=").append(getActivationFn());
            text.append(", weightInitFn=").append(getWeightInitFn());
            text.append(", biasInit=").append(getBiasInit());
            text.append(", l1=").append(getL1());
            text.append(", l2=").append(getL2());
            text.append(", l1Bias=").append(getL1Bias());
            text.append(", l2Bias=").append(getL2Bias());
            text.append(", idropOut=").append(getIdropOut());
            text.append(", weightNoise=").append(getWeightNoise());
            text.append(", iUpdater=").append(getIUpdater());
            text.append(", biasUpdater=").append(getBiasUpdater());
            text.append(", layer=").append(getLayer());
            text.append(", miniBatch=").append(isMiniBatch());
            text.append(", maxNumLineSearchIterations=").append(getMaxNumLineSearchIterations());
            text.append(", seed=").append(getSeed());
            text.append(", optimizationAlgo=").append(getOptimizationAlgo());
            text.append(", stepFunction=").append(getStepFunction());
            text.append(", minimize=").append(isMinimize());
            text.append(", gradientNormalization=").append(getGradientNormalization());
            text.append(", gradientNormalizationThreshold=").append(getGradientNormalizationThreshold());
            text.append(", allParamConstraints=").append(getAllParamConstraints());
            text.append(", weightConstraints=").append(getWeightConstraints());
            text.append(", biasConstraints=").append(getBiasConstraints());
            text.append(", legacyBatchScaledL2=").append(isLegacyBatchScaledL2());
            text.append(", trainingWorkspaceMode=").append(getTrainingWorkspaceMode());
            text.append(", inferenceWorkspaceMode=").append(getInferenceWorkspaceMode());
            text.append(", setTWM=").append(isSetTWM());
            text.append(", setIWM=").append(isSetIWM());
            text.append(", cacheMode=").append(getCacheMode());
            text.append(", convolutionMode=").append(getConvolutionMode());
            text.append(", cudnnAlgoMode=").append(getCudnnAlgoMode());
            text.append(")");
            return text.toString();
        }
    }

    /* loaded from: input_file:org/deeplearning4j/nn/conf/NeuralNetConfiguration$ListBuilder.class */
    /**
     * Builder that assembles a {@link MultiLayerConfiguration} from an indexed set of layers.
     *
     * <p>Every layer index is mapped to its own {@link Builder}. A new entry is created by cloning
     * the global configuration builder and assigning the layer to the clone. {@link #build()}
     * requires the layer indices to run from 0 to {@code size - 1} without gaps.
     */
    public static class ListBuilder extends MultiLayerConfiguration.Builder {
        /** Highest layer index registered through {@code layer(...)}; starts at -1. */
        private int layerCounter;
        /** Per-layer builders keyed by layer index. */
        private Map<Integer, Builder> layerwise;
        /** Builder whose clone seeds every newly added layer entry. */
        private Builder globalConfig;

        /* loaded from: input_file:org/deeplearning4j/nn/conf/NeuralNetConfiguration$ListBuilder$InputTypeBuilder.class */
        /**
         * Convenience helper whose methods create an {@link InputType} and pass it to
         * {@link ListBuilder#setInputType(InputType)} of the enclosing builder.
         */
        public class InputTypeBuilder {
            public InputTypeBuilder() {
            }

            /** Sets the input type to {@code InputType.convolutional(i, i2, i3)}. */
            public ListBuilder convolutional(int i, int i2, int i3) {
                return ListBuilder.this.setInputType(InputType.convolutional(i, i2, i3));
            }

            /** Sets the input type to {@code InputType.convolutionalFlat(i, i2, i3)}. */
            public ListBuilder convolutionalFlat(int i, int i2, int i3) {
                return ListBuilder.this.setInputType(InputType.convolutionalFlat(i, i2, i3));
            }

            /** Sets the input type to {@code InputType.feedForward(i)}. */
            public ListBuilder feedForward(int i) {
                return ListBuilder.this.setInputType(InputType.feedForward(i));
            }

            /** Sets the input type to {@code InputType.recurrent(i)}. */
            public ListBuilder recurrent(int i) {
                return ListBuilder.this.setInputType(InputType.recurrent(i));
            }
        }

        /**
         * Creates a list builder on top of an existing index-to-builder map.
         *
         * @param builder global configuration cloned for every newly added layer
         * @param map     per-layer builders keyed by layer index; used directly, not copied
         */
        public ListBuilder(Builder builder, Map<Integer, Builder> map) {
            this.layerCounter = -1;
            this.globalConfig = builder;
            this.layerwise = map;
        }

        /**
         * Creates a list builder with no layers defined yet.
         *
         * @param builder global configuration cloned for every newly added layer
         */
        public ListBuilder(Builder builder) {
            this(builder, new HashMap<Integer, Builder>());
        }

        /**
         * Sets the layer at the given index, replacing the layer of an existing entry if present.
         *
         * @param i     layer index
         * @param layer layer configuration; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code layer} is {@code null}
         */
        public ListBuilder layer(int i, @NonNull Layer layer) {
            if (layer == null) {
                throw new NullPointerException("layer is marked @NonNull but is null");
            }
            Integer index = Integer.valueOf(i);
            Builder existing = this.layerwise.get(index);
            if (existing != null) {
                // Report the type of the layer being replaced, not the type of the builder holding it.
                Layer previous = existing.getLayer();
                String previousType = previous == null ? "null" : previous.getClass().getSimpleName();
                NeuralNetConfiguration.log.info("Layer index {} already exists, layer of type {} will be replaced by layer type {}", index, previousType, layer.getClass().getSimpleName());
                existing.layer(layer);
            } else {
                // No usable entry for this index (absent or mapped to null): start from the global config.
                this.layerwise.put(index, this.globalConfig.m34clone().layer(layer));
            }
            if (this.layerCounter < i) {
                this.layerCounter = i;
            }
            return this;
        }

        /**
         * Adds the layer at the index following the highest index used so far.
         *
         * @param layer layer configuration; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code layer} is {@code null}
         */
        public ListBuilder layer(Layer layer) {
            // layer(int, Layer) advances layerCounter itself, and only after the null check passed,
            // so a rejected layer no longer leaves a gap in the index sequence.
            return layer(this.layerCounter + 1, layer);
        }

        /**
         * Returns the internal index-to-builder map (not a copy).
         *
         * @return per-layer builders keyed by layer index
         */
        public Map<Integer, Builder> getLayerwise() {
            return this.layerwise;
        }

        /**
         * Delegates to the superclass and narrows the return type to {@code ListBuilder}.
         *
         * @param inputType the network input type
         * @return this builder
         */
        @Override // org.deeplearning4j.nn.conf.MultiLayerConfiguration.Builder
        public ListBuilder setInputType(InputType inputType) {
            return (ListBuilder) super.setInputType(inputType);
        }

        /**
         * Creates a helper for setting the input type through named factory-style methods.
         *
         * @return a new {@link InputTypeBuilder} bound to this builder
         */
        public InputTypeBuilder inputType() {
            return new InputTypeBuilder();
        }

        /**
         * Builds the configuration and asks it for the layer activation types for the current input type.
         *
         * @return the activation types reported by the built configuration
         * @throws RuntimeException if building or the activation type calculation fails
         */
        public List<InputType> getLayerActivationTypes() {
            Preconditions.checkState(this.inputType != null, "Can only calculate activation types if input type has been set. Use setInputType(InputType)");
            try {
                return build().getLayerActivationTypes(this.inputType);
            } catch (Exception e) {
                throw new RuntimeException("Error calculating layer activation types: error instantiating MultiLayerConfiguration", e);
            }
        }

        /**
         * Builds the {@link MultiLayerConfiguration} from the registered layers.
         *
         * <p>Layers without a name are named {@code "layer" + index}. The training and inference
         * workspace modes come from the global configuration when its {@code setTWM} /
         * {@code setIWM} flags are set, otherwise from this builder.
         *
         * @return the assembled configuration
         * @throws IllegalStateException if no layers are defined, if an index in
         *                               {@code 0..size-1} has no builder, or if a builder has no layer
         */
        @Override // org.deeplearning4j.nn.conf.MultiLayerConfiguration.Builder
        public MultiLayerConfiguration build() {
            if (this.layerwise.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: no layers defined");
            }
            int layerCount = this.layerwise.size();
            List<NeuralNetConfiguration> layerConfs = new ArrayList<>(layerCount);
            for (int i = 0; i < layerCount; i++) {
                Builder layerBuilder = this.layerwise.get(Integer.valueOf(i));
                if (layerBuilder == null) {
                    throw new IllegalStateException("Invalid configuration: layer number " + i + " not specified. Expect layer numbers to be 0 to " + (layerCount - 1) + " inclusive (number of layers defined: " + layerCount + ")");
                }
                Layer layer = layerBuilder.getLayer();
                if (layer == null) {
                    throw new IllegalStateException("Cannot construct network: Layer config for layer with index " + i + " is not defined");
                }
                if (layer.getLayerName() == null) {
                    layer.setLayerName("layer" + i);
                }
                layerConfs.add(layerBuilder.build());
            }
            return new MultiLayerConfiguration.Builder()
                    .inputPreProcessors(this.inputPreProcessors)
                    .backpropType(this.backpropType)
                    .tBPTTForwardLength(this.tbpttFwdLength)
                    .tBPTTBackwardLength(this.tbpttBackLength)
                    .setInputType(this.inputType)
                    .trainingWorkspaceMode(this.globalConfig.setTWM ? this.globalConfig.trainingWorkspaceMode : this.trainingWorkspaceMode)
                    .cacheMode(this.globalConfig.cacheMode)
                    .inferenceWorkspaceMode(this.globalConfig.setIWM ? this.globalConfig.inferenceWorkspaceMode : this.inferenceWorkspaceMode)
                    .confs(layerConfs)
                    .validateOutputLayerConfig(this.validateOutputConfig)
                    .legacyBatchScaledL2(this.legacyBatchScaledL2)
                    .build();
        }
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a copy of this configuration via {@code super.clone()}.
     *
     * <p>The {@code layer} and {@code stepFunction} of the copy are replaced by their own clones
     * when non-null, and a non-null {@code variables} list is replaced by a new {@link ArrayList}
     * holding the same elements.
     *
     * @return the cloned configuration
     * @throws RuntimeException wrapping a {@link CloneNotSupportedException} raised by {@code super.clone()}
     */
    public NeuralNetConfiguration m33clone() {
        final NeuralNetConfiguration copy;
        try {
            copy = (NeuralNetConfiguration) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
        if (copy.layer != null) {
            copy.layer = copy.layer.mo55clone();
        }
        if (copy.stepFunction != null) {
            copy.stepFunction = copy.stepFunction.m108clone();
        }
        if (copy.variables != null) {
            copy.variables = new ArrayList<>(copy.variables);
        }
        return copy;
    }

    /**
     * Returns a new list containing the current variable names.
     *
     * @return a fresh {@link ArrayList} copy of the {@code variables} field
     */
    public List<String> variables() {
        List<String> snapshot = new ArrayList<>(variables);
        return snapshot;
    }

    /**
     * Returns the variable names, either as a copy or as the internal list itself.
     *
     * @param z {@code true} to get a copy (see {@link #variables()}), {@code false} to get the
     *          {@code variables} field directly
     * @return the copied or the internal list, depending on {@code z}
     */
    public List<String> variables(boolean z) {
        if (z) {
            return variables();
        }
        return variables;
    }

    /**
     * Appends a variable name to the {@code variables} list unless the list already contains it.
     *
     * @param str the variable name to add
     */
    public void addVariable(String str) {
        if (!variables.contains(str)) {
            variables.add(str);
        }
    }

    /** Removes all entries from the {@code variables} list. */
    public void clearVariables() {
        variables.clear();
    }

    /**
     * Serializes this configuration to a string using the mapper returned by {@link #mapperYaml()}.
     *
     * @return the serialized form of this configuration
     * @throws RuntimeException wrapping the {@code JsonProcessingException} if serialization fails
     */
    public String toYaml() {
        ObjectMapper yamlMapper = mapperYaml();
        try {
            return yamlMapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a configuration from a string using the mapper returned by {@link #mapperYaml()}.
     *
     * @param str the serialized configuration
     * @return the deserialized configuration
     * @throws RuntimeException wrapping the {@link IOException} if reading fails
     */
    public static NeuralNetConfiguration fromYaml(String str) {
        ObjectMapper yamlMapper = mapperYaml();
        try {
            return yamlMapper.readValue(str, NeuralNetConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Serializes this configuration to a string using the mapper returned by {@link #mapper()}.
     *
     * @return the serialized form of this configuration
     * @throws RuntimeException wrapping the {@code JsonProcessingException} if serialization fails
     */
    public String toJson() {
        ObjectMapper jsonMapper = mapper();
        try {
            return jsonMapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a configuration from a string using the mapper returned by {@link #mapper()}.
     *
     * @param str the serialized configuration
     * @return the deserialized configuration
     * @throws RuntimeException wrapping the {@link IOException} if reading fails
     */
    public static NeuralNetConfiguration fromJson(String str) {
        ObjectMapper jsonMapper = mapper();
        try {
            return jsonMapper.readValue(str, NeuralNetConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the mapper supplied by {@code JsonMappers.getMapperYaml()}.
     *
     * @return the mapper used by {@link #toYaml()} and {@link #fromYaml(String)}
     */
    public static ObjectMapper mapperYaml() {
        return JsonMappers.getMapperYaml();
    }

    /**
     * Returns the mapper supplied by {@code JsonMappers.getMapper()}.
     *
     * @return the mapper used by {@link #toJson()} and {@link #fromJson(String)}
     */
    public static ObjectMapper mapper() {
        return JsonMappers.getMapper();
    }

    /**
     * Varargs convenience overload: wraps the classes in a list and forwards them to
     * {@link #registerLegacyCustomClassesForJSONList(List)}.
     *
     * @param clsArr the classes to register
     */
    public static void registerLegacyCustomClassesForJSON(Class<?>... clsArr) {
        List<Class<?>> classes = Arrays.asList(clsArr);
        registerLegacyCustomClassesForJSONList(classes);
    }

    /**
     * Registers each class under its simple name by pairing {@code getSimpleName()} with the
     * class and forwarding the pairs to {@link #registerLegacyCustomClassesForJSON(List)}.
     *
     * @param list the classes to register
     */
    public static void registerLegacyCustomClassesForJSONList(List<Class<?>> list) {
        List<Pair<String, Class>> namedClasses = new ArrayList<>();
        for (Class<?> customClass : list) {
            namedClasses.add(new Pair<String, Class>(customClass.getSimpleName(), customClass));
        }
        registerLegacyCustomClassesForJSON(namedClasses);
    }

    /**
     * Registers named classes with the legacy deserializer that matches their supertype.
     *
     * <p>Each class is checked against every entry of {@code REGISTERABLE_CUSTOM_CLASSES}; for
     * each entry it is assignable to, the corresponding legacy deserializer (layer, graph vertex,
     * preprocessor, activation, loss function or reconstruction distribution) receives the
     * name/class pair.
     *
     * @param list pairs of (name, class) to register
     * @throws IllegalArgumentException if a class is not assignable to any entry of
     *                                  {@code REGISTERABLE_CUSTOM_CLASSES}
     */
    public static void registerLegacyCustomClassesForJSON(List<Pair<String, Class>> list) {
        for (Pair<String, Class> entry : list) {
            String legacyName = entry.getFirst();
            Class<?> customClass = entry.getRight();
            boolean matched = false;
            for (Class<?> supported : REGISTERABLE_CUSTOM_CLASSES) {
                if (!supported.isAssignableFrom(customClass)) {
                    continue;
                }
                if (supported == Layer.class) {
                    LegacyLayerDeserializer.registerLegacyClassSpecifiedName(legacyName, customClass);
                } else if (supported == GraphVertex.class) {
                    LegacyGraphVertexDeserializer.registerLegacyClassSpecifiedName(legacyName, customClass);
                } else if (supported == InputPreProcessor.class) {
                    LegacyPreprocessorDeserializer.registerLegacyClassSpecifiedName(legacyName, customClass);
                } else if (supported == IActivation.class) {
                    LegacyIActivationDeserializer.registerLegacyClassSpecifiedName(legacyName, customClass);
                } else if (supported == ILossFunction.class) {
                    LegacyILossFunctionDeserializer.registerLegacyClassSpecifiedName(legacyName, customClass);
                } else if (supported == ReconstructionDistribution.class) {
                    LegacyReconstructionDistributionDeserializer.registerLegacyClassSpecifiedName(legacyName, customClass);
                }
                // Counts as matched even when none of the branches above applied.
                matched = true;
            }
            if (!matched) {
                throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " + customClass.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
            }
        }
    }

    /**
     * Returns the value of the {@code layer} field.
     *
     * @return the layer configuration; may be {@code null}
     */
    public Layer getLayer() {
        return layer;
    }

    /**
     * Returns the value of the {@code miniBatch} flag.
     *
     * @return the current {@code miniBatch} setting
     */
    public boolean isMiniBatch() {
        return miniBatch;
    }

    /**
     * Returns the value of the {@code maxNumLineSearchIterations} field.
     *
     * @return the current {@code maxNumLineSearchIterations} setting
     */
    public int getMaxNumLineSearchIterations() {
        return maxNumLineSearchIterations;
    }

    /**
     * Returns the value of the {@code seed} field.
     *
     * @return the current seed
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Returns the value of the {@code optimizationAlgo} field.
     *
     * @return the optimization algorithm; may be {@code null}
     */
    public OptimizationAlgorithm getOptimizationAlgo() {
        return optimizationAlgo;
    }

    /**
     * Returns the {@code variables} list itself, not a copy.
     *
     * @return the internal list of variable names; may be {@code null}
     */
    public List<String> getVariables() {
        return variables;
    }

    /**
     * Returns the value of the {@code stepFunction} field.
     *
     * @return the step function; may be {@code null}
     */
    public StepFunction getStepFunction() {
        return stepFunction;
    }

    /**
     * Returns the value of the {@code minimize} flag.
     *
     * @return the current {@code minimize} setting
     */
    public boolean isMinimize() {
        return minimize;
    }

    /**
     * Returns the value of the {@code cacheMode} field.
     *
     * @return the cache mode; may be {@code null}
     */
    public CacheMode getCacheMode() {
        return cacheMode;
    }

    /**
     * Returns the value of the {@code iterationCount} field.
     *
     * @return the current iteration count
     */
    public int getIterationCount() {
        return iterationCount;
    }

    /**
     * Returns the value of the {@code epochCount} field.
     *
     * @return the current epoch count
     */
    public int getEpochCount() {
        return epochCount;
    }

    /**
     * Assigns the {@code layer} field.
     *
     * @param layer the layer configuration to store
     */
    public void setLayer(Layer layer) {
        this.layer = layer;
    }

    /**
     * Assigns the {@code miniBatch} flag.
     *
     * @param z the new {@code miniBatch} value
     */
    public void setMiniBatch(boolean z) {
        miniBatch = z;
    }

    /**
     * Assigns the {@code maxNumLineSearchIterations} field.
     *
     * @param i the new {@code maxNumLineSearchIterations} value
     */
    public void setMaxNumLineSearchIterations(int i) {
        maxNumLineSearchIterations = i;
    }

    /**
     * Assigns the {@code seed} field.
     *
     * @param j the new seed
     */
    public void setSeed(long j) {
        seed = j;
    }

    /**
     * Assigns the {@code optimizationAlgo} field.
     *
     * @param optimizationAlgorithm the optimization algorithm to store
     */
    public void setOptimizationAlgo(OptimizationAlgorithm optimizationAlgorithm) {
        optimizationAlgo = optimizationAlgorithm;
    }

    /**
     * Assigns the {@code variables} field to the given list; the list is stored as-is, not copied.
     *
     * @param list the list of variable names to store
     */
    public void setVariables(List<String> list) {
        variables = list;
    }

    /**
     * Assigns the {@code stepFunction} field.
     *
     * @param stepFunction the step function to store
     */
    public void setStepFunction(StepFunction stepFunction) {
        this.stepFunction = stepFunction;
    }

    /**
     * Assigns the {@code minimize} flag.
     *
     * @param z the new {@code minimize} value
     */
    public void setMinimize(boolean z) {
        minimize = z;
    }

    /**
     * Assigns the {@code cacheMode} field.
     *
     * @param cacheMode the cache mode to store
     */
    public void setCacheMode(CacheMode cacheMode) {
        this.cacheMode = cacheMode;
    }

    /**
     * Assigns the {@code iterationCount} field.
     *
     * @param i the new iteration count
     */
    public void setIterationCount(int i) {
        iterationCount = i;
    }

    /**
     * Assigns the {@code epochCount} field.
     *
     * @param i the new epoch count
     */
    public void setEpochCount(int i) {
        epochCount = i;
    }

    /**
     * Renders this configuration as {@code NeuralNetConfiguration(name=value, ...)}, listing the
     * values returned by the getters.
     *
     * @return the textual representation of this configuration
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("NeuralNetConfiguration(");
        text.append("layer=").append(getLayer());
        text.append(", miniBatch=").append(isMiniBatch());
        text.append(", maxNumLineSearchIterations=").append(getMaxNumLineSearchIterations());
        text.append(", seed=").append(getSeed());
        text.append(", optimizationAlgo=").append(getOptimizationAlgo());
        text.append(", variables=").append(getVariables());
        text.append(", stepFunction=").append(getStepFunction());
        text.append(", minimize=").append(isMinimize());
        text.append(", cacheMode=").append(getCacheMode());
        text.append(", iterationCount=").append(getIterationCount());
        text.append(", epochCount=").append(getEpochCount());
        text.append(")");
        return text.toString();
    }

    /**
     * Compares this configuration with another object for equality.
     *
     * <p>Two configurations are equal when the other object is a {@code NeuralNetConfiguration}
     * that accepts this one through {@code canEqual}, and layer, miniBatch,
     * maxNumLineSearchIterations, seed, optimizationAlgo, variables, stepFunction, minimize and
     * cacheMode all match. {@code iterationCount} and {@code epochCount} are not compared.
     *
     * @param obj the object to compare with; may be {@code null}
     * @return {@code true} if the objects are considered equal
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof NeuralNetConfiguration)) {
            return false;
        }
        NeuralNetConfiguration that = (NeuralNetConfiguration) obj;
        if (!that.canEqual(this)) {
            return false;
        }

        Layer ownLayer = getLayer();
        Layer otherLayer = that.getLayer();
        if (ownLayer == null ? otherLayer != null : !ownLayer.equals(otherLayer)) {
            return false;
        }

        if (isMiniBatch() != that.isMiniBatch()) {
            return false;
        }
        if (getMaxNumLineSearchIterations() != that.getMaxNumLineSearchIterations()) {
            return false;
        }
        if (getSeed() != that.getSeed()) {
            return false;
        }

        OptimizationAlgorithm ownAlgo = getOptimizationAlgo();
        OptimizationAlgorithm otherAlgo = that.getOptimizationAlgo();
        if (ownAlgo == null ? otherAlgo != null : !ownAlgo.equals(otherAlgo)) {
            return false;
        }

        List<String> ownVariables = getVariables();
        List<String> otherVariables = that.getVariables();
        if (ownVariables == null ? otherVariables != null : !ownVariables.equals(otherVariables)) {
            return false;
        }

        StepFunction ownStep = getStepFunction();
        StepFunction otherStep = that.getStepFunction();
        if (ownStep == null ? otherStep != null : !ownStep.equals(otherStep)) {
            return false;
        }

        if (isMinimize() != that.isMinimize()) {
            return false;
        }

        CacheMode ownCacheMode = getCacheMode();
        CacheMode otherCacheMode = that.getCacheMode();
        if (ownCacheMode == null) {
            return otherCacheMode == null;
        }
        return ownCacheMode.equals(otherCacheMode);
    }

    /**
     * Type guard consulted by {@code equals}: reports whether the given object is a
     * {@code NeuralNetConfiguration}.
     *
     * @param obj the candidate object; may be {@code null}
     * @return {@code true} if {@code obj} is a {@code NeuralNetConfiguration}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof NeuralNetConfiguration;
    }

    /**
     * Computes a hash code from layer, miniBatch, maxNumLineSearchIterations, seed,
     * optimizationAlgo, variables, stepFunction, minimize and cacheMode.
     *
     * <p>The values are folded in that order with multiplier 59. A {@code null} reference
     * contributes 43, a boolean contributes 79 ({@code true}) or 97 ({@code false}), and the seed
     * is folded by XOR-ing its upper and lower 32 bits. {@code iterationCount} and
     * {@code epochCount} do not take part.
     *
     * @return the combined hash code
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;

        Layer layer = getLayer();
        result = result * prime + (layer == null ? nullHash : layer.hashCode());

        result = result * prime + (isMiniBatch() ? 79 : 97);
        result = result * prime + getMaxNumLineSearchIterations();

        long seed = getSeed();
        result = result * prime + (int) (seed ^ (seed >>> 32));

        OptimizationAlgorithm optimizationAlgo = getOptimizationAlgo();
        result = result * prime + (optimizationAlgo == null ? nullHash : optimizationAlgo.hashCode());

        List<String> variables = getVariables();
        result = result * prime + (variables == null ? nullHash : variables.hashCode());

        StepFunction stepFunction = getStepFunction();
        result = result * prime + (stepFunction == null ? nullHash : stepFunction.hashCode());

        result = result * prime + (isMinimize() ? 79 : 97);

        CacheMode cacheMode = getCacheMode();
        result = result * prime + (cacheMode == null ? nullHash : cacheMode.hashCode());

        return result;
    }
}
