package org.deeplearning4j.nn.layers.training;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.CenterLossParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;

/**
 * Output layer that combines the configured loss function with a "center loss" penalty.
 *
 * <p>In addition to the usual weight ({@code "W"}) and bias ({@code "b"}) parameters, the layer
 * holds a matrix of class centers under {@link CenterLossParamInitializer#CENTER_KEY}. For every
 * example the center of its class is looked up via {@code labels.mmul(centers)}, and the score is
 * increased by {@code lambda / 2} times the squared L2 distance between the layer input and that
 * center.
 */
public class CenterLossOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer> {
    /** L1 term last passed to {@code computeScore}; reused by {@link #computeGradientAndScore}. */
    private double fullNetworkL1;
    /** L2 term last passed to {@code computeScore}; reused by {@link #computeGradientAndScore}. */
    private double fullNetworkL2;

    public CenterLossOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public CenterLossOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Computes the scalar score for the current input and labels and stores it in {@code score}.
     *
     * <p>The result is {@code (lossFnScore + lambda / 2 * sum_i ||x_i - c_i||^2 + l1 + l2) / miniBatchSize}.
     * The supplied L1/L2 terms are also remembered for later calls to {@link #computeGradientAndScore}.
     *
     * @param fullNetworkL1 L1 regularization term to add to the score
     * @param fullNetworkL2 L2 regularization term to add to the score
     * @param training      forwarded to {@code preOutput2d}
     * @param workspaceMgr  workspace manager forwarded to the helper calls
     * @return the computed score
     * @throws IllegalStateException if input or labels have not been set
     */
    @Override
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training,
                               LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;

        INDArray preOut = preOutput2d(training, workspaceMgr);
        ILossFunction interClassLoss = layerConf().getLossFn();

        // Intra-class component: per-example L2 distance to the class center, squared in place.
        INDArray squaredDistance = input.sub(centersForExamples()).norm2(1);
        squaredDistance.muli(squaredDistance);

        // Inter-class component: the configured loss function.
        double interClassScore = interClassLoss.computeScore(
                getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
                layerConf().getActivationFn(), maskArray, false);
        double intraClassScore = (layerConf().getLambda() / 2.0) * squaredDistance.sumNumber().doubleValue();

        double total = (interClassScore + intraClassScore + (fullNetworkL1 + fullNetworkL2))
                / getInputMiniBatchSize();
        this.score = total;
        return total;
    }

    /**
     * Computes one score per example: the loss-function score plus
     * {@code lambda / 2 * ||x_i - c_i||^2}, plus the L1/L2 terms when their sum is non-zero.
     *
     * <p>The center-loss term is reduced to a single squared distance per example, matching the
     * quantity that {@link #computeScore} sums, instead of adding the raw (unsquared, per-feature)
     * difference matrix to the score array.
     *
     * @param fullNetworkL1 L1 regularization term
     * @param fullNetworkL2 L2 regularization term
     * @param workspaceMgr  workspace manager; the result is leveraged to {@code ArrayType.ACTIVATIONS}
     * @return the per-example score array
     * @throws IllegalStateException if input or labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2,
                                            LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        INDArray preOut = preOutput2d(false, workspaceMgr);

        // Intra-class component: squared L2 distance of every example to its class center.
        INDArray squaredDistance = input.sub(centersForExamples()).norm2(1);
        squaredDistance.muli(squaredDistance);

        // Inter-class component: the configured loss function, evaluated per example.
        INDArray scoreArray = layerConf().getLossFn().computeScoreArray(
                getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
                layerConf().getActivationFn(), maskArray);

        // NOTE(review): assumes the loss function returns exactly one score per example, so the
        // per-example distances can be reshaped to the score array's shape - confirm.
        double halfLambda = layerConf().getLambda() / 2.0;
        scoreArray.addi(squaredDistance.muli(halfLambda).reshape(scoreArray.shape()));

        double l1l2 = fullNetworkL1 + fullNetworkL2;
        if (l1l2 != 0.0) {
            scoreArray.addi(l1l2);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreArray);
    }

    /**
     * Recomputes {@code gradient} and {@code score} from the current input and labels, using the
     * L1/L2 terms remembered from the last {@code computeScore} call. Does nothing when either the
     * input or the labels are missing.
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            return;
        }
        INDArray preOut = preOutput2d(true, workspaceMgr);
        this.gradient = getGradientsAndDelta(preOut, workspaceMgr).getFirst();
        this.score = computeScore(fullNetworkL1, fullNetworkL2, true, workspaceMgr);
    }

    /** Not supported by this layer; always throws. */
    @Override
    protected void setScoreWithZ(INDArray z) {
        throw new RuntimeException("Not supported " + layerId());
    }

