package org.deeplearning4j.util;

import java.util.Arrays;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/util/TimeSeriesUtils.class */
/**
 * Static helpers for reshaping, reversing and slicing time series activations and masks.
 *
 * <p>Conventions used by the methods below: 3d arrays are treated as {@code [dim0, dim1, time]}
 * (dimension 2 is the axis reversed by {@link #reverseTimeSeries(INDArray)} and the axis indexed
 * by {@link #pullLastTimeSteps(INDArray, INDArray)}), and 2d masks are {@code [dim0, time]}.
 * NOTE(review): dim0/dim1 are presumably [minibatch, features] — confirm against callers.
 *
 * <p>Overloads taking a {@link LayerWorkspaceMgr} place their result in the workspace for the
 * given {@link ArrayType} via {@code leverageTo}.
 */
public class TimeSeriesUtils {
    private TimeSeriesUtils() {
    }

    /**
     * Computes a simple moving average over windows of {@code n} consecutive values, using a
     * cumulative sum so each window is obtained by one subtraction.
     *
     * <p>NOTE(review): indexing uses a single column interval, so this presumably expects a row
     * vector — confirm against callers. The result has {@code toAvg.columns() - n + 1} entries.
     *
     * @param toAvg values to average
     * @param n     window size
     * @return the window sums divided by {@code n}
     */
    public static INDArray movingAverage(INDArray toAvg, int n) {
        INDArray ret = Nd4j.cumsum(toAvg);
        INDArrayIndex[] ends = {NDArrayIndex.interval(n, toAvg.columns())};
        INDArrayIndex[] begins = {NDArrayIndex.interval(0, toAvg.columns() - n, false)};
        INDArrayIndex[] nMinusOne = {NDArrayIndex.interval(n - 1, toAvg.columns())};
        // cumsum[j] - cumsum[j - n] is the sum of the n values ending at position j
        ret.put(ends, ret.get(ends).sub(ret.get(begins)));
        return ret.get(nMinusOne).divi(n);
    }

