package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayerUtils;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PadMode;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

/**
 * Locally connected 2D layer, implemented as a SameDiff layer.
 *
 * <p>Like a 2D convolution, this layer slides a {@code kernel[0] x kernel[1]} window over the
 * (optionally padded) input. The weights are <em>not shared</em>: every output spatial position
 * has its own {@code [featureDim, nOut]} weight matrix. The weight parameter {@code "W"}
 * therefore has shape {@code [outH * outW, featureDim, nOut]}, with
 * {@code featureDim = kernel[0] * kernel[1] * nIn}. An optional bias {@code "b"} of length
 * {@code nOut} is added to every output position.
 *
 * <p>The input spatial size must be known before parameters are defined, because the number of
 * weight matrices depends on the output size. It is set either through
 * {@link Builder#setInputSize(int...)} or inferred in {@link #getOutputType(int, InputType)}.
 */
@JsonIgnoreProperties({"paramShapes"})
public class LocallyConnected2D extends SameDiffLayer {
    private static final List<String> WEIGHT_KEYS = Collections.singletonList("W");
    private static final List<String> BIAS_KEYS = Collections.singletonList("b");
    private static final List<String> PARAM_KEYS = Arrays.asList("b", "W");

    /** Number of input channels. */
    private long nIn;
    /** Number of output channels. */
    private long nOut;
    private Activation activation;
    /** Kernel size as [height, width]. */
    private int[] kernel;
    /** Stride as [height, width]. */
    private int[] stride;
    /** Top/left padding; overwritten by {@link #computeOutputSize()} in Same mode. */
    private int[] padding;
    /** Bottom/right padding; only computed (by {@link #computeOutputSize()}) in Same mode. */
    private int[] paddingBr;
    private ConvolutionMode cm;
    private int[] dilation;
    private boolean hasBias;
    /** Input spatial size as [height, width]. */
    private int[] inputSize;
    /** Output spatial size as [height, width], set by {@link #computeOutputSize()}. */
    private int[] outputSize;
    /** Length of one flattened input patch: {@code kernel[0] * kernel[1] * nIn}. */
    private int featureDim;
    /** Data format of the layer's input and output. */
    protected CNN2DFormat format = CNN2DFormat.NCHW;

    /** Builder for {@link LocallyConnected2D}. */
    public static class Builder extends SameDiffLayer.Builder<Builder> {
        private int nIn;
        private int nOut;
        private int[] inputSize;
        private Activation activation = Activation.TANH;
        private int[] kernel = {2, 2};
        private int[] stride = {1, 1};
        private int[] padding = {0, 0};
        private int[] dilation = {1, 1};
        private ConvolutionMode cm = ConvolutionMode.Same;
        private boolean hasBias = true;
        protected CNN2DFormat format = CNN2DFormat.NCHW;

        /** Sets the kernel size, validated via {@code ValidationUtils.validate2NonNegative}. */
        public void setKernel(int... kernel) {
            this.kernel = ValidationUtils.validate2NonNegative(kernel, false, "kernel");
        }

        /** Sets the stride, validated via {@code ValidationUtils.validate2NonNegative}. */
        public void setStride(int... stride) {
            this.stride = ValidationUtils.validate2NonNegative(stride, false, "stride");
        }

        /** Sets the padding, validated via {@code ValidationUtils.validate2NonNegative}. */
        public void setPadding(int... padding) {
            this.padding = ValidationUtils.validate2NonNegative(padding, false, "padding");
        }

        /** Sets the dilation, validated via {@code ValidationUtils.validate2NonNegative}. */
        public void setDilation(int... dilation) {
            this.dilation = ValidationUtils.validate2NonNegative(dilation, false, "dilation");
        }

        /** Fluent form of {@link #setNIn(int)}. */
        public Builder nIn(int nIn) {
            setNIn(nIn);
            return this;
        }

