package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.SameDiffParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base configuration class for layers whose computation is defined by SameDiff.
 *
 * <p>Holds per-layer regularization, updater and gradient-normalization settings. Any of these left
 * unset on the layer are filled in from the global {@link NeuralNetConfiguration.Builder} by
 * {@link #applyGlobalConfig(NeuralNetConfiguration.Builder)}. The layer's parameter definitions are
 * created lazily, via {@link #defineParameters(SDLayerParams)}, on the first call to
 * {@link #getLayerParams()}.
 */
public abstract class AbstractSameDiffLayer extends Layer {
    private static final Logger log = LoggerFactory.getLogger(AbstractSameDiffLayer.class);

    /** Regularization applied to weight parameters. */
    protected List<Regularization> regularization;
    /** Regularization applied to bias parameters. */
    protected List<Regularization> regularizationBias;
    /** Updater used for weight parameters (and for biases when {@link #biasUpdater} is null). */
    protected IUpdater updater;
    /** Optional updater used only for bias parameters. */
    protected IUpdater biasUpdater;
    protected GradientNormalization gradientNormalization;
    /** {@code NaN} means "not set on this layer"; replaced by the global value in applyGlobalConfig. */
    protected double gradientNormalizationThreshold;
    /** Created on demand by {@link #getLayerParams()}; may be null until then. */
    private SDLayerParams layerParams;

    /**
     * Base builder for SameDiff layer configurations.
     *
     * @param <T> the concrete builder type, returned by the fluent setters to allow chaining
     */
    public static abstract class Builder<T extends Builder<T>> extends Layer.Builder<T> {
        protected List<Regularization> regularization = new ArrayList<>();
        protected List<Regularization> regularizationBias = new ArrayList<>();
        protected IUpdater updater = null;
        protected IUpdater biasUpdater = null;

        /**
         * Sets the L1 coefficient for weight parameters, replacing any earlier L1 setting.
         * A coefficient {@code <= 0} just removes the existing L1 regularization.
         */
        @SuppressWarnings("unchecked")
        public T l1(double l1) {
            NetworkUtils.removeInstances(regularization, L1Regularization.class);
            if (l1 > 0.0) {
                regularization.add(new L1Regularization(l1));
            }
            return (T) this;
        }

        /**
         * Sets the L2 coefficient for weight parameters, replacing any earlier L2 setting. A positive
         * coefficient also removes (with a warning) any WeightDecay, which is incompatible with L2.
         */
        @SuppressWarnings("unchecked")
        public T l2(double l2) {
            NetworkUtils.removeInstances(regularization, L2Regularization.class);
            if (l2 > 0.0) {
                NetworkUtils.removeInstancesWithWarning(regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization");
                regularization.add(new L2Regularization(l2));
            }
            return (T) this;
        }

        /**
         * Sets the L1 coefficient for bias parameters, replacing any earlier bias L1 setting.
         * A coefficient {@code <= 0} just removes the existing bias L1 regularization.
         */
        @SuppressWarnings("unchecked")
        public T l1Bias(double l1Bias) {
            NetworkUtils.removeInstances(regularizationBias, L1Regularization.class);
            if (l1Bias > 0.0) {
                regularizationBias.add(new L1Regularization(l1Bias));
            }
            return (T) this;
        }

        /**
         * Sets the L2 coefficient for bias parameters, replacing any earlier bias L2 setting. A
         * positive coefficient also removes (with a warning) any bias WeightDecay.
         */
        @SuppressWarnings("unchecked")
        public T l2Bias(double l2Bias) {
            NetworkUtils.removeInstances(regularizationBias, L2Regularization.class);
            if (l2Bias > 0.0) {
                NetworkUtils.removeInstancesWithWarning(regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization");
                regularizationBias.add(new L2Regularization(l2Bias));
            }
            return (T) this;
        }

        /** Adds weight decay for weight parameters with {@code applyLR = true}. */
        public Builder weightDecay(double coefficient) {
            return weightDecay(coefficient, true);
        }

        /**
         * Sets weight decay for weight parameters, replacing any earlier WeightDecay. A positive
         * coefficient also removes (with a warning) any L2 regularization, which is incompatible.
         *
         * @param coefficient weight decay coefficient; {@code <= 0} only removes the existing one
         * @param applyLR     passed through to {@link WeightDecay}
         */
        public Builder weightDecay(double coefficient, boolean applyLR) {
            NetworkUtils.removeInstances(regularization, WeightDecay.class);
            if (coefficient > 0.0) {
                NetworkUtils.removeInstancesWithWarning(regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization");
                regularization.add(new WeightDecay(coefficient, applyLR));
            }
            return this;
        }

        /** Adds weight decay for bias parameters with {@code applyLR = true}. */
        public Builder weightDecayBias(double coefficient) {
            return weightDecayBias(coefficient, true);
        }

        /**
         * Sets weight decay for bias parameters, replacing any earlier bias WeightDecay. A positive
         * coefficient also removes (with a warning) any bias L2 regularization.
         *
         * @param coefficient weight decay coefficient; {@code <= 0} only removes the existing one
         * @param applyLR     passed through to {@link WeightDecay}
         */
        public Builder weightDecayBias(double coefficient, boolean applyLR) {
            NetworkUtils.removeInstances(regularizationBias, WeightDecay.class);
            if (coefficient > 0.0) {
                NetworkUtils.removeInstancesWithWarning(regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization");
                regularizationBias.add(new WeightDecay(coefficient, applyLR));
            }
            return this;
        }

        /** Replaces the weight regularization list (the list is stored, not copied). */
        public Builder regularization(List<Regularization> regularization) {
            setRegularization(regularization);
            return this;
        }

        /** Replaces the bias regularization list (the list is stored, not copied). */
        public Builder regularizationBias(List<Regularization> regularizationBias) {
            setRegularizationBias(regularizationBias);
            return this;
        }

        /** Sets the updater for weight parameters (and biases, unless a bias updater is set). */
        @SuppressWarnings("unchecked")
        public T updater(IUpdater updater) {
            setUpdater(updater);
            return (T) this;
        }

        /** Sets a separate updater used only for bias parameters. */
        @SuppressWarnings("unchecked")
        public T biasUpdater(IUpdater biasUpdater) {
            setBiasUpdater(biasUpdater);
            return (T) this;
        }

        public List<Regularization> getRegularization() {
            return regularization;
        }

        public List<Regularization> getRegularizationBias() {
            return regularizationBias;
        }

        public IUpdater getUpdater() {
            return updater;
        }

        public IUpdater getBiasUpdater() {
            return biasUpdater;
        }

        public void setRegularization(List<Regularization> regularization) {
            this.regularization = regularization;
        }

        public void setRegularizationBias(List<Regularization> regularizationBias) {
            this.regularizationBias = regularizationBias;
        }

        public void setUpdater(IUpdater updater) {
            this.updater = updater;
        }

        public void setBiasUpdater(IUpdater biasUpdater) {
            this.biasUpdater = biasUpdater;
        }
    }

    /**
     * Returns the regularization list for a parameter.
     *
     * @param paramName parameter key
     * @return the weight list for weight params, the bias list for bias params, otherwise null
     */
    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        // Go through the lazy accessor: the field is null until getLayerParams() has been called.
        SDLayerParams params = getLayerParams();
        if (params.isWeightParam(paramName)) {
            return regularization;
        }
        if (params.isBiasParam(paramName)) {
            return regularizationBias;
        }
        return null;
    }

    /**
     * Creates the configuration from a builder.
     *
     * <p>Also logs a warning if the concrete subclass lacks a no-arg constructor, which JSON
     * deserialization needs.
     */
    public AbstractSameDiffLayer(Builder<?> builder) {
        super(builder);
        this.gradientNormalizationThreshold = Double.NaN;
        this.regularization = builder.regularization;
        this.regularizationBias = builder.regularizationBias;
        this.updater = builder.updater;
        this.biasUpdater = builder.biasUpdater;
        try {
            getClass().getDeclaredConstructor();
        } catch (NoSuchMethodException missing) {
            log.warn("***SameDiff layer {} does not have a zero argument (no-arg) constructor.***\nA no-arg constructor is required for JSON deserialization, which is used for both model saving and distributed (Spark) training.\nA no-arg constructor (private, protected or public) as well as setters (or simply a Lombok @Data annotation) should be added to avoid JSON errors later.", getClass().getName());
        } catch (SecurityException ignored) {
            // Reflective access was denied; this check is only advisory, so skip it.
        }
    }

    /** No-arg constructor (used e.g. by JSON deserialization); threshold starts as "unset" (NaN). */
    public AbstractSameDiffLayer() {
        this.gradientNormalizationThreshold = Double.NaN;
    }

    /** Returns the parameter definitions, calling {@link #defineParameters} on first use. */
    public SDLayerParams getLayerParams() {
        if (layerParams == null) {
            layerParams = new SDLayerParams();
            defineParameters(layerParams);
        }
        return layerParams;
    }

    /** No-op by default; subclasses that infer nIn from the input type override this. */
    @Override
    public void setNIn(InputType inputType, boolean override) {
    }

    /** Returns null: no preprocessor by default. */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    /** Hook invoked at the end of {@link #applyGlobalConfig}; no-op by default. */
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
    }

    /** Declares the layer's parameters (names and shapes) on the given holder. */
    public abstract void defineParameters(SDLayerParams params);

    /** Initializes the values of the parameter arrays in the given map. */
    public abstract void initializeParameters(Map<String, INDArray> params);

    @Override
    public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration neuralNetConfiguration, Collection<TrainingListener> collection, int i, INDArray iNDArray, boolean z, DataType dataType);

    @Override
    public ParamInitializer initializer() {
        return SameDiffParamInitializer.getInstance();
    }

    /**
     * Returns the updater for a parameter: the bias updater for bias params when set, otherwise
     * the main updater for any known weight or bias param.
     *
     * @throws IllegalStateException if the key is neither a weight nor a bias parameter
     */
    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        boolean isBias = initializer().isBiasParam(this, paramName);
        if (isBias && biasUpdater != null) {
            return biasUpdater;
        }
        if (isBias || initializer().isWeightParam(this, paramName)) {
            return updater;
        }
        throw new IllegalStateException("Unknown parameter key: " + paramName);
    }

    /** SameDiff layers have no pretrain-only parameters. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /** Returns an empty memory report (no estimate is computed here). */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        return new LayerMemoryReport();
    }

    /** Order used when reshaping parameter views; 'c' by default. */
    public char paramReshapeOrder(String paramName) {
        return 'c';
    }

    /** Initializes {@code array} in place with the given weight init scheme and no distribution. */
    protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
        WeightInitUtil.initWeights(fanIn, fanOut, array.shape(), weightInit, (Distribution) null, paramReshapeOrder(null), array);
    }

    /**
     * Fills in every setting not configured on this layer from the global configuration, then
     * calls {@link #applyGlobalConfigToLayer}. Empty regularization lists count as unset; a NaN
     * threshold counts as unset.
     */
    public void applyGlobalConfig(NeuralNetConfiguration.Builder globalConfig) {
        if (regularization == null || regularization.isEmpty()) {
            regularization = globalConfig.getRegularization();
        }
        if (regularizationBias == null || regularizationBias.isEmpty()) {
            regularizationBias = globalConfig.getRegularizationBias();
        }
        if (updater == null) {
            updater = globalConfig.getIUpdater();
        }
        if (biasUpdater == null) {
            biasUpdater = globalConfig.getBiasUpdater();
        }
        if (gradientNormalization == null) {
            gradientNormalization = globalConfig.getGradientNormalization();
        }
        if (Double.isNaN(gradientNormalizationThreshold)) {
            gradientNormalizationThreshold = globalConfig.getGradientNormalizationThreshold();
        }
        applyGlobalConfigToLayer(globalConfig);
    }

    /**
     * Builds an all-ones mask, with the input's data type, for an input of rank 2–5.
     *
     * <p>Shapes: rank 2 → [size(0), 1]; rank 3 → [size(0), size(2)]; rank 4 → [size(0), 1, 1, 1];
     * rank 5 → [size(0), 1, 1, 1, 1].
     *
     * @throws IllegalStateException for any other rank; subclasses must override in that case
     */
    public INDArray onesMaskForInput(INDArray input) {
        DataType type = input.dataType();
        switch (input.rank()) {
            case 2:
                return Nd4j.ones(type, new long[]{input.size(0), 1});
            case 3:
                return Nd4j.ones(type, new long[]{input.size(0), input.size(2)});
            case 4:
                return Nd4j.ones(type, new long[]{input.size(0), 1, 1, 1});
            case 5:
                return Nd4j.ones(type, new long[]{input.size(0), 1, 1, 1, 1});
            default:
                throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, in order to determine the correct mask shape for this layer");
        }
    }

    public List<Regularization> getRegularization() {
        return regularization;
    }

    public List<Regularization> getRegularizationBias() {
        return regularizationBias;
    }

    public IUpdater getUpdater() {
        return updater;
    }

    public IUpdater getBiasUpdater() {
        return biasUpdater;
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return gradientNormalization;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return gradientNormalizationThreshold;
    }

    public void setRegularization(List<Regularization> regularization) {
        this.regularization = regularization;
    }

    public void setRegularizationBias(List<Regularization> regularizationBias) {
        this.regularizationBias = regularizationBias;
    }

    public void setUpdater(IUpdater updater) {
        this.updater = updater;
    }

    public void setBiasUpdater(IUpdater biasUpdater) {
        this.biasUpdater = biasUpdater;
    }

    public void setGradientNormalization(GradientNormalization gradientNormalization) {
        this.gradientNormalization = gradientNormalization;
    }

    public void setGradientNormalizationThreshold(double gradientNormalizationThreshold) {
        this.gradientNormalizationThreshold = gradientNormalizationThreshold;
    }

    public void setLayerParams(SDLayerParams layerParams) {
        this.layerParams = layerParams;
    }

    /** Note: calls {@link #getLayerParams()}, which may trigger {@link #defineParameters}. */
    @Override
    public String toString() {
        return "AbstractSameDiffLayer(regularization=" + getRegularization() + ", regularizationBias=" + getRegularizationBias() + ", updater=" + getUpdater() + ", biasUpdater=" + getBiasUpdater() + ", gradientNormalization=" + getGradientNormalization() + ", gradientNormalizationThreshold=" + getGradientNormalizationThreshold() + ", layerParams=" + getLayerParams() + ")";
    }

    /** Field-wise equality including the superclass state; uses the lazy layer-params accessor. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AbstractSameDiffLayer)) {
            return false;
        }
        AbstractSameDiffLayer other = (AbstractSameDiffLayer) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return Objects.equals(getRegularization(), other.getRegularization())
                && Objects.equals(getRegularizationBias(), other.getRegularizationBias())
                && Objects.equals(getUpdater(), other.getUpdater())
                && Objects.equals(getBiasUpdater(), other.getBiasUpdater())
                && Objects.equals(getGradientNormalization(), other.getGradientNormalization())
                && Double.compare(getGradientNormalizationThreshold(), other.getGradientNormalizationThreshold()) == 0
                && Objects.equals(getLayerParams(), other.getLayerParams());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof AbstractSameDiffLayer;
    }

    /** Hash consistent with {@link #equals}; same prime/null-constant scheme as before. */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = result * 59 + hashOrDefault(getRegularization());
        result = result * 59 + hashOrDefault(getRegularizationBias());
        result = result * 59 + hashOrDefault(getUpdater());
        result = result * 59 + hashOrDefault(getBiasUpdater());
        result = result * 59 + hashOrDefault(getGradientNormalization());
        result = result * 59 + Double.hashCode(getGradientNormalizationThreshold());
        result = result * 59 + hashOrDefault(getLayerParams());
        return result;
    }

    /** Returns {@code o.hashCode()}, or 43 for null (the constant the original hash used). */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }
}
