package org.deeplearning4j.nn.layers.recurrent;

import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.class */
/**
 * Bidirectional LSTM layer built from two independent LSTM passes over the full input sequence.
 *
 * <p>One pass uses the "forwards" parameter set and the other the "backwards" parameter set
 * (see {@link GravesBidirectionalLSTMParamInitializer}). The layer output is the element-wise sum
 * of the two pass outputs, and the epsilon returned from backprop is the sum of both directions'
 * epsilons. Because the layer must see the whole sequence at once, time-stepping, stored state and
 * truncated BPTT are not supported.
 *
 * <p>When the layer's {@link CacheMode} is not {@link CacheMode#NONE}, the two passes computed by a
 * training activation are retained in {@link #cachedPassForward} / {@link #cachedPassBackward}.
 * Each cached pass is handed out (and cleared) at most once, normally to the following backprop call.
 */
public class GravesBidirectionalLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM> {
    private static final Logger log = LoggerFactory.getLogger(GravesBidirectionalLSTM.class);

    /** Forwards-direction pass retained for reuse; {@code null} when nothing is cached. */
    protected FwdPassReturn cachedPassForward;
    /** Backwards-direction pass retained for reuse; {@code null} when nothing is cached. */
    protected FwdPassReturn cachedPassBackward;

    public GravesBidirectionalLSTM(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not supported " + layerId());
    }

    /**
     * Backpropagates {@code epsilon} through both directions over the full sequence.
     *
     * @param epsilon      error signal w.r.t. this layer's output
     * @param workspaceMgr workspace manager passed through to the LSTM helpers
     * @return the parameter gradients (in this layer's parameter order) and the summed input epsilon
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return backpropGradientHelper(epsilon, false, -1, workspaceMgr);
    }

    /**
     * Truncated BPTT is not supported for a bidirectional layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) {
        return backpropGradientHelper(epsilon, true, tbpttBackLength, workspaceMgr);
    }

    /**
     * Shared backprop implementation: rejects truncated BPTT, runs backprop for the forwards and then
     * the backwards direction, and merges the results.
     */
    private Pair<Gradient, INDArray> backpropGradientHelper(INDArray epsilon, boolean truncatedBPTT,
            int tbpttBackwardLength, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        if (truncatedBPTT) {
            throw new UnsupportedOperationException("Time step for bidirectional RNN not supported: it has to run on a batch of data all at once " + layerId());
        }

        // Forwards direction first, then backwards (same evaluation order as before).
        Pair<Gradient, INDArray> forwardsResult =
                backpropDirection(epsilon, truncatedBPTT, tbpttBackwardLength, true, workspaceMgr);
        Pair<Gradient, INDArray> backwardsResult =
                backpropDirection(epsilon, truncatedBPTT, tbpttBackwardLength, false, workspaceMgr);

        DefaultGradient merged = new DefaultGradient();
        for (Map.Entry<String, INDArray> entry : forwardsResult.getFirst().gradientForVariable().entrySet()) {
            merged.setGradientFor(entry.getKey(), entry.getValue());
        }
        for (Map.Entry<String, INDArray> entry : backwardsResult.getFirst().gradientForVariable().entrySet()) {
            merged.setGradientFor(entry.getKey(), entry.getValue());
        }

        // Re-emit the gradients in the iteration order of this layer's parameter map.
        Gradient ordered = new DefaultGradient();
        for (String paramKey : this.params.keySet()) {
            ordered.setGradientFor(paramKey, merged.getGradientFor(paramKey));
        }

        // Each direction yields its own epsilon w.r.t. the input; the layer's epsilon is their sum.
        INDArray inputEpsilon = forwardsResult.getSecond().addi(backwardsResult.getSecond());
        return new Pair<>(ordered, inputEpsilon);
    }

