package org.deeplearning4j.nn.layers.ocnn;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.class */
/**
 * Output layer implementation for a one-class neural network (OCNN).
 *
 * <p>The forward pass computes {@code act(input * V) * W}, producing one value per example
 * (see {@link #doOutput}). The layer does not consume external labels: {@link #setLabels} is a
 * no-op, {@link #needsLabels()} returns {@code false}, and the {@code labels} field is
 * overwritten with the layer's own output on every forward pass.
 *
 * <p>Parameters are looked up by the keys declared in {@code OCNNParamInitializer}:
 * {@code W_KEY}, {@code V_KEY} and {@code R_KEY}. The R value is periodically re-estimated from
 * a window of recent outputs in {@link #getGradientsAndDelta}.
 */
public class OCNNOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer> {
    /**
     * Activation exposed through {@link #getActivation()} / {@link #setActivation(IActivation)}.
     * The forward and backward passes in this class use the configuration's activation function
     * ({@code layerConf().getActivationFn()}), not this field.
     */
    private IActivation activation;
    /** ReLU instance shared by every {@link OCNNLossFunction}. */
    private static final IActivation relu = new ActivationReLU();
    /** Loss function installed on the layer configuration by the constructors. */
    private ILossFunction lossFunction;
    /** Write position inside {@link #window}; reset to 0 whenever R is re-estimated. */
    private int batchWindowSizeIndex;
    /** Buffer of recent layer outputs from which R is re-estimated; created lazily. */
    private INDArray window;

    /**
     * Loss function of the OCNN layer. It is a non-static inner class because it reads the
     * enclosing layer's W, V and R parameters and its configuration.
     */
    public class OCNNLossFunction implements ILossFunction {
        public OCNNLossFunction() {
        }

        /**
         * Computes the scalar score
         * {@code 0.5*sum(W^2) + 0.5*sum(V^2) + (1/nu)*mean(relu(r - preOutput)) - r}.
         *
         * @param iNDArray    labels as passed by the caller; not read here
         * @param iNDArray2   pre-output of the layer (the caller in this file passes
         *                    {@code preOutput2d(...)})
         * @param iActivation activation function; not read here
         * @param iNDArray3   mask array; not read here
         * @param z           not read here
         * @return the score described above
         */
        public double computeScore(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3, boolean z) {
            double wNormTerm = Transforms.pow(OCNNOutputLayer.this.getParam(OCNNParamInitializer.W_KEY), 2).sumNumber().doubleValue() * 0.5d;
            double vNormTerm = Transforms.pow(OCNNOutputLayer.this.getParam(OCNNParamInitializer.V_KEY), 2).sumNumber().doubleValue() * 0.5d;
            org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer layerConfig = (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) OCNNOutputLayer.this.conf().getLayer();
            // R is read once; nothing between its two uses below modifies the parameter.
            double r = OCNNOutputLayer.this.getParam(OCNNParamInitializer.R_KEY).getDouble(0L);
            // rsub yields (r - preOutput) as a new array, so the caller's pre-output is untouched.
            double meanHinge = OCNNOutputLayer.relu.getActivation(iNDArray2.rsub(Double.valueOf(r)), true).meanNumber().doubleValue();
            double hingeTerm = (1.0d / layerConfig.getNu()) * meanHinge;
            return wNormTerm + vNormTerm + hingeTerm + (-r);
        }

        /**
         * Returns {@code R - preOutput}.
         *
         * @param iNDArray    labels; not read here
         * @param iNDArray2   pre-output of the layer
         * @param iActivation activation function; not read here
         * @param iNDArray3   mask array; not read here
         * @return the result of subtracting {@code iNDArray2} from the R parameter array
         */
        public INDArray computeScoreArray(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3) {
            return OCNNOutputLayer.this.getParam(OCNNParamInitializer.R_KEY).sub(iNDArray2);
        }

        /**
         * Back-propagates an all-ones epsilon through ReLU evaluated at {@code r - preOutput}.
         *
         * @param iNDArray    labels; not read here
         * @param iNDArray2   pre-output of the layer
         * @param iActivation activation function; not read here
         * @param iNDArray3   mask array; not read here
         * @return first element of the pair returned by the ReLU {@code backprop} call
         */
        public INDArray computeGradient(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3) {
            INDArray rMinusPreOut = iNDArray2.rsub(Double.valueOf(OCNNOutputLayer.this.getParam(OCNNParamInitializer.R_KEY).getDouble(0L)));
            return (INDArray) OCNNOutputLayer.relu.backprop(rMinusPreOut, Nd4j.ones(rMinusPreOut.shape())).getFirst();
        }

