package org.deeplearning4j.nn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/nn/layers/LossLayer.class */
/**
 * Parameterless output layer that applies the configured activation function and loss function
 * directly to its input.
 *
 * <p>The layer holds no trainable state: {@link #params()} returns {@code null}, the gradient it
 * produces is an empty {@link DefaultGradient}, and its own L1/L2 contributions are always zero.
 * The full-network regularization terms passed to {@link #computeScore} are cached so that
 * {@link #computeGradientAndScore} can reuse them.
 */
public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossLayer> implements Serializable, IOutputLayer {
    /** Labels for the current minibatch, set via {@link #setLabels(INDArray)}. */
    protected INDArray labels;
    private transient Solver solver;
    /** Full-network L1 term passed to the most recent {@link #computeScore} call. */
    private double fullNetworkL1;
    /** Full-network L2 term passed to the most recent {@link #computeScore} call. */
    private double fullNetworkL2;

    /**
     * Creates the layer from its configuration.
     *
     * @param conf network configuration, forwarded to {@code BaseLayer}
     */
    public LossLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from its configuration and an array, both forwarded unchanged to the
     * corresponding {@code BaseLayer} constructor.
     *
     * @param conf  network configuration
     * @param input array forwarded to {@code BaseLayer} (presumably the initial layer input —
     *              confirm against {@code BaseLayer})
     */
    public LossLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Computes the score of the current input against the current labels.
     *
     * <p>The loss function's score is called with its final flag set to {@code false}
     * (presumably "do not average"). The supplied full-network regularization terms are added to
     * it, and the total is divided by the input minibatch size. The result is stored as this
     * layer's score, and the L1/L2 terms are cached for {@link #computeGradientAndScore}.
     *
     * @param fullNetworkL1 L1 regularization term of the whole network
     * @param fullNetworkL2 L2 regularization term of the whole network
     * @param training      not used by this layer
     * @param workspaceMgr  not used by this computation
     * @return the computed score
     * @throws IllegalStateException if the input or labels have not been set
     */
    @Override
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;
        double lossScore = layerConf().getLossFn().computeScore(getLabels2d(), input, layerConf().getActivationFn(), maskArray, false);
        double total = (lossScore + (fullNetworkL1 + fullNetworkL2)) / getInputMiniBatchSize();
        this.score = total;
        return total;
    }

    /**
     * Computes a per-example score array for the current input and labels.
     *
     * <p>The combined L1+L2 term is added to every entry when it is non-zero. Unlike
     * {@link #computeScore}, the regularization terms are not cached here.
     *
     * @param fullNetworkL1 L1 regularization term of the whole network
     * @param fullNetworkL2 L2 regularization term of the whole network
     * @param workspaceMgr  workspace manager the result is leveraged into ({@code ACTIVATIONS})
     * @return the per-example score array
     * @throws IllegalStateException if the input or labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2, LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        INDArray scoreArray = layerConf().getLossFn().computeScoreArray(getLabels2d(), input, layerConf().getActivationFn(), maskArray);
        double l1l2 = fullNetworkL1 + fullNetworkL2;
        // Skip the in-place add entirely when there is no regularization to apply.
        if (l1l2 != 0.0) {
            scoreArray.addi(l1l2);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreArray);
    }

    /**
     * Computes the gradient and the score using the regularization terms cached by the last
     * {@link #computeScore} call.
     *
     * <p>This is a silent no-op when either the input or the labels are unset.
     *
     * @param workspaceMgr workspace manager used for the gradient and score computations
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (input == null || labels == null) {
            return;
        }
        gradient = getGradientsAndDelta(input, workspaceMgr).getFirst();
        score = computeScore(fullNetworkL1, fullNetworkL2, true, workspaceMgr);
    }

    /**
     * Not supported by this layer.
     *
     * @throws RuntimeException always
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        throw new RuntimeException("Not supported " + layerId());
    }

    /**
     * Returns the current gradient paired with the current score.
     *
     * @return a pair of {@link #gradient()} and {@code score()}
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    /**
     * Computes the (empty) parameter gradient and the error signal with respect to this layer's
     * input.
     *
     * <p>The incoming epsilon is ignored: as an output layer, the error is derived from the loss
     * function and the labels.
     *
     * @param epsilon      ignored
     * @param workspaceMgr workspace manager the returned delta is leveraged into
     * @return an empty gradient paired with the loss gradient
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return getGradientsAndDelta(input, workspaceMgr);
    }

    /**
     * Builds the empty parameter gradient and the loss gradient for {@code preOut}.
     *
     * @param preOut       array passed to the loss function as its pre-output argument
     * @param workspaceMgr workspace manager the delta is leveraged into ({@code ACTIVATION_GRAD})
     * @return a pair of an empty {@link DefaultGradient} and the loss gradient
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        Gradient emptyGradient = new DefaultGradient();
        INDArray delta = layerConf().getLossFn().computeGradient(getLabels2d(), preOut, layerConf().getActivationFn(), maskArray);
        return new Pair<>(emptyGradient, workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta));
    }

    /** Returns the gradient stored by the last {@link #computeGradientAndScore} call. */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /** Returns {@code 0.0}: this layer has no parameters to regularize. */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0.0;
    }

    /** Returns {@code 0.0}: this layer has no parameters to regularize. */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Applies the configured activation function to a copy of the current input.
     *
     * <p>When a mask array is set, the activations are multiplied in place by the mask as a
     * column vector.
     *
     * @param training     forwarded to the activation function
     * @param workspaceMgr workspace manager the result is leveraged into ({@code ACTIVATIONS})
     * @return the activations
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        // dup() so the activation function never modifies the stored input in place.
        INDArray activation = layerConf().getActivationFn().getActivation(input.dup(), training);
        if (maskArray != null) {
            activation.muliColumnVector(maskArray);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, activation);
    }

    /**
     * Sets {@code newInput} as the layer input, then activates it.
     *
     * @param newInput     the new input array
     * @param training     forwarded to {@link #activate(boolean, LayerWorkspaceMgr)}
     * @param workspaceMgr workspace manager used for both steps
     * @return the activations
     */
    @Override
    public INDArray activate(INDArray newInput, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(newInput, workspaceMgr);
        return activate(training, workspaceMgr);
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Returns {@code null}: this layer has no parameters. */
    @Override
    public INDArray params() {
        return null;
    }

    /**
     * Computes the F1 score for the features and labels of {@code dataSet}.
     *
     * @param dataSet the data to evaluate
     * @return the F1 score
     */
    @Override
    public double f1Score(DataSet dataSet) {
        return f1Score(dataSet.getFeatures(), dataSet.getLabels());
    }

    /**
     * Evaluates the layer's inference-mode output for {@code examples} against {@code trueLabels}
     * and returns the F1 score.
     *
     * @param examples   input features
     * @param trueLabels expected labels
     * @return the F1 score
     */
    @Override
    public double f1Score(INDArray examples, INDArray trueLabels) {
        Evaluation evaluation = new Evaluation();
        evaluation.eval(trueLabels, labelProbabilities(examples));
        return evaluation.f1();
    }

    /**
     * Returns the number of label columns ({@code labels.size(1)}).
     *
     * @throws NullPointerException if labels have not been set
     */
    @Override
    public int numLabels() {
        return (int) labels.size(1);
    }

    /**
     * Does nothing.
     *
     * <p>NOTE(review): unlike the other {@code fit} overloads, which throw
     * {@link UnsupportedOperationException}, this one is a silent no-op — confirm that is
     * intended.
     */
    @Override
    public void fit(DataSetIterator iterator) {
    }

    /**
     * Predicts a class index for every row of {@code features}.
     *
     * <p>Each index is what BLAS {@code iamax} returns for the corresponding row of the
     * inference-mode activations, i.e. the position of the largest absolute value.
     *
     * @param features input features, one example per row
     * @return one predicted class index per row
     */
    @Override
    public int[] predict(INDArray features) {
        INDArray output = activate(features, false, LayerWorkspaceMgr.noWorkspacesImmutable());
        int[] predictions = new int[features.rows()];
        for (int row = 0; row < predictions.length; row++) {
            predictions[row] = Nd4j.getBlasWrapper().iamax(output.getRow(row));
        }
        return predictions;
    }

    /**
     * Predicts a label name for every example in {@code dataSet}.
     *
     * @param dataSet data whose features are classified and whose label names are used
     * @return the predicted label names, in example order
     */
    @Override
    public List<String> predict(DataSet dataSet) {
        int[] classIndices = predict(dataSet.getFeatures());
        List<String> labelNames = new ArrayList<>(classIndices.length);
        for (int classIndex : classIndices) {
            // Append so that element k corresponds to example k. The predicted class index is a
            // label id, not a list position: inserting at it would throw
            // IndexOutOfBoundsException whenever it exceeded the current list size.
            labelNames.add(dataSet.getLabelName(classIndex));
        }
        return labelNames;
    }

    /**
     * Returns the inference-mode activations for {@code features}.
     *
     * @param features input features
     * @return the activations
     */
    @Override
    public INDArray labelProbabilities(INDArray features) {
        return activate(features, false, LayerWorkspaceMgr.noWorkspacesImmutable());
    }

    /**
     * Not supported: this layer has no parameters.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data, INDArray targets) {
        throw new UnsupportedOperationException("LossLayer has no parameters and cannot be fit");
    }

    /**
     * Delegates to {@link #fit(INDArray, INDArray)}, which always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(DataSet dataSet) {
        fit(dataSet.getFeatures(), dataSet.getLabels());
    }

    /**
     * Converts {@code labelIndices} to an outcome matrix and delegates to
     * {@link #fit(INDArray, INDArray)}, which always throws.
     *
     * @throws UnsupportedOperationException always (after the label conversion)
     */
    @Override
    public void fit(INDArray data, int[] labelIndices) {
        fit(data, FeatureUtil.toOutcomeMatrix(labelIndices, numLabels()));
    }

    /**
     * Clears the superclass state, releases the labels and drops the solver.
     *
     * <p>NOTE(review): this destroys the data buffer of the labels array, which was supplied
     * externally via {@link #setLabels(INDArray)}. If the caller still references that array
     * (e.g. through its DataSet), the caller's array is invalidated as well — confirm the layer
     * owns it.
     */
    @Override
    public void clear() {
        super.clear();
        if (labels != null) {
            labels.data().destroy();
            labels = null;
        }
        solver = null;
    }

    @Override
    public INDArray getLabels() {
        return labels;
    }

    @Override
    public boolean needsLabels() {
        return true;
    }

    /**
     * Stores {@code newLabels} by reference; no copy is made.
     *
     * @param newLabels labels for the current minibatch
     */
    @Override
    public void setLabels(INDArray newLabels) {
        this.labels = newLabels;
    }

    /**
     * Returns the labels as a 2d array.
     *
     * <p>Labels of rank 2 or less are returned unchanged. For rank &gt; 2, the labels are
     * reshaped to {@code [size(2), size(1)]}.
     *
     * <p>NOTE(review): that reshape drops {@code size(0)} and is not a transpose. It only
     * preserves the element count when {@code size(0) == 1}, and only preserves row/column
     * meaning when {@code size(1)} or {@code size(2)} is 1 — verify against the expected
     * label layout.
     */
    protected INDArray getLabels2d() {
        if (labels.rank() > 2) {
            return labels.reshape(labels.size(2), labels.size(1));
        }
        return labels;
    }
}
