package org.deeplearning4j.util;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/util/MaskedReductionUtil.class */
/**
 * Masked global pooling (reduction) utilities for time series and CNN activations, together with the matching
 * backpropagation (epsilon) calculations.
 *
 * <p>Masked-out positions are excluded from the reduction as follows:
 * <ul>
 * <li>{@code MAX}: {@code -Infinity} is added wherever the mask is 0, so those positions can never be the maximum.</li>
 * <li>{@code SUM}: activations are multiplied by the mask, then summed.</li>
 * <li>{@code AVG}: as {@code SUM}, then divided by the sum of the mask along the non-leading dimensions.</li>
 * <li>{@code PNORM}: activations are multiplied by the mask, then {@code (sum |x|^p)^(1/p)} is taken.</li>
 * </ul>
 *
 * <p>NOTE(review): masks are assumed to contain 1 for present and 0 for masked-out entries. An entry along
 * dimension 0 whose mask is entirely 0 yields {@code -Infinity} for MAX and {@code NaN} (0 / 0) for AVG; nothing
 * here guards against that.
 */
public final class MaskedReductionUtil {

    private MaskedReductionUtil() {
        // Static utility class; not instantiable.
    }

    /**
     * Applies masked global pooling over dimension 2 of a rank 3 array.
     *
     * @param poolingType pooling type; MAX, AVG, SUM and PNORM are supported
     * @param toReduce    rank 3 array reduced over dimension 2
     * @param mask        rank 2 mask, broadcast against {@code toReduce} along dimensions 0 and 2
     * @param pnorm       exponent p for PNORM pooling (must be positive); ignored by the other pooling types
     * @param dataType    type that {@code toReduce} and {@code mask} are cast to; also the type of the result
     * @return the pooled array, i.e. {@code toReduce} with dimension 2 reduced away
     * @throws IllegalArgumentException      if {@code toReduce} is not rank 3, {@code mask} is not rank 2, or
     *                                       {@code pnorm <= 0} for PNORM pooling
     * @throws UnsupportedOperationException for any other pooling type
     */
    public static INDArray maskedPoolingTimeSeries(PoolingType poolingType, INDArray toReduce, INDArray mask,
                    int pnorm, DataType dataType) {
        if (toReduce.rank() != 3) {
            throw new IllegalArgumentException("Expect rank 3 array: got " + toReduce.rank());
        }
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
        }

        INDArray typedInput = toReduce.castTo(dataType);
        INDArray typedMask = mask.castTo(dataType);

