package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonIgnore;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer.class */
/**
 * Configuration for a convolutional layer.
 *
 * <p>Holds the kernel / stride / padding / dilation geometry, the {@link ConvolutionMode}, the
 * {@link CNN2DFormat} data format and optional cuDNN algorithm hints. {@link #instantiate} builds
 * the runtime {@code org.deeplearning4j.nn.layers.convolution.ConvolutionLayer} from this
 * configuration.
 */
public class ConvolutionLayer extends FeedForwardLayer {
    protected boolean hasBias = true;
    protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
    protected int[] dilation = {1, 1};
    protected int[] kernelSize;
    protected int[] stride;
    protected int[] padding;
    protected boolean cudnnAllowFallback = true;
    protected CNN2DFormat cnn2dDataFormat = CNN2DFormat.NCHW;

    /**
     * When true (and nIn is already positive), {@link #setNIn(InputType, boolean)} keeps the
     * existing nIn unless asked to override. Not serialized, not part of equals/hashCode.
     */
    @JsonIgnore
    private boolean defaultValueOverriden = false;
    protected AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST;
    protected FwdAlgo cudnnFwdAlgo;
    protected BwdFilterAlgo cudnnBwdFilterAlgo;
    protected BwdDataAlgo cudnnBwdDataAlgo;

    /**
     * Selection mode for cuDNN convolution algorithms (set via the builder's
     * {@code cudnnAlgoMode}). NOTE(review): USER_SPECIFIED presumably means the explicit
     * {@link FwdAlgo}/{@link BwdFilterAlgo}/{@link BwdDataAlgo} hints are used — confirm in the
     * cuDNN helper.
     */
    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer$AlgoMode.class */
    public enum AlgoMode {
        NO_WORKSPACE,
        PREFER_FASTEST,
        USER_SPECIFIED
    }

    /**
     * Shared builder state for convolution-style layers.
     *
     * <p>Defaults: 2 spatial dimensions, bias enabled, 5x5 kernel, 1x1 stride, 0x0 padding, 1x1
     * dilation, helper fallback allowed. {@code convolutionMode} and {@code cudnnAlgoMode} stay
     * {@code null} unless set explicitly.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer$BaseConvBuilder.class */
    public static abstract class BaseConvBuilder<T extends BaseConvBuilder<T>> extends FeedForwardLayer.Builder<T> {
        protected ConvolutionMode convolutionMode;
        protected FwdAlgo cudnnFwdAlgo;
        protected BwdFilterAlgo cudnnBwdFilterAlgo;
        protected BwdDataAlgo cudnnBwdDataAlgo;
        protected int convolutionDim = 2;
        protected boolean hasBias = true;
        protected int[] dilation = {1, 1};
        public int[] kernelSize = {5, 5};
        protected int[] stride = {1, 1};
        protected int[] padding = {0, 0};
        protected AlgoMode cudnnAlgoMode = null;
        protected boolean cudnnAllowFallback = true;

        // The constructors route every value through the (overridable) setters so that subclass
        // validation, e.g. Builder.setKernelSize, is applied to constructor arguments too.

        /* JADX INFO: Access modifiers changed from: protected */
        public BaseConvBuilder(int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, int i) {
            setKernelSize(iArr);
            setStride(iArr2);
            setPadding(iArr3);
            setDilation(iArr4);
            setConvolutionDim(i);
        }

        protected BaseConvBuilder(int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4) {
            setKernelSize(iArr);
            setStride(iArr2);
            setPadding(iArr3);
            setDilation(iArr4);
        }