    /** Obtains one direction's forward pass (cached or recomputed) and backpropagates through it. */
    private Pair<Gradient, INDArray> backpropDirection(INDArray epsilon, boolean truncatedBPTT,
            int tbpttBackwardLength, boolean forwards, LayerWorkspaceMgr workspaceMgr) {
        org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM layerConfig = layerConf();
        String inputWeightKey = forwards
                ? GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS
                : GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS;
        String recurrentWeightKey = forwards
                ? GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS
                : GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS;
        String biasKey = forwards
                ? GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS
                : GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS;

        FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, forwards, workspaceMgr);
        // The literal `true, null` pair is passed exactly as before; presumably the peephole-connection
        // flag and an absent platform helper in LSTMHelpers' signature — TODO confirm.
        return LSTMHelpers.backpropGradientHelper(this, this.conf, layerConfig.getGateActivationFn(), this.input,
                getParam(recurrentWeightKey), getParam(inputWeightKey), epsilon, truncatedBPTT, tbpttBackwardLength,
                fwdPass, forwards, inputWeightKey, recurrentWeightKey, biasKey, this.gradientViews, this.maskArray,
                true, null, workspaceMgr, layerConfig.isHelperAllowFallback());
    }

    /** Sets the input, then returns the summed output of both directions. */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        return activateOutput(training, false, workspaceMgr);
    }

    /** Returns the summed output of both directions for the previously set input. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return activateOutput(training, false, workspaceMgr);
    }

    /**
     * Computes (or takes from the cache) both directional passes and returns their sum.
     *
     * @param training     whether this is a training-time activation
     * @param forBackprop  whether the passes must carry the state needed for backprop
     * @param workspaceMgr workspace manager passed through to the LSTM helpers
     */
    private INDArray activateOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        if (this.cacheMode == null) {
            // Same normalisation as activateHelperDirectional, so both paths agree on whether caching is on.
            this.cacheMode = CacheMode.NONE;
        }
        boolean cachingEnabled = this.cacheMode != CacheMode.NONE;

        FwdPassReturn forwardsEval;
        FwdPassReturn backwardsEval;
        if (cachingEnabled && this.cachedPassForward != null && this.cachedPassBackward != null) {
            // Consume the cached passes: each cached pass is handed out only once.
            forwardsEval = this.cachedPassForward;
            backwardsEval = this.cachedPassBackward;
            this.cachedPassForward = null;
            this.cachedPassBackward = null;
        } else {
            // With caching on, a training activation keeps backprop state so the next backprop can reuse it.
            boolean keepBackpropState = forBackprop || (cachingEnabled && training);
            CacheMode helperCacheMode = forBackprop ? this.cacheMode : CacheMode.NONE;
            forwardsEval = runDirection(training, null, null, keepBackpropState, true, helperCacheMode, workspaceMgr);
            backwardsEval = runDirection(training, null, null, keepBackpropState, false, helperCacheMode, workspaceMgr);

            // Retain the passes only when the sum below uses the non-mutating add(). In every other case
            // addi() overwrites forwardsEval.fwdPassOutput, and a cached copy would be corrupted.
            boolean retain = cachingEnabled && training && !forBackprop;
            this.cachedPassForward = retain ? forwardsEval : null;
            this.cachedPassBackward = retain ? backwardsEval : null;
        }

        INDArray forwardsOutput = forwardsEval.fwdPassOutput;
        INDArray backwardsOutput = backwardsEval.fwdPassOutput;
        if (training && cachingEnabled && !forBackprop) {
            // The forwards pass may be cached for backprop, so its output must not be modified in place.
            return forwardsOutput.add(backwardsOutput);
        }
        return forwardsOutput.addi(backwardsOutput);
    }

    /**
     * Returns one direction's forward pass.
     *
     * <p>For backprop with caching enabled, a cached pass for that direction is returned (and cleared)
     * if present. Otherwise the pass is recomputed.
     */
    private FwdPassReturn activateHelperDirectional(boolean training, INDArray prevOutputActivations,
            INDArray prevMemCellState, boolean forBackprop, boolean forwards, LayerWorkspaceMgr workspaceMgr) {
        if (this.cacheMode == null) {
            this.cacheMode = CacheMode.NONE;
        }
        if (this.cacheMode != CacheMode.NONE && forBackprop) {
            FwdPassReturn cached = forwards ? this.cachedPassForward : this.cachedPassBackward;
            // Guard both directions: previously a missing backwards cache returned null to backprop.
            if (cached != null) {
                if (forwards) {
                    this.cachedPassForward = null;
                } else {
                    this.cachedPassBackward = null;
                }
                return cached;
            }
        }
        return runDirection(training, prevOutputActivations, prevMemCellState, forBackprop, forwards,
                forBackprop ? this.cacheMode : CacheMode.NONE, workspaceMgr);
    }

    /** Runs {@link LSTMHelpers#activateHelper} for one direction using that direction's parameters. */
    private FwdPassReturn runDirection(boolean training, INDArray prevOutputActivations, INDArray prevMemCellState,
            boolean forBackprop, boolean forwards, CacheMode helperCacheMode, LayerWorkspaceMgr workspaceMgr) {
        org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM layerConfig = layerConf();
        String recurrentWeightKey = forwards
                ? GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS
                : GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS;
        String inputWeightKey = forwards
                ? GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS
                : GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS;
        String biasKey = forwards
                ? GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS
                : GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS;
        // The literal `true, null` pair is passed exactly as before; presumably the peephole-connection
        // flag and an absent platform helper in LSTMHelpers' signature — TODO confirm.
        return LSTMHelpers.activateHelper(this, this.conf, layerConfig.getGateActivationFn(), this.input,
                getParam(recurrentWeightKey), getParam(inputWeightKey), getParam(biasKey), training,
                prevOutputActivations, prevMemCellState, forBackprop, forwards, inputWeightKey, this.maskArray,
                true, null, helperCacheMode, workspaceMgr, layerConfig.isHelperAllowFallback());
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Not supported: the layer must process the whole sequence at once.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("you can not time step a bidirectional RNN, it has to run on a batch of data all at once " + layerId());
    }

    /**
     * Not supported: a bidirectional layer keeps no stored state.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT,
            LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Cannot set stored state: bidirectional RNNs don't have stored state " + layerId());
    }

    /**
     * Stores the incoming mask and mask state, and passes the mask on marked as {@link MaskState#Active}.
     *
     * <p>The returned state is always {@code Active}, whatever {@code currentMaskState} was, so
     * downstream layers continue to apply the mask.
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.maskArray = maskArray;
        this.maskState = currentMaskState;
        return new Pair<>(maskArray, MaskState.Active);
    }
}
