package org.nd4j.linalg.lossfunctions.impl;

import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/impl/LossFMeasure.class */
/**
 * Loss function for binary classification whose score is {@code 1 - F_beta}.
 *
 * <p>The F-measure is evaluated over the whole set of rows passed in, using sums of
 * element-wise products of labels and network activations rather than thresholded
 * predictions. Labels must have either 1 or 2 columns; with 2 columns, column 1 is
 * used as the positive class and column 0 as the negative class.
 *
 * <p>Per-example scores are not available: {@link #computeScoreArray} always throws.
 */
public class LossFMeasure implements ILossFunction {
    /** Beta used by the no-argument constructor. */
    public static final double DEFAULT_BETA = 1.0d;

    /** Weighting factor of the F-measure; the constructor rejects values {@code <= 0}. */
    private final double beta;

    /** Creates the loss with {@link #DEFAULT_BETA}. */
    public LossFMeasure() {
        this(DEFAULT_BETA);
    }

    /**
     * Creates the loss with an explicit beta.
     *
     * @param beta weighting factor of the F-measure
     * @throws UnsupportedOperationException if {@code beta <= 0}
     */
    public LossFMeasure(@JsonProperty("beta") double beta) {
        if (beta <= 0.0d) {
            throw new UnsupportedOperationException("Invalid value: beta must be > 0. Got: " + beta);
        }
        this.beta = beta;
    }

    /**
     * Computes {@code 1 - numerator / denominator} of the F-beta measure.
     *
     * @param labels       label array with 1 or 2 columns
     * @param preOutput    network output before the activation function is applied
     * @param activationFn activation applied to a copy of {@code preOutput}
     * @param mask         optional mask; may be {@code null}
     * @param average      not used by this implementation
     * @return the loss, or {@code 0.0} when numerator and denominator are both zero
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        final double[] parts = computeScoreNumDenom(labels, preOutput, activationFn, mask, average);
        final double numerator = parts[0];
        final double denominator = parts[1];

        // 0/0 would be NaN; this case is defined as zero loss instead.
        final boolean degenerate = numerator == 0.0d && denominator == 0.0d;
        return degenerate ? 0.0d : 1.0d - (numerator / denominator);
    }

    /**
     * Computes the numerator and denominator of the F-beta measure.
     *
     * <p>With {@code tp = sum(positive * p1)}, {@code fn = sum(positive * p0)} and
     * {@code fp = sum(negative * p1)}, the result is
     * {@code {(1 + beta^2) * tp, (1 + beta^2) * tp + beta^2 * fn + fp}}.
     *
     * @param average ignored; kept so the signature mirrors {@link #computeScore}
     * @return a two-element array: {@code [numerator, denominator]}
     * @throws UnsupportedOperationException if the labels do not have 1 or 2 columns
     */
    private double[] computeScoreNumDenom(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        // The activation works on a copy so preOutput itself is left untouched here.
        final INDArray output = activationFn.getActivation(preOutput.dup(), true);

        final int width = labels.size(1);
        if (!(width == 1 || width == 2)) {
            throw new UnsupportedOperationException("For binary classification: expect output size of 1 or 2. Got: " + width);
        }

        INDArray positiveLabels;
        INDArray negativeLabels;
        INDArray probNegative;
        INDArray probPositive;
        if (width == 2) {
            // Two columns: column 1 is the positive class, column 0 the negative class.
            positiveLabels = labels.getColumn(1);
            negativeLabels = labels.getColumn(0);
            probNegative = output.getColumn(0);
            probPositive = output.getColumn(1);
        } else {
            // Single column: the negative side is derived from the positive one.
            positiveLabels = labels;
            negativeLabels = Transforms.not(positiveLabels);
            probNegative = output.rsub(1.0d);
            probPositive = output;
        }

        if (mask != null) {
            positiveLabels = positiveLabels.mulColumnVector(mask);
            negativeLabels = negativeLabels.mulColumnVector(mask);
        }

        final double betaSquared = this.beta * this.beta;
        final double onePlusBetaSquared = 1.0d + betaSquared;

        final double truePositives = sumOfProduct(positiveLabels, probPositive);
        final double falseNegatives = sumOfProduct(positiveLabels, probNegative);
        final double falsePositives = sumOfProduct(negativeLabels, probPositive);

        final double numerator = onePlusBetaSquared * truePositives;
        final double denominator = numerator + (betaSquared * falseNegatives) + falsePositives;
        return new double[]{numerator, denominator};
    }

