package org.deeplearning4j.nn.layers.recurrent;

import java.util.Arrays;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.class */
/**
 * Output layer for recurrent (time series) networks.
 *
 * <p>The layer expects rank 3 input with shape {@code [minibatch, layerInSize, sequenceLength]}.
 * Most operations temporarily reshape the rank 3 arrays to rank 2 via {@link TimeSeriesUtils},
 * delegate to the {@link BaseOutputLayer} implementation, and reshape the result back to rank 3.
 */
public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.RnnOutputLayer> {

    /**
     * Creates the layer from its configuration.
     *
     * @param neuralNetConfiguration configuration passed straight to the superclass
     */
    public RnnOutputLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from its configuration and an initial input array.
     *
     * @param neuralNetConfiguration configuration passed straight to the superclass
     * @param iNDArray               array passed straight to the superclass constructor
     */
    public RnnOutputLayer(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Computes the gradient and the epsilon to pass to the previous layer.
     *
     * <p>The rank 3 input is swapped for its 2d reshaped view while the superclass computes the
     * gradients; the 2d epsilon returned by the superclass is then reshaped back to rank 3.
     *
     * @param iNDArray          epsilon array, forwarded unchanged to the superclass
     * @param layerWorkspaceMgr workspace manager used for the reshapes and the superclass call
     * @return pair of (gradient, rank 3 epsilon)
     * @throws UnsupportedOperationException if the current input is not rank 3
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        applyDropOutIfNecessary(true, layerWorkspaceMgr);
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]. Got input with rank " + this.input.rank() + " and shape " + Arrays.toString(this.input.shape()) + " - " + layerId());
        }

        INDArray input3d = this.input;
        Pair<Gradient, INDArray> gradAndEpsilon2d;
        this.input = TimeSeriesUtils.reshape3dTo2d(input3d, layerWorkspaceMgr, ArrayType.BP_WORKING_MEM);
        try {
            gradAndEpsilon2d = super.backpropGradient(iNDArray, layerWorkspaceMgr);
        } finally {
            // Always put the rank 3 input back, even if the superclass call fails, so the layer
            // is never left holding the temporary 2d view.
            this.input = input3d;
        }

        INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(gradAndEpsilon2d.getSecond(),
                (int) this.input.size(0), layerWorkspaceMgr, ArrayType.ACTIVATION_GRAD);
        this.weightNoiseParams.clear();
        return new Pair<>(gradAndEpsilon2d.getFirst(), epsilon3d);
    }

    /**
     * Computes the F1 score via the superclass, after flattening any rank 3 argument to 2d.
     *
     * @param iNDArray  first array forwarded to the superclass (presumably the examples - confirm
     *                  against BaseOutputLayer); reshaped to 2d when it is rank 3
     * @param iNDArray2 second array forwarded to the superclass (presumably the labels - confirm
     *                  against BaseOutputLayer); reshaped to 2d when it is rank 3
     * @return the value returned by the superclass {@code f1Score}
     */
    @Override
    public double f1Score(INDArray iNDArray, INDArray iNDArray2) {
        INDArray first2d = iNDArray;
        if (first2d.rank() == 3) {
            first2d = TimeSeriesUtils.reshape3dTo2d(first2d, LayerWorkspaceMgr.noWorkspaces(), ArrayType.ACTIVATIONS);
        }
        INDArray second2d = iNDArray2;
        if (second2d.rank() == 3) {
            second2d = TimeSeriesUtils.reshape3dTo2d(second2d, LayerWorkspaceMgr.noWorkspaces(), ArrayType.ACTIVATIONS);
        }
        return super.f1Score(first2d, second2d);
    }

    /**
     * Returns the input array currently held by this layer.
     *
     * @return the current input field, unchanged
     */
    @Override
    public INDArray getInput() {
        return this.input;
    }

    /**
     * Reports the layer type.
     *
     * @return always {@link Layer.Type#RECURRENT}
     */
    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Computes the pre-output (pre-activation) values using the superclass {@code preOutput}.
     *
     * <p>When the current input is rank 3 it is temporarily replaced by its 2d reshaped view for
     * the duration of the superclass call; any other rank is passed through as-is.
     *
     * @param z                 flag forwarded unchanged to the superclass {@code preOutput}
     * @param layerWorkspaceMgr workspace manager used for the reshape and the superclass call
     * @return the array returned by the superclass {@code preOutput}
     */
    @Override
    public INDArray preOutput2d(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(false);
        if (this.input.rank() != 3) {
            return super.preOutput(z, layerWorkspaceMgr);
        }

        INDArray input3d = this.input;
        this.input = TimeSeriesUtils.reshape3dTo2d(input3d, layerWorkspaceMgr, ArrayType.FF_WORKING_MEM);
        try {
            return super.preOutput(z, layerWorkspaceMgr);
        } finally {
            // Restore the rank 3 input regardless of whether preOutput succeeded.
            this.input = input3d;
        }
    }

    /**
     * Returns the labels in 2d form.
     *
     * @param layerWorkspaceMgr workspace manager used for the reshape or cast
     * @param arrayType         array type passed to the reshape or cast
     * @return rank 3 labels reshaped to 2d; otherwise the labels cast via the workspace manager
     *         to the default floating point type
     */
    @Override
    protected INDArray getLabels2d(LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (this.labels.rank() == 3) {
            return TimeSeriesUtils.reshape3dTo2d(this.labels, layerWorkspaceMgr, arrayType);
        }
        return layerWorkspaceMgr.castTo(arrayType, Nd4j.defaultFloatingPointType(), this.labels, false);
    }

    /**
     * Computes the layer activations for the current rank 3 input.
     *
     * <p>The input is flattened to 2d, multiplied by the "W" parameter, offset by the "b"
     * parameter as a row vector, passed through the configured activation function, masked if a
     * mask array is set, and finally reshaped back to rank 3.
     *
     * @param z                 flag forwarded to parameter-noise lookup, dropout and the
     *                          activation function
     * @param layerWorkspaceMgr workspace manager used for parameters, dropout and the final reshape
     * @return rank 3 activations
     * @throws UnsupportedOperationException if the current input is not rank 3
     */
    @Override
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        // Same guard as backpropGradient/preOutput2d: report a missing input explicitly instead
        // of failing with a bare NullPointerException on the rank check below.
        assertInputSet(false);
        if (this.input.rank() != 3) {
            throw new UnsupportedOperationException("Input must be rank 3. Got input with rank " + this.input.rank() + " " + layerId());
        }

        INDArray bias = getParamWithNoise("b", z, layerWorkspaceMgr);
        INDArray weights = getParamWithNoise("W", z, layerWorkspaceMgr);
        applyDropOutIfNecessary(z, layerWorkspaceMgr);

        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(this.input, LayerWorkspaceMgr.noWorkspaces(), ArrayType.FF_WORKING_MEM);
        INDArray preOut2d = input2d.mmul(weights).addiRowVector(bias);
        INDArray act2d = ((org.deeplearning4j.nn.conf.layers.RnnOutputLayer) layerConf()).getActivationFn().getActivation(preOut2d, z);

        if (this.maskArray != null) {
            // A column-vector mask whose shape differs from the activations is broadcast along
            // the columns; any other mask is applied element-wise.
            boolean broadcastColumnMask = this.maskArray.isColumnVectorOrScalar()
                    && !Arrays.equals(this.maskArray.shape(), act2d.shape());
            if (broadcastColumnMask) {
                act2d.muliColumnVector(this.maskArray);
            } else {
                act2d.muli(this.maskArray);
            }
        }
        return TimeSeriesUtils.reshape2dTo3d(act2d, (int) this.input.size(0), layerWorkspaceMgr, ArrayType.ACTIVATIONS);
    }

    /**
     * Stores the mask array in the 2d/vector form used by this layer.
     *
     * @param iNDArray mask to set: {@code null} clears the mask, a rank 2 mask is converted with
     *                 {@code reshapeTimeSeriesMaskToVector}, a rank 3 mask with {@code reshape3dTo2d}
     * @throws UnsupportedOperationException if the mask is non-null and neither rank 2 nor rank 3
     */
    @Override
    public void setMaskArray(INDArray iNDArray) {
        if (iNDArray == null) {
            this.maskArray = null;
            return;
        }
        if (iNDArray.rank() == 2) {
            this.maskArray = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(iNDArray, LayerWorkspaceMgr.noWorkspacesImmutable(), ArrayType.INPUT);
            return;
        }
        if (iNDArray.rank() != 3) {
            throw new UnsupportedOperationException("Invalid mask array: must be rank 2 or 3 (got: rank " + iNDArray.rank() + ", shape = " + Arrays.toString(iNDArray.shape()) + ") " + layerId());
        }
        this.maskArray = TimeSeriesUtils.reshape3dTo2d(iNDArray, LayerWorkspaceMgr.noWorkspacesImmutable(), ArrayType.INPUT);
    }

    /**
     * Records the incoming mask on this layer.
     *
     * <p>If the mask is {@code null} or its state is not {@link MaskState#Active}, the stored
     * input mask and state are cleared; otherwise the mask is stored in vector form together
     * with its state.
     *
     * @param iNDArray  incoming mask array, may be {@code null}
     * @param maskState state of the incoming mask
     * @param i         not used by this implementation
     * @return always {@code null}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray iNDArray, MaskState maskState, int i) {
        boolean activeMask = iNDArray != null && maskState == MaskState.Active;
        if (activeMask) {
            this.inputMaskArray = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(iNDArray, LayerWorkspaceMgr.noWorkspacesImmutable(), ArrayType.INPUT);
            this.inputMaskArrayState = maskState;
        } else {
            this.inputMaskArray = null;
            this.inputMaskArrayState = null;
        }
        return null;
    }

    /**
     * Computes one score per example.
     *
     * <p>The loss function's score array (computed on the 2d labels and 2d pre-output, with the
     * current mask) is reshaped with {@code reshapeVectorToTimeSeriesMask} and summed along
     * dimension 1; the sum {@code d + d2} is then added to every entry when it is non-zero.
     *
     * @param d                 first additive term (presumably the network L1 term - confirm
     *                          against BaseOutputLayer)
     * @param d2                second additive term (presumably the network L2 term - confirm
     *                          against BaseOutputLayer)
     * @param layerWorkspaceMgr workspace manager used for the label and pre-output computations
     * @return array of summed scores with dimension 1 kept
     * @throws IllegalStateException if the input or the labels are not set
     */
    @Override
    public INDArray computeScoreForExamples(double d, double d2, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }

        org.deeplearning4j.nn.conf.layers.RnnOutputLayer conf =
                (org.deeplearning4j.nn.conf.layers.RnnOutputLayer) layerConf();
        INDArray scoreArray = conf.getLossFn().computeScoreArray(
                getLabels2d(layerWorkspaceMgr, ArrayType.FF_WORKING_MEM),
                preOutput2d(false, layerWorkspaceMgr),
                conf.getActivationFn(),
                this.maskArray);
        INDArray summedScores = TimeSeriesUtils
                .reshapeVectorToTimeSeriesMask(scoreArray, (int) this.input.size(0))
                .sum(true, 1);

        double regularization = d + d2;
        // Skip the in-place add when there is nothing to add.
        if (regularization != 0.0) {
            summedScores.addi(regularization);
        }
        return summedScores;
    }
}
