package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/LSTM.class */
/**
 * Runtime implementation of a unidirectional LSTM recurrent layer.
 *
 * <p>The heavy lifting (forward pass and backpropagation) is delegated to {@code LSTMHelpers}. On the CUDA
 * backend, an optional cuDNN-backed {@link LSTMHelper} is used when it can be loaded and supports the
 * configured activation functions. Inputs must be rank 3; NWC-formatted inputs are permuted before being
 * handed to {@code LSTMHelpers} and the results are permuted back.
 */
public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.LSTM> {
    private static final Logger log = LoggerFactory.getLogger(LSTM.class);

    /** State-map key for the output activations of the last processed time step. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";
    /** State-map key for the memory cell state of the last processed time step. */
    public static final String STATE_KEY_PREV_MEMCELL = "prevMem";

    // Parameter keys of this layer's input weights, recurrent weights and biases.
    private static final String INPUT_WEIGHT_KEY = "W";
    private static final String RECURRENT_WEIGHT_KEY = "RW";
    private static final String BIAS_KEY = "b";

    /** Fully-qualified name of the optional cuDNN helper, loaded reflectively on the CUDA backend. */
    private static final String CUDNN_HELPER_CLASS = "org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper";

    /** Optional accelerated implementation; {@code null} means the built-in implementation is used. */
    protected LSTMHelper helper;
    /** Forward pass retained for reuse by a subsequent backward pass, if any. */
    protected FwdPassReturn cachedFwdPass;

    public LSTM(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
        this.helper = null;
        initializeHelper();
    }

    /**
     * Tries to install the cuDNN LSTM helper when running on the CUDA backend.
     *
     * <p>This is best-effort: if the helper class cannot be loaded or instantiated, or it does not support
     * this layer's gate/output activation functions, {@link #helper} stays {@code null} and the built-in
     * implementation is used instead.
     */
    void initializeHelper() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if (!"CUDA".equalsIgnoreCase(backend)) {
            return;
        }

        LSTMHelper candidate;
        try {
            candidate = (LSTMHelper) DL4JClassLoading.createNewInstance(
                    CUDNN_HELPER_CLASS, LSTMHelper.class, new Object[]{this.dataType});
        } catch (RuntimeException | LinkageError e) {
            // Missing cuDNN module or native libraries: fall back instead of failing layer construction.
            log.warn("Could not initialize CudnnLSTMHelper; using the built-in LSTM implementation", e);
            this.helper = null;
            return;
        }

        org.deeplearning4j.nn.conf.layers.LSTM lstmConf = lstmConf();
        if (candidate != null
                && candidate.checkSupported(lstmConf.getGateActivationFn(), lstmConf.getActivationFn(), false)) {
            this.helper = candidate;
            log.debug("CudnnLSTMHelper successfully initialized");
        } else {
            this.helper = null;
            log.debug("CudnnLSTMHelper does not support this layer configuration; using the built-in LSTM implementation");
        }
    }

    /** Returns this layer's configuration typed as an LSTM layer configuration. */
    private org.deeplearning4j.nn.conf.layers.LSTM lstmConf() {
        return (org.deeplearning4j.nn.conf.layers.LSTM) layerConf();
    }

    /**
     * Not supported: LSTM layers cannot be pretrained layer-wise.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("gradient() method for layerwise pretraining: not supported for LSTMs (pretraining not possible) " + layerId());
    }

    /** Standard (full-sequence) backpropagation. */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return backpropGradientHelper(epsilon, false, -1, workspaceMgr);
    }

    /** Truncated backpropagation through time over at most {@code tbpttBackLength} time steps. */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) {
        return backpropGradientHelper(epsilon, true, tbpttBackLength, workspaceMgr);
    }

    /**
     * Shared implementation of full and truncated backpropagation.
     *
     * <p>Re-runs (or reuses a cached) forward pass in training mode, then delegates the gradient computation
     * to {@code LSTMHelpers}. For truncated BPTT, the forward pass starts from the state stored in
     * {@code stateMap}, and the final activation/memory-cell state is stored (detached) in {@code tBpttStateMap}.
     *
     * @param epsilon          gradient w.r.t. this layer's output, in the layer's configured RNN format
     * @param truncatedBPTT    whether truncated BPTT is being performed
     * @param tbpttBackLength  number of time steps to backpropagate for truncated BPTT ({@code -1} otherwise)
     * @param workspaceMgr     workspace manager
     * @return the parameter gradients and the gradient w.r.t. this layer's input (in the configured format)
     */
    private Pair<Gradient, INDArray> backpropGradientHelper(INDArray epsilon, boolean truncatedBPTT, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);

        // Fetched in this order (W, then RW) to keep any stochastic weight noise sequence unchanged.
        INDArray inputWeights = getParamWithNoise(INPUT_WEIGHT_KEY, true, workspaceMgr);
        INDArray recurrentWeights = getParamWithNoise(RECURRENT_WEIGHT_KEY, true, workspaceMgr);

        FwdPassReturn fwdPass;
        if (truncatedBPTT) {
            fwdPass = activateHelper(true, this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                    this.stateMap.get(STATE_KEY_PREV_MEMCELL), true, workspaceMgr);
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct.detach());
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell.detach());
        } else {
            fwdPass = activateHelper(true, null, null, true, workspaceMgr);
        }

        // activateHelper returns output in the configured format; convert it the same way as the
        // input and epsilon passed to LSTMHelpers below.
        fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput);

        org.deeplearning4j.nn.conf.layers.LSTM lstmConf = lstmConf();
        Pair<Gradient, INDArray> result = LSTMHelpers.backpropGradientHelper(this, this.conf,
                lstmConf.getGateActivationFn(), permuteIfNWC(this.input), recurrentWeights, inputWeights,
                permuteIfNWC(epsilon), truncatedBPTT, tbpttBackLength, fwdPass, true,
                INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY, this.gradientViews, null, false,
                this.helper, workspaceMgr, lstmConf.isHelperAllowFallback());

        this.weightNoiseParams.clear();

        // Apply the dropout backward pass, then return the input gradient in the configured format.
        INDArray inputGradient = backpropDropOutIfPresent((INDArray) result.getSecond());
        result.setSecond(permuteIfNWC(inputGradient));
        return result;
    }

    /** Sets the input, then runs a forward pass from a zero initial state. */
    @Override
    public INDArray activate(INDArray newInput, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(newInput, workspaceMgr);
        return activateHelper(training, null, null, false, workspaceMgr).fwdPassOutput;
    }

    /** Runs a forward pass on the current input from a zero initial state. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return activateHelper(training, null, null, false, workspaceMgr).fwdPassOutput;
    }

    /**
     * Core forward pass.
     *
     * <p>For NWC layers, {@code this.input} is temporarily replaced by its permuted form for the duration of
     * the call. It is always restored before returning, including on the cached-result early return and
     * when an exception is thrown.
     *
     * @param training               whether this is a training-mode pass (affects dropout and weight noise)
     * @param prevOutputActivations  initial output activations, or {@code null}
     * @param prevMemCellState       initial memory cell state, or {@code null}
     * @param forBackprop            whether the result will be used for backpropagation
     * @param workspaceMgr           workspace manager
     * @return the forward-pass result; {@code fwdPassOutput} is in the layer's configured RNN format
     */
    private FwdPassReturn activateHelper(boolean training, INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        Preconditions.checkState(this.input.rank() == 3, "3D input expected to RNN layer expected, got " + this.input.rank());

        boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(layerConf()) == RNNFormat.NWC;
        INDArray originalInput = this.input;
        if (nwc) {
            this.input = permuteIfNWC(this.input);
        }
        try {
            applyDropOutIfNecessary(training, workspaceMgr);

            // Forward-pass caching is forced off here, so this class never populates cachedFwdPass itself.
            this.cacheMode = CacheMode.NONE;

            if (forBackprop && this.cachedFwdPass != null) {
                FwdPassReturn cached = this.cachedFwdPass;
                this.cachedFwdPass = null;
                return cached;
            }

            org.deeplearning4j.nn.conf.layers.LSTM lstmConf = lstmConf();
            FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, lstmConf.getGateActivationFn(),
                    this.input,
                    getParamWithNoise(RECURRENT_WEIGHT_KEY, training, workspaceMgr),
                    getParamWithNoise(INPUT_WEIGHT_KEY, training, workspaceMgr),
                    getParamWithNoise(BIAS_KEY, training, workspaceMgr),
                    training, prevOutputActivations, prevMemCellState,
                    (training && this.cacheMode != CacheMode.NONE) || forBackprop, true, INPUT_WEIGHT_KEY,
                    this.maskArray, false, this.helper,
                    forBackprop ? this.cacheMode : CacheMode.NONE, workspaceMgr,
                    lstmConf.isHelperAllowFallback());

            fwd.fwdPassOutput = permuteIfNWC(fwd.fwdPassOutput);

            if (training && this.cacheMode != CacheMode.NONE) {
                this.cachedFwdPass = fwd;
            }
            return fwd;
        } finally {
            // Undo the temporary permutation on every exit path so later calls (e.g. backprop,
            // which permutes this.input itself) see the input in its configured format.
            if (nwc) {
                this.input = originalInput;
            }
        }
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** The mask array is passed through unchanged, with {@link MaskState#Passthrough}. */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        return new Pair<>(maskArray, MaskState.Passthrough);
    }

    /**
     * Inference-mode forward pass that continues from, and then updates, the state stored in {@code stateMap}.
     */
    @Override
    public INDArray rnnTimeStep(INDArray newInput, LayerWorkspaceMgr workspaceMgr) {
        setInput(newInput, workspaceMgr);
        FwdPassReturn fwdPass = activateHelper(false, this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                this.stateMap.get(STATE_KEY_PREV_MEMCELL), false, workspaceMgr);
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct.detach());
        this.stateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell.detach());
        return fwdPass.fwdPassOutput;
    }

    /**
     * Forward pass that starts from the state in {@code tBpttStateMap}, optionally storing the resulting
     * final state back into it.
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray newInput, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMgr) {
        setInput(newInput, workspaceMgr);
        FwdPassReturn fwdPass = activateHelper(training, this.tBpttStateMap.get(STATE_KEY_PREV_ACTIVATION),
                this.tBpttStateMap.get(STATE_KEY_PREV_MEMCELL), false, workspaceMgr);
        if (storeLastForTBPTT) {
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct.detach());
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell.detach());
        }
        return fwdPass.fwdPassOutput;
    }

    @Override
    public LayerHelper getHelper() {
        return this.helper;
    }
}
