package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/NormalizerStandardize.class */
/**
 * Data set normalizer backed by {@link DistributionStats} (a mean and a standard deviation) for
 * the features and, optionally, for the labels.
 *
 * <p>The normalization strategy handed to the superclass is a {@link StandardizeStrategy}; this
 * class manages the statistics, their persistence to files, and equality.
 */
public class NormalizerStandardize extends AbstractDataSetNormalizer<DistributionStats> {

    /** Files needed to persist the feature statistics alone. */
    private static final int FEATURE_FILE_COUNT = 2;

    /** Files needed when label statistics are persisted as well. */
    private static final int FEATURE_AND_LABEL_FILE_COUNT = 4;

    /**
     * Creates a normalizer from precomputed feature statistics; labels are not normalized.
     *
     * @param featureMean feature mean
     * @param featureStd  feature standard deviation
     * @throws NullPointerException if either argument is {@code null}
     */
    public NormalizerStandardize(@NonNull INDArray featureMean, @NonNull INDArray featureStd) {
        this();
        Objects.requireNonNull(featureMean, "featureMean");
        Objects.requireNonNull(featureStd, "featureStd");
        setFeatureStats(new DistributionStats(featureMean, featureStd));
        fitLabel(false);
    }

    /**
     * Creates a normalizer from precomputed feature and label statistics; labels are normalized.
     *
     * @param featureMean feature mean
     * @param featureStd  feature standard deviation
     * @param labelMean   label mean
     * @param labelStd    label standard deviation
     * @throws NullPointerException if any argument is {@code null}
     */
    public NormalizerStandardize(@NonNull INDArray featureMean, @NonNull INDArray featureStd,
                                 @NonNull INDArray labelMean, @NonNull INDArray labelStd) {
        this();
        Objects.requireNonNull(featureMean, "featureMean");
        Objects.requireNonNull(featureStd, "featureStd");
        Objects.requireNonNull(labelMean, "labelMean");
        Objects.requireNonNull(labelStd, "labelStd");
        setFeatureStats(new DistributionStats(featureMean, featureStd));
        setLabelStats(new DistributionStats(labelMean, labelStd));
        fitLabel(true);
    }

    /** Creates an unfitted normalizer that uses a {@link StandardizeStrategy}. */
    public NormalizerStandardize() {
        super(new StandardizeStrategy());
    }

    /**
     * Replaces the label statistics. Note that this does not change the fit-label flag.
     *
     * @param labelMean label mean
     * @param labelStd  label standard deviation
     * @throws NullPointerException if either argument is {@code null}
     */
    public void setLabelStats(@NonNull INDArray labelMean, @NonNull INDArray labelStd) {
        Objects.requireNonNull(labelMean, "labelMean");
        Objects.requireNonNull(labelStd, "labelStd");
        setLabelStats(new DistributionStats(labelMean, labelStd));
    }

    /** @return the mean held in the feature statistics */
    public INDArray getMean() {
        return getFeatureStats().getMean();
    }

    /**
     * @return the mean held in the label statistics
     *     (NOTE(review): behavior when no label statistics are set depends on
     *     {@code getLabelStats()} in the superclass — confirm)
     */
    public INDArray getLabelMean() {
        return getLabelStats().getMean();
    }

    /** @return the standard deviation held in the feature statistics */
    public INDArray getStd() {
        return getFeatureStats().getStd();
    }

    /**
     * @return the standard deviation held in the label statistics
     *     (NOTE(review): behavior when no label statistics are set depends on
     *     {@code getLabelStats()} in the superclass — confirm)
     */
    public INDArray getLabelStd() {
        return getLabelStats().getStd();
    }

    /**
     * Loads statistics previously written by {@link #save(File...)}.
     *
     * <p>The first two files hold the feature statistics; when labels are fitted, the next two hold
     * the label statistics. The file count is validated before anything is loaded, so a bad call
     * leaves the current statistics untouched.
     *
     * @param files 2 files, or 4 when labels are fitted
     * @throws IOException              if reading a file fails
     * @throws IllegalArgumentException if too few files are supplied
     * @throws NullPointerException     if {@code files} is {@code null}
     */
    public void load(File... files) throws IOException {
        checkFileCount(files, "load");
        setFeatureStats(DistributionStats.load(files[0], files[1]));
        if (isFitLabel()) {
            setLabelStats(DistributionStats.load(files[2], files[3]));
        }
    }

    /**
     * Saves the statistics to files, in the layout expected by {@link #load(File...)}.
     *
     * <p>The file count is validated before anything is written, so a bad call produces no partial
     * output.
     *
     * @param files 2 files, or 4 when labels are fitted
     * @throws IOException              if writing a file fails
     * @throws IllegalArgumentException if too few files are supplied
     * @throws NullPointerException     if {@code files} is {@code null}
     */
    public void save(File... files) throws IOException {
        checkFileCount(files, "save");
        getFeatureStats().save(files[0], files[1]);
        if (isFitLabel()) {
            getLabelStats().save(files[2], files[3]);
        }
    }

    /** Fails fast when {@code files} cannot cover every statistic that {@code operation} touches. */
    private void checkFileCount(File[] files, String operation) {
        Objects.requireNonNull(files, "files");
        int required = isFitLabel() ? FEATURE_AND_LABEL_FILE_COUNT : FEATURE_FILE_COUNT;
        if (files.length < required) {
            throw new IllegalArgumentException(operation + " requires at least " + required
                    + " files (2 for feature statistics"
                    + (isFitLabel() ? ", 2 for label statistics" : "")
                    + "), but got " + files.length);
        }
    }

    /** @return a builder for the {@link DistributionStats} this normalizer uses */
    @Override
    protected NormalizerStats.Builder newBuilder() {
        return new DistributionStats.Builder();
    }

    /** @return {@link NormalizerType#STANDARDIZE} */
    @Override
    public NormalizerType getType() {
        return NormalizerType.STANDARDIZE;
    }

    /** Equal when {@code obj} is a mutually-comparable {@code NormalizerStandardize} and the superclass state matches. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof NormalizerStandardize)) {
            return false;
        }
        // canEqual keeps equals symmetric if a subclass narrows equality.
        return ((NormalizerStandardize) obj).canEqual(this) && super.equals(obj);
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof NormalizerStandardize;
    }

    /** Derived from the superclass hash, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return 59 + super.hashCode();
    }
}
