package org.deeplearning4j.nn.conf.inputs;

import java.io.Serializable;
import java.util.Arrays;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
/* loaded from: input_file:org/deeplearning4j/nn/conf/inputs/InputType.class */
public abstract class InputType implements Serializable {

    /**
     * Input type for 2D convolutional data: each example is a {@code channels x height x width}
     * volume, reported channels-first by {@link #getShape(boolean)}.
     */
    public static class InputTypeConvolutional extends InputType {
        private long height;
        private long width;
        private long channels;

        /** Creates an instance whose dimensions are all left at their default value of 0. */
        public InputTypeConvolutional() {
        }

        /**
         * Creates a 2D convolutional input type.
         *
         * @param height   spatial height
         * @param width    spatial width
         * @param channels number of channels
         */
        public InputTypeConvolutional(@JsonProperty("height") long height,
                                      @JsonProperty("width") long width,
                                      @JsonProperty("channels") long channels) {
            this.height = height;
            this.width = width;
            this.channels = channels;
        }

        /**
         * @return the number of channels
         * @deprecated use {@link #getChannels()} instead; "depth" is an older name for channels here
         */
        @Deprecated
        public long getDepth() {
            return this.channels;
        }

        /**
         * @param depth the number of channels
         * @deprecated use {@link #setChannels(long)} instead
         */
        @Deprecated
        public void setDepth(long depth) {
            this.channels = depth;
        }

        @Override
        public Type getType() {
            return Type.CNN;
        }

        @Override
        public String toString() {
            return "InputTypeConvolutional(h=" + this.height + ",w=" + this.width + ",c=" + this.channels + ")";
        }

        /** @return {@code height * width * channels} */
        @Override
        public long arrayElementsPerExample() {
            return this.height * this.width * this.channels;
        }

        /**
         * @param includeBatchDim if true, a leading {@code -1} placeholder is added for the minibatch
         * @return {@code [channels, height, width]}, optionally prefixed by {@code -1}
         */
        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (includeBatchDim) {
                return new long[]{-1, this.channels, this.height, this.width};
            }
            return new long[]{this.channels, this.height, this.width};
        }

        public long getHeight() {
            return this.height;
        }

        public long getWidth() {
            return this.width;
        }

        public long getChannels() {
            return this.channels;
        }

        public void setHeight(long height) {
            this.height = height;
        }

        public void setWidth(long width) {
            this.width = width;
        }

        public void setChannels(long channels) {
            this.channels = channels;
        }

        /** Equal when {@code other} is a compatible instance with the same height, width and channels. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof InputTypeConvolutional)) {
                return false;
            }
            InputTypeConvolutional other = (InputTypeConvolutional) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            return getHeight() == other.getHeight()
                    && getWidth() == other.getWidth()
                    && getChannels() == other.getChannels();
        }

        /** Symmetry hook for {@link #equals(Object)} in subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof InputTypeConvolutional;
        }

        @Override
        public int hashCode() {
            // Long.hashCode(v) == (int) (v ^ (v >>> 32)); 59 is the per-field multiplier.
            int result = 1;
            result = result * 59 + Long.hashCode(getHeight());
            result = result * 59 + Long.hashCode(getWidth());
            result = result * 59 + Long.hashCode(getChannels());
            return result;
        }
    }

    /**
     * Input type for 3D convolutional data: each example has {@code depth x height x width} spatial
     * positions with {@code channels} values each. The memory layout reported by
     * {@link #getShape(boolean)} depends on {@link #getDataFormat()}.
     */
    public static class InputTypeConvolutional3D extends InputType {
        private Convolution3D.DataFormat dataFormat;
        private long depth;
        private long height;
        private long width;
        private long channels;

        /** Creates an instance with a null data format and all dimensions left at 0. */
        public InputTypeConvolutional3D() {
        }

        /**
         * Creates a 3D convolutional input type.
         *
         * @param dataFormat layout of the array; NDHWC is channels-last, any other value is
         *                   treated as channels-first by {@link #getShape(boolean)}
         * @param depth      spatial depth
         * @param height     spatial height
         * @param width      spatial width
         * @param channels   number of channels
         */
        public InputTypeConvolutional3D(@JsonProperty("dataFormat") Convolution3D.DataFormat dataFormat,
                                        @JsonProperty("depth") long depth,
                                        @JsonProperty("height") long height,
                                        @JsonProperty("width") long width,
                                        @JsonProperty("channels") long channels) {
            this.dataFormat = dataFormat;
            this.depth = depth;
            this.height = height;
            this.width = width;
            this.channels = channels;
        }