        protected BaseConvBuilder(int[] iArr, int[] iArr2, int[] iArr3, int i) {
            setKernelSize(iArr);
            setStride(iArr2);
            setPadding(iArr3);
            setConvolutionDim(i);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public BaseConvBuilder(int[] iArr, int[] iArr2, int[] iArr3) {
            setKernelSize(iArr);
            setStride(iArr2);
            setPadding(iArr3);
        }

        protected BaseConvBuilder(int[] iArr, int[] iArr2, int i) {
            setKernelSize(iArr);
            setStride(iArr2);
            setConvolutionDim(i);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public BaseConvBuilder(int[] iArr, int[] iArr2) {
            setKernelSize(iArr);
            setStride(iArr2);
        }

        protected BaseConvBuilder(int i, int... iArr) {
            setKernelSize(iArr);
            setConvolutionDim(i);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public BaseConvBuilder(int... iArr) {
            setKernelSize(iArr);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public BaseConvBuilder() {
        }

        /**
         * Returns this builder as its concrete subtype. The cast is safe by construction of the
         * recursive bound {@code T extends BaseConvBuilder<T>}; it is the single unchecked cast
         * used by all fluent setters.
         */
        @SuppressWarnings("unchecked")
        private T self() {
            return (T) this;
        }

        /** @return whether {@link ConvolutionMode#Causal} is permitted for this layer type */
        protected abstract boolean allowCausal();

        /**
         * Sets the convolution mode.
         *
         * @throws IllegalStateException (via Preconditions.checkState) if the mode is Causal and
         *     {@link #allowCausal()} is false
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public void setConvolutionMode(ConvolutionMode convolutionMode) {
            Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D convolutional neural network layers");
            this.convolutionMode = convolutionMode;
        }

        /** Enables or disables the bias. */
        public T hasBias(boolean z) {
            setHasBias(z);
            return self();
        }

        /** Sets the convolution mode; see {@link #setConvolutionMode(ConvolutionMode)}. */
        public T convolutionMode(ConvolutionMode convolutionMode) {
            setConvolutionMode(convolutionMode);
            return self();
        }

        /** Sets the dilation, one value per spatial dimension. */
        public T dilation(int... iArr) {
            setDilation(iArr);
            return self();
        }

        /** Sets the kernel size, one value per spatial dimension. */
        public T kernelSize(int... iArr) {
            setKernelSize(iArr);
            return self();
        }

        /** Sets the stride, one value per spatial dimension. */
        public T stride(int... iArr) {
            setStride(iArr);
            return self();
        }

        /** Sets the padding, one value per spatial dimension. */
        public T padding(int... iArr) {
            setPadding(iArr);
            return self();
        }

        /** Sets the cuDNN algorithm selection mode. */
        public T cudnnAlgoMode(AlgoMode algoMode) {
            setCudnnAlgoMode(algoMode);
            return self();
        }

        /** Sets the cuDNN forward algorithm hint. */
        public T cudnnFwdMode(FwdAlgo fwdAlgo) {
            setCudnnFwdAlgo(fwdAlgo);
            return self();
        }

        /** Sets the cuDNN backward-filter algorithm hint. */
        public T cudnnBwdFilterMode(BwdFilterAlgo bwdFilterAlgo) {
            setCudnnBwdFilterAlgo(bwdFilterAlgo);
            return self();
        }

        /** Sets the cuDNN backward-data algorithm hint. */
        public T cudnnBwdDataMode(BwdDataAlgo bwdDataAlgo) {
            setCudnnBwdDataAlgo(bwdDataAlgo);
            return self();
        }

        /**
         * Sets whether a helper fallback is allowed.
         *
         * @deprecated use {@link #helperAllowFallback(boolean)}, which sets the same flag
         */
        @Deprecated
        public T cudnnAllowFallback(boolean z) {
            setCudnnAllowFallback(z);
            return self();
        }

        /** Sets whether a helper fallback is allowed (same flag as {@code cudnnAllowFallback}). */
        public T helperAllowFallback(boolean z) {
            this.cudnnAllowFallback = z;
            return self();
        }

        public int getConvolutionDim() {
            return this.convolutionDim;
        }

        public boolean isHasBias() {
            return this.hasBias;
        }

        public ConvolutionMode getConvolutionMode() {
            return this.convolutionMode;
        }

        public int[] getDilation() {
            return this.dilation;
        }

        public int[] getKernelSize() {
            return this.kernelSize;
        }

        public int[] getStride() {
            return this.stride;
        }

        public int[] getPadding() {
            return this.padding;
        }

        public AlgoMode getCudnnAlgoMode() {
            return this.cudnnAlgoMode;
        }

        public FwdAlgo getCudnnFwdAlgo() {
            return this.cudnnFwdAlgo;
        }

        public BwdFilterAlgo getCudnnBwdFilterAlgo() {
            return this.cudnnBwdFilterAlgo;
        }

        public BwdDataAlgo getCudnnBwdDataAlgo() {
            return this.cudnnBwdDataAlgo;
        }

        public boolean isCudnnAllowFallback() {
            return this.cudnnAllowFallback;
        }

        public void setConvolutionDim(int i) {
            this.convolutionDim = i;
        }

        public void setHasBias(boolean z) {
            this.hasBias = z;
        }

        public void setDilation(int[] iArr) {
            this.dilation = iArr;
        }

        public void setKernelSize(int[] iArr) {
            this.kernelSize = iArr;
        }

        public void setStride(int[] iArr) {
            this.stride = iArr;
        }

        public void setPadding(int[] iArr) {
            this.padding = iArr;
        }

        public void setCudnnAlgoMode(AlgoMode algoMode) {
            this.cudnnAlgoMode = algoMode;
        }

        public void setCudnnFwdAlgo(FwdAlgo fwdAlgo) {
            this.cudnnFwdAlgo = fwdAlgo;
        }

        public void setCudnnBwdFilterAlgo(BwdFilterAlgo bwdFilterAlgo) {
            this.cudnnBwdFilterAlgo = bwdFilterAlgo;
        }

        public void setCudnnBwdDataAlgo(BwdDataAlgo bwdDataAlgo) {
            this.cudnnBwdDataAlgo = bwdDataAlgo;
        }

        public void setCudnnAllowFallback(boolean z) {
            this.cudnnAllowFallback = z;
        }
    }

    /**
     * Builder for a 2D {@link ConvolutionLayer}.
     *
     * <p>Data format defaults to {@link CNN2DFormat#NCHW}; Causal mode is rejected. Kernel size,
     * stride, padding and dilation are passed through
     * {@code ValidationUtils.validate2NonNegative(..., false, name)}.
     */
    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer$Builder.class */
    public static class Builder extends BaseConvBuilder<Builder> {
        protected CNN2DFormat dataFormat;

        public Builder(int[] iArr, int[] iArr2, int[] iArr3) {
            super(iArr, iArr2, iArr3);
            this.dataFormat = CNN2DFormat.NCHW;
        }

        public Builder(int[] iArr, int[] iArr2) {
            super(iArr, iArr2);
            this.dataFormat = CNN2DFormat.NCHW;
        }

        public Builder(int... iArr) {
            super(iArr);
            this.dataFormat = CNN2DFormat.NCHW;
        }

        public Builder() {
            this.dataFormat = CNN2DFormat.NCHW;
        }

        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        protected boolean allowCausal() {
            return false;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public Builder kernelSize(int... iArr) {
            setKernelSize(iArr);
            return this;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public Builder stride(int... iArr) {
            setStride(iArr);
            return this;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public Builder padding(int... iArr) {
            setPadding(iArr);
            return this;
        }

        /** Sets the CNN data format copied into the built layer. */
        public Builder dataFormat(CNN2DFormat cNN2DFormat) {
            this.dataFormat = cNN2DFormat;
            return this;
        }

        /**
         * Validates the padding against the convolution mode and the kernel/stride/padding
         * combination (via ConvolutionUtils), then builds the layer.
         */
        @Override // org.deeplearning4j.nn.conf.layers.Layer.Builder
        public ConvolutionLayer build() {
            ConvolutionUtils.validateConvolutionModePadding(this.convolutionMode, this.padding);
            ConvolutionUtils.validateCnnKernelStridePadding(this.kernelSize, this.stride, this.padding);
            return new ConvolutionLayer(this);
        }

        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public void setKernelSize(int... iArr) {
            this.kernelSize = ValidationUtils.validate2NonNegative(iArr, false, "kernelSize");
        }

        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public void setStride(int... iArr) {
            this.stride = ValidationUtils.validate2NonNegative(iArr, false, "stride");
        }

        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public void setPadding(int... iArr) {
            this.padding = ValidationUtils.validate2NonNegative(iArr, false, "padding");
        }

        @Override // org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BaseConvBuilder
        public void setDilation(int... iArr) {
            this.dilation = ValidationUtils.validate2NonNegative(iArr, false, "dilation");
        }

        public void setDataFormat(CNN2DFormat cNN2DFormat) {
            this.dataFormat = cNN2DFormat;
        }
    }

    /** cuDNN backward-data algorithm hint. NOTE(review): names appear to mirror cuDNN's enum — confirm. */
    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer$BwdDataAlgo.class */
    public enum BwdDataAlgo {
        ALGO_0,
        ALGO_1,
        FFT,
        FFT_TILING,
        WINOGRAD,
        WINOGRAD_NONFUSED,
        COUNT
    }

    /** cuDNN backward-filter algorithm hint. NOTE(review): names appear to mirror cuDNN's enum — confirm. */
    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer$BwdFilterAlgo.class */
    public enum BwdFilterAlgo {
        ALGO_0,
        ALGO_1,
        FFT,
        ALGO_3,
        WINOGRAD,
        WINOGRAD_NONFUSED,
        FFT_TILING,
        COUNT
    }

    /** cuDNN forward algorithm hint. NOTE(review): names appear to mirror cuDNN's enum — confirm. */
    /* loaded from: input_file:org/deeplearning4j/nn/conf/layers/ConvolutionLayer$FwdAlgo.class */
    public enum FwdAlgo {
        IMPLICIT_GEMM,
        IMPLICIT_PRECOMP_GEMM,
        GEMM,
        DIRECT,
        FFT,
        FFT_TILING,
        WINOGRAD,
        WINOGRAD_NONFUSED,
        COUNT
    }

    /**
     * Creates the configuration from a builder.
     *
     * <p>The builder's {@code convolutionMode} and {@code cudnnAlgoMode} are copied even when
     * null, which replaces this class's defaults. NOTE(review): presumably a global network
     * configuration fills them in later — confirm. The data format is only copied from
     * {@link Builder} instances; other builders leave it at NCHW.
     *
     * @throws IllegalArgumentException if kernel size, stride, padding or dilation does not have
     *     exactly {@code convolutionDim} entries
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public ConvolutionLayer(BaseConvBuilder<?> baseConvBuilder) {
        super(baseConvBuilder);
        int dim = baseConvBuilder.convolutionDim;
        this.hasBias = baseConvBuilder.hasBias;
        this.convolutionMode = baseConvBuilder.convolutionMode;
        this.kernelSize = requireLength(baseConvBuilder.kernelSize, dim, "Kernel");
        this.stride = requireLength(baseConvBuilder.stride, dim, "Strides");
        this.padding = requireLength(baseConvBuilder.padding, dim, "Padding");
        this.dilation = requireLength(baseConvBuilder.dilation, dim, "Dilation");
        this.cudnnAlgoMode = baseConvBuilder.cudnnAlgoMode;
        this.cudnnFwdAlgo = baseConvBuilder.cudnnFwdAlgo;
        this.cudnnBwdFilterAlgo = baseConvBuilder.cudnnBwdFilterAlgo;
        this.cudnnBwdDataAlgo = baseConvBuilder.cudnnBwdDataAlgo;
        this.cudnnAllowFallback = baseConvBuilder.cudnnAllowFallback;
        if (baseConvBuilder instanceof Builder) {
            this.cnn2dDataFormat = ((Builder) baseConvBuilder).dataFormat;
        }
        initializeConstraints(baseConvBuilder);
    }

    /**
     * Returns {@code values} unchanged if it has exactly {@code expectedDim} entries.
     *
     * @throws IllegalArgumentException naming {@code label} otherwise
     */
    private static int[] requireLength(int[] values, int expectedDim, String label) {
        if (values.length != expectedDim) {
            throw new IllegalArgumentException(label + " argument should be a " + expectedDim + "d array, got " + Arrays.toString(values));
        }
        return values;
    }

    public boolean hasBias() {
        return this.hasBias;
    }

    /**
     * Clones this configuration. The geometry arrays (kernel size, stride, padding, dilation)
     * are copied, so mutating them on the clone does not affect the original.
     */
    @Override // org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    /* renamed from: clone */
    public ConvolutionLayer mo62clone() {
        ConvolutionLayer convolutionLayer = (ConvolutionLayer) super.mo62clone();
        if (convolutionLayer.kernelSize != null) {
            convolutionLayer.kernelSize = convolutionLayer.kernelSize.clone();
        }
        if (convolutionLayer.stride != null) {
            convolutionLayer.stride = convolutionLayer.stride.clone();
        }
        if (convolutionLayer.padding != null) {
            convolutionLayer.padding = convolutionLayer.padding.clone();
        }
        // Previously missing: dilation was left shared between original and clone.
        if (convolutionLayer.dilation != null) {
            convolutionLayer.dilation = convolutionLayer.dilation.clone();
        }
        return convolutionLayer;
    }

    /**
     * Builds the runtime layer: checks nIn/nOut are set, attaches listeners and index, sets the
     * parameter view and the parameter table produced by {@link #initializer()}.
     */
    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration neuralNetConfiguration, Collection<TrainingListener> collection, int i, INDArray iNDArray, boolean z, DataType dataType) {
        LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), i, getNIn(), getNOut());
        org.deeplearning4j.nn.layers.convolution.ConvolutionLayer convolutionLayer = new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(neuralNetConfiguration, dataType);
        convolutionLayer.setListeners(collection);
        convolutionLayer.setIndex(i);
        convolutionLayer.setParamsViewArray(iNDArray);
        convolutionLayer.setParamTable(initializer().init(neuralNetConfiguration, iNDArray, z));
        convolutionLayer.setConf(neuralNetConfiguration);
        return convolutionLayer;
    }

    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public ParamInitializer initializer() {
        return ConvolutionParamInitializer.getInstance();
    }

    /**
     * Computes the output type for a CNN input.
     *
     * @throws IllegalStateException if {@code inputType} is null or not of type CNN
     */
    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.Layer
    public InputType getOutputType(int i, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input for Convolution layer (layer name=\"" + getLayerName() + "\"): Expected CNN input, got " + inputType);
        }
        return InputTypeUtil.getOutputTypeCnnLayers(inputType, this.kernelSize, this.stride, this.padding, this.dilation, this.convolutionMode, this.nOut, i, getLayerName(), this.cnn2dDataFormat, ConvolutionLayer.class);
    }

    /**
     * Sets nIn (input channels) and the data format from a CNN input type.
     *
     * <p>nIn is left alone only when {@link #isDefaultValueOverriden()} is true, nIn is already
     * positive, and {@code z} (override) is false. The data format is also filled in whenever it
     * is null or {@code z} is true.
     *
     * @throws IllegalStateException if {@code inputType} is null or not of type CNN
     */
    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.Layer
    public void setNIn(InputType inputType, boolean z) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input for Convolution layer (layer name=\"" + getLayerName() + "\"): Expected CNN input, got " + inputType);
        }
        InputType.InputTypeConvolutional convInput = (InputType.InputTypeConvolutional) inputType;
        if (!this.defaultValueOverriden || this.nIn <= 0 || z) {
            this.nIn = convInput.getChannels();
            this.cnn2dDataFormat = convInput.getFormat();
        }
        if (this.cnn2dDataFormat == null || z) {
            this.cnn2dDataFormat = convInput.getFormat();
        }
    }

