package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.class */
/**
 * Wrapper layer that folds one axis of its input into the leading (minibatch) axis, runs the
 * wrapped layer on the resulting lower-rank array, and then unfolds the result back to the
 * original layout.
 *
 * <p>The folded axis is chosen from the configured {@link RNNFormat}: axis 2 for
 * {@link RNNFormat#NCW}, axis 1 for any other value (including {@code null}). Presumably this is
 * the time/sequence axis of the respective format - confirm against the {@code RNNFormat} docs.
 */
public class TimeDistributedLayer extends BaseWrapperLayer {
    private final RNNFormat rnnDataFormat;

    /**
     * @param underlying    the layer that receives the folded arrays
     * @param rnnDataFormat data format used to pick the axis that gets folded into axis 0
     */
    public TimeDistributedLayer(Layer underlying, RNNFormat rnnDataFormat) {
        super(underlying);
        this.rnnDataFormat = rnnDataFormat;
    }

    /**
     * Folds {@code epsilon}, delegates to the wrapped layer, and unfolds the returned array.
     *
     * @param epsilon      array passed to the wrapped layer's {@code backpropGradient} after folding
     * @param workspaceMgr workspace manager handed to the wrapped layer and used to duplicate the result
     * @return the pair returned by the wrapped layer, with its second element replaced by the
     *         unfolded array duplicated as {@link ArrayType#ACTIVATION_GRAD}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        Pair<Gradient, INDArray> result = this.underlying.backpropGradient(reshape(epsilon), workspaceMgr);
        INDArray unfolded = revertReshape(result.getSecond(), epsilon.size(0));
        result.setSecond(workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, unfolded));
        return result;
    }

    /** Activates using the array returned by {@code input()}. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return activate(input(), training, workspaceMgr);
    }

    /**
     * Folds {@code input}, activates the wrapped layer on it, and unfolds the activations.
     *
     * @return the unfolded activations, duplicated as {@link ArrayType#ACTIVATIONS}
     */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray folded = reshape(input);
        INDArray activations = this.underlying.activate(folded, training, workspaceMgr);
        INDArray unfolded = revertReshape(activations, input.size(0));
        return workspaceMgr.dup(ArrayType.ACTIVATIONS, unfolded);
    }

    /**
     * Folds the format-selected axis into axis 0.
     *
     * <p>The array is permuted so the folded axis sits directly after axis 0, copied, and then
     * reshaped in 'c' order to a shape of rank {@code rank - 1} whose first entry is
     * {@code size(0) * size(axis)} and whose remaining entries are the other sizes in their
     * original order.
     *
     * @param array array to fold; its rank must be greater than the folded axis index
     * @return a reshaped copy of {@code array}
     * @throws IllegalArgumentException if the rank is too small for the folded axis
     */
    protected INDArray reshape(INDArray array) {
        int rank = array.rank();
        int axis = foldedAxis();

        INDArray permuted = array.permute(permuteAxes(rank, axis));

        long[] foldedShape = new long[rank - 1];
        foldedShape[0] = array.size(0) * array.size(axis);
        int next = 1;
        for (int dim = 1; dim < rank; dim++) {
            if (dim != axis) {
                foldedShape[next++] = array.size(dim);
            }
        }
        // dup() before reshaping: the reshape is applied to a fresh copy of the permuted view.
        return permuted.dup().reshape('c', foldedShape);
    }

    /**
     * Builds the axis order {@code [0, axis, <all other axes from 1 upward, in order>]}.
     *
     * @param rank number of axes in the array to be permuted
     * @param axis axis to move to position 1; must satisfy {@code 1 <= axis < rank}
     * @return a new array of length {@code rank} holding the permutation
     * @throws IllegalArgumentException if {@code axis} is outside {@code [1, rank)}
     */
    protected int[] permuteAxes(int rank, int axis) {
        // Previously any such input ended in a bare ArrayIndexOutOfBoundsException (or
        // NegativeArraySizeException); fail with a message that names the actual problem.
        if (axis < 1 || axis >= rank) {
            throw new IllegalArgumentException(
                    "Cannot move axis " + axis + " of a rank " + rank + " array: axis must be in [1, rank)");
        }
        int[] order = new int[rank];
        order[0] = 0;
        order[1] = axis;
        int next = 2;
        for (int dim = 1; dim < rank; dim++) {
            if (dim != axis) {
                order[next++] = dim;
            }
        }
        return order;
    }

    /**
     * Inverse of {@link #reshape(INDArray)}: splits axis 0 into {@code [minibatch, size(0) / minibatch]},
     * then applies the inverse of the permutation used when folding.
     *
     * @param array     folded array of rank {@code r}
     * @param minibatch size of axis 0 of the unfolded result; integer division is used, so
     *                  {@code array.size(0)} is expected to be a multiple of it
     * @return a rank {@code r + 1} array
     * @throws IllegalArgumentException if {@code r + 1} is too small for the folded axis
     */
    protected INDArray revertReshape(INDArray array, long minibatch) {
        int rank = array.rank();
        int axis = foldedAxis();

        long[] unfoldedShape = new long[rank + 1];
        unfoldedShape[0] = minibatch;
        unfoldedShape[1] = array.size(0) / minibatch;
        for (int dim = 1; dim < rank; dim++) {
            unfoldedShape[dim + 1] = array.size(dim);
        }

        INDArray unfolded = array.reshape('c', unfoldedShape);
        return unfolded.permute(ArrayUtil.invertPermutation(permuteAxes(rank + 1, axis)));
    }

    /**
     * Passes the mask on to the wrapped layer; a non-null mask is first converted with
     * {@link TimeSeriesUtils#reshapeTimeSeriesMaskToVector}.
     */
    @Override
    public void setMaskArray(INDArray maskArray) {
        if (maskArray == null) {
            this.underlying.setMaskArray(null);
            return;
        }
        // NOTE(review): reshape(...) flattens in 'c' order after moving the folded axis next to
        // axis 0 - confirm the helper below flattens the mask in the same element order.
        INDArray vectorMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(
                maskArray, LayerWorkspaceMgr.noWorkspaces(), ArrayType.ACTIVATIONS);
        this.underlying.setMaskArray(vectorMask);
    }

    /**
     * Feeds the mask through the wrapped layer. A non-null mask is converted to vector form
     * before delegation, and a non-null returned mask is converted back using the input mask's
     * size along axis 0.
     *
     * @return the wrapped layer's result, with its first element converted back when present
     * @throws ArithmeticException if {@code maskArray.size(0)} does not fit in an {@code int}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
                                                          int minibatchSize) {
        if (maskArray == null) {
            return this.underlying.feedForwardMaskArray(null, currentMaskState, minibatchSize);
        }

        INDArray vectorMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(
                maskArray, LayerWorkspaceMgr.noWorkspaces(), ArrayType.ACTIVATIONS);
        Pair<INDArray, MaskState> result =
                this.underlying.feedForwardMaskArray(vectorMask, currentMaskState, minibatchSize);
        if (result == null || result.getFirst() == null) {
            return result;
        }

        // toIntExact instead of a plain (int) cast: fail loudly rather than truncate silently.
        int minibatch = Math.toIntExact(maskArray.size(0));
        result.setFirst(TimeSeriesUtils.reshapeVectorToTimeSeriesMask(result.getFirst(), minibatch));
        return result;
    }

    /** Index of the axis that is folded into axis 0: 2 for NCW, otherwise 1. */
    private int foldedAxis() {
        return this.rnnDataFormat == RNNFormat.NCW ? 2 : 1;
    }
}