    /** Returns the sum over all elements of the element-wise product of {@code a} and {@code b}. */
    private static double sumOfProduct(INDArray a, INDArray b) {
        return a.mul(b).sumNumber().doubleValue();
    }

    /**
     * Not supported for this loss.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        throw new UnsupportedOperationException("Cannot compute score array for FMeasure loss function: loss is only defined for minibatches");
    }

    /**
     * Computes the gradient of the loss and pushes it back through the activation.
     *
     * <p>The gradient with respect to the activations is
     * {@code -((1 + beta^2) * labels / denominator - numerator / denominator^2)}; for
     * two-column labels only column 1 is filled and column 0 stays at its initial value.
     * Note that {@code preOutput} is handed to {@code activationFn.backprop} directly,
     * not as a copy.
     *
     * @param labels       label array with 1 or 2 columns
     * @param preOutput    network output before the activation function is applied
     * @param activationFn activation whose {@code backprop} is applied to the gradient
     * @param mask         optional mask applied to the resulting gradient; may be {@code null}
     * @return the first element of the pair returned by {@code activationFn.backprop},
     *         or a freshly created array shaped like {@code preOutput} when numerator
     *         and denominator are both zero
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        final double[] parts = computeScoreNumDenom(labels, preOutput, activationFn, mask, false);
        final double numerator = parts[0];
        final double denominator = parts[1];

        if (numerator == 0.0d && denominator == 0.0d) {
            // Matches the zero-score case of computeScore.
            return Nd4j.create(preOutput.shape());
        }

        final double secondTerm = numerator / (denominator * denominator);
        final double onePlusBetaSquared = 1.0d + (this.beta * this.beta);

        final INDArray dLdOut;
        if (labels.size(1) == 1) {
            dLdOut = labels.mul(onePlusBetaSquared).divi(denominator).subi(secondTerm);
        } else {
            // Only the positive-class column (index 1) receives a gradient.
            dLdOut = Nd4j.create(labels.shape());
            final INDArray positiveColumn = dLdOut.getColumn(1);
            final INDArray positiveGradient =
                    labels.getColumn(1).mul(onePlusBetaSquared).divi(denominator).subi(secondTerm);
            positiveColumn.assign(positiveGradient);
        }

        // The score is 1 - F, so the gradient of the score is the negated gradient of F.
        dLdOut.negi();

        final Pair<INDArray, INDArray> backpropResult = activationFn.backprop(preOutput, dLdOut);
        final INDArray dLdPreOut = backpropResult.getFirst();
        if (mask != null) {
            dLdPreOut.muliColumnVector(mask);
        }
        return dLdPreOut;
    }

    /**
     * Computes score and gradient together by delegating to {@link #computeScore} and then
     * {@link #computeGradient}, in that order.
     *
     * @return a pair of (score, gradient)
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        // Score first: computeGradient passes preOutput itself to backprop.
        final double score = computeScore(labels, preOutput, activationFn, mask, average);
        final INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
        return new Pair<>(score, gradient);
    }

    /** @return the fixed identifier {@code "floss"} */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public String name() {
        return "floss";
    }

    @Override
    public String toString() {
        return "LossFMeasure(beta=" + this.beta + ")";
    }

    /** @return the beta this instance was constructed with */
    public double getBeta() {
        return this.beta;
    }

    /** Two instances are equal when both agree via {@link #canEqual} and have the same beta. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof LossFMeasure)) {
            return false;
        }
        final LossFMeasure other = (LossFMeasure) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Double.compare(getBeta(), other.getBeta()) == 0;
    }

    /** Hook letting subclasses refuse equality with instances of this class. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossFMeasure;
    }

    /** Hash derived from the bit pattern of beta, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        final long betaBits = Double.doubleToLongBits(getBeta());
        return 59 + (int) ((betaBits >>> 32) ^ betaBits);
    }
}
