package org.deeplearning4j.nn.layers.mkldnn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.class */
public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
    private static final int[] RANK2_DIMS = {0};
    private static final int[] RANK4_DIMS_NCHW = {0, 2, 3};
    private static final int[] RANK4_DIMS_NHWC = {0, 1, 2};
    protected OpContext context;
    private INDArray meanCache;
    private INDArray varCache;

    public MKLDNNBatchNormHelper(DataType dataType) {
    }

    @Override // org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper
    public boolean checkSupported(double d, boolean z) {
        return !z && BaseMKLDNNHelper.mklDnnEnabled();
    }

    @Override // org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, long[] jArr, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5, INDArray iNDArray6, double d, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (!Shape.hasDefaultStridesForShape(iNDArray2)) {
            iNDArray2 = iNDArray2.dup('c');
        }
        if (iNDArray.dataType() != DataType.FLOAT) {
            return null;
        }
        int i = (iNDArray.rank() != 4 || cNN2DFormat == CNN2DFormat.NCHW) ? 1 : 3;
        ArrayList arrayList = new ArrayList();
        arrayList.add(iNDArray);
        arrayList.add(this.meanCache);
        arrayList.add(this.varCache);
        if (iNDArray3 != null) {
            arrayList.add(iNDArray3.reshape(new long[]{iNDArray3.length()}));
        }
        if (iNDArray4 != null) {
            arrayList.add(iNDArray4.reshape(new long[]{iNDArray4.length()}));
        }
        arrayList.add(iNDArray2);
        DynamicCustomOp.DynamicCustomOpsBuilder addInputs = DynamicCustomOp.builder("batchnorm_bp").addInputs((INDArray[]) arrayList.toArray(new INDArray[0]));
        int[] iArr = new int[3];
        iArr[0] = iNDArray3 == null ? 0 : 1;
        iArr[1] = iNDArray4 == null ? 0 : 1;
        iArr[2] = i;
        DynamicCustomOp build = addInputs.addIntegerArguments(iArr).addFloatingPointArguments(new Double[]{Double.valueOf(d)}).build();
        INDArray createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray.dataType(), iNDArray.shape());
        INDArray createUninitialized2 = layerWorkspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, this.meanCache.dataType(), this.meanCache.shape());
        INDArray createUninitialized3 = layerWorkspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, this.meanCache.dataType(), this.meanCache.shape());
        build.setOutputArgument(0, createUninitialized);
        build.setOutputArgument(1, createUninitialized2);
        build.setOutputArgument(2, createUninitialized3);
        if (iNDArray5 != null) {
            build.setOutputArgument(3, iNDArray5.reshape(new long[]{iNDArray5.length()}));
            build.setOutputArgument(4, iNDArray6.reshape(new long[]{iNDArray6.length()}));
        }
        Nd4j.exec(build);
        DefaultGradient defaultGradient = new DefaultGradient();
        defaultGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, iNDArray5);
        defaultGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, iNDArray6);
        return new Pair<>(defaultGradient, createUninitialized);
    }

    @Override // org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper
    public INDArray preOutput(INDArray iNDArray, boolean z, long[] jArr, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5, double d, double d2, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray reshape;
        INDArray reshape2;
        if (iNDArray.dataType() != DataType.FLOAT) {
            return null;
        }
        int i = (iNDArray.rank() != 4 || cNN2DFormat == CNN2DFormat.NCHW) ? 1 : 3;
        if (this.context == null) {
            this.context = Nd4j.getExecutioner().buildContext();
            OpContext opContext = this.context;
            long[] jArr2 = new long[3];
            jArr2[0] = ArrayUtil.fromBoolean(iNDArray2 != null);
            jArr2[1] = ArrayUtil.fromBoolean(iNDArray3 != null);
            jArr2[2] = i;
            opContext.setIArguments(jArr2);
            this.context.setTArguments(new double[]{d2});
        }
        if (z) {
            if (this.meanCache == null) {
                MemoryWorkspace scopeOutOfWorkspaces = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                Throwable th = null;
                try {
                    this.meanCache = Nd4j.createUninitialized(iNDArray.dataType(), new long[]{iNDArray.size(i)});
                    this.varCache = Nd4j.createUninitialized(iNDArray.dataType(), new long[]{iNDArray.size(i)});
                    if (scopeOutOfWorkspaces != null) {
                        if (0 != 0) {
                            try {
                                scopeOutOfWorkspaces.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            scopeOutOfWorkspaces.close();
                        }
                    }
                } catch (Throwable th3) {
                    if (scopeOutOfWorkspaces != null) {
                        if (0 != 0) {
                            try {
                                scopeOutOfWorkspaces.close();
                            } catch (Throwable th4) {
                                th.addSuppressed(th4);
                            }
                        } else {
                            scopeOutOfWorkspaces.close();
                        }
                    }
                    throw th3;
                }
            }
            int[] iArr = iNDArray.rank() == 2 ? RANK2_DIMS : cNN2DFormat == CNN2DFormat.NCHW ? RANK4_DIMS_NCHW : RANK4_DIMS_NHWC;
            iNDArray.mean(this.meanCache, iArr);
            Nd4j.exec(new Variance(iNDArray, this.varCache, false, iArr));
            reshape = this.meanCache;
            reshape2 = this.varCache;
        } else {
            reshape = iNDArray4.reshape(new long[]{iNDArray4.length()});
            reshape2 = iNDArray5.reshape(new long[]{iNDArray5.length()});
        }
        this.context.purge();
        this.context.setInputArray(0, iNDArray);
        this.context.setInputArray(1, reshape);
        this.context.setInputArray(2, reshape2);
        if (iNDArray2 != null && iNDArray3 != null) {
            this.context.setInputArray(3, iNDArray2.rank() == 2 ? iNDArray2.reshape(new long[]{iNDArray2.length()}) : iNDArray2);
            this.context.setInputArray(4, iNDArray3.rank() == 2 ? iNDArray3.reshape(new long[]{iNDArray3.length()}) : iNDArray3);
        }
        INDArray createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, iNDArray.dataType(), iNDArray.shape());
        this.context.setOutputArray(0, createUninitialized);
        Nd4j.exec(new BatchNorm(), this.context);
        return createUninitialized;
    }

    @Override // org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper
    public INDArray getMeanCache(DataType dataType) {
        return this.meanCache;
    }

    @Override // org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper
    public INDArray getVarCache(DataType dataType) {
        return this.varCache;
    }

    @Override // org.deeplearning4j.nn.layers.LayerHelper
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }
}