        /**
         * Convenience wrapper returning the score and the gradient together; the score is
         * evaluated first, then the gradient.
         */
        public Pair<Double, INDArray> computeGradientAndScore(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3, boolean z) {
            return new Pair<>(Double.valueOf(computeScore(iNDArray, iNDArray2, iActivation, iNDArray3, z)), computeGradient(iNDArray, iNDArray2, iActivation, iNDArray3));
        }

        /** @return the fixed name {@code "OCNNLossFunction"} */
        public String name() {
            return "OCNNLossFunction";
        }
    }

    /**
     * Creates the layer and installs a fresh {@link OCNNLossFunction} on the layer configuration.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to the OCNN output layer
     *                               configuration type
     */
    public OCNNOutputLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
        this.activation = new ActivationReLU();
        this.lossFunction = new OCNNLossFunction();
        ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) neuralNetConfiguration.getLayer()).setLossFn(this.lossFunction);
    }

    /**
     * Creates the layer with an initial input and installs a fresh {@link OCNNLossFunction} on
     * the layer configuration.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to the OCNN output layer
     *                               configuration type
     * @param iNDArray               array forwarded unchanged to the superclass constructor
     */
    public OCNNOutputLayer(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
        this.activation = new ActivationReLU();
        // The loss function must be created before it is handed to the configuration; without
        // this assignment the field is still null here and setLossFn(null) would make every
        // later getLossFn().computeScore/computeGradient call fail with a NullPointerException.
        this.lossFunction = new OCNNLossFunction();
        ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) neuralNetConfiguration.getLayer()).setLossFn(this.lossFunction);
    }

    /**
     * Intentionally does nothing: the {@code labels} field is assigned by {@link #doOutput}.
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.api.layers.IOutputLayer
    public void setLabels(INDArray iNDArray) {
    }

    /**
     * Computes the layer score, stores it in {@code this.score} and returns it.
     *
     * @param d                 added to the loss value (presumably a regularization term supplied
     *                          by the network; confirm against the caller)
     * @param d2                added to the loss value (same caveat as {@code d})
     * @param z                 forwarded to {@link #preOutput2d} as the training flag
     * @param layerWorkspaceMgr workspace manager used for the forward pass
     * @return loss + d + d2, divided by the input mini-batch size when the configuration
     *         reports mini-batch mode
     * @throws IllegalStateException if no input has been set
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.api.layers.IOutputLayer
    public double computeScore(double d, double d2, boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.input == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        double total = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getLossFn().computeScore(getLabels2d(layerWorkspaceMgr, ArrayType.FF_WORKING_MEM), preOutput2d(z, layerWorkspaceMgr), ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getActivationFn(), this.maskArray, false) + d + d2;
        if (conf().isMiniBatch()) {
            total /= getInputMiniBatchSize();
        }
        this.score = total;
        return total;
    }

    /** @return always {@code false}; this layer generates its own labels */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.api.layers.IOutputLayer
    public boolean needsLabels() {
        return false;
    }

    /**
     * Computes parameter gradients and the activation gradient passed to the previous layer.
     *
     * <p>The delta from {@link #getGradientsAndDelta} is broadcast into an
     * {@code [nIn, delta.length()]} array in 'f' order and then transposed.
     *
     * @param iNDArray          incoming epsilon; not read by this implementation
     * @param layerWorkspaceMgr workspace manager used for allocation and the forward pass
     * @return pair of (parameter gradients, activation gradient)
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        Pair<Gradient, INDArray> gradientsAndDelta = getGradientsAndDelta(preOutput2d(true, layerWorkspaceMgr), layerWorkspaceMgr);
        long nIn = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) getConf().getLayer()).getNIn();
        INDArray delta = (INDArray) gradientsAndDelta.getSecond();
        INDArray epsilonNext = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{nIn, delta.length()}, 'f');
        return new Pair<>(gradientsAndDelta.getFirst(), epsilonNext.assign(delta.broadcast(epsilonNext.shape())).transpose());
    }

    /**
     * Updates the R-estimation window and computes the gradients for W, V and R.
     *
     * @param preOut       pre-output of the layer for the current input
     * @param workspaceMgr workspace manager used for additional forward passes
     * @return pair of (gradients keyed by W_KEY, V_KEY and R_KEY; loss delta after
     *         {@code backpropDropOutIfPresent})
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        INDArray delta = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getLossFn().computeGradient(getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM), preOut, ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getActivationFn(), this.maskArray);
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnConf = (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) conf().getLayer();
        if (ocnnConf.getLastEpochSinceRUpdated() == 0 && this.epochCount == 0) {
            // Epoch 0: fill the window with outputs, clipping the final write to the window end.
            INDArray currentOut = doOutput(false, workspaceMgr);
            if (this.window == null) {
                // NOTE(review): DEFAULT_EDGE_VALUE is used only as the initial fill value here;
                // it is presumably 0.0 (decompiler constant substitution) - confirm.
                this.window = Nd4j.createUninitializedDetached(ocnnConf.getWindowSize()).assign(Double.valueOf(EvaluationBinary.DEFAULT_EDGE_VALUE));
            }
            if (this.batchWindowSizeIndex < this.window.length() - currentOut.length()) {
                this.window.put(new INDArrayIndex[]{NDArrayIndex.interval(this.batchWindowSizeIndex, this.batchWindowSizeIndex + currentOut.length())}, currentOut);
            } else if (this.batchWindowSizeIndex < this.window.length()) {
                int remaining = ((int) this.window.length()) - this.batchWindowSizeIndex;
                this.window.put(new INDArrayIndex[]{NDArrayIndex.interval(this.window.length() - remaining, this.window.length())}, currentOut.get(new INDArrayIndex[]{NDArrayIndex.interval(0, remaining)}));
            }
            this.batchWindowSizeIndex = (int) (this.batchWindowSizeIndex + currentOut.length());
            ocnnConf.setLastEpochSinceRUpdated(this.epochCount);
        } else if (ocnnConf.getLastEpochSinceRUpdated() != this.epochCount) {
            // Epoch changed: set R to the (100 * nu)-th percentile of the window and restart it.
            // NOTE(review): window is dereferenced without a null check; this relies on the
            // epoch-0 branch having run on this instance first - confirm for restored models.
            getParam(OCNNParamInitializer.R_KEY).putScalar(0L, this.window.percentileNumber(Double.valueOf(100.0d * ocnnConf.getNu())).doubleValue());
            ocnnConf.setLastEpochSinceRUpdated(this.epochCount);
            this.batchWindowSizeIndex = 0;
        } else {
            // Same epoch (after epoch 0): overwrite the window at the current write position.
            // NOTE(review): the write position is not advanced and the interval is not
            // bounds-checked in this branch - confirm this is intended.
            INDArray currentOut = doOutput(false, workspaceMgr);
            this.window.put(new INDArrayIndex[]{NDArrayIndex.interval(this.batchWindowSizeIndex, this.batchWindowSizeIndex + currentOut.length())}, currentOut);
        }
        DefaultGradient gradient = new DefaultGradient();
        INDArray vGradientView = this.gradientViews.get(OCNNParamInitializer.V_KEY);
        double invNu = 1.0d / ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getNu();
        INDArray inputTimesV = this.input.mmul(getParam(OCNNParamInitializer.V_KEY));

        // W gradient: (1/nu) * mean over dimension 0 of (-act(input*V) * delta), plus W.
        defaultGradientForW(gradient, inputTimesV, delta, invNu);

        // V gradient: act'(input*V) scaled by -W and delta, combined with the input through a
        // rank-3 broadcast multiply, averaged over dimension 0, scaled by 1/nu, plus V.
        INDArray hiddenGrad = ((INDArray) ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getActivationFn().backprop(inputTimesV.dup(), Nd4j.ones(inputTimesV.shape())).getFirst()).muliRowVector(getParam(OCNNParamInitializer.W_KEY).neg()).muliColumnVector(delta).reshape('f', new long[]{this.input.size(0), 1, ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getHiddenSize()});
        INDArray input3d = this.input.reshape('f', new long[]{this.input.size(0), getParam(OCNNParamInitializer.V_KEY).size(0), 1});
        long[] broadcastShape = new long[hiddenGrad.shape().length];
        for (int dim = 0; dim < hiddenGrad.rank(); dim++) {
            broadcastShape[dim] = Math.max(hiddenGrad.size(dim), input3d.size(dim));
        }
        INDArray product = hiddenGrad.broadcast(Nd4j.createUninitialized(broadcastShape));
        Broadcast.mul(product, input3d, product, new int[]{0, 1});
        gradient.setGradientFor(OCNNParamInitializer.V_KEY, vGradientView.assign(product.mean(new int[]{0}).muli(Double.valueOf(invNu)).addi(getParam(OCNNParamInitializer.V_KEY))));

        // R gradient: (1/nu) * mean(delta) - 1.
        gradient.setGradientFor(OCNNParamInitializer.R_KEY, this.gradientViews.get(OCNNParamInitializer.R_KEY).assign(Nd4j.scalar(delta.meanNumber()).muli(Double.valueOf(invNu)).addi(-1)));
        clearNoiseWeightParams();
        return new Pair<>(gradient, backpropDropOutIfPresent(delta));
    }

    /**
     * Writes the W gradient into its gradient view and registers it on {@code gradient}.
     *
     * @param gradient    gradient container receiving the W entry
     * @param inputTimesV result of {@code input.mmul(V)}; duplicated before activation so the
     *                    caller's array is left intact
     * @param delta       loss gradient applied as a column vector
     * @param invNu       the factor {@code 1 / nu}
     */
    private void defaultGradientForW(DefaultGradient gradient, INDArray inputTimesV, INDArray delta, double invNu) {
        gradient.setGradientFor(OCNNParamInitializer.W_KEY, this.gradientViews.get(OCNNParamInitializer.W_KEY).assign(((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getActivationFn().getActivation(inputTimesV.dup(), true).negi().muliColumnVector(delta).mean(new int[]{0}).muli(Double.valueOf(invNu)).addi(getParam(OCNNParamInitializer.W_KEY))));
    }

    /**
     * Stores {@code iNDArray} as the layer input and runs the forward pass.
     *
     * @param iNDArray          new layer input
     * @param z                 training flag forwarded to {@link #doOutput}
     * @param layerWorkspaceMgr workspace manager used for allocation
     * @return the layer output
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.input = iNDArray;
        return doOutput(z, layerWorkspaceMgr);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.api.Classifier
    public double f1Score(INDArray iNDArray, INDArray iNDArray2) {
        throw new UnsupportedOperationException();
    }

    /**
     * Thresholds the given values at zero: entries below 0 become 0, every other entry becomes 1.
     *
     * @param iNDArray values to threshold; read through {@code data().asFloat()}
     * @return a new array created from the thresholded float values
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.api.Classifier
    public INDArray labelProbabilities(INDArray iNDArray) {
        float[] values = iNDArray.data().asFloat();
        for (int i = 0; i < values.length; i++) {
            if (values[i] < 0.0f) {
                values[i] = 0.0f;
            } else {
                values[i] = 1.0f;
            }
        }
        return Nd4j.create(values);
    }

    /** @return {@code Layer.Type.FEED_FORWARD} */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /**
     * Returns the full forward-pass output; this layer makes no distinction between pre-output
     * and output.
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer
    public INDArray preOutput2d(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        return doOutput(z, layerWorkspaceMgr);
    }

    /**
     * Returns the {@code labels} field as-is; neither argument is used.
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer
    protected INDArray getLabels2d(LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        return this.labels;
    }

    /**
     * Runs the forward pass on the input that is already set on this layer.
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        return doOutput(z, layerWorkspaceMgr);
    }

    /**
     * Forward pass: {@code act(input * V) * W}, one value per input row.
     *
     * <p>Side effect: the result is also stored in {@code this.labels}.
     *
     * @param training     training flag passed to parameter-noise, dropout and activation calls
     * @param workspaceMgr workspace manager used to allocate the output array
     * @return array of length {@code input.size(0)} holding the layer output
     */
    private INDArray doOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        INDArray w = getParamWithNoise(OCNNParamInitializer.W_KEY, training, workspaceMgr);
        INDArray v = getParamWithNoise(OCNNParamInitializer.V_KEY, training, workspaceMgr);
        applyDropOutIfNecessary(training, workspaceMgr);
        INDArray hidden = Nd4j.createUninitialized(this.input.size(0), v.size(1));
        this.input.mmuli(v, hidden);
        INDArray hiddenAct = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getActivationFn().getActivation(hidden, training);
        INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{this.input.size(0)});
        // W is flattened to rank 1 so the product is one value per example.
        hiddenAct.mmuli(w.reshape(new long[]{w.length()}), output);
        this.labels = output;
        return output;
    }

    /**
     * Computes per-example scores by summing the loss function's score array over dimension 1
     * and adding {@code d + d2} when that sum is not equal to
     * {@code EvaluationBinary.DEFAULT_EDGE_VALUE}.
     *
     * @param d                 first additive term (presumably regularization; confirm)
     * @param d2                second additive term (same caveat)
     * @param layerWorkspaceMgr workspace manager used for the forward pass
     * @return per-example score array
     * @throws IllegalStateException if the input or the labels field is null
     */
    @Override // org.deeplearning4j.nn.layers.BaseOutputLayer, org.deeplearning4j.nn.api.layers.IOutputLayer
    public INDArray computeScoreForExamples(double d, double d2, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        INDArray scores = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getLossFn().computeScoreArray(getLabels2d(layerWorkspaceMgr, ArrayType.FF_WORKING_MEM), preOutput2d(false, layerWorkspaceMgr), ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layerConf()).getActivationFn(), this.maskArray).sum(new int[]{1});
        double extra = d + d2;
        if (extra != EvaluationBinary.DEFAULT_EDGE_VALUE) {
            scores.addi(Double.valueOf(extra));
        }
        return scores;
    }

    /** Replaces the activation returned by {@link #getActivation()}. */
    public void setActivation(IActivation iActivation) {
        this.activation = iActivation;
    }

    /** @return the activation stored on this layer instance */
    public IActivation getActivation() {
        return this.activation;
    }
}
