package org.deeplearning4j.datasets.fetchers;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.zip.Adler32;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.MathUtils;

/* loaded from: input_file:org/deeplearning4j/datasets/fetchers/MnistDataFetcher.class */
/**
 * Data fetcher for the MNIST handwritten-digit dataset.
 *
 * <p>On construction the fetcher makes sure the four unzipped MNIST files are present in the
 * DL4J resource directory (downloading them when missing), verifies the Adler-32 checksum of the
 * feature/label pair it is going to read, and opens them through a {@link MnistManager}.
 * {@link #fetch(int)} then converts raw image bytes and labels into a {@link DataSet} of
 * float features and one-hot labels.
 */
public class MnistDataFetcher extends BaseDataFetcher {
    /** Number of examples in the training split. */
    public static final int NUM_EXAMPLES = 60000;
    /** Number of examples in the test split. */
    public static final int NUM_EXAMPLES_TEST = 10000;

    /** Side length used by the column-major re-ordering in {@link #fetch(int)}. */
    private static final int IMAGE_SIDE = 28;
    /** Number of bytes re-ordered per image when {@link #fOrder} is set (28 * 28). */
    private static final int IMAGE_PIXELS = IMAGE_SIDE * IMAGE_SIDE;
    /** Unsigned pixel values strictly above this become 1.0 when binarizing, otherwise 0.0. */
    private static final float BINARIZE_THRESHOLD = 30.0f;
    /** Divisor used to scale an unsigned byte pixel when not binarizing. */
    private static final float MAX_PIXEL_VALUE = 255.0f;

    /** Reader for the image and label files; assigned once in the constructor. */
    protected transient MnistManager man;
    /** When true, pixels are thresholded to 0/1; when false they are scaled by 1/255. */
    protected boolean binarize = true;
    /** True when reading the training split, false for the test split. */
    protected boolean train;
    /** Permutation of example indices; {@code order[cursor]} is the next example read. */
    protected int[] order;
    /** Random source used for shuffling {@link #order}. */
    protected Random rng;
    /** Whether {@link #reset()} shuffles {@link #order}. */
    protected boolean shuffle;
    /** When true, the label read from file is decremented by one before one-hot encoding. */
    protected boolean oneIndexed = false;
    /** When true, each image's bytes are transposed (28x28) before conversion to floats. */
    protected boolean fOrder = false;
    /** True until the first full shuffle has happened for a partial-size fetcher. */
    protected boolean firstShuffle = true;
    /** Example count supplied at construction; {@link #reset()} uses it to pick the shuffle strategy. */
    protected final int numExamples;

    protected static final long CHECKSUM_TRAIN_FEATURES = 2094436111L;
    protected static final long CHECKSUM_TRAIN_LABELS = 4008842612L;
    /** Expected Adler-32 checksums for the training files, ordered {features, labels}. */
    protected static final long[] CHECKSUMS_TRAIN = {CHECKSUM_TRAIN_FEATURES, CHECKSUM_TRAIN_LABELS};
    protected static final long CHECKSUM_TEST_FEATURES = 2165396896L;
    protected static final long CHECKSUM_TEST_LABELS = 2212998611L;
    /** Expected Adler-32 checksums for the test files, ordered {features, labels}. */
    protected static final long[] CHECKSUMS_TEST = {CHECKSUM_TEST_FEATURES, CHECKSUM_TEST_LABELS};

