package org.deeplearning4j.earlystopping;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.common.function.Supplier;

/* loaded from: input_file:org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.class */
/**
 * Holds the settings for an early-stopping run: where models are saved, which
 * termination conditions apply, how often evaluation happens and how a score
 * is calculated.
 *
 * <p>Instances are normally created through {@link Builder}. The public
 * no-argument constructor leaves every field at its Java default
 * ({@code null}, {@code false}, {@code 0}); callers using it are expected to
 * populate the object through the setters and may call {@link #validate()}
 * to check the result.
 *
 * <p>A score calculator can be configured either directly or through a
 * supplier. When a supplier is present, {@link #getScoreCalculator()} asks the
 * supplier and ignores the directly configured calculator.
 *
 * @param <T> the model type this configuration applies to
 */
public class EarlyStoppingConfiguration<T extends Model> implements Serializable {
    private EarlyStoppingModelSaver<T> modelSaver;
    private List<EpochTerminationCondition> epochTerminationConditions;
    private List<IterationTerminationCondition> iterationTerminationConditions;
    private boolean saveLastModel;
    private int evaluateEveryNEpochs;
    private ScoreCalculator<T> scoreCalculator;
    private Supplier<ScoreCalculator> scoreCalculatorSupplier;

    /**
     * Fluent builder for {@link EarlyStoppingConfiguration}.
     *
     * <p>Defaults: an {@link InMemoryModelSaver}, no termination conditions,
     * {@code saveLastModel = false}, {@code evaluateEveryNEpochs = 1} and no
     * score calculator or supplier.
     *
     * @param <T> the model type of the configuration being built
     */
    public static class Builder<T extends Model> {
        private EarlyStoppingModelSaver<T> modelSaver = new InMemoryModelSaver<>();
        private List<EpochTerminationCondition> epochTerminationConditions = new ArrayList<>();
        private List<IterationTerminationCondition> iterationTerminationConditions = new ArrayList<>();
        private boolean saveLastModel = false;
        private int evaluateEveryNEpochs = 1;
        private ScoreCalculator<T> scoreCalculator;
        private Supplier<ScoreCalculator> scoreCalculatorSupplier;

        /**
         * Sets the model saver.
         *
         * @param earlyStoppingModelSaver the saver to use; {@code null} is accepted here
         *        but is rejected later by {@link EarlyStoppingConfiguration#validate()}
         * @return this builder
         */
        public Builder<T> modelSaver(EarlyStoppingModelSaver<T> earlyStoppingModelSaver) {
            this.modelSaver = earlyStoppingModelSaver;
            return this;
        }

        /**
         * Replaces the epoch termination conditions with the given ones.
         *
         * @param epochTerminationConditionArr the new conditions
         * @return this builder
         * @throws NullPointerException if the array itself is {@code null}
         */
        public Builder<T> epochTerminationConditions(EpochTerminationCondition... epochTerminationConditionArr) {
            // Install a fresh list instead of clearing the current one: the current
            // list may be null, or shared with an already-built configuration.
            List<EpochTerminationCondition> conditions = new ArrayList<>();
            Collections.addAll(conditions, epochTerminationConditionArr);
            this.epochTerminationConditions = conditions;
            return this;
        }

        /**
         * Replaces the epoch termination conditions with a copy of the given list.
         *
         * @param list the new conditions; {@code null} is stored as {@code null}
         * @return this builder
         */
        public Builder<T> epochTerminationConditions(List<EpochTerminationCondition> list) {
            // Copy so that later builder calls can never modify the caller's list.
            this.epochTerminationConditions = copyOrNull(list);
            return this;
        }

        /**
         * Replaces the iteration termination conditions with the given ones.
         *
         * @param iterationTerminationConditionArr the new conditions
         * @return this builder
         * @throws NullPointerException if the array itself is {@code null}
         */
        public Builder<T> iterationTerminationConditions(IterationTerminationCondition... iterationTerminationConditionArr) {
            // Fresh list for the same reason as in the epoch-condition overload.
            List<IterationTerminationCondition> conditions = new ArrayList<>();
            Collections.addAll(conditions, iterationTerminationConditionArr);
            this.iterationTerminationConditions = conditions;
            return this;
        }

        /**
         * Sets the {@code saveLastModel} flag.
         *
         * @param z the new flag value
         * @return this builder
         */
        public Builder<T> saveLastModel(boolean z) {
            this.saveLastModel = z;
            return this;
        }

