package org.nd4j.linalg.learning.config;

import java.beans.ConstructorProperties;
import java.util.HashMap;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.NesterovsUpdater;

/* loaded from: input_file:org/nd4j/linalg/learning/config/Nesterovs.class */
/**
 * Configuration for the Nesterov momentum updater.
 *
 * <p>Holds a learning rate, a momentum coefficient and an optional per-iteration momentum
 * schedule, and creates {@link NesterovsUpdater} instances from that configuration via
 * {@link #instantiate(INDArray, boolean)}.
 */
public class Nesterovs implements IUpdater {
    /** Momentum used by the no-arg constructor and the builder when none is supplied. */
    public static final double DEFAULT_NESTEROV_MOMENTUM = 0.9d;
    /** Learning rate used by the shorter constructors and the builder when none is supplied. */
    public static final double DEFAULT_NESTEROV_LEARNING_RATE = 0.1d;

    private double learningRate;
    private double momentum;
    /**
     * Optional map from iteration number to the momentum to switch to at that iteration.
     * May be {@code null}, meaning the momentum is never changed by {@link #applySchedules}.
     */
    private Map<Integer, Double> momentumSchedule;

    /**
     * Fluent builder for {@link Nesterovs}. Fields that are not set fall back to
     * {@link Nesterovs#DEFAULT_NESTEROV_LEARNING_RATE} and
     * {@link Nesterovs#DEFAULT_NESTEROV_MOMENTUM}; the schedule defaults to {@code null}.
     */
    public static class Builder {
        private double learningRate = DEFAULT_NESTEROV_LEARNING_RATE;
        // Must match the no-arg Nesterovs constructor; this previously defaulted to 0.1
        // (the learning-rate default), so builder().build() disagreed with new Nesterovs().
        private double momentum = DEFAULT_NESTEROV_MOMENTUM;
        private Map<Integer, Double> momentumSchedule;

        /** Sets the learning rate. */
        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /** Sets the momentum coefficient. */
        public Builder momentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /** Sets the iteration-to-momentum schedule (stored by reference; may be {@code null}). */
        public Builder momentumSchedule(Map<Integer, Double> momentumSchedule) {
            this.momentumSchedule = momentumSchedule;
            return this;
        }

        /** Creates a {@link Nesterovs} from the values collected so far. */
        public Nesterovs build() {
            return new Nesterovs(this.learningRate, this.momentum, this.momentumSchedule);
        }

        @Override
        public String toString() {
            return "Nesterovs.Builder(learningRate=" + this.learningRate + ", momentum=" + this.momentum + ", momentumSchedule=" + this.momentumSchedule + ")";
        }
    }

    /** Returns a new builder initialised with the default learning rate and momentum. */
    public static Builder builder() {
        return new Builder();
    }

    /** Creates a configuration with the default learning rate, default momentum and no schedule. */
    public Nesterovs() {
        this(DEFAULT_NESTEROV_LEARNING_RATE, DEFAULT_NESTEROV_MOMENTUM, null);
    }

    /**
     * Creates a configuration with the default learning rate and the given momentum.
     *
     * @param momentum momentum coefficient
     */
    public Nesterovs(double momentum) {
        this(DEFAULT_NESTEROV_LEARNING_RATE, momentum);
    }

    /**
     * Creates a configuration with the given learning rate and momentum and no schedule.
     *
     * @param learningRate learning rate
     * @param momentum     momentum coefficient
     */
    public Nesterovs(double learningRate, double momentum) {
        this(learningRate, momentum, null);
    }

    /**
     * Creates a fully specified configuration.
     *
     * @param learningRate     learning rate
     * @param momentum         momentum coefficient
     * @param momentumSchedule iteration-to-momentum map (stored by reference); may be {@code null}
     */
    @ConstructorProperties({"learningRate", "momentum", "momentumSchedule"})
    public Nesterovs(double learningRate, double momentum, Map<Integer, Double> momentumSchedule) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.momentumSchedule = momentumSchedule;
    }

    /**
     * Returns the size of the updater state for the given argument. This implementation returns
     * its argument unchanged (presumably one state value per parameter — confirm against the
     * {@code IUpdater} contract).
     */
    @Override
    public long stateSize(long numParams) {
        return numParams;
    }

    /**
     * Applies per-iteration schedules: always replaces the learning rate with
     * {@code newLearningRate}, and replaces the momentum when the momentum schedule has a
     * non-null entry for {@code iteration}. Otherwise the current momentum is kept.
     *
     * @param iteration       iteration index used to look up the momentum schedule
     * @param newLearningRate learning rate to use from now on
     */
    @Override
    public void applySchedules(int iteration, double newLearningRate) {
        this.learningRate = newLearningRate;
        if (this.momentumSchedule == null) {
            return;
        }
        // Single lookup; a missing key or an explicit null value leaves momentum unchanged
        // (the previous containsKey + get().doubleValue() threw NPE on a null value).
        Double scheduled = this.momentumSchedule.get(iteration);
        if (scheduled != null) {
            this.momentum = scheduled;
        }
    }

    /**
     * Creates a {@link NesterovsUpdater} bound to this configuration and attaches the given state
     * view array to it.
     *
     * @param viewArray           array passed to the updater's {@code setStateViewArray}, along
     *                            with its shape and ordering
     * @param initializeViewArray flag forwarded to {@code setStateViewArray} (presumably whether
     *                            the view should be initialised — confirm in NesterovsUpdater)
     * @return the new updater
     */
    @Override
    public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
        NesterovsUpdater updater = new NesterovsUpdater(this);
        updater.setStateViewArray(viewArray, viewArray.shape(), viewArray.ordering(), initializeViewArray);
        return updater;
    }

    /**
     * Returns a copy of this configuration. The momentum schedule, if present, is copied into a
     * new {@link HashMap} so the clone does not share the original map.
     */
    @Override
    public Nesterovs clone() {
        Map<Integer, Double> scheduleCopy =
                this.momentumSchedule == null ? null : new HashMap<>(this.momentumSchedule);
        return new Nesterovs(this.learningRate, this.momentum, scheduleCopy);
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public double getMomentum() {
        return this.momentum;
    }

    /** Returns the momentum schedule map by reference; may be {@code null}. */
    public Map<Integer, Double> getMomentumSchedule() {
        return this.momentumSchedule;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** Replaces the momentum schedule (stored by reference; may be {@code null}). */
    public void setMomentumSchedule(Map<Integer, Double> momentumSchedule) {
        this.momentumSchedule = momentumSchedule;
    }

    /**
     * Two configurations are equal when learning rate and momentum compare equal via
     * {@link Double#compare} and the schedules are both null or {@code equals}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Nesterovs)) {
            return false;
        }
        Nesterovs other = (Nesterovs) obj;
        if (!other.canEqual(this)
                || Double.compare(getLearningRate(), other.getLearningRate()) != 0
                || Double.compare(getMomentum(), other.getMomentum()) != 0) {
            return false;
        }
        Map<Integer, Double> mine = getMomentumSchedule();
        Map<Integer, Double> theirs = other.getMomentumSchedule();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof Nesterovs;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        // Double.hashCode(d) == (int) (bits ^ (bits >>> 32)), matching the previous expansion.
        result = result * prime + Double.hashCode(getLearningRate());
        result = result * prime + Double.hashCode(getMomentum());
        Map<Integer, Double> schedule = getMomentumSchedule();
        result = result * prime + (schedule == null ? 43 : schedule.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "Nesterovs(learningRate=" + getLearningRate() + ", momentum=" + getMomentum() + ", momentumSchedule=" + getMomentumSchedule() + ")";
    }
}
