package org.deeplearning4j.earlystopping.termination;

import org.deeplearning4j.eval.EvaluationBinary;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/earlystopping/termination/ScoreImprovementEpochTerminationCondition.class */
/**
 * Epoch termination condition that signals termination once the score has not improved
 * by more than {@code minImprovement} for {@code maxEpochsWithNoImprovement} epochs.
 *
 * <p>The instance is stateful: it remembers the best score seen so far and the epoch at
 * which it was observed. {@link #initialize()} resets that state.
 */
public class ScoreImprovementEpochTerminationCondition implements EpochTerminationCondition {
    private static final Logger log = LoggerFactory.getLogger(ScoreImprovementEpochTerminationCondition.class);

    /** Sentinel stored in {@code bestEpoch} while no score has been recorded yet. */
    private static final int NO_BEST_EPOCH = -1;

    /**
     * Improvement threshold used when the caller does not supply one.
     * NOTE(review): this reuses a constant from an unrelated evaluation class, which looks like
     * a decompiler artifact standing in for a plain literal (presumably 0.0) - confirm and inline.
     */
    private static final double DEFAULT_MIN_IMPROVEMENT = EvaluationBinary.DEFAULT_EDGE_VALUE;

    /** Number of epochs without sufficient improvement after which {@link #terminate} returns true. */
    @JsonProperty
    private int maxEpochsWithNoImprovement;

    /** Epoch at which {@code bestScore} was recorded, or -1 if nothing has been recorded yet. */
    @JsonProperty
    private int bestEpoch;

    /** Best score recorded so far; set to NaN by {@link #initialize()}. */
    @JsonProperty
    private double bestScore;

    /** A new score must beat the best score by strictly more than this amount to count as an improvement. */
    @JsonProperty
    private double minImprovement;

    /**
     * Creates a condition using the default improvement threshold.
     *
     * @param maxEpochsWithNoImprovement number of epochs without improvement to tolerate
     */
    public ScoreImprovementEpochTerminationCondition(int maxEpochsWithNoImprovement) {
        this(maxEpochsWithNoImprovement, DEFAULT_MIN_IMPROVEMENT);
    }

    /**
     * Creates a condition with an explicit improvement threshold.
     *
     * @param maxEpochsWithNoImprovement number of epochs without improvement to tolerate
     * @param minImprovement             amount by which a score must beat the best score to count as an improvement
     */
    public ScoreImprovementEpochTerminationCondition(int maxEpochsWithNoImprovement, double minImprovement) {
        this.bestEpoch = NO_BEST_EPOCH;
        this.maxEpochsWithNoImprovement = maxEpochsWithNoImprovement;
        this.minImprovement = minImprovement;
    }

    /** Clears the recorded best epoch and best score so the condition can be reused. */
    @Override // org.deeplearning4j.earlystopping.termination.EpochTerminationCondition
    public void initialize() {
        this.bestEpoch = NO_BEST_EPOCH;
        this.bestScore = Double.NaN;
    }

    /**
     * Records the score of the given epoch and decides whether the termination condition is met.
     *
     * <p>The first call after construction or {@link #initialize()} only records the score and
     * returns {@code false}. A NaN score is never counted as an improvement, so it can neither
     * replace the recorded best score nor restart the no-improvement window.
     *
     * @param epochNum number of the epoch that produced {@code score}
     * @param score    score obtained at {@code epochNum}
     * @param minimize {@code true} if lower scores are better, {@code false} if higher scores are better
     * @return {@code true} once at least {@code maxEpochsWithNoImprovement} epochs have passed
     *         since the best epoch without a sufficient improvement
     */
    @Override // org.deeplearning4j.earlystopping.termination.EpochTerminationCondition
    public boolean terminate(int epochNum, double score, boolean minimize) {
        if (this.bestEpoch == NO_BEST_EPOCH) {
            this.bestEpoch = epochNum;
            this.bestScore = score;
            return false;
        }

        double improvement = minimize ? this.bestScore - score : score - this.bestScore;

        // Stated positively on purpose: any comparison against NaN is false, so a NaN score
        // falls through to the "no improvement" path instead of overwriting the best score.
        if (improvement > this.minImprovement) {
            if (this.minImprovement > DEFAULT_MIN_IMPROVEMENT) {
                log.info("Epoch with score greater than threshold * * *");
            }
            this.bestScore = score;
            this.bestEpoch = epochNum;
            return false;
        }

        // Subtract in long arithmetic rather than adding the two ints, so a very large
        // maxEpochsWithNoImprovement cannot overflow and trigger a premature termination.
        return (long) epochNum - this.bestEpoch >= this.maxEpochsWithNoImprovement;
    }