        switch (poolingType) {
            case MAX: {
                INDArray withNegInf = Nd4j.createUninitialized(dataType, typedInput.shape());
                Nd4j.getExecutioner().exec(
                                new BroadcastAddOp(typedInput, negInfMask(typedMask, dataType), withNegInf, 0, 2));
                // Masked-out steps now hold -Infinity and cannot be selected by max
                return withNegInf.max(2);
            }
            case AVG:
            case SUM: {
                INDArray masked = Nd4j.createUninitialized(dataType, typedInput.shape());
                Nd4j.getExecutioner().exec(new BroadcastMulOp(typedInput, typedMask, masked, 0, 2));
                INDArray summed = masked.sum(2);
                if (poolingType == PoolingType.SUM) {
                    return summed;
                }
                // Average over the unmasked steps only: divide each row by its mask count
                summed.diviColumnVector(typedMask.sum(1));
                return summed;
            }
            case PNORM:
                return maskedPNorm(typedInput, typedMask, pnorm, dataType, new int[] {0, 2}, 2);
            default:
                throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType);
        }
    }

    /**
     * Computes the gradient of {@link #maskedPoolingTimeSeries} with respect to its input.
     *
     * <p>All computation is done in {@code input}'s data type; {@code mask} and {@code epsilon2d} are cast to it.
     *
     * @param poolingType pooling type used in the forward pass; MAX, AVG, SUM and PNORM are supported
     * @param input       rank 3 activations that were pooled over dimension 2
     * @param mask        rank 2 mask, broadcast against {@code input} along dimensions 0 and 2
     * @param epsilon2d   rank 2 gradient with respect to the pooled output, broadcast along dimensions 0 and 1
     * @param pnorm       exponent p used for PNORM pooling (must be positive); ignored by the other pooling types
     * @return gradient with respect to {@code input}, with the same shape as {@code input}
     * @throws IllegalArgumentException      if an array has the wrong rank, or {@code pnorm <= 0} for PNORM pooling
     * @throws UnsupportedOperationException for any other pooling type
     */
    public static INDArray maskedPoolingEpsilonTimeSeries(PoolingType poolingType, INDArray input, INDArray mask,
                    INDArray epsilon2d, int pnorm) {
        if (input.rank() != 3) {
            throw new IllegalArgumentException("Expect rank 3 input activation array: got " + input.rank());
        }
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
        }
        if (epsilon2d.rank() != 2) {
            throw new IllegalArgumentException("Expected rank 2 array for errors: got " + epsilon2d.rank());
        }

        DataType dataType = input.dataType();
        INDArray typedMask = mask.castTo(dataType);
        // Broadcast ops need matching operand types; castTo is a no-op when the type already matches
        INDArray typedEpsilon = epsilon2d.castTo(dataType);

        switch (poolingType) {
            case MAX: {
                INDArray withNegInf = Nd4j.createUninitialized(dataType, input.shape());
                Nd4j.getExecutioner().exec(
                                new BroadcastAddOp(input, negInfMask(typedMask, dataType), withNegInf, 0, 2));
                // Route the incoming gradient to the position(s) IsMax flags along dimension 2
                // (presumably 1 at the maximum, 0 elsewhere)
                INDArray isMax = Nd4j.exec(new IsMax(withNegInf, withNegInf.ulike(), new int[] {2}))[0];
                return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, typedEpsilon, isMax, 0, 1));
            }
            case AVG:
            case SUM: {
                // SUM: dL/dIn is dL/dOut copied to every step, then masked.
                // AVG: additionally divided by the number of unmasked steps per row.
                INDArray out = Nd4j.createUninitialized(dataType, input.shape(), 'f');
                Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, typedEpsilon, out, 0, 1));
                Nd4j.getExecutioner().exec(new BroadcastMulOp(out, typedMask, out, 0, 2));
                if (poolingType == PoolingType.SUM) {
                    return out;
                }
                Nd4j.getExecutioner().exec(new BroadcastDivOp(out, typedMask.sum(1), out, 0));
                return out;
            }
            case PNORM:
                return pNormEpsilon(input, typedMask, typedEpsilon, pnorm, dataType, new int[] {0, 2}, 2);
            default:
                throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType);
        }
    }

    /**
     * Applies masked global pooling over dimensions 2 and 3 of a rank 4 array.
     *
     * <p>The mask is broadcast along every dimension (0..3) on which its size equals that of {@code toReduce}.
     *
     * @param poolingType pooling type; MAX, AVG, SUM and PNORM are supported
     * @param toReduce    rank 4 array reduced over dimensions 2 and 3
     * @param mask        rank 4 mask
     * @param pnorm       exponent p for PNORM pooling (must be positive); ignored by the other pooling types
     * @param dataType    type that {@code toReduce} and {@code mask} are cast to; also the type of the result
     * @return the pooled array, i.e. {@code toReduce} with dimensions 2 and 3 reduced away
     * @throws IllegalStateException         if {@code mask} is not rank 4
     * @throws IllegalArgumentException      if {@code pnorm <= 0} for PNORM pooling
     * @throws UnsupportedOperationException for any other pooling type
     */
    public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArray toReduce, INDArray mask,
                    int pnorm, DataType dataType) {
        if (mask.rank() != 4) {
            throw new IllegalStateException("Expected rank 4 mask array: Got array with shape "
                            + Arrays.toString(mask.shape()));
        }

        // Cast the activations as well as the mask (as maskedPoolingTimeSeries does) so broadcast operands agree
        INDArray typedInput = toReduce.castTo(dataType);
        INDArray typedMask = mask.castTo(dataType);
        int[] maskDims = equalSizeDimensions(typedInput, typedMask);

        switch (poolingType) {
            case MAX: {
                INDArray withNegInf = Nd4j.createUninitialized(dataType, typedInput.shape());
                Nd4j.getExecutioner().exec(
                                new BroadcastAddOp(typedInput, negInfMask(typedMask, dataType), withNegInf, maskDims));
                // Masked-out positions now hold -Infinity and cannot be selected by max
                return withNegInf.max(2, 3);
            }
            case AVG:
            case SUM: {
                INDArray masked = Nd4j.createUninitialized(dataType, typedInput.shape());
                Nd4j.getExecutioner().exec(new BroadcastMulOp(typedInput, typedMask, masked, maskDims));
                INDArray summed = masked.sum(2, 3);
                if (poolingType == PoolingType.SUM) {
                    return summed;
                }
                summed.diviColumnVector(typedMask.sum(1, 2, 3));
                return summed;
            }
            case PNORM:
                return maskedPNorm(typedInput, typedMask, pnorm, dataType, maskDims, 2, 3);
            default:
                throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType);
        }
    }

    /**
     * Computes the gradient of {@link #maskedPoolingConvolution} with respect to its input.
     *
     * <p>The mask is broadcast along every dimension (0..3) on which its size equals that of {@code input}.
     *
     * @param poolingType pooling type used in the forward pass; MAX, AVG, SUM and PNORM are supported
     * @param input       rank 4 activations that were pooled over dimensions 2 and 3
     * @param mask        rank 4 mask
     * @param epsilon2d   rank 2 gradient with respect to the pooled output, broadcast along dimensions 0 and 1
     * @param pnorm       exponent p used for PNORM pooling (must be positive); ignored by the other pooling types
     * @param dataType    type that {@code input}, {@code mask} and {@code epsilon2d} are cast to; also the result type
     * @return gradient with respect to {@code input}, with the same shape as {@code input}
     * @throws IllegalStateException         if {@code mask} is not rank 4
     * @throws IllegalArgumentException      if {@code pnorm <= 0} for PNORM pooling
     * @throws UnsupportedOperationException for any other pooling type
     */
    public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray input, INDArray mask,
                    INDArray epsilon2d, int pnorm, DataType dataType) {
        // Same validation as maskedPoolingConvolution; without it a wrong-rank mask fails obscurely in size(i)
        if (mask.rank() != 4) {
            throw new IllegalStateException("Expected rank 4 mask array: Got array with shape "
                            + Arrays.toString(mask.shape()));
        }

        // Broadcast ops need matching operand types; castTo is a no-op when the type already matches
        INDArray typedInput = input.castTo(dataType);
        INDArray typedMask = mask.castTo(dataType);
        INDArray typedEpsilon = epsilon2d.castTo(dataType);
        int[] maskDims = equalSizeDimensions(typedInput, typedMask);

        switch (poolingType) {
            case MAX: {
                INDArray withNegInf = Nd4j.createUninitialized(dataType, typedInput.shape());
                Nd4j.getExecutioner().exec(
                                new BroadcastAddOp(typedInput, negInfMask(typedMask, dataType), withNegInf, maskDims));
                // Route the incoming gradient to the position(s) IsMax flags over dimensions 2 and 3
                // (presumably 1 at the maximum, 0 elsewhere)
                INDArray isMax = Nd4j.exec(new IsMax(withNegInf, withNegInf.ulike(), new int[] {2, 3}))[0];
                return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, typedEpsilon, isMax, 0, 1));
            }
            case AVG:
            case SUM: {
                INDArray out = Nd4j.createUninitialized(dataType, typedInput.shape(), 'f');
                Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, typedEpsilon, out, 0, 1));
                Nd4j.getExecutioner().exec(new BroadcastMulOp(out, typedMask, out, maskDims));
                if (poolingType == PoolingType.SUM) {
                    return out;
                }
                Nd4j.getExecutioner().exec(new BroadcastDivOp(out, typedMask.sum(1, 2, 3), out, 0));
                return out;
            }
            case PNORM:
                return pNormEpsilon(typedInput, typedMask, typedEpsilon, pnorm, dataType, maskDims, 2, 3);
            default:
                throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType);
        }
    }

    /**
     * Builds an additive mask: {@code 1 - mask}, with every entry equal to 1 (i.e. where the mask is 0) replaced by
     * {@code -Infinity}. BOOL masks are logically negated and cast to {@code dataType} instead of subtracted.
     */
    private static INDArray negInfMask(INDArray mask, DataType dataType) {
        INDArray inverted = mask.dataType() == DataType.BOOL
                        ? Transforms.not(mask).castTo(dataType)
                        : mask.rsub(1.0);
        BooleanIndexing.replaceWhere(inverted, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
        return inverted;
    }

    /** Returns, in increasing order, the dimensions among 0..3 on which {@code array} and {@code mask} have equal size. */
    private static int[] equalSizeDimensions(INDArray array, INDArray mask) {
        int[] dims = new int[4];
        int count = 0;
        for (int d = 0; d < 4; d++) {
            if (array.size(d) == mask.size(d)) {
                dims[count++] = d;
            }
        }
        return count < 4 ? Arrays.copyOf(dims, count) : dims;
    }

    /**
     * Computes {@code (sum over reduceDims of |toReduce * mask|^pnorm)^(1/pnorm)}, broadcasting {@code mask} against
     * {@code toReduce} along {@code maskDims}.
     *
     * @throws IllegalArgumentException if {@code pnorm <= 0}
     */
    private static INDArray maskedPNorm(INDArray toReduce, INDArray mask, int pnorm, DataType dataType,
                    int[] maskDims, int... reduceDims) {
        if (pnorm <= 0) {
            throw new IllegalArgumentException("Invalid p-norm value: p must be > 0, got " + pnorm);
        }
        INDArray masked = Nd4j.createUninitialized(dataType, toReduce.shape());
        Nd4j.getExecutioner().exec(new BroadcastMulOp(toReduce, mask, masked, maskDims));
        // masked is a private temporary, so abs and pow can safely work in place
        INDArray abs = Transforms.abs(masked, false);
        Transforms.pow(abs, pnorm, false);
        return Transforms.pow(abs.sum(reduceDims), 1.0 / pnorm);
    }

    /**
     * Gradient of masked p-norm pooling with respect to {@code input}:
     * {@code mask * x * |x|^(p-2) * epsilon / ||x||_p^(p-1)}. Epsilon is broadcast along dimensions 0 and 1, and the
     * mask along {@code maskDims}.
     *
     * @throws IllegalArgumentException if {@code pnorm <= 0}
     */
    private static INDArray pNormEpsilon(INDArray input, INDArray mask, INDArray epsilon2d, int pnorm,
                    DataType dataType, int[] maskDims, int... reduceDims) {
        INDArray pNorm = maskedPNorm(input, mask, pnorm, dataType, maskDims, reduceDims);

        // d/dx_i (sum_j |x_j|^p)^(1/p) = x_i * |x_i|^(p-2) / ||x||_p^(p-1)
        INDArray numerator = pnorm == 2
                        ? input.dup()
                        : input.mul(Transforms.pow(Transforms.abs(input, true), pnorm - 2, false));

        INDArray scale = Transforms.pow(pNorm, pnorm - 1, false);
        scale.rdivi(epsilon2d); // scale = epsilon / ||x||_p^(p-1)
        Nd4j.getExecutioner().exec(new BroadcastMulOp(numerator, scale, numerator, 0, 1));
        // Zero the gradient at masked-out positions
        Nd4j.getExecutioner().exec(new BroadcastMulOp(numerator, mask, numerator, maskDims));
        return numerator;
    }
}