        /** Fluent form of {@link #setNOut(int)}. */
        public Builder nOut(int nOut) {
            setNOut(nOut);
            return this;
        }

        /** Fluent form of {@link #setActivation(Activation)}. */
        public Builder activation(Activation activation) {
            setActivation(activation);
            return this;
        }

        /** Fluent form of {@link #setKernel(int...)}. */
        public Builder kernelSize(int... kernel) {
            setKernel(kernel);
            return this;
        }

        /** Fluent form of {@link #setStride(int...)}. */
        public Builder stride(int... stride) {
            setStride(stride);
            return this;
        }

        /** Fluent form of {@link #setPadding(int...)}. */
        public Builder padding(int... padding) {
            setPadding(padding);
            return this;
        }

        /** Fluent form of {@link #setCm(ConvolutionMode)}. */
        public Builder convolutionMode(ConvolutionMode convolutionMode) {
            setCm(convolutionMode);
            return this;
        }

        /** Fluent form of {@link #setDilation(int...)}. */
        public Builder dilation(int... dilation) {
            setDilation(dilation);
            return this;
        }

        /** Sets the input/output data format (NCHW or NHWC). */
        public Builder dataFormat(CNN2DFormat format) {
            this.format = format;
            return this;
        }

        /** Fluent form of {@link #setHasBias(boolean)}. */
        public Builder hasBias(boolean hasBias) {
            setHasBias(hasBias);
            return this;
        }

        /** Sets the input spatial size, validated via {@code ValidationUtils.validate2}. */
        public Builder setInputSize(int... inputSize) {
            this.inputSize = ValidationUtils.validate2(inputSize, false, "inputSize");
            return this;
        }

        /**
         * Validates padding against the convolution mode and the kernel/stride/padding values,
         * then creates the layer.
         */
        @Override
        public LocallyConnected2D build() {
            ConvolutionUtils.validateConvolutionModePadding(cm, padding);
            ConvolutionUtils.validateCnnKernelStridePadding(kernel, stride, padding);
            return new LocallyConnected2D(this);
        }

        public int getNIn() {
            return nIn;
        }

        public int getNOut() {
            return nOut;
        }

        public Activation getActivation() {
            return activation;
        }

        public int[] getKernel() {
            return kernel;
        }

        public int[] getStride() {
            return stride;
        }

        public int[] getPadding() {
            return padding;
        }

        public int[] getDilation() {
            return dilation;
        }

        public int[] getInputSize() {
            return inputSize;
        }

        public ConvolutionMode getCm() {
            return cm;
        }

        public boolean isHasBias() {
            return hasBias;
        }

        public CNN2DFormat getFormat() {
            return format;
        }

        public void setNIn(int nIn) {
            this.nIn = nIn;
        }

        public void setNOut(int nOut) {
            this.nOut = nOut;
        }

        public void setActivation(Activation activation) {
            this.activation = activation;
        }

        public void setCm(ConvolutionMode cm) {
            this.cm = cm;
        }

        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        public void setFormat(CNN2DFormat format) {
            this.format = format;
        }
    }

    /** Creates the layer from a builder's configuration. */
    protected LocallyConnected2D(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
        this.activation = builder.activation;
        this.kernel = builder.kernel;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.cm = builder.cm;
        this.dilation = builder.dilation;
        this.hasBias = builder.hasBias;
        this.inputSize = builder.inputSize;
        this.featureDim = kernel[0] * kernel[1] * (int) nIn;
        this.format = builder.format;
    }

    /** No-arg constructor, presumably for deserialization. */
    private LocallyConnected2D() {
    }