    /** Returns a description containing only the two configuration values, not the tracked state. */
    @Override
    public String toString() {
        return "ScoreImprovementEpochTerminationCondition(maxEpochsWithNoImprovement=" + this.maxEpochsWithNoImprovement + ", minImprovement=" + this.minImprovement + ")";
    }

    /** @return the number of epochs without improvement that is tolerated */
    public int getMaxEpochsWithNoImprovement() {
        return this.maxEpochsWithNoImprovement;
    }

    /** @return the epoch at which the best score was recorded, or -1 if none has been recorded */
    public int getBestEpoch() {
        return this.bestEpoch;
    }

    /** @return the best score recorded so far */
    public double getBestScore() {
        return this.bestScore;
    }

    /** @return the improvement a score must exceed to be recorded as the new best score */
    public double getMinImprovement() {
        return this.minImprovement;
    }

    /** @param maxEpochsWithNoImprovement the number of epochs without improvement to tolerate */
    public void setMaxEpochsWithNoImprovement(int maxEpochsWithNoImprovement) {
        this.maxEpochsWithNoImprovement = maxEpochsWithNoImprovement;
    }

    /** @param bestEpoch the epoch to treat as the best one; -1 marks "nothing recorded yet" */
    public void setBestEpoch(int bestEpoch) {
        this.bestEpoch = bestEpoch;
    }

    /** @param bestScore the score to treat as the best one seen so far */
    public void setBestScore(double bestScore) {
        this.bestScore = bestScore;
    }

    /** @param minImprovement the improvement a score must exceed to be recorded as the new best score */
    public void setMinImprovement(double minImprovement) {
        this.minImprovement = minImprovement;
    }

    /**
     * Compares configuration and tracked state. Doubles are compared with
     * {@link Double#compare(double, double)}, so two NaN best scores are considered equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ScoreImprovementEpochTerminationCondition)) {
            return false;
        }
        ScoreImprovementEpochTerminationCondition other = (ScoreImprovementEpochTerminationCondition) obj;
        // canEqual lets a subclass that adds state refuse equality with a plain instance.
        if (!other.canEqual(this)) {
            return false;
        }
        return getMaxEpochsWithNoImprovement() == other.getMaxEpochsWithNoImprovement()
                && getBestEpoch() == other.getBestEpoch()
                && Double.compare(getBestScore(), other.getBestScore()) == 0
                && Double.compare(getMinImprovement(), other.getMinImprovement()) == 0;
    }

    /**
     * Hook used by {@link #equals(Object)} to ask the other object whether it accepts being
     * compared with this one.
     *
     * @param obj the object that wants to compare itself with this instance
     * @return {@code true} if {@code obj} is a {@code ScoreImprovementEpochTerminationCondition}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof ScoreImprovementEpochTerminationCondition;
    }

    /** Combines the same four fields that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + getMaxEpochsWithNoImprovement();
        result = result * prime + getBestEpoch();
        result = result * prime + hashOf(getBestScore());
        result = result * prime + hashOf(getMinImprovement());
        return result;
    }

    /** Folds the 64-bit representation of a double into an int by XOR-ing its high and low halves. */
    private static int hashOf(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) ((bits >>> 32) ^ bits);
    }
}