        @Override
        public Type getType() {
            return Type.CNN3D;
        }

        @Override
        public String toString() {
            return "InputTypeConvolutional3D(format=" + this.dataFormat + ",d=" + this.depth + ",h=" + this.height + ",w=" + this.width + ",c=" + this.channels + ")";
        }

        /** @return {@code height * width * depth * channels} */
        @Override
        public long arrayElementsPerExample() {
            return this.height * this.width * this.depth * this.channels;
        }

        /**
         * Returns the per-example shape in this instance's layout.
         *
         * @param includeBatchDim if true, a leading {@code -1} placeholder is added for the minibatch
         * @return {@code [d, h, w, c]} for NDHWC; otherwise (including a null format)
         *         {@code [c, d, h, w]}; optionally prefixed by {@code -1}
         */
        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (this.dataFormat == Convolution3D.DataFormat.NDHWC) {
                // Channels-last layout.
                if (includeBatchDim) {
                    return new long[]{-1, this.depth, this.height, this.width, this.channels};
                }
                return new long[]{this.depth, this.height, this.width, this.channels};
            }
            // Channels-first layout; also the fallback when no NDHWC format is set.
            if (includeBatchDim) {
                return new long[]{-1, this.channels, this.depth, this.height, this.width};
            }
            return new long[]{this.channels, this.depth, this.height, this.width};
        }

        public Convolution3D.DataFormat getDataFormat() {
            return this.dataFormat;
        }

        public long getDepth() {
            return this.depth;
        }

        public long getHeight() {
            return this.height;
        }

        public long getWidth() {
            return this.width;
        }

        public long getChannels() {
            return this.channels;
        }

        public void setDataFormat(Convolution3D.DataFormat dataFormat) {
            this.dataFormat = dataFormat;
        }

        public void setDepth(long depth) {
            this.depth = depth;
        }

        public void setHeight(long height) {
            this.height = height;
        }

        public void setWidth(long width) {
            this.width = width;
        }

        public void setChannels(long channels) {
            this.channels = channels;
        }

        /** Equal when format (null-safe) and all four dimensions match. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof InputTypeConvolutional3D)) {
                return false;
            }
            InputTypeConvolutional3D other = (InputTypeConvolutional3D) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            Convolution3D.DataFormat mine = getDataFormat();
            Convolution3D.DataFormat theirs = other.getDataFormat();
            boolean sameFormat = (mine == null) ? theirs == null : mine.equals(theirs);
            if (!sameFormat) {
                return false;
            }
            return getDepth() == other.getDepth()
                    && getHeight() == other.getHeight()
                    && getWidth() == other.getWidth()
                    && getChannels() == other.getChannels();
        }

        /** Symmetry hook for {@link #equals(Object)} in subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof InputTypeConvolutional3D;
        }

        @Override
        public int hashCode() {
            // 43 stands in for a null format; Long.hashCode(v) == (int) (v ^ (v >>> 32)).
            Convolution3D.DataFormat format = getDataFormat();
            int result = 1;
            result = result * 59 + (format != null ? format.hashCode() : 43);
            result = result * 59 + Long.hashCode(getDepth());
            result = result * 59 + Long.hashCode(getHeight());
            result = result * 59 + Long.hashCode(getWidth());
            result = result * 59 + Long.hashCode(getChannels());
            return result;
        }
    }

    /**
     * Input type for 2D convolutional data that has been flattened to
     * {@code height * width * depth} values per example. {@link #getUnflattenedType()} recovers the
     * corresponding {@link InputTypeConvolutional}.
     */
    public static class InputTypeConvolutionalFlat extends InputType {
        private long height;
        private long width;
        private long depth;

        /** Creates an instance whose dimensions are all left at their default value of 0. */
        public InputTypeConvolutionalFlat() {
        }

        /**
         * @param height spatial height of the unflattened data
         * @param width  spatial width of the unflattened data
         * @param depth  number of channels of the unflattened data
         */
        public InputTypeConvolutionalFlat(@JsonProperty("height") long height,
                                          @JsonProperty("width") long width,
                                          @JsonProperty("depth") long depth) {
            this.height = height;
            this.width = width;
            this.depth = depth;
        }

        @Override
        public Type getType() {
            return Type.CNNFlat;
        }

        /** @return {@code height * width * depth}, the length of one flattened example */
        public long getFlattenedSize() {
            return this.height * this.width * this.depth;
        }