    /** @throws IllegalStateException if {@code inputType} is null */
    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.Layer
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input for Convolution layer (layer name=\"" + getLayerName() + "\"): input is null");
        }
        return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getLayerName());
    }

    /**
     * Estimates memory use for this layer.
     *
     * <p>The per-example size used below is inChannels * outHeight * outWidth * kH * kW.
     * NOTE(review): this is presumably the im2col buffer — confirm. With CacheMode.NONE, training
     * working memory is twice that size and nothing is cached. Otherwise both the working memory
     * and the cache are one copy. When dropout is configured, the input size is added to the
     * working memory.
     */
    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        long numParams = initializer().numParams(this);
        int updaterStateSize = (int) getIUpdater().stateSize(numParams);
        InputType.InputTypeConvolutional outputType = (InputType.InputTypeConvolutional) getOutputType(-1, inputType);
        long colSizePerEx = ((InputType.InputTypeConvolutional) inputType).getChannels() * outputType.getHeight() * outputType.getWidth() * this.kernelSize[0] * this.kernelSize[1];
        HashMap<CacheMode, Long> trainWorkingPerEx = new HashMap<>();
        HashMap<CacheMode, Long> cachedPerEx = new HashMap<>();
        for (CacheMode cacheMode : CacheMode.values()) {
            long workingPerEx;
            long cachePerEx;
            if (cacheMode == CacheMode.NONE) {
                workingPerEx = 2 * colSizePerEx;
                cachePerEx = 0;
            } else {
                workingPerEx = colSizePerEx;
                cachePerEx = colSizePerEx;
            }
            if (getIDropout() != null) {
                workingPerEx += inputType.arrayElementsPerExample();
            }
            trainWorkingPerEx.put(cacheMode, workingPerEx);
            cachedPerEx.put(cacheMode, cachePerEx);
        }
        return new LayerMemoryReport.Builder(this.layerName, ConvolutionLayer.class, inputType, outputType)
                .standardMemory(numParams, updaterStateSize)
                .workingMemory(0L, colSizePerEx, MemoryReport.CACHE_MODE_ALL_ZEROS, trainWorkingPerEx)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cachedPerEx)
                .build();
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public ConvolutionMode getConvolutionMode() {
        return this.convolutionMode;
    }

    public int[] getDilation() {
        return this.dilation;
    }

    public int[] getKernelSize() {
        return this.kernelSize;
    }

    public int[] getStride() {
        return this.stride;
    }

    public int[] getPadding() {
        return this.padding;
    }

    public boolean isCudnnAllowFallback() {
        return this.cudnnAllowFallback;
    }

    public CNN2DFormat getCnn2dDataFormat() {
        return this.cnn2dDataFormat;
    }

    public boolean isDefaultValueOverriden() {
        return this.defaultValueOverriden;
    }

    public AlgoMode getCudnnAlgoMode() {
        return this.cudnnAlgoMode;
    }

    public FwdAlgo getCudnnFwdAlgo() {
        return this.cudnnFwdAlgo;
    }

    public BwdFilterAlgo getCudnnBwdFilterAlgo() {
        return this.cudnnBwdFilterAlgo;
    }

    public BwdDataAlgo getCudnnBwdDataAlgo() {
        return this.cudnnBwdDataAlgo;
    }

    public void setHasBias(boolean z) {
        this.hasBias = z;
    }

    public void setConvolutionMode(ConvolutionMode convolutionMode) {
        this.convolutionMode = convolutionMode;
    }

    public void setDilation(int[] iArr) {
        this.dilation = iArr;
    }

    public void setKernelSize(int[] iArr) {
        this.kernelSize = iArr;
    }

    public void setStride(int[] iArr) {
        this.stride = iArr;
    }

    public void setPadding(int[] iArr) {
        this.padding = iArr;
    }

    public void setCudnnAllowFallback(boolean z) {
        this.cudnnAllowFallback = z;
    }

    public void setCnn2dDataFormat(CNN2DFormat cNN2DFormat) {
        this.cnn2dDataFormat = cNN2DFormat;
    }

    public void setDefaultValueOverriden(boolean z) {
        this.defaultValueOverriden = z;
    }

    public void setCudnnAlgoMode(AlgoMode algoMode) {
        this.cudnnAlgoMode = algoMode;
    }

    public void setCudnnFwdAlgo(FwdAlgo fwdAlgo) {
        this.cudnnFwdAlgo = fwdAlgo;
    }

    public void setCudnnBwdFilterAlgo(BwdFilterAlgo bwdFilterAlgo) {
        this.cudnnBwdFilterAlgo = bwdFilterAlgo;
    }

    public void setCudnnBwdDataAlgo(BwdDataAlgo bwdDataAlgo) {
        this.cudnnBwdDataAlgo = bwdDataAlgo;
    }

    /** No-arg constructor (e.g. for deserialization); all defaults come from the field initializers. */
    public ConvolutionLayer() {
    }

    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    public String toString() {
        return "ConvolutionLayer(super=" + super.toString() + ", hasBias=" + isHasBias() + ", convolutionMode=" + getConvolutionMode() + ", dilation=" + Arrays.toString(getDilation()) + ", kernelSize=" + Arrays.toString(getKernelSize()) + ", stride=" + Arrays.toString(getStride()) + ", padding=" + Arrays.toString(getPadding()) + ", cudnnAllowFallback=" + isCudnnAllowFallback() + ", cnn2dDataFormat=" + getCnn2dDataFormat() + ", defaultValueOverriden=" + isDefaultValueOverriden() + ", cudnnAlgoMode=" + getCudnnAlgoMode() + ", cudnnFwdAlgo=" + getCudnnFwdAlgo() + ", cudnnBwdFilterAlgo=" + getCudnnBwdFilterAlgo() + ", cudnnBwdDataAlgo=" + getCudnnBwdDataAlgo() + ")";
    }

    /**
     * Compares the superclass state plus every configuration field of this class except
     * {@code defaultValueOverriden}. Arrays are compared by content.
     */
    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ConvolutionLayer)) {
            return false;
        }
        ConvolutionLayer other = (ConvolutionLayer) obj;
        if (!other.canEqual(this) || !super.equals(obj) || isHasBias() != other.isHasBias()) {
            return false;
        }
        ConvolutionMode convolutionMode = getConvolutionMode();
        ConvolutionMode convolutionMode2 = other.getConvolutionMode();
        if (convolutionMode == null) {
            if (convolutionMode2 != null) {
                return false;
            }
        } else if (!convolutionMode.equals(convolutionMode2)) {
            return false;
        }
        if (!Arrays.equals(getDilation(), other.getDilation()) || !Arrays.equals(getKernelSize(), other.getKernelSize()) || !Arrays.equals(getStride(), other.getStride()) || !Arrays.equals(getPadding(), other.getPadding()) || isCudnnAllowFallback() != other.isCudnnAllowFallback()) {
            return false;
        }
        CNN2DFormat cnn2dDataFormat = getCnn2dDataFormat();
        CNN2DFormat cnn2dDataFormat2 = other.getCnn2dDataFormat();
        if (cnn2dDataFormat == null) {
            if (cnn2dDataFormat2 != null) {
                return false;
            }
        } else if (!cnn2dDataFormat.equals(cnn2dDataFormat2)) {
            return false;
        }
        AlgoMode cudnnAlgoMode = getCudnnAlgoMode();
        AlgoMode cudnnAlgoMode2 = other.getCudnnAlgoMode();
        if (cudnnAlgoMode == null) {
            if (cudnnAlgoMode2 != null) {
                return false;
            }
        } else if (!cudnnAlgoMode.equals(cudnnAlgoMode2)) {
            return false;
        }
        FwdAlgo cudnnFwdAlgo = getCudnnFwdAlgo();
        FwdAlgo cudnnFwdAlgo2 = other.getCudnnFwdAlgo();
        if (cudnnFwdAlgo == null) {
            if (cudnnFwdAlgo2 != null) {
                return false;
            }
        } else if (!cudnnFwdAlgo.equals(cudnnFwdAlgo2)) {
            return false;
        }
        BwdFilterAlgo cudnnBwdFilterAlgo = getCudnnBwdFilterAlgo();
        BwdFilterAlgo cudnnBwdFilterAlgo2 = other.getCudnnBwdFilterAlgo();
        if (cudnnBwdFilterAlgo == null) {
            if (cudnnBwdFilterAlgo2 != null) {
                return false;
            }
        } else if (!cudnnBwdFilterAlgo.equals(cudnnBwdFilterAlgo2)) {
            return false;
        }
        BwdDataAlgo cudnnBwdDataAlgo = getCudnnBwdDataAlgo();
        BwdDataAlgo cudnnBwdDataAlgo2 = other.getCudnnBwdDataAlgo();
        return cudnnBwdDataAlgo == null ? cudnnBwdDataAlgo2 == null : cudnnBwdDataAlgo.equals(cudnnBwdDataAlgo2);
    }

    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    protected boolean canEqual(Object obj) {
        return obj instanceof ConvolutionLayer;
    }

    /** Hash over the same fields as {@link #equals(Object)}, folded with multiplier 59. */
    @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    public int hashCode() {
        int result = (super.hashCode() * 59) + (isHasBias() ? 79 : 97);
        ConvolutionMode mode = getConvolutionMode();
        result = (result * 59) + (mode == null ? 43 : mode.hashCode());
        result = (result * 59) + Arrays.hashCode(getDilation());
        result = (result * 59) + Arrays.hashCode(getKernelSize());
        result = (result * 59) + Arrays.hashCode(getStride());
        result = (result * 59) + Arrays.hashCode(getPadding());
        result = (result * 59) + (isCudnnAllowFallback() ? 79 : 97);
        CNN2DFormat format = getCnn2dDataFormat();
        result = (result * 59) + (format == null ? 43 : format.hashCode());
        AlgoMode algoMode = getCudnnAlgoMode();
        result = (result * 59) + (algoMode == null ? 43 : algoMode.hashCode());
        FwdAlgo fwdAlgo = getCudnnFwdAlgo();
        result = (result * 59) + (fwdAlgo == null ? 43 : fwdAlgo.hashCode());
        BwdFilterAlgo bwdFilterAlgo = getCudnnBwdFilterAlgo();
        result = (result * 59) + (bwdFilterAlgo == null ? 43 : bwdFilterAlgo.hashCode());
        BwdDataAlgo bwdDataAlgo = getCudnnBwdDataAlgo();
        return (result * 59) + (bwdDataAlgo == null ? 43 : bwdDataAlgo.hashCode());
    }
}