    /**
     * Computes {@link #outputSize} from {@link #inputSize}. In Same mode, it also recomputes
     * {@link #padding} (top/left) and {@link #paddingBr} (bottom/right).
     *
     * @throws IllegalArgumentException if the input size has not been set
     */
    public void computeOutputSize() {
        int channels = (int) getNIn();
        if (inputSize == null) {
            throw new IllegalArgumentException("Input size has to be specified for locally connected layers.");
        }
        // ConvolutionUtils.getOutputSize works on an array, so build a dummy single-example input.
        int[] dummyShape = format == CNN2DFormat.NCHW
                ? new int[]{1, channels, inputSize[0], inputSize[1]}
                : new int[]{1, inputSize[0], inputSize[1], channels};
        INDArray dummy = Nd4j.ones(dummyShape);
        if (cm == ConvolutionMode.Same) {
            outputSize = ConvolutionUtils.getOutputSize(dummy, kernel, stride, null, cm, dilation, format);
            padding = ConvolutionUtils.getSameModeTopLeftPadding(outputSize, inputSize, kernel, stride, dilation);
            paddingBr = ConvolutionUtils.getSameModeBottomRightPadding(outputSize, inputSize, kernel, stride, dilation);
        } else {
            outputSize = ConvolutionUtils.getOutputSize(dummy, kernel, stride, padding, cm, dilation, format);
        }
    }

    /**
     * Records the input spatial size from a CNN input type and computes the output size.
     *
     * @throws IllegalArgumentException if {@code inputType} is null or not of CNN type
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalArgumentException("Provided input type for locally connected 2D layers has to be of CNN type, got: " + inputType);
        }
        InputType.InputTypeConvolutional cnnType = (InputType.InputTypeConvolutional) inputType;
        this.inputSize = new int[]{(int) cnnType.getHeight(), (int) cnnType.getWidth()};
        computeOutputSize();
        // Use the configured dilation, the same one computeOutputSize() used to size "W" and the
        // output reshape, so the declared output type agrees with the actual output shape.
        return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, dilation, cm, nOut,
                layerIndex, getLayerName(), format, LocallyConnected2D.class);
    }

    /**
     * Sets nIn (and the derived featureDim) from a convolutional input type when nIn is not yet
     * set or {@code override} is true. The data format is always taken from the input type.
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        InputType.InputTypeConvolutional cnnType = (InputType.InputTypeConvolutional) inputType;
        if (nIn <= 0 || override) {
            nIn = cnnType.getChannels();
            featureDim = kernel[0] * kernel[1] * (int) nIn;
        }
        format = cnnType.getFormat();
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getLayerName());
    }

    /**
     * Declares {@code "W"} with shape {@code [outH * outW, featureDim, nOut]}, plus {@code "b"}
     * with shape {@code [nOut]} when bias is enabled. Requires {@link #outputSize} to have been
     * computed.
     */
    @Override
    public void defineParameters(SDLayerParams params) {
        params.clear();
        params.addWeightParam("W", outputSize[0] * outputSize[1], featureDim, nOut);
        if (hasBias) {
            params.addBiasParam("b", nOut);
        }
    }

    /**
     * Zeroes the bias and initializes every other parameter with the configured weight init.
     * Runs outside any active workspace.
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            double fanIn = nIn * kernel[0] * kernel[1];
            // Floating-point division: integer division would truncate the fan-out.
            double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                INDArray view = entry.getValue();
                if ("b".equals(entry.getKey())) {
                    view.assign(0);
                } else {
                    WeightInitUtil.initWeights(fanIn, fanOut, view.shape(), weightInit, (Distribution) null, 'c', view);
                }
            }
        }
    }

    /**
     * Builds the forward pass.
     *
     * <ol>
     *   <li>Pad the input with zeros (NHWC input is first permuted to NCHW).</li>
     *   <li>Extract one {@code kH x kW} patch per output position and flatten each to
     *       {@code [1, miniBatch, featureDim]}.</li>
     *   <li>Concatenate the patches and batch-multiply by {@code "W"}.</li>
     *   <li>Reshape and permute back to the configured data format.</li>
     *   <li>Add the bias (if enabled) and apply the activation.</li>
     * </ol>
     */
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
        SDVariable weights = paramTable.get("W");
        // NOTE(review): taken from the static shape; may be -1 for a dynamic batch dimension.
        long miniBatch = layerInput.getShape()[0];
        int outH = outputSize[0];
        int outW = outputSize[1];
        int sH = stride[0];
        int sW = stride[1];
        int kH = kernel[0];
        int kW = kernel[1];
        boolean nchw = format == CNN2DFormat.NCHW;