        /** @return a convolutional input type with this height, width, and depth as channels */
        public InputType getUnflattenedType() {
            return InputType.convolutional(this.height, this.width, this.depth);
        }

        @Override
        public String toString() {
            return "InputTypeConvolutionalFlat(h=" + this.height + ",w=" + this.width + ",d=" + this.depth + ")";
        }

        /** @return {@code height * width * depth} */
        @Override
        public long arrayElementsPerExample() {
            return this.height * this.width * this.depth;
        }

        /**
         * @param includeBatchDim if true, a leading {@code -1} placeholder is added for the minibatch
         * @return {@code [depth, height, width]}, optionally prefixed by {@code -1}
         */
        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (includeBatchDim) {
                return new long[]{-1, this.depth, this.height, this.width};
            }
            return new long[]{this.depth, this.height, this.width};
        }

        public long getHeight() {
            return this.height;
        }

        public long getWidth() {
            return this.width;
        }

        public long getDepth() {
            return this.depth;
        }

        public void setHeight(long height) {
            this.height = height;
        }

        public void setWidth(long width) {
            this.width = width;
        }

        public void setDepth(long depth) {
            this.depth = depth;
        }

        /** Equal when {@code other} is a compatible instance with the same height, width and depth. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof InputTypeConvolutionalFlat)) {
                return false;
            }
            InputTypeConvolutionalFlat other = (InputTypeConvolutionalFlat) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            return getHeight() == other.getHeight()
                    && getWidth() == other.getWidth()
                    && getDepth() == other.getDepth();
        }

        /** Symmetry hook for {@link #equals(Object)} in subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof InputTypeConvolutionalFlat;
        }

        @Override
        public int hashCode() {
            // Long.hashCode(v) == (int) (v ^ (v >>> 32)); 59 is the per-field multiplier.
            int result = 1;
            result = result * 59 + Long.hashCode(getHeight());
            result = result * 59 + Long.hashCode(getWidth());
            result = result * 59 + Long.hashCode(getDepth());
            return result;
        }
    }

    /** Input type for flat vectors with {@code size} elements per example. */
    public static class InputTypeFeedForward extends InputType {
        private long size;

        /** Creates an instance whose size is left at its default value of 0. */
        public InputTypeFeedForward() {
        }

        /** @param size number of elements per example */
        public InputTypeFeedForward(@JsonProperty("size") long size) {
            this.size = size;
        }

        @Override
        public Type getType() {
            return Type.FF;
        }

        @Override
        public String toString() {
            return "InputTypeFeedForward(" + this.size + ")";
        }

        /** @return {@code size} */
        @Override
        public long arrayElementsPerExample() {
            return this.size;
        }

        /**
         * @param includeBatchDim if true, a leading {@code -1} placeholder is added for the minibatch
         * @return {@code [size]}, optionally prefixed by {@code -1}
         */
        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (includeBatchDim) {
                return new long[]{-1, this.size};
            }
            return new long[]{this.size};
        }

        public long getSize() {
            return this.size;
        }

        /** Equal when {@code other} is a compatible instance with the same size. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof InputTypeFeedForward)) {
                return false;
            }
            InputTypeFeedForward other = (InputTypeFeedForward) obj;
            return other.canEqual(this) && getSize() == other.getSize();
        }

        /** Symmetry hook for {@link #equals(Object)} in subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof InputTypeFeedForward;
        }

        @Override
        public int hashCode() {
            // Long.hashCode(v) == (int) (v ^ (v >>> 32)); seeded as 1 * 59.
            return 59 + Long.hashCode(getSize());
        }
    }

    /**
     * Input type for sequence data: {@code size} features per time step over
     * {@code timeSeriesLength} steps. A non-positive length (the single-argument constructor uses
     * -1) means the length is unknown.
     */
    public static class InputTypeRecurrent extends InputType {
        private long size;
        private long timeSeriesLength;

        /** Creates an instance with size and time series length left at their default value of 0. */
        public InputTypeRecurrent() {
        }

        /**
         * Creates a recurrent input type with an unknown time series length (stored as -1).
         *
         * @param size number of features per time step
         */
        public InputTypeRecurrent(long size) {
            this(size, -1L);
        }

        /**
         * @param size             number of features per time step
         * @param timeSeriesLength number of time steps; non-positive means unknown
         */
        public InputTypeRecurrent(@JsonProperty("size") long size,
                                  @JsonProperty("timeSeriesLength") long timeSeriesLength) {
            this.size = size;
            this.timeSeriesLength = timeSeriesLength;
        }

