package org.deeplearning4j.nn.graph.vertex.impl.rnn;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.class */
/**
 * Parameter-free graph vertex that reverses a time series along dimension 2 of its input.
 *
 * <p>If the vertex was constructed with the name of a network input, the mask array of that network
 * input (when one is set on the graph) restricts the reversal to unmasked time steps: for each example,
 * the k-th unmasked step (counting from the start) is written to the k-th unmasked step counting from
 * the end. Without an input name, or when no mask is present, all time steps are reversed.
 *
 * <p>The identical reversal is applied to activations in {@link #doForward} and to the epsilon in
 * {@link #doBackward}.
 */
public class ReverseTimeSeriesVertex extends BaseGraphVertex {
    /** Name of the network input whose mask is used, or {@code null} if masking is not used. */
    private final String inputName;
    /** Position of {@link #inputName} in the network input list, or -1 if no input name was given. */
    private final int inputIdx;

    /**
     * Creates the vertex.
     *
     * @param graph       the computation graph this vertex belongs to
     * @param name        the name of this vertex
     * @param vertexIndex the index of this vertex in the graph
     * @param inputName   name of the network input whose mask should be used; may be {@code null}
     * @param dataType    the network data type
     * @throws IllegalArgumentException if {@code inputName} is non-null and is not one of the network inputs
     */
    public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputName,
                                   DataType dataType) {
        super(graph, name, vertexIndex, null, null, dataType);
        this.inputName = inputName;
        if (inputName == null) {
            this.inputIdx = -1;
        } else {
            this.inputIdx = graph.getConfiguration().getNetworkInputs().indexOf(inputName);
            if (this.inputIdx == -1) {
                throw new IllegalArgumentException("Invalid input name: \"" + inputName + "\" not found in list of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")");
            }
        }
    }

    /** @return {@code false}: this vertex does not wrap a layer */
    @Override
    public boolean hasLayer() {
        return false;
    }

    /** @return {@code false}: this vertex is never an output vertex */
    @Override
    public boolean isOutputVertex() {
        return false;
    }

    /** @return {@code null}: this vertex does not wrap a layer */
    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Reverses the first (and only) input along the time dimension.
     *
     * @param training     unused
     * @param workspaceMgr workspace manager used to allocate the output as {@link ArrayType#ACTIVATIONS}
     * @return the time-reversed input
     */
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return revertTimeSeries(this.inputs[0], getMask(), workspaceMgr, ArrayType.ACTIVATIONS);
    }

    /**
     * Reverses the epsilon along the time dimension, using the same mask as the forward pass.
     *
     * @param tbptt        unused
     * @param workspaceMgr workspace manager used to allocate the result as {@link ArrayType#ACTIVATION_GRAD}
     * @return a pair of {@code null} (no parameter gradients) and a single-element array holding the
     *         reversed epsilon
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        INDArray reversedEpsilon = revertTimeSeries(this.epsilon, getMask(), workspaceMgr, ArrayType.ACTIVATION_GRAD);
        // Plain null (not "(Object) null") so the diamond infers Pair<Gradient, INDArray[]> from the return type.
        return new Pair<>(null, new INDArray[]{reversedEpsilon});
    }

    /**
     * Looks up the mask of the configured network input.
     *
     * @return the mask array of the configured network input, or {@code null} when no input name was
     *         configured or the graph currently has no input mask arrays
     */
    private INDArray getMask() {
        if (this.inputIdx < 0) {
            return null;
        }
        INDArray[] inputMasks = this.graph.getInputMaskArrays();
        return inputMasks == null ? null : inputMasks[this.inputIdx];
    }

    /**
     * Reverses {@code input} along dimension 2, independently for each index along dimension 0.
     *
     * <p>Assumes {@code input} is rank 3 with time on dimension 2 (TODO confirm against callers). When
     * {@code mask} is non-null, time steps whose mask value is exactly 0 are skipped on both the read
     * and the write side. Output positions that are never written keep the initial contents of the
     * array returned by {@code workspaceMgr.create}.
     *
     * @param input        the array to reverse
     * @param mask         per-example, per-time-step mask indexed as {@code mask.getDouble(example, step)},
     *                     or {@code null}
     * @param workspaceMgr workspace manager used to allocate the result
     * @param arrayType    array type under which the result is allocated
     * @return a new 'f'-ordered array with the same shape and data type as {@code input}
     */
    private static INDArray revertTimeSeries(INDArray input, INDArray mask, LayerWorkspaceMgr workspaceMgr,
                                             ArrayType arrayType) {
        long numExamples = input.size(0);
        long seriesLength = input.size(2);
        INDArray out = workspaceMgr.create(arrayType, input.dataType(), input.shape(), 'f');

        for (long example = 0; example < numExamples; example++) {
            long src = 0;                // read position, moves forward
            long dst = seriesLength - 1; // write position, moves backward

            while (src < seriesLength && dst >= 0) {
                if (mask != null) {
                    while (src < seriesLength && mask.getDouble(example, src) == 0.0) {
                        src++;
                    }
                    while (dst >= 0 && mask.getDouble(example, dst) == 0.0) {
                        dst--;
                    }
                    // Skipping masked steps may exhaust either side; copying now would index out of bounds.
                    if (src >= seriesLength || dst < 0) {
                        break;
                    }
                }

                INDArray step = input.get(NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(src));
                out.put(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(dst)}, step);

                src++;
                dst--;
            }
        }
        return out;
    }

    /**
     * This vertex has no parameters, so only a {@code null} view array is accepted.
     *
     * @throws RuntimeException if {@code backpropGradientsViewArray} is non-null
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Passes the single input mask and its state through unchanged.
     *
     * @param maskArrays        input masks; at most one element is allowed
     * @param currentMaskState  state of the input mask, returned unchanged
     * @param minibatchSize     unused
     * @return the first input mask paired with {@code currentMaskState}
     * @throws IllegalArgumentException if more than one mask array is supplied
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState,
                                                           int minibatchSize) {
        if (maskArrays.length > 1) {
            throw new IllegalArgumentException("This vertex can only handle one input and hence only one mask");
        }
        return new Pair<>(maskArrays[0], currentMaskState);
    }

    @Override
    public String toString() {
        return "ReverseTimeSeriesVertex(" + (this.inputName == null ? "" : "inputName=" + this.inputName) + ")";
    }
}