        SDVariable x = nchw ? layerInput : layerInput.permute(new int[]{0, 3, 1, 2});

        boolean sameMode = cm == ConvolutionMode.Same;
        if (padding[0] > 0 || padding[1] > 0 || (sameMode && (paddingBr[0] > 0 || paddingBr[1] > 0))) {
            // In Same mode the bottom/right padding may differ from top/left; otherwise it is symmetric.
            int bottom = sameMode ? paddingBr[0] : padding[0];
            int right = sameMode ? paddingBr[1] : padding[1];
            int[][] padSpec = {{0, 0}, {0, 0}, {padding[0], bottom}, {padding[1], right}};
            x = sameDiff.nn().pad(x, sameDiff.constant(Nd4j.createFromArray(padSpec)), PadMode.CONSTANT, 0.0);
        }

        // Patches are contiguous kH x kW windows; the dilation field is not applied here.
        SDVariable[] patches = new SDVariable[outH * outW];
        for (int row = 0; row < outH; row++) {
            for (int col = 0; col < outW; col++) {
                SDVariable patch = x.get(
                        SDIndex.all(),
                        SDIndex.all(),
                        SDIndex.interval(Integer.valueOf(row * sH), Integer.valueOf(row * sH + kH)),
                        SDIndex.interval(Integer.valueOf(col * sW), Integer.valueOf(col * sW + kW)));
                // Row-major position index, matching the [outH, outW, ...] reshape below.
                patches[row * outW + col] = sameDiff.reshape(patch, new long[]{1, miniBatch, featureDim});
            }
        }