        @Override
        public Type getType() {
            return Type.RNN;
        }

        /** The time series length is only included when it is positive. */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("InputTypeRecurrent(").append(this.size);
            if (this.timeSeriesLength > 0) {
                sb.append(",timeSeriesLength=").append(this.timeSeriesLength);
            }
            return sb.append(')').toString();
        }

        /**
         * @return {@code timeSeriesLength * size}
         * @throws IllegalStateException if the time series length is not positive (unknown)
         */
        @Override
        public long arrayElementsPerExample() {
            boolean lengthKnown = this.timeSeriesLength > 0;
            if (!lengthKnown) {
                throw new IllegalStateException("Cannot calculate number of array elements per example: time series length is not set. Use InputType.recurrent(int size, int timeSeriesLength) instead?");
            }
            return this.timeSeriesLength * this.size;
        }

        /**
         * @param includeBatchDim if true, a leading {@code -1} placeholder is added for the minibatch
         * @return {@code [size, timeSeriesLength]}, optionally prefixed by {@code -1}; the length
         *         entry is returned as stored, even if it is the unknown marker
         */
        @Override
        public long[] getShape(boolean includeBatchDim) {
            if (includeBatchDim) {
                return new long[]{-1, this.size, this.timeSeriesLength};
            }
            return new long[]{this.size, this.timeSeriesLength};
        }

        public long getSize() {
            return this.size;
        }

        public long getTimeSeriesLength() {
            return this.timeSeriesLength;
        }

        /** Equal when {@code other} is a compatible instance with the same size and length. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof InputTypeRecurrent)) {
                return false;
            }
            InputTypeRecurrent other = (InputTypeRecurrent) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            return getSize() == other.getSize()
                    && getTimeSeriesLength() == other.getTimeSeriesLength();
        }

        /** Symmetry hook for {@link #equals(Object)} in subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof InputTypeRecurrent;
        }

        @Override
        public int hashCode() {
            // Long.hashCode(v) == (int) (v ^ (v >>> 32)); 59 is the per-field multiplier.
            int result = 1;
            result = result * 59 + Long.hashCode(getSize());
            result = result * 59 + Long.hashCode(getTimeSeriesLength());
            return result;
        }
    }

    /** Coarse category of an {@link InputType}, as reported by {@link InputType#getType()}. */
    public enum Type {
        /** Flat feature vectors ({@link InputTypeFeedForward}). */
        FF,
        /** Sequence data ({@link InputTypeRecurrent}). */
        RNN,
        /** 2D convolutional data ({@link InputTypeConvolutional}). */
        CNN,
        /** Flattened 2D convolutional data ({@link InputTypeConvolutionalFlat}). */
        CNNFlat,
        /** 3D convolutional data ({@link InputTypeConvolutional3D}). */
        CNN3D
    }

    /** @return the category of this input type */
    @JsonIgnore
    public abstract Type getType();

    /** @return a human-readable description of this input type and its dimensions */
    @Override
    public abstract String toString();

    /**
     * @return the number of array elements a single example of this type occupies; implementations
     *         may throw {@link IllegalStateException} when a required dimension is unknown
     */
    @JsonIgnore
    public abstract long arrayElementsPerExample();

    /**
     * @param includeBatchDim whether to prepend a {@code -1} placeholder for the minibatch dimension
     * @return the per-example array shape
     */
    @JsonIgnore
    public abstract long[] getShape(boolean includeBatchDim);

    /**
     * Returns the per-example shape without a minibatch dimension.
     * NOTE(review): unlike the abstract overload, this getter is not {@code @JsonIgnore}'d, so a
     * bean-style serializer may emit it as a "shape" property; confirm that is intended.
     *
     * @return the same result as {@code getShape(false)}
     */
    public long[] getShape() {
        final boolean includeBatchDim = false;
        return getShape(includeBatchDim);
    }

    /**
     * @param size number of elements per example
     * @return a new {@link InputTypeFeedForward}
     */
    public static InputType feedForward(long size) {
        return new InputTypeFeedForward(size);
    }

    /**
     * Creates a recurrent input type whose time series length is unknown (stored as -1).
     *
     * @param size number of features per time step
     * @return a new {@link InputTypeRecurrent}
     */
    public static InputType recurrent(long size) {
        return new InputTypeRecurrent(size);
    }