        /**
         * Sets the {@code evaluateEveryNEpochs} value. The value is stored as given;
         * no range check is performed here.
         *
         * @param i the new value
         * @return this builder
         */
        public Builder<T> evaluateEveryNEpochs(int i) {
            this.evaluateEveryNEpochs = i;
            return this;
        }

        /**
         * Sets the score calculator directly.
         *
         * @param scoreCalculator the calculator to use
         * @return this builder
         */
        // The raw parameter type is part of the existing public signature.
        @SuppressWarnings({"unchecked", "rawtypes"})
        public Builder<T> scoreCalculator(ScoreCalculator scoreCalculator) {
            this.scoreCalculator = scoreCalculator;
            return this;
        }

        /**
         * Sets a supplier of score calculators. When a supplier is set, the built
         * configuration's {@link EarlyStoppingConfiguration#getScoreCalculator()}
         * uses it in preference to a directly configured calculator.
         *
         * @param supplier the supplier to use
         * @return this builder
         */
        public Builder<T> scoreCalculator(Supplier<ScoreCalculator> supplier) {
            this.scoreCalculatorSupplier = supplier;
            return this;
        }

        /**
         * Creates a configuration from the current builder state. The result is
         * not validated; call {@link EarlyStoppingConfiguration#validate()} for that.
         *
         * @return a new configuration
         */
        public EarlyStoppingConfiguration<T> build() {
            return new EarlyStoppingConfiguration<>(this);
        }
    }

    /**
     * Creates an empty configuration with every field at its Java default.
     */
    public EarlyStoppingConfiguration() {
    }

    private EarlyStoppingConfiguration(Builder<T> builder) {
        this.modelSaver = builder.modelSaver;
        // Copy the lists so that configurations built from one builder do not
        // share mutable state with the builder or with each other.
        this.epochTerminationConditions = copyOrNull(builder.epochTerminationConditions);
        this.iterationTerminationConditions = copyOrNull(builder.iterationTerminationConditions);
        this.saveLastModel = builder.saveLastModel;
        this.evaluateEveryNEpochs = builder.evaluateEveryNEpochs;
        this.scoreCalculator = builder.scoreCalculator;
        this.scoreCalculatorSupplier = builder.scoreCalculatorSupplier;
    }

    /** Returns a mutable copy of {@code source}, or {@code null} when {@code source} is {@code null}. */
    private static <E> List<E> copyOrNull(List<E> source) {
        return source == null ? null : new ArrayList<>(source);
    }

    /** Returns the hash code of {@code value}, or 43 for {@code null}. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /**
     * Returns the score calculator to use.
     *
     * @return the result of calling the supplier when a supplier is set (the
     *         supplier is invoked on every call); otherwise the directly
     *         configured calculator, which may be {@code null}
     */
    // The supplier is declared with the raw ScoreCalculator type in the public API.
    @SuppressWarnings("unchecked")
    public ScoreCalculator<T> getScoreCalculator() {
        if (this.scoreCalculatorSupplier != null) {
            return this.scoreCalculatorSupplier.get();
        }
        return this.scoreCalculator;
    }

    /**
     * Checks that this configuration is usable.
     *
     * @throws DL4JInvalidConfigException if neither a score calculator nor a
     *         supplier is set, if no model saver is set, or if both termination
     *         condition lists are {@code null} or empty
     */
    public void validate() {
        if (this.scoreCalculator == null && this.scoreCalculatorSupplier == null) {
            throw new DL4JInvalidConfigException("A score calculator or score calculator supplier must be defined.");
        }
        if (this.modelSaver == null) {
            throw new DL4JInvalidConfigException("A model saver must be defined");
        }
        boolean hasIterationConditions =
                this.iterationTerminationConditions != null && !this.iterationTerminationConditions.isEmpty();
        boolean hasEpochConditions =
                this.epochTerminationConditions != null && !this.epochTerminationConditions.isEmpty();
        if (!hasIterationConditions && !hasEpochConditions) {
            throw new DL4JInvalidConfigException("No termination conditions defined.");
        }
    }

    /** @return the configured model saver, possibly {@code null} */
    public EarlyStoppingModelSaver<T> getModelSaver() {
        return this.modelSaver;
    }

    /** @return the epoch termination conditions (the internal list, not a copy), possibly {@code null} */
    public List<EpochTerminationCondition> getEpochTerminationConditions() {
        return this.epochTerminationConditions;
    }

    /** @return the iteration termination conditions (the internal list, not a copy), possibly {@code null} */
    public List<IterationTerminationCondition> getIterationTerminationConditions() {
        return this.iterationTerminationConditions;
    }