    /** Returns the current gradient together with the current score. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    /**
     * Computes the parameter gradients and the gradient with respect to this layer's input.
     *
     * <p>The returned activation gradient is {@code (W * delta^T)^T + lambda * (input - centersForExamples)}.
     *
     * @param epsilon      not read by this implementation
     * @param workspaceMgr workspace manager used to allocate the activation-gradient buffer
     * @return the parameter gradients and the gradient with respect to the input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        Pair<Gradient, INDArray> gradientsAndDelta =
                getGradientsAndDelta(preOutput2d(true, workspaceMgr), workspaceMgr);
        INDArray delta = gradientsAndDelta.getSecond();

        // Contribution of the center-loss term (before scaling by lambda).
        INDArray centerLossGrad = input.sub(centersForExamples());

        INDArray w = getParamWithNoise("W", true, workspaceMgr);
        INDArray buffer = workspaceMgr.createUninitialized(
                ArrayType.ACTIVATION_GRAD, new long[]{w.size(0), delta.size(0)}, 'f');
        INDArray epsilonNext = w.mmuli(delta.transpose(), buffer).transpose();
        epsilonNext.addi(centerLossGrad.muli(layerConf().getLambda()));

        weightNoiseParams.clear();
        return new Pair<>(gradientsAndDelta.getFirst(), epsilonNext);
    }

    /** Returns the gradient stored by the last {@link #computeGradientAndScore} call. */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /**
     * Fills the gradient views for {@code "W"}, {@code "b"} and the class centers and returns them
     * together with the loss-function gradient ("delta") for the given pre-output.
     *
     * @param preOut       pre-activation output; its column count must equal that of the labels
     * @param workspaceMgr workspace manager forwarded to {@code getLabels2d}
     * @return the parameter gradients paired with delta
     * @throws DL4JInvalidInputException if labels and pre-output disagree on the number of columns
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        ILossFunction lossFunction = layerConf().getLossFn();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM);
        if (labels2d.size(1) != preOut.size(1)) {
            throw new DL4JInvalidInputException("Labels array numColumns (size(1) = " + labels2d.size(1)
                    + ") does not match output layer number of outputs (nOut = " + preOut.size(1) + ") "
                    + layerId());
        }

        INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFn(), maskArray);

        INDArray weightGradView = gradientViews.get("W");
        INDArray biasGradView = gradientViews.get("b");
        INDArray centersGradView = gradientViews.get(CenterLossParamInitializer.CENTER_KEY);

        // Centers update: alpha * (center - input), accumulated per class via labels^T.
        INDArray diff = centersForExamples().sub(input).muli(layerConf().getAlpha());
        INDArray numerator = labels.transpose().mmul(diff);

        INDArray deltaC;
        if (layerConf().getGradientCheck()) {
            // In gradient-check mode the accumulated difference is scaled by lambda instead of
            // being normalized by the class counts.
            deltaC = numerator.muli(layerConf().getLambda());
        } else {
            // Normalize each class row by (1 + column sum of labels for that class).
            INDArray denominator = labels.sum(0).reshape(labels.size(1), 1L).addi(1.0);
            deltaC = numerator.diviColumnVector(denominator);
        }
        centersGradView.assign(deltaC);

        // Weight gradient: input^T * delta, written straight into the gradient view (beta = 0).
        Nd4j.gemm(input, delta, weightGradView, true, false, 1.0, 0.0);
        // Bias gradient: sum of delta along dimension 0, written into the gradient view.
        delta.sum(biasGradView, 0);

        Gradient result = new DefaultGradient();
        result.gradientForVariable().put("W", weightGradView);
        result.gradientForVariable().put("b", biasGradView);
        result.gradientForVariable().put(CenterLossParamInitializer.CENTER_KEY, centersGradView);
        return new Pair<>(result, delta);
    }

    /** Returns the labels array as-is; both arguments are ignored. */
    @Override
    protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        return labels;
    }

    /** Looks up the class center for every example: {@code labels.mmul(centers)}. */
    private INDArray centersForExamples() {
        return labels.mmul(params.get(CenterLossParamInitializer.CENTER_KEY));
    }
}