    /**
     * Reshapes a rank-2 time series mask into a column vector of shape {@code [length, 1]},
     * flattening in 'f' (column-major) order.
     *
     * @param timeSeriesMask rank 2 mask
     * @return the mask as a {@code [length, 1]} array
     * @throws IllegalArgumentException if the mask is not rank 2
     */
    public static INDArray reshapeTimeSeriesMaskToVector(INDArray timeSeriesMask) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        // Same guard as the workspace overload: an 'f' array whose strides are not the defaults
        // for its shape is copied first so the 'f' reshape sees a contiguous layout.
        if (timeSeriesMask.ordering() != 'f' || !Shape.hasDefaultStridesForShape(timeSeriesMask)) {
            timeSeriesMask = timeSeriesMask.dup('f');
        }
        return timeSeriesMask.reshape('f', new long[]{timeSeriesMask.length(), 1});
    }

    /**
     * Workspace-aware variant of {@link #reshapeTimeSeriesMaskToVector(INDArray)}.
     *
     * @param timeSeriesMask rank 2 mask
     * @param workspaceMgr   workspace manager used for any copy and for the result
     * @param arrayType      workspace array type of the result
     * @return the mask as a {@code [length, 1]} array in the {@code arrayType} workspace
     * @throws IllegalArgumentException if the mask is not rank 2
     */
    public static INDArray reshapeTimeSeriesMaskToVector(INDArray timeSeriesMask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        if (timeSeriesMask.ordering() != 'f' || !Shape.hasDefaultStridesForShape(timeSeriesMask)) {
            timeSeriesMask = workspaceMgr.dup(arrayType, timeSeriesMask, 'f');
        }
        return workspaceMgr.leverageTo(arrayType, timeSeriesMask.reshape('f', new long[]{timeSeriesMask.length(), 1}));
    }

    /**
     * Reshapes a rank-2 time series mask into a rank-4 array of shape {@code [length, 1, 1, 1]},
     * flattening in 'f' order.
     *
     * @param timeSeriesMask rank 2 mask
     * @param workspaceMgr   workspace manager used for any copy and for the result
     * @param arrayType      workspace array type of the result
     * @return the mask as a {@code [length, 1, 1, 1]} array in the {@code arrayType} workspace
     * @throws IllegalArgumentException if the mask is not rank 2
     */
    public static INDArray reshapeTimeSeriesMaskToCnn4dMask(INDArray timeSeriesMask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        if (timeSeriesMask.ordering() != 'f' || !Shape.hasDefaultStridesForShape(timeSeriesMask)) {
            timeSeriesMask = workspaceMgr.dup(arrayType, timeSeriesMask, 'f');
        }
        return workspaceMgr.leverageTo(arrayType, timeSeriesMask.reshape('f', new long[]{timeSeriesMask.length(), 1, 1, 1}));
    }

    /**
     * Reshapes a vector mask back into a 2d mask of shape {@code [minibatchSize, length / minibatchSize]}
     * using 'f' order (the inverse of {@link #reshapeTimeSeriesMaskToVector(INDArray)}).
     *
     * @param timeSeriesMaskAsVector vector mask
     * @param minibatchSize          size of the first output dimension
     * @return the 2d mask
     * @throws IllegalArgumentException if the input is not a vector
     */
    public static INDArray reshapeVectorToTimeSeriesMask(INDArray timeSeriesMaskAsVector, int minibatchSize) {
        if (!timeSeriesMaskAsVector.isVector()) {
            throw new IllegalArgumentException("Cannot reshape mask: expected vector");
        }
        return timeSeriesMaskAsVector.reshape('f', new long[]{minibatchSize, timeSeriesMaskAsVector.length() / minibatchSize});
    }

    /**
     * Reshapes a CNN-style mask of shape {@code [mb*seqLength, 1, 1, 1]} back into a 2d mask of
     * shape {@code [minibatchSize, length / minibatchSize]} using 'f' order.
     *
     * @param timeSeriesMaskAsCnnMask rank 4 mask with trailing dimensions of size 1
     * @param minibatchSize           size of the first output dimension
     * @return the 2d mask
     * @throws IllegalArgumentException if the mask is not rank 4 with shape {@code [*, 1, 1, 1]}
     */
    public static INDArray reshapeCnnMaskToTimeSeriesMask(INDArray timeSeriesMaskAsCnnMask, int minibatchSize) {
        // Rank is checked first so size(1..3) is never queried on a lower-rank array.
        boolean valid = timeSeriesMaskAsCnnMask.rank() == 4
                && timeSeriesMaskAsCnnMask.size(1) == 1
                && timeSeriesMaskAsCnnMask.size(2) == 1
                && timeSeriesMaskAsCnnMask.size(3) == 1;
        Preconditions.checkArgument(valid, "Expected rank 4 mask with shape [mb*seqLength, 1, 1, 1]. Got rank %s mask array with shape %s", Integer.valueOf(timeSeriesMaskAsCnnMask.rank()), timeSeriesMaskAsCnnMask.shape());
        return timeSeriesMaskAsCnnMask.reshape('f', new long[]{minibatchSize, timeSeriesMaskAsCnnMask.length() / minibatchSize});
    }

    /**
     * Reshapes a rank-3 per-output mask to 2d using {@link #reshape3dTo2d(INDArray)}.
     *
     * @param perOutputTimeSeriesMask rank 3 mask
     * @return the 2d mask
     * @throws IllegalArgumentException if the mask is not rank 3
     */
    public static INDArray reshapePerOutputTimeSeriesMaskTo2d(INDArray perOutputTimeSeriesMask) {
        if (perOutputTimeSeriesMask.rank() != 3) {
            throw new IllegalArgumentException("Cannot reshape per output mask: rank is not 3 (is: " + perOutputTimeSeriesMask.rank() + ", shape = " + Arrays.toString(perOutputTimeSeriesMask.shape()) + ")");
        }
        return reshape3dTo2d(perOutputTimeSeriesMask);
    }

    /**
     * Workspace-aware variant of {@link #reshapePerOutputTimeSeriesMaskTo2d(INDArray)}.
     *
     * @param perOutputTimeSeriesMask rank 3 mask
     * @param workspaceMgr            workspace manager for the result
     * @param arrayType               workspace array type of the result
     * @return the 2d mask in the {@code arrayType} workspace
     * @throws IllegalArgumentException if the mask is not rank 3
     */
    public static INDArray reshapePerOutputTimeSeriesMaskTo2d(INDArray perOutputTimeSeriesMask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (perOutputTimeSeriesMask.rank() != 3) {
            throw new IllegalArgumentException("Cannot reshape per output mask: rank is not 3 (is: " + perOutputTimeSeriesMask.rank() + ", shape = " + Arrays.toString(perOutputTimeSeriesMask.shape()) + ")");
        }
        return reshape3dTo2d(perOutputTimeSeriesMask, workspaceMgr, arrayType);
    }

    /**
     * Reshapes a 3d array {@code [d0, d1, d2]} into a 2d array {@code [d0*d2, d1]} by permuting to
     * {@code [d0, d2, d1]} and reshaping in 'f' order. Shortcuts avoid the permute+reshape when
     * {@code d0 == 1} or {@code d2 == 1}.
     *
     * @param in rank 3 array
     * @return the 2d array (possibly a view of {@code in})
     * @throws IllegalArgumentException if {@code in} is not rank 3
     */
    public static INDArray reshape3dTo2d(INDArray in) {
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        long[] shape = in.shape();
        if (shape[0] == 1) {
            // Edge case d0 == 1: a single tensor along dims {1, 2}, transposed
            return in.tensorAlongDimension(0, new int[]{1, 2}).permutei(new int[]{1, 0});
        }
        if (shape[2] == 1) {
            // Edge case d2 == 1: a single tensor along dims {1, 0}
            return in.tensorAlongDimension(0, new int[]{1, 0});
        }
        // Permute first so the 'f' reshape places values in the intended order
        return in.permute(new int[]{0, 2, 1}).reshape('f', new long[]{shape[0] * shape[2], shape[1]});
    }

    /**
     * Workspace-aware variant of {@link #reshape3dTo2d(INDArray)}.
     *
     * @param in           rank 3 array
     * @param workspaceMgr workspace manager for the result
     * @param arrayType    workspace array type of the result
     * @return the 2d array in the {@code arrayType} workspace
     * @throws IllegalArgumentException if {@code in} is not rank 3
     */
    public static INDArray reshape3dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        return workspaceMgr.leverageTo(arrayType, reshape3dTo2d(in));
    }

    /**
     * Reshapes a 2d array {@code [r, c]} into a 3d array {@code [miniBatchSize, c, r / miniBatchSize]}:
     * an 'f' reshape to {@code [miniBatchSize, r / miniBatchSize, c]} followed by a permute.
     *
     * @param in            rank 2 array
     * @param miniBatchSize size of the first output dimension
     * @return the 3d array (a permuted view)
     * @throws IllegalArgumentException if {@code in} is not rank 2
     */
    public static INDArray reshape2dTo3d(INDArray in, int miniBatchSize) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        long[] shape = in.shape();
        if (in.ordering() != 'f') {
            in = Shape.toOffsetZeroCopy(in, 'f');
        }
        return in.reshape('f', new long[]{miniBatchSize, shape[0] / miniBatchSize, shape[1]}).permute(new int[]{0, 2, 1});
    }

    /**
     * Workspace-aware variant of {@link #reshape2dTo3d(INDArray, int)}.
     *
     * @param in            rank 2 array
     * @param miniBatchSize size of the first output dimension
     * @param workspaceMgr  workspace manager used for any copy and for the result
     * @param arrayType     workspace array type of the result
     * @return the 3d array in the {@code arrayType} workspace
     * @throws IllegalArgumentException if {@code in} is not rank 2
     */
    public static INDArray reshape2dTo3d(INDArray in, int miniBatchSize, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        long[] shape = in.shape();
        if (in.ordering() != 'f') {
            in = workspaceMgr.dup(arrayType, in, 'f');
        }
        return workspaceMgr.leverageTo(arrayType, in.reshape('f', new long[]{miniBatchSize, shape[0] / miniBatchSize, shape[1]}).permute(new int[]{0, 2, 1}));
    }

    /**
     * Reverses a 3d array along dimension 2 (the time axis).
     *
     * @param in rank 3 array, or {@code null}
     * @return a new 'f'-ordered array with dimension 2 reversed, or {@code null} if {@code in} is null
     */
    public static INDArray reverseTimeSeries(INDArray in) {
        if (in == null) {
            return null;
        }
        // Ensure a contiguous, non-view 'f' layout so the reshapes below are pure reinterpretations
        if (in.ordering() != 'f' || in.isView() || !Shape.strideDescendingCAscendingF(in)) {
            in = in.dup('f');
        }
        int[] idxs = reversedIndices((int) in.size(2));
        // Collapse to [d0*d1, time], reorder the time axis with pullRows, then restore the 3d shape
        INDArray inReshape = in.reshape('f', new long[]{in.size(0) * in.size(1), in.size(2)});
        INDArray outReshape = Nd4j.pullRows(inReshape, 0, idxs, 'f');
        return outReshape.reshape('f', new long[]{in.size(0), in.size(1), in.size(2)});
    }

    /**
     * Workspace-aware variant of {@link #reverseTimeSeries(INDArray)}.
     *
     * @param in           rank 3 array, or {@code null}
     * @param workspaceMgr workspace manager used for any copy and for the result
     * @param arrayType    workspace array type of the result
     * @return the reversed array in the {@code arrayType} workspace, or {@code null} if {@code in} is null
     */
    public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (in == null) {
            return null;
        }
        if (in.ordering() != 'f' || in.isView() || !Shape.strideDescendingCAscendingF(in)) {
            in = workspaceMgr.dup(arrayType, in, 'f');
        }
        int[] idxs = reversedIndices((int) in.size(2));
        INDArray inReshape = in.reshape('f', new long[]{in.size(0) * in.size(1), in.size(2)});
        INDArray outReshape = workspaceMgr.create(arrayType, new long[]{inReshape.size(0), idxs.length}, 'f');
        Nd4j.pullRows(inReshape, outReshape, 0, idxs);
        return workspaceMgr.leverageTo(arrayType, outReshape.reshape('f', new long[]{in.size(0), in.size(1), in.size(2)}));
    }

    /**
     * Reverses a mask along its time axis: dimension 1 for rank-2 masks, dimension 2 for rank-3
     * masks (delegating to {@link #reverseTimeSeries(INDArray)}).
     *
     * @param mask rank 2 or rank 3 mask, or {@code null}
     * @return the reversed mask, or {@code null} if {@code mask} is null
     * @throws IllegalArgumentException if the mask is neither rank 2 nor rank 3
     */
    public static INDArray reverseTimeSeriesMask(INDArray mask) {
        if (mask == null) {
            return null;
        }
        if (mask.rank() == 3) {
            return reverseTimeSeries(mask);
        }
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Invalid mask rank: must be rank 2 or 3. Got rank " + mask.rank() + " with shape " + Arrays.toString(mask.shape()));
        }
        return Nd4j.pullRows(mask, 0, reversedIndices((int) mask.size(1)));
    }

    /**
     * Workspace-aware variant of {@link #reverseTimeSeriesMask(INDArray)}.
     *
     * @param mask         rank 2 or rank 3 mask, or {@code null}
     * @param workspaceMgr workspace manager used to allocate the result
     * @param arrayType    workspace array type of the result
     * @return the reversed mask in the {@code arrayType} workspace, or {@code null} if {@code mask} is null
     * @throws IllegalArgumentException if the mask is neither rank 2 nor rank 3
     */
    public static INDArray reverseTimeSeriesMask(INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (mask == null) {
            return null;
        }
        if (mask.rank() == 3) {
            return reverseTimeSeries(mask, workspaceMgr, arrayType);
        }
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Invalid mask rank: must be rank 2 or 3. Got rank " + mask.rank() + " with shape " + Arrays.toString(mask.shape()));
        }
        int[] idxs = reversedIndices((int) mask.size(1));
        INDArray result = workspaceMgr.createUninitialized(arrayType, new long[]{mask.size(0), idxs.length}, 'f');
        return Nd4j.pullRows(mask, result, 0, idxs);
    }

    /**
     * Extracts, for each example (dimension 0), the activations at its last time step.
     *
     * <p>Without a mask the last index of dimension 2 is used for every example and the returned
     * index array is {@code null}. With a mask, the time step for row {@code i} is the last index
     * along dimension 1 of the mask whose value is not (eps-)equal to
     * {@code EvaluationBinary.DEFAULT_EDGE_VALUE}.
     * NOTE(review): a mask row with no such entry presumably yields index -1 from
     * {@code lastIndex}; how {@code NDArrayIndex.point(-1)} behaves is not checked here.
     *
     * @param pullFrom rank 3 array to extract from
     * @param mask     2d mask, or {@code null}
     * @return the extracted {@code [d0, d1]} array and the per-example time step indices
     *         ({@code null} when no mask was given)
     */
    public static Pair<INDArray, int[]> pullLastTimeSteps(INDArray pullFrom, INDArray mask) {
        if (mask == null) {
            int lastTimeStep = ((int) pullFrom.size(2)) - 1;
            INDArray out = pullFrom.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTimeStep)});
            return new Pair<>(out, null);
        }

        INDArray out = Nd4j.create(new long[]{pullFrom.size(0), pullFrom.size(1)});
        int[] fwdPassTimeSteps = BooleanIndexing.lastIndex(mask, Conditions.epsNotEquals(Double.valueOf(EvaluationBinary.DEFAULT_EDGE_VALUE)), new int[]{1}).data().asInt();
        for (int i = 0; i < fwdPassTimeSteps.length; i++) {
            out.putRow(i, pullFrom.get(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(fwdPassTimeSteps[i])}));
        }
        return new Pair<>(out, fwdPassTimeSteps);
    }

    /**
     * Workspace-aware variant of {@link #pullLastTimeSteps(INDArray, INDArray)}.
     *
     * @param pullFrom     rank 3 array to extract from
     * @param mask         2d mask, or {@code null}
     * @param workspaceMgr workspace manager for the result
     * @param arrayType    workspace array type of the result
     * @return the extracted array (in the {@code arrayType} workspace) and the per-example time
     *         step indices ({@code null} when no mask was given)
     */
    public static Pair<INDArray, int[]> pullLastTimeSteps(INDArray pullFrom, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        Pair<INDArray, int[]> pulled = pullLastTimeSteps(pullFrom, mask);
        return new Pair<>(workspaceMgr.leverageTo(arrayType, pulled.getFirst()), pulled.getSecond());
    }

    /** Returns {@code [n-1, n-2, ..., 0]}: the index order that reverses an axis of length {@code n}. */
    private static int[] reversedIndices(int n) {
        int[] idxs = new int[n];
        for (int i = 0; i < n; i++) {
            idxs[i] = n - 1 - i;
        }
        return idxs;
    }
}