    /**
     * Creates a shuffling fetcher over the training split, seeded with the current time.
     *
     * @param binarize whether pixel values are thresholded to 0/1
     * @throws IOException if downloading or opening the MNIST files fails
     */
    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis(), NUM_EXAMPLES);
    }

    /**
     * Creates a fetcher, downloading the dataset when it is missing or fails checksum validation.
     *
     * @param binarize    whether pixel values are thresholded to 0/1 (otherwise scaled by 1/255)
     * @param train       true to read the training split, false for the test split
     * @param shuffle     whether {@link #reset()} shuffles the example order
     * @param rngSeed     seed for the shuffling random number generator
     * @param numExamples example count used by {@link #reset()} to choose between a full and a
     *                    subset shuffle; {@link #fetch(int)} does not consult it directly
     * @throws IOException if downloading or opening the MNIST files fails
     */
    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed, int numExamples)
            throws IOException {
        if (!mnistExists()) {
            new MnistFetcher().downloadAndUntar();
        }

        String mnistRoot = DL4JResources.getDirectory(ResourceType.DATASET, "MNIST").getAbsolutePath();
        String images;
        String labels;
        long[] checksums;
        if (train) {
            images = FilenameUtils.concat(mnistRoot, MnistFetcher.TRAINING_FILES_FILENAME_UNZIPPED);
            labels = FilenameUtils.concat(mnistRoot, MnistFetcher.TRAINING_FILE_LABELS_FILENAME_UNZIPPED);
            this.totalExamples = NUM_EXAMPLES;
            checksums = CHECKSUMS_TRAIN;
        } else {
            images = FilenameUtils.concat(mnistRoot, MnistFetcher.TEST_FILES_FILENAME_UNZIPPED);
            labels = FilenameUtils.concat(mnistRoot, MnistFetcher.TEST_FILE_LABELS_FILENAME_UNZIPPED);
            this.totalExamples = NUM_EXAMPLES_TEST;
            checksums = CHECKSUMS_TEST;
        }
        String[] files = {images, labels};

        try {
            // Validate before opening: a checksum failure must not leave a MnistManager
            // (which presumably keeps the files open) alive while the directory is deleted.
            validateFiles(files, checksums);
            this.man = new MnistManager(images, labels, train);
        } catch (Exception firstFailure) {
            try {
                FileUtils.deleteDirectory(new File(mnistRoot));
            } catch (Exception ignored) {
                // Best effort: the fresh download below is validated again regardless.
            }
            try {
                new MnistFetcher().downloadAndUntar();
                validateFiles(files, checksums);
                this.man = new MnistManager(images, labels, train);
            } catch (Exception retryFailure) {
                // Keep the original cause visible when the recovery attempt fails as well.
                retryFailure.addSuppressed(firstFailure);
                throw retryFailure;
            }
        }

        this.numOutcomes = 10;
        this.binarize = binarize;
        this.cursor = 0;
        this.inputColumns = this.man.getImages().getEntryLength();
        this.train = train;
        this.shuffle = shuffle;

        this.order = new int[train ? NUM_EXAMPLES : NUM_EXAMPLES_TEST];
        for (int idx = 0; idx < this.order.length; idx++) {
            this.order[idx] = idx;
        }
        this.rng = new Random(rngSeed);
        this.numExamples = numExamples;
        reset();
    }

    /**
     * Checks whether all four unzipped MNIST files exist in the DL4J dataset directory.
     *
     * @return true only if training and test feature and label files are all present
     */
    private boolean mnistExists() {
        String mnistRoot = DL4JResources.getDirectory(ResourceType.DATASET, "MNIST").getAbsolutePath();
        String[] required = {
            MnistFetcher.TRAINING_FILES_FILENAME_UNZIPPED,
            MnistFetcher.TRAINING_FILE_LABELS_FILENAME_UNZIPPED,
            MnistFetcher.TEST_FILES_FILENAME_UNZIPPED,
            MnistFetcher.TEST_FILE_LABELS_FILENAME_UNZIPPED
        };
        for (String fileName : required) {
            if (!new File(mnistRoot, fileName).exists()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Verifies that every file exists and has the expected Adler-32 checksum.
     *
     * @param files     paths of the files to check
     * @param checksums expected checksum for the file at the same index
     * @throws IllegalStateException if a file is missing or its checksum differs
     * @throws RuntimeException      wrapping the {@link IOException} if a file cannot be read
     */
    private void validateFiles(String[] files, long[] checksums) {
        for (int idx = 0; idx < files.length; idx++) {
            File file = new File(files[idx]);
            boolean present = file.exists();
            long actual;
            try {
                // -1 marks a missing file in the error message below.
                actual = present ? FileUtils.checksum(file, new Adler32()).getValue() : -1L;
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            if (!present || actual != checksums[idx]) {
                throw new IllegalStateException(
                        "Failed checksum: expected " + checksums[idx] + ", got " + actual + " for file: " + file);
            }
        }
    }

    /**
     * Creates a binarizing, shuffling fetcher over the training split.
     *
     * @throws IOException if downloading or opening the MNIST files fails
     */
    public MnistDataFetcher() throws IOException {
        this(true);
    }

    /**
     * Reads up to {@code batchSize} examples starting at the cursor and stores them in
     * {@code curr} as a {@link DataSet} of features and one-hot labels.
     *
     * <p>Fewer examples are returned when {@code hasMore()} turns false part-way through.
     *
     * @param batchSize maximum number of examples to read
     * @throws IllegalStateException if there are no more examples on entry
     */
    public void fetch(int batchSize) {
        if (!hasMore()) {
            throw new IllegalStateException("Unable to get more; there are no more images");
        }

        // Rows are filled in below; unfilled tail rows are trimmed before use.
        float[][] featureData = new float[batchSize][];
        float[][] labelData = new float[batchSize][];
        byte[] transposed = null; // scratch buffer, allocated lazily and reused across examples
        int fetched = 0;

        while (fetched < batchSize && hasMore()) {
            int exampleIdx = this.order[this.cursor];
            byte[] pixels = this.man.readImageUnsafe(exampleIdx);
            if (this.fOrder) {
                if (transposed == null) {
                    transposed = new byte[IMAGE_PIXELS];
                }
                // Transpose the 28x28 image: element (row, col) is taken from (col, row).
                for (int k = 0; k < IMAGE_PIXELS; k++) {
                    transposed[k] = pixels[(IMAGE_SIDE * (k % IMAGE_SIDE)) + (k / IMAGE_SIDE)];
                }
                pixels = transposed;
            }

            int label = this.man.readLabel(exampleIdx);
            if (this.oneIndexed) {
                label--;
            }

            float[] features = new float[pixels.length];
            float[] oneHot = new float[this.numOutcomes];
            oneHot[label] = 1.0f;
            for (int p = 0; p < pixels.length; p++) {
                float value = pixels[p] & 255; // interpret the byte as unsigned 0..255
                if (this.binarize) {
                    features[p] = value > BINARIZE_THRESHOLD ? 1.0f : 0.0f;
                } else {
                    features[p] = value / MAX_PIXEL_VALUE;
                }
            }

            featureData[fetched] = features;
            labelData[fetched] = oneHot;
            fetched++;
            this.cursor++;
        }

        if (fetched < batchSize) {
            featureData = Arrays.copyOfRange(featureData, 0, fetched);
            labelData = Arrays.copyOfRange(labelData, 0, fetched);
        }
        this.curr = new DataSet(Nd4j.create(featureData), Nd4j.create(labelData));
    }

    /**
     * Rewinds the cursor, clears the current data set and, when shuffling is enabled,
     * re-shuffles the example order.
     *
     * <p>If {@link #numExamples} covers the whole split, the whole order array is shuffled every
     * time. Otherwise the first reset shuffles the whole array and later resets call
     * {@code MathUtils.shuffleArraySubset(order, numExamples, rng)}.
     */
    public void reset() {
        this.cursor = 0;
        this.curr = null;
        if (!this.shuffle) {
            return;
        }

        int splitSize = this.train ? NUM_EXAMPLES : NUM_EXAMPLES_TEST;
        if (this.numExamples >= splitSize) {
            MathUtils.shuffleArray(this.order, this.rng);
        } else if (this.firstShuffle) {
            MathUtils.shuffleArray(this.order, this.rng);
            this.firstShuffle = false;
        } else {
            MathUtils.shuffleArraySubset(this.order, this.numExamples, this.rng);
        }
    }

    /**
     * Delegates to the superclass implementation.
     *
     * @return whatever {@code super.next()} returns
     */
    public DataSet next() {
        return super.next();
    }
}
