package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.class */
public abstract class KerasConvolution extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasConvolution.class);

    /** Number of trainable parameters reported by {@link #getNumParams()}. */
    protected int numTrainableParams;

    /** Whether a bias parameter is expected (and required) in {@link #setWeights(Map)}. */
    protected boolean hasBias;

    /**
     * Creates a layer via the {@code KerasLayer(Integer)} super constructor.
     *
     * @param num value forwarded unchanged to the super constructor
     * @throws UnsupportedKerasConfigurationException propagated from the super constructor
     */
    public KerasConvolution(Integer num) throws UnsupportedKerasConfigurationException {
        super(num);
    }

    /**
     * Creates a layer from a Keras layer configuration map, delegating with {@code true}
     * as the boolean flag of {@link #KerasConvolution(Map, boolean)}.
     *
     * @param map Keras layer configuration
     * @throws InvalidKerasConfigurationException propagated from the super constructor
     * @throws UnsupportedKerasConfigurationException propagated from the super constructor
     */
    public KerasConvolution(Map<String, Object> map) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, true);
    }

    /**
     * Creates a layer from a Keras layer configuration map.
     *
     * @param map Keras layer configuration
     * @param z   flag forwarded unchanged to the super constructor
     *            (presumably "enforce training config" — TODO confirm against KerasLayer)
     * @throws InvalidKerasConfigurationException propagated from the super constructor
     * @throws UnsupportedKerasConfigurationException propagated from the super constructor
     */
    public KerasConvolution(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(map, z);
    }

    /**
     * Returns the number of trainable parameters of this layer.
     *
     * @return {@link #numTrainableParams}
     */
    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public int getNumParams() {
        return this.numTrainableParams;
    }

    /**
     * Copies the convolution kernel (and, if {@link #hasBias}, the bias) out of the Keras
     * weight map into this layer's weights under the DL4J keys {@code "W"} and {@code "b"}.
     *
     * <p>The kernel is converted with {@link #getConvParameterValues(INDArray)}; the bias is
     * stored as-is. Afterwards the map is passed to
     * {@code KerasLayerUtils.removeDefaultWeights} together with the layer config.
     *
     * @param kerasWeights weights read from the Keras model, keyed by Keras parameter name
     * @throws InvalidKerasConfigurationException if the kernel is missing, if a required bias
     *         is missing, or if the kernel cannot be converted
     */
    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public void setWeights(Map<String, INDArray> kerasWeights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap<>();

        String kernelName = this.conf.getKERAS_PARAM_NAME_W();
        if (!kerasWeights.containsKey(kernelName)) {
            throw new InvalidKerasConfigurationException("Parameter " + kernelName + " does not exist in weights");
        }
        this.weights.put("W", getConvParameterValues(kerasWeights.get(kernelName)));

        if (this.hasBias) {
            String biasName = this.conf.getKERAS_PARAM_NAME_B();
            if (!kerasWeights.containsKey(biasName)) {
                throw new InvalidKerasConfigurationException("Parameter " + biasName + " does not exist in weights");
            }
            this.weights.put("b", kerasWeights.get(biasName));
        }

        KerasLayerUtils.removeDefaultWeights(kerasWeights, this.conf);
    }

    /**
     * Converts a Keras convolution kernel into the layout used by DL4J.
     *
     * <ul>
     *   <li>{@code TENSORFLOW}: returns a permuted <em>view</em> of the input. Rank-5 kernels
     *       use axis order {4, 3, 0, 1, 2}; all others use {3, 2, 0, 1}. This presumably maps
     *       channels-last {@code [spatial..., in, out]} to {@code [out, in, spatial...]}
     *       (TODO confirm).</li>
     *   <li>{@code THEANO}: returns a copy in which every 2-D slice over dimensions 2 and 3 is
     *       rotated by 180 degrees (all of its elements reversed).</li>
     * </ul>
     *
     * @param kernel kernel array as read from the Keras model
     * @return the converted kernel
     * @throws InvalidKerasConfigurationException if the dim order is neither
     *         {@code TENSORFLOW} nor {@code THEANO}
     */
    public INDArray getConvParameterValues(INDArray kernel) throws InvalidKerasConfigurationException {
        switch (getDimOrder()) {
            case TENSORFLOW:
                if (kernel.rank() == 5) {
                    return kernel.permute(new int[]{4, 3, 0, 1, 2});
                }
                return kernel.permute(new int[]{3, 2, 0, 1});
            case THEANO:
                return flipSpatialSlices(kernel);
            default:
                throw new InvalidKerasConfigurationException("Unknown keras backend " + getDimOrder());
        }
    }

    /**
     * Returns a copy of {@code kernel} in which each 2-D tensor along dimensions {2, 3} has
     * its elements reversed (a 180-degree rotation of the slice).
     *
     * <p>Each slice is overwritten with {@code assign} rather than {@code muli(0).addi(...)}.
     * Multiplying by zero turns {@code Inf}/{@code NaN} weights into {@code NaN}, and that
     * {@code NaN} would then leak into whichever position received the flipped value.
     */
    private static INDArray flipSpatialSlices(INDArray kernel) {
        INDArray flipped = kernel.dup();
        int[] spatialDims = {2, 3};
        long numSlices = flipped.tensorsAlongDimension(spatialDims);
        for (int i = 0; i < numSlices; i++) {
            // In-place view into `flipped`; writes to it update the result.
            INDArray slice = flipped.tensorAlongDimension(i, spatialDims);
            // Snapshot the slice before reading so its backing buffer holds exactly its elements.
            INDArray snapshot = slice.dup();
            double[] values = snapshot.ravel().data().asDouble();
            ArrayUtils.reverse(values);
            INDArray reversed = Nd4j.create(values, snapshot.shape());
            slice.assign(reversed.castTo(slice.dataType()));
        }
        return flipped;
    }

    /**
     * @return the number of trainable parameters
     */
    public int getNumTrainableParams() {
        return this.numTrainableParams;
    }

    /**
     * @return whether this layer expects a bias parameter
     */
    public boolean isHasBias() {
        return this.hasBias;
    }

    /**
     * @param i the number of trainable parameters
     */
    public void setNumTrainableParams(int i) {
        this.numTrainableParams = i;
    }

    /**
     * @param z whether this layer expects a bias parameter
     */
    public void setHasBias(boolean z) {
        this.hasBias = z;
    }

    @Override
    public String toString() {
        return "KerasConvolution(numTrainableParams=" + getNumTrainableParams() + ", hasBias=" + isHasBias() + ")";
    }

    /**
     * Compares only {@code numTrainableParams} and {@code hasBias}; superclass state is not
     * considered. {@link #canEqual(Object)} keeps the comparison symmetric for subclasses.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof KerasConvolution)) {
            return false;
        }
        KerasConvolution other = (KerasConvolution) obj;
        return other.canEqual(this)
                && getNumTrainableParams() == other.getNumTrainableParams()
                && isHasBias() == other.isHasBias();
    }

    /**
     * @param obj candidate object
     * @return whether {@code obj} is a {@code KerasConvolution} and may be equal to this
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof KerasConvolution;
    }

    /** Hash consistent with {@link #equals(Object)}, combining the same two fields. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + getNumTrainableParams();
        result = result * 59 + (isHasBias() ? 79 : 97);
        return result;
    }
}
