package org.deeplearning4j.nn.layers.mkldnn;

import java.util.Collections;
import java.util.Map;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.class */
public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper implements LocalResponseNormalizationHelper {
    protected OpContext context;

    public MKLDNNLocalResponseNormalizationHelper(DataType dataType) {
    }

    @Override // org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper
    public boolean checkSupported(double d, double d2, double d3, double d4) {
        return BaseMKLDNNHelper.mklDnnEnabled();
    }

    @Override // org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, double d, double d2, double d3, double d4, LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray.dataType(), iNDArray.shape());
        if (this.context == null) {
            this.context = Nd4j.getExecutioner().buildContext();
            this.context.setTArguments(new double[]{d, d3, d4});
            this.context.setIArguments(new long[]{(int) d2});
        } else {
            this.context.purge();
        }
        LocalResponseNormalization localResponseNormalization = new LocalResponseNormalization();
        this.context.setInputArray(0, iNDArray);
        this.context.setInputArray(0, iNDArray2);
        this.context.setOutputArray(0, createUninitialized);
        Nd4j.exec(localResponseNormalization, this.context);
        return new Pair<>(new DefaultGradient(), createUninitialized);
    }

    @Override // org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper
    public INDArray activate(INDArray iNDArray, boolean z, double d, double d2, double d3, double d4, LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, iNDArray.dataType(), iNDArray.shape());
        if (this.context == null) {
            this.context = Nd4j.getExecutioner().buildContext();
            this.context.setTArguments(new double[]{d, d3, d4});
            this.context.setIArguments(new long[]{(int) d2});
        } else {
            this.context.purge();
        }
        this.context.setInputArray(0, iNDArray);
        this.context.setOutputArray(0, createUninitialized);
        Nd4j.exec(new LocalResponseNormalization(), this.context);
        return createUninitialized;
    }

    @Override // org.deeplearning4j.nn.layers.LayerHelper
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }
}
