package org.deeplearning4j.nn.weights;

import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/deeplearning4j/nn/weights/WeightInitDistribution.class */
/**
 * {@link IWeightInit} that fills a parameter array by sampling from a configured
 * {@link Distribution}.
 */
public class WeightInitDistribution implements IWeightInit {
    /** Distribution configuration to sample from; the constructor rejects null. */
    private final Distribution distribution;

    /**
     * Creates a weight initializer backed by the given distribution configuration.
     *
     * @param distribution distribution configuration to sample weights from
     * @throws IllegalArgumentException if {@code distribution} is null
     */
    public WeightInitDistribution(@JsonProperty("distribution") Distribution distribution) {
        if (distribution == null) {
            // Fail fast at construction instead of later inside init().
            throw new IllegalArgumentException("Must set distribution!");
        }
        this.distribution = distribution;
    }

    /**
     * Samples weights from the configured distribution into {@code paramView}.
     *
     * <p>A fresh ND4J sampler is created from the configuration on every call. The sampler
     * may be an {@link OrthogonalDistribution}. In that case it samples into
     * {@code paramView.reshape(order, shape)}. Otherwise it samples into {@code paramView}
     * directly.
     *
     * <p>NOTE(review): the orthogonal path relies on {@code reshape} returning a view of
     * {@code paramView}, so that the samples land in the caller's buffer. Confirm this for
     * non-contiguous inputs.
     *
     * @param fanIn     ignored by this implementation
     * @param fanOut    ignored by this implementation
     * @param shape     target shape of the returned array
     * @param order     element ordering used for the reshape
     * @param paramView array to fill with sampled values
     * @return {@code paramView} reshaped to {@code shape} in {@code order}
     */
    @Override
    public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
        org.nd4j.linalg.api.rng.distribution.Distribution sampler =
                Distributions.createDistribution(this.distribution);
        if (sampler instanceof OrthogonalDistribution) {
            // Presumably orthogonal sampling needs the real matrix shape, not the flat view.
            sampler.sample(paramView.reshape(order, shape));
        } else {
            sampler.sample(paramView);
        }
        return paramView.reshape(order, shape);
    }

    /**
     * Two instances are equal when their distribution configurations are equal and the
     * other object's {@link #canEqual(Object)} accepts this one.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof WeightInitDistribution)) {
            return false;
        }
        WeightInitDistribution other = (WeightInitDistribution) obj;
        // Symmetry hook: a subclass can override canEqual to refuse equality with this class.
        if (!other.canEqual(this)) {
            return false;
        }
        return this.distribution == null
                ? other.distribution == null
                : this.distribution.equals(other.distribution);
    }

    /**
     * Symmetry hook used by {@link #equals(Object)}.
     *
     * @param obj candidate object
     * @return true if {@code obj} is a {@code WeightInitDistribution}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof WeightInitDistribution;
    }

    /** Hash derived from the distribution configuration, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        // Constants kept identical to the previous implementation, so hash values are unchanged.
        return 59 + (this.distribution == null ? 43 : this.distribution.hashCode());
    }

    /** Human-readable form showing the configured distribution. */
    @Override
    public String toString() {
        return "WeightInitDistribution(distribution=" + this.distribution + ")";
    }
}