        SDVariable product = sameDiff.mmul(sameDiff.concat(0, patches), weights);
        SDVariable reshaped = sameDiff.reshape(product, new long[]{outH, outW, miniBatch, nOut});
        SDVariable out = nchw
                ? reshaped.permute(new int[]{2, 3, 0, 1})
                : reshaped.permute(new int[]{2, 0, 1, 3});
        if (hasBias) {
            out = sameDiff.nn().biasAdd(out, paramTable.get("b"), nchw);
        }
        return activation.asSameDiff("out", sameDiff, out);
    }

    /** Fills in the activation and convolution mode from the global configuration when unset. */
    @Override
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        if (activation == null) {
            activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn());
        }
        if (cm == null) {
            cm = globalConfig.getConvolutionMode();
        }
    }

    public long getNIn() {
        return nIn;
    }

    public long getNOut() {
        return nOut;
    }

    public Activation getActivation() {
        return activation;
    }

    public int[] getKernel() {
        return kernel;
    }

    public int[] getStride() {
        return stride;
    }

    public int[] getPadding() {
        return padding;
    }

    public int[] getPaddingBr() {
        return paddingBr;
    }

    public ConvolutionMode getCm() {
        return cm;
    }

    public int[] getDilation() {
        return dilation;
    }

    public boolean isHasBias() {
        return hasBias;
    }

    public int[] getInputSize() {
        return inputSize;
    }

    public int[] getOutputSize() {
        return outputSize;
    }

    public int getFeatureDim() {
        return featureDim;
    }

    public CNN2DFormat getFormat() {
        return format;
    }

    public void setNIn(long nIn) {
        this.nIn = nIn;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setActivation(Activation activation) {
        this.activation = activation;
    }

    public void setKernel(int[] kernel) {
        this.kernel = kernel;
    }

    public void setStride(int[] stride) {
        this.stride = stride;
    }

    public void setPadding(int[] padding) {
        this.padding = padding;
    }

    public void setPaddingBr(int[] paddingBr) {
        this.paddingBr = paddingBr;
    }

    public void setCm(ConvolutionMode cm) {
        this.cm = cm;
    }

    public void setDilation(int[] dilation) {
        this.dilation = dilation;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setInputSize(int[] inputSize) {
        this.inputSize = inputSize;
    }

    public void setOutputSize(int[] outputSize) {
        this.outputSize = outputSize;
    }

    public void setFeatureDim(int featureDim) {
        this.featureDim = featureDim;
    }

    public void setFormat(CNN2DFormat format) {
        this.format = format;
    }

    @Override
    public String toString() {
        return "LocallyConnected2D(nIn=" + getNIn() + ", nOut=" + getNOut() + ", activation=" + getActivation() + ", kernel=" + Arrays.toString(getKernel()) + ", stride=" + Arrays.toString(getStride()) + ", padding=" + Arrays.toString(getPadding()) + ", paddingBr=" + Arrays.toString(getPaddingBr()) + ", cm=" + getCm() + ", dilation=" + Arrays.toString(getDilation()) + ", hasBias=" + isHasBias() + ", inputSize=" + Arrays.toString(getInputSize()) + ", outputSize=" + Arrays.toString(getOutputSize()) + ", featureDim=" + getFeatureDim() + ", format=" + getFormat() + ")";
    }

    /** Field-by-field equality (arrays compared by content), including the superclass state. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LocallyConnected2D)) {
            return false;
        }
        LocallyConnected2D other = (LocallyConnected2D) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        if (getNIn() != other.getNIn() || getNOut() != other.getNOut()) {
            return false;
        }
        Activation act = getActivation();
        if (act == null ? other.getActivation() != null : !act.equals(other.getActivation())) {
            return false;
        }
        if (!Arrays.equals(getKernel(), other.getKernel())
                || !Arrays.equals(getStride(), other.getStride())
                || !Arrays.equals(getPadding(), other.getPadding())
                || !Arrays.equals(getPaddingBr(), other.getPaddingBr())) {
            return false;
        }
        ConvolutionMode mode = getCm();
        if (mode == null ? other.getCm() != null : !mode.equals(other.getCm())) {
            return false;
        }
        if (!Arrays.equals(getDilation(), other.getDilation())
                || isHasBias() != other.isHasBias()
                || !Arrays.equals(getInputSize(), other.getInputSize())
                || !Arrays.equals(getOutputSize(), other.getOutputSize())
                || getFeatureDim() != other.getFeatureDim()) {
            return false;
        }
        CNN2DFormat fmt = getFormat();
        return fmt == null ? other.getFormat() == null : fmt.equals(other.getFormat());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof LocallyConnected2D;
    }

    /** Hash consistent with {@link #equals(Object)} (same fields, same combination order). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        long nInVal = getNIn();
        result = result * prime + (int) ((nInVal >>> 32) ^ nInVal);
        long nOutVal = getNOut();
        result = result * prime + (int) ((nOutVal >>> 32) ^ nOutVal);
        Activation act = getActivation();
        result = result * prime + (act == null ? 43 : act.hashCode());
        result = result * prime + Arrays.hashCode(getKernel());
        result = result * prime + Arrays.hashCode(getStride());
        result = result * prime + Arrays.hashCode(getPadding());
        result = result * prime + Arrays.hashCode(getPaddingBr());
        ConvolutionMode mode = getCm();
        result = result * prime + (mode == null ? 43 : mode.hashCode());
        result = result * prime + Arrays.hashCode(getDilation());
        result = result * prime + (isHasBias() ? 79 : 97);
        result = result * prime + Arrays.hashCode(getInputSize());
        result = result * prime + Arrays.hashCode(getOutputSize());
        result = result * prime + getFeatureDim();
        CNN2DFormat fmt = getFormat();
        result = result * prime + (fmt == null ? 43 : fmt.hashCode());
        return result;
    }
}