    /** @return the {@code saveLastModel} flag */
    public boolean isSaveLastModel() {
        return this.saveLastModel;
    }

    /** @return the {@code evaluateEveryNEpochs} value */
    public int getEvaluateEveryNEpochs() {
        return this.evaluateEveryNEpochs;
    }

    /** @return the score calculator supplier, or {@code null} if none is set */
    public Supplier<ScoreCalculator> getScoreCalculatorSupplier() {
        return this.scoreCalculatorSupplier;
    }

    /** @param earlyStoppingModelSaver the new model saver */
    public void setModelSaver(EarlyStoppingModelSaver<T> earlyStoppingModelSaver) {
        this.modelSaver = earlyStoppingModelSaver;
    }

    /** @param list the new epoch termination conditions; stored by reference */
    public void setEpochTerminationConditions(List<EpochTerminationCondition> list) {
        this.epochTerminationConditions = list;
    }

    /** @param list the new iteration termination conditions; stored by reference */
    public void setIterationTerminationConditions(List<IterationTerminationCondition> list) {
        this.iterationTerminationConditions = list;
    }

    /** @param z the new {@code saveLastModel} flag */
    public void setSaveLastModel(boolean z) {
        this.saveLastModel = z;
    }

    /** @param i the new {@code evaluateEveryNEpochs} value */
    public void setEvaluateEveryNEpochs(int i) {
        this.evaluateEveryNEpochs = i;
    }

    /** @param scoreCalculator the new directly configured score calculator */
    public void setScoreCalculator(ScoreCalculator<T> scoreCalculator) {
        this.scoreCalculator = scoreCalculator;
    }

    /** @param supplier the new score calculator supplier */
    public void setScoreCalculatorSupplier(Supplier<ScoreCalculator> supplier) {
        this.scoreCalculatorSupplier = supplier;
    }

    /**
     * Compares this configuration with another object field by field.
     *
     * <p>The directly configured score calculator and the supplier are compared
     * as stored. The supplier is never invoked, so the comparison has no side
     * effects and does not depend on what the supplier returns.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is an equal configuration
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof EarlyStoppingConfiguration)) {
            return false;
        }
        EarlyStoppingConfiguration<?> other = (EarlyStoppingConfiguration<?>) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getModelSaver(), other.getModelSaver())
                && Objects.equals(getEpochTerminationConditions(), other.getEpochTerminationConditions())
                && Objects.equals(getIterationTerminationConditions(), other.getIterationTerminationConditions())
                && isSaveLastModel() == other.isSaveLastModel()
                && getEvaluateEveryNEpochs() == other.getEvaluateEveryNEpochs()
                && Objects.equals(this.scoreCalculator, other.scoreCalculator)
                && Objects.equals(getScoreCalculatorSupplier(), other.getScoreCalculatorSupplier());
    }

    /**
     * Hook that lets subclasses restrict which objects may be equal to them.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is an {@code EarlyStoppingConfiguration}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof EarlyStoppingConfiguration;
    }

    /**
     * Computes a hash code from the same fields that {@link #equals(Object)} compares.
     * The supplier is not invoked, so this method never triggers supplier side effects.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + hashOrDefault(getModelSaver());
        result = result * 59 + hashOrDefault(getEpochTerminationConditions());
        result = result * 59 + hashOrDefault(getIterationTerminationConditions());
        result = result * 59 + (isSaveLastModel() ? 79 : 97);
        result = result * 59 + getEvaluateEveryNEpochs();
        result = result * 59 + hashOrDefault(this.scoreCalculator);
        result = result * 59 + hashOrDefault(getScoreCalculatorSupplier());
        return result;
    }

    /**
     * Returns a description listing every field. The {@code scoreCalculator}
     * entry shows the directly configured calculator; the supplier is printed
     * but not invoked.
     *
     * @return the string form of this configuration
     */
    @Override
    public String toString() {
        return "EarlyStoppingConfiguration(modelSaver=" + getModelSaver()
                + ", epochTerminationConditions=" + getEpochTerminationConditions()
                + ", iterationTerminationConditions=" + getIterationTerminationConditions()
                + ", saveLastModel=" + isSaveLastModel()
                + ", evaluateEveryNEpochs=" + getEvaluateEveryNEpochs()
                + ", scoreCalculator=" + this.scoreCalculator
                + ", scoreCalculatorSupplier=" + getScoreCalculatorSupplier() + ")";
    }
}
