package org.deeplearning4j.nn.modelimport.keras.layers.recurrent;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.class */
/**
 * Imports a Keras {@code SimpleRNN} layer as a DL4J {@link SimpleRnn} layer.
 *
 * <p>When the Keras config has {@code return_sequences = false}, the built {@link SimpleRnn} is
 * wrapped in a {@link LastTimeStep} layer; otherwise the bare {@link SimpleRnn} is used.
 */
public class KerasSimpleRnn extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasSimpleRnn.class);

    /** Number of parameter arrays expected by {@link #setWeights}: input weights, recurrent weights, bias. */
    private static final int NUM_TRAINABLE_PARAMS = 3;

    protected boolean unroll = false;
    protected boolean returnSequences;

    /**
     * Constructor used only to fix the Keras major version (no layer is built).
     *
     * @param kerasVersion Keras major version, forwarded to {@link KerasLayer}
     * @throws UnsupportedKerasConfigurationException propagated from the superclass constructor
     */
    public KerasSimpleRnn(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
        super(kerasVersion);
    }

    /**
     * Builds the layer from a Keras layer config, enforcing the training config.
     *
     * @param layerConfig Keras layer configuration map
     * @throws InvalidKerasConfigurationException if the config is malformed
     * @throws UnsupportedKerasConfigurationException if the config uses unsupported features
     */
    public KerasSimpleRnn(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    /**
     * Builds the layer from a Keras layer config.
     *
     * @param layerConfig           Keras layer configuration map
     * @param enforceTrainingConfig forwarded to the superclass and to the weight-init lookups
     * @throws InvalidKerasConfigurationException if the config is malformed, including a missing or
     *                                            non-boolean {@code return_sequences} field
     * @throws UnsupportedKerasConfigurationException if the config uses unsupported features
     */
    public KerasSimpleRnn(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);

        Pair<WeightInit, Distribution> inputInit = KerasInitilizationUtils.getWeightInitFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, this.conf,
                this.kerasMajorVersion.intValue());
        WeightInit weightInit = inputInit.getFirst();
        Distribution weightDistribution = inputInit.getSecond();

        Pair<WeightInit, Distribution> recurrentInit = KerasInitilizationUtils.getWeightInitFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_INNER_INIT(), enforceTrainingConfig, this.conf,
                this.kerasMajorVersion.intValue());
        WeightInit recurrentWeightInit = recurrentInit.getFirst();
        Distribution recurrentDistribution = recurrentInit.getSecond();

        this.returnSequences = readReturnSequences(layerConfig);

        // Return value is intentionally unused; presumably this call validates (and rejects)
        // unsupported recurrent-dropout settings — confirm in KerasRnnUtils.
        KerasRnnUtils.getRecurrentDropout(this.conf, layerConfig);
        this.unroll = KerasRnnUtils.getUnrollRecurrentLayer(this.conf, layerConfig);

        LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_B_CONSTRAINT(), this.conf, this.kerasMajorVersion.intValue());
        LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_W_CONSTRAINT(), this.conf, this.kerasMajorVersion.intValue());
        LayerConstraint recurrentConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_RECURRENT_CONSTRAINT(), this.conf,
                this.kerasMajorVersion.intValue());

        SimpleRnn.Builder builder = new SimpleRnn.Builder()
                .name(this.layerName)
                .nOut(KerasLayerUtils.getNOutFromConfig(layerConfig, this.conf))
                .dropOut(this.dropout)
                .activation(KerasActivationUtils.getActivationFromConfig(layerConfig, this.conf))
                .weightInit(weightInit)
                .weightInitRecurrent(recurrentWeightInit)
                .biasInit(0.0d)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization);
        if (weightDistribution != null) {
            builder.dist(weightDistribution);
        }
        if (recurrentDistribution != null) {
            // NOTE(review): this second dist(...) call overwrites the input-weight distribution set
            // above with the recurrent one. A recurrent-specific distribution setter would be more
            // accurate if this DL4J version's builder provides one — confirm before changing.
            builder.dist(recurrentDistribution);
        }
        if (biasConstraint != null) {
            builder.constrainBias(new LayerConstraint[]{biasConstraint});
        }
        if (weightConstraint != null) {
            builder.constrainInputWeights(new LayerConstraint[]{weightConstraint});
        }
        if (recurrentConstraint != null) {
            builder.constrainRecurrent(new LayerConstraint[]{recurrentConstraint});
        }

        this.layer = this.returnSequences ? builder.build() : new LastTimeStep(builder.build());
    }

    /**
     * Reads the {@code return_sequences} flag from the inner Keras layer config.
     *
     * @throws InvalidKerasConfigurationException if the field is missing or not a boolean
     */
    private boolean readReturnSequences(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String field = this.conf.getLAYER_FIELD_RETURN_SEQUENCES();
        Object value = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, this.conf).get(field);
        if (!(value instanceof Boolean)) {
            throw new InvalidKerasConfigurationException(
                    "Keras SimpleRNN layer config field " + field + " is missing or not a boolean (found " + value + ")");
        }
        return (Boolean) value;
    }

    /**
     * Throws unless exactly one input type was supplied (this layer has a single input).
     */
    private static void checkSingleInput(InputType[] inputType) throws InvalidKerasConfigurationException {
        if (inputType == null || inputType.length != 1) {
            int received = inputType == null ? 0 : inputType.length;
            throw new InvalidKerasConfigurationException(
                    "Keras SimpleRnn layer accepts only one input (received " + received + ")");
        }
    }

    /** Returns the built DL4J layer: a {@link SimpleRnn}, or a {@link LastTimeStep} wrapping one. */
    public Layer getSimpleRnnLayer() {
        return this.layer;
    }

    /**
     * Computes this layer's output type.
     *
     * <p>If an input preprocessor is needed, its output is what the layer actually receives, so
     * it is fed through the layer's own {@code getOutputType}. Returning the preprocessor output
     * directly would report the layer's input shape instead of its output shape.
     *
     * @throws InvalidKerasConfigurationException if not exactly one input type is given
     */
    @Override
    public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
        checkSingleInput(inputType);
        InputPreProcessor preProcessor = getInputPreprocessor(inputType);
        InputType layerInput = preProcessor != null ? preProcessor.getOutputType(inputType[0]) : inputType[0];
        return getSimpleRnnLayer().getOutputType(-1, layerInput);
    }

    /** Returns the number of parameter arrays this layer expects (input weights, recurrent weights, bias). */
    @Override
    public int getNumParams() {
        return NUM_TRAINABLE_PARAMS;
    }

    /**
     * Returns the preprocessor needed to feed {@code inputType[0]} into a recurrent layer, or
     * {@code null} if none is needed (as decided by {@link InputTypeUtil}).
     *
     * @throws InvalidKerasConfigurationException if not exactly one input type is given
     */
    @Override
    public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
        checkSingleInput(inputType);
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], this.layerName);
    }

    public boolean getUnroll() {
        return this.unroll;
    }

    /**
     * Maps Keras weight arrays onto DL4J parameter keys "W", "RW" and "b".
     *
     * <p>The caller's map is not modified. Any entries beyond the three expected ones are
     * reported with a warning and otherwise ignored.
     *
     * @param weights Keras parameter name to array
     * @throws InvalidKerasConfigurationException if any of the three required parameters is missing
     */
    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        String inputKey = this.conf.getKERAS_PARAM_NAME_W();
        String recurrentKey = this.conf.getKERAS_PARAM_NAME_RW();
        String biasKey = this.conf.getKERAS_PARAM_NAME_B();

        this.weights = new HashMap<>();
        this.weights.put("W", requireParam(weights, inputKey));
        this.weights.put("RW", requireParam(weights, recurrentKey));
        this.weights.put("b", requireParam(weights, biasKey));

        if (weights.size() > NUM_TRAINABLE_PARAMS) {
            // Copy the keys: removing from map.keySet() directly would delete entries from the caller's map.
            Set<String> unknown = new HashSet<>(weights.keySet());
            unknown.remove(biasKey);
            unknown.remove(inputKey);
            unknown.remove(recurrentKey);
            log.warn("Attemping to set weights for unknown parameters: {}", String.join(", ", unknown));
        }
    }

    /** Returns {@code weights.get(key)}, or throws if the key is absent. */
    private static INDArray requireParam(Map<String, INDArray> weights, String key)
            throws InvalidKerasConfigurationException {
        if (!weights.containsKey(key)) {
            throw new InvalidKerasConfigurationException("Keras SimpleRNN layer does not contain parameter " + key);
        }
        return weights.get(key);
    }

    public int getNUM_TRAINABLE_PARAMS() {
        return NUM_TRAINABLE_PARAMS;
    }

    public boolean isReturnSequences() {
        return this.returnSequences;
    }

    public void setUnroll(boolean unroll) {
        this.unroll = unroll;
    }

    public void setReturnSequences(boolean returnSequences) {
        this.returnSequences = returnSequences;
    }

    @Override
    public String toString() {
        return "KerasSimpleRnn(NUM_TRAINABLE_PARAMS=" + getNUM_TRAINABLE_PARAMS() + ", unroll=" + getUnroll()
                + ", returnSequences=" + isReturnSequences() + ")";
    }

    /** Equality over NUM_TRAINABLE_PARAMS, unroll and returnSequences only (superclass state is ignored). */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof KerasSimpleRnn)) {
            return false;
        }
        KerasSimpleRnn other = (KerasSimpleRnn) obj;
        return other.canEqual(this)
                && getNUM_TRAINABLE_PARAMS() == other.getNUM_TRAINABLE_PARAMS()
                && getUnroll() == other.getUnroll()
                && isReturnSequences() == other.isReturnSequences();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof KerasSimpleRnn;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + getNUM_TRAINABLE_PARAMS();
        result = result * 59 + (getUnroll() ? 79 : 97);
        result = result * 59 + (isReturnSequences() ? 79 : 97);
        return result;
    }
}