    /**
     * @param size             number of features per time step
     * @param timeSeriesLength number of time steps
     * @return a new {@link InputTypeRecurrent}
     */
    public static InputType recurrent(long size, long timeSeriesLength) {
        return new InputTypeRecurrent(size, timeSeriesLength);
    }

    /**
     * @param height   spatial height
     * @param width    spatial width
     * @param channels number of channels
     * @return a new {@link InputTypeConvolutional}
     */
    public static InputType convolutional(long height, long width, long channels) {
        return new InputTypeConvolutional(height, width, channels);
    }

    /**
     * Creates a 3D convolutional input type with the NDHWC (channels-last) data format.
     *
     * @param depth    spatial depth
     * @param height   spatial height
     * @param width    spatial width
     * @param channels number of channels
     * @return a new {@link InputTypeConvolutional3D} tagged NDHWC
     * @deprecated use {@link #convolutional3D(Convolution3D.DataFormat, long, long, long, long)}
     *             and state the data format explicitly
     */
    @Deprecated
    public static InputType convolutional3D(long depth, long height, long width, long channels) {
        return convolutional3D(Convolution3D.DataFormat.NDHWC, depth, height, width, channels);
    }

    /**
     * @param dataFormat layout of the 3D data
     * @param depth      spatial depth
     * @param height     spatial height
     * @param width      spatial width
     * @param channels   number of channels
     * @return a new {@link InputTypeConvolutional3D}
     */
    public static InputType convolutional3D(Convolution3D.DataFormat dataFormat, long depth, long height, long width, long channels) {
        return new InputTypeConvolutional3D(dataFormat, depth, height, width, channels);
    }

    /**
     * @param height spatial height of the unflattened data
     * @param width  spatial width of the unflattened data
     * @param depth  number of channels of the unflattened data
     * @return a new {@link InputTypeConvolutionalFlat}
     */
    public static InputType convolutionalFlat(long height, long width, long depth) {
        return new InputTypeConvolutionalFlat(height, width, depth);
    }

    /**
     * Infers an {@link InputType} from the shape of an array whose first dimension is the minibatch.
     * <ul>
     *   <li>rank 2 {@code [mb, size]} &rarr; {@link #feedForward(long)}</li>
     *   <li>rank 3 {@code [mb, size, timeSeriesLength]} &rarr; {@link #recurrent(long, long)}</li>
     *   <li>rank 4 {@code [mb, channels, height, width]} &rarr;
     *       {@link #convolutional(long, long, long)}</li>
     *   <li>rank 5 {@code [mb, channels, depth, height, width]} &rarr;
     *       {@link #convolutional3D(Convolution3D.DataFormat, long, long, long, long)} with NCDHW</li>
     * </ul>
     * Rank-2 arrays always map to feed-forward, never to convolutional-flat.
     *
     * @param iNDArray the example array
     * @return the inferred input type
     * @throws IllegalArgumentException if the array's rank is not 2, 3, 4 or 5
     */
    public static InputType inferInputType(INDArray iNDArray) {
        // Sizes are kept as long (no int narrowing), so dimensions above Integer.MAX_VALUE are
        // not silently truncated.
        switch (iNDArray.rank()) {
            case 2:
                return feedForward(iNDArray.size(1));
            case 3:
                return recurrent(iNDArray.size(1), iNDArray.size(2));
            case 4:
                return convolutional(iNDArray.size(2), iNDArray.size(3), iNDArray.size(1));
            case 5:
                // Channels are read from dimension 1, i.e. the array is channels-first. The format
                // must therefore be NCDHW. The deprecated 4-arg overload tags the type NDHWC, which
                // would make getShape() disagree with the array the type was inferred from.
                return convolutional3D(Convolution3D.DataFormat.NCDHW,
                        iNDArray.size(2), iNDArray.size(3), iNDArray.size(4), iNDArray.size(1));
            default:
                throw new IllegalArgumentException("Cannot infer input type for array with shape: " + Arrays.toString(iNDArray.shape()));
        }
    }

    /**
     * Applies {@link #inferInputType(INDArray)} to each array in order.
     *
     * @param iNDArrayArr the example arrays
     * @return one inferred input type per array, in the same order
     * @throws IllegalArgumentException if any array has an unsupported rank
     */
    public static InputType[] inferInputTypes(INDArray... iNDArrayArr) {
        // Sequential stream: elements are processed in order, so the first bad array is reported.
        return Arrays.stream(iNDArrayArr)
                .map(InputType::inferInputType)
                .toArray(InputType[]::new);
    }
}
