package org.datavec.image.loader;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.EqualizeHistTransform;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * Loader for the CIFAR-10 binary data set ({@code cifar-10-batches-bin}).
 *
 * <p>On construction the loader does the following:
 * <ul>
 *   <li>locates the binary batch files, downloading them if the directory is missing;</li>
 *   <li>opens one stream over all training batch files and one over {@code test_batch.bin};</li>
 *   <li>reads label names from {@code batches.meta.txt};</li>
 *   <li>when the "special" CIFAR preprocessing is enabled for the training set, converts the raw
 *       batches once into serialized {@link DataSet} files that {@link #next(int, int)} reads back.</li>
 * </ul>
 *
 * <p>Each raw record is {@value #BYTEFILELEN} bytes: one label byte followed by three 1024-byte
 * colour planes (R, G, B per the CIFAR-10 binary format), decoded by {@link #convertMat(byte[])}.
 *
 * <p>The file/stream based {@code asRowVector}/{@code asMatrix} overloads are not supported.
 * No synchronization is performed and batches are read from shared stateful streams, so an
 * instance should not be used from several threads at once.
 */
public class CifarLoader extends NativeImageLoader implements Serializable {
    public static final int NUM_TRAIN_IMAGES = 50000;
    public static final int NUM_TEST_IMAGES = 10000;
    public static final int NUM_LABELS = 10;
    public static final int HEIGHT = 32;
    public static final int WIDTH = 32;
    public static final int CHANNELS = 3;
    public static final boolean DEFAULT_USE_SPECIAL_PREPROC = false;
    public static final boolean DEFAULT_SHUFFLE = true;

    /** Bytes per raw record: 1 label byte + HEIGHT * WIDTH * CHANNELS pixel bytes. */
    private static final int BYTEFILELEN = 3073;
    private static final String TESTFILENAME = "test_batch.bin";
    private static final String dataBinUrl = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz";
    private static final String localDir = "cifar";
    private static final String dataBinFile = "cifar-10-batches-bin";
    private static final String labelFileName = "batches.meta.txt";
    /** Number of records written to each serialized preprocessed file. */
    private static final int numToConvertDS = 10000;
    /** OpenCV colour-conversion code COLOR_BGR2YCrCb, used by {@link #convertCifar(Mat)}. */
    private static final int COLOR_BGR2YCRCB = 36;
    private static final String[] TRAINFILENAMES = {"data_batch_1.bin", "data_batch_2.bin",
            "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin"};

    /** Download descriptor consumed by {@code downloadAndUntar}; populated by {@link #generateMaps()}. */
    public static Map<String, String> cifarDataMap = new HashMap<>();

    protected final File fullDir;
    protected final File meanVarPath;
    protected final String trainFilesSerialized;
    protected final String testFilesSerialized;
    /** Stream currently decoded by {@link #convertDataSet(int)}; one of the two streams below. */
    protected InputStream inputStream;
    protected InputStream trainInputStream;
    protected InputStream testInputStream;
    protected List<String> labels = new ArrayList<>();
    protected boolean train;
    protected boolean useSpecialPreProcessCifar;
    protected long seed;
    protected boolean shuffle = DEFAULT_SHUFFLE;
    protected int numExamples = 0;
    // Running per-example statistics of channels 1 and 2, accumulated in special-preprocess mode.
    protected double uMean = 0;
    protected double uStd = 0;
    protected double vMean = 0;
    protected double vStd = 0;
    protected boolean meanStdStored = false;
    /** Index of the next example to serve from {@link #loadDS}. */
    protected int loadDSIndex = 0;
    /** Most recently loaded serialized preprocessed file (special-preprocess mode only). */
    protected DataSet loadDS = new DataSet();
    /** 1-based number of the serialized file currently held in {@link #loadDS}; 0 before the first load. */
    protected int fileNum = 0;

    private static File getDefaultDirectory() {
        return new File(BASE_DIR, FilenameUtils.concat(localDir, dataBinFile));
    }

    /** Creates a training-set loader with the default dimensions and settings. */
    public CifarLoader() {
        this(true);
    }

    /**
     * Creates a loader with the default dimensions and settings.
     *
     * @param train {@code true} to read the training batches, {@code false} for the test batch
     */
    public CifarLoader(boolean train) {
        this(train, null);
    }

    /**
     * Creates a loader with the default dimensions and settings.
     *
     * @param train   {@code true} to read the training batches, {@code false} for the test batch
     * @param fullDir directory containing the extracted binary batches, or {@code null} for the default
     */
    public CifarLoader(boolean train, File fullDir) {
        this(HEIGHT, WIDTH, CHANNELS, null, train, DEFAULT_USE_SPECIAL_PREPROC, fullDir,
                System.currentTimeMillis(), DEFAULT_SHUFFLE);
    }

    public CifarLoader(int height, int width, int channels, boolean train, boolean useSpecialPreProcessCifar) {
        this(height, width, channels, null, train, useSpecialPreProcessCifar);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train,
            boolean useSpecialPreProcessCifar) {
        this(height, width, channels, imgTransform, train, useSpecialPreProcessCifar, DEFAULT_SHUFFLE);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train,
            boolean useSpecialPreProcessCifar, boolean shuffle) {
        this(height, width, channels, imgTransform, train, useSpecialPreProcessCifar, null,
                System.currentTimeMillis(), shuffle);
    }

    /**
     * Creates a loader and immediately opens (downloading if necessary) the CIFAR data.
     *
     * @param height                    target image height
     * @param width                     target image width
     * @param channels                  target number of channels
     * @param imgTransform              optional transform passed to the parent loader; may be {@code null}
     * @param train                     {@code true} for the training batches, {@code false} for the test batch
     * @param useSpecialPreProcessCifar enable the serialized, statistics-collecting preprocessing path
     * @param fullDir                   directory containing the extracted batches, or {@code null} for the default
     * @param seed                      seed used for shuffling and for the transforms in {@link #convertCifar(Mat)}
     * @param shuffle                   shuffle examples within each batch / loaded file
     * @throws RuntimeException if the data files cannot be opened or converted
     */
    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train,
            boolean useSpecialPreProcessCifar, File fullDir, long seed, boolean shuffle) {
        super(height, width, channels, imgTransform);
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.train = train;
        this.useSpecialPreProcessCifar = useSpecialPreProcessCifar;
        this.seed = seed;
        this.shuffle = shuffle;
        this.fullDir = fullDir == null ? getDefaultDirectory() : fullDir;
        this.meanVarPath = new File(this.fullDir, "meanVarPath.txt");
        this.trainFilesSerialized = FilenameUtils.concat(this.fullDir.toString(), "cifar_train_serialized");
        this.testFilesSerialized = FilenameUtils.concat(this.fullDir.toString(), "cifar_test_serialized.ser");
        load();
    }

    @Override
    public INDArray asRowVector(File f) throws IOException {
        throw new UnsupportedOperationException("CifarLoader reads CIFAR binary batches; use next(...) instead");
    }

    @Override
    public INDArray asRowVector(InputStream is) throws IOException {
        throw new UnsupportedOperationException("CifarLoader reads CIFAR binary batches; use next(...) instead");
    }

    @Override
    public INDArray asMatrix(File f) throws IOException {
        throw new UnsupportedOperationException("CifarLoader reads CIFAR binary batches; use next(...) instead");
    }

    @Override
    public INDArray asMatrix(InputStream is) throws IOException {
        throw new UnsupportedOperationException("CifarLoader reads CIFAR binary batches; use next(...) instead");
    }

    /** Fills {@link #cifarDataMap} with the archive name, URL and extracted directory name. */
    protected void generateMaps() {
        cifarDataMap.put("filesFilename", new File(dataBinUrl).getName());
        cifarDataMap.put("filesURL", dataBinUrl);
        cifarDataMap.put("filesFilenameUnzipped", dataBinFile);
    }

    /** Appends every line of the label file in {@link #fullDir} to {@link #labels}. */
    private void defineLabels() {
        File labelFile = new File(fullDir, labelFileName);
        try (BufferedReader reader = new BufferedReader(new FileReader(labelFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                labels.add(line);
            }
        } catch (IOException e) {
            log.warn("Could not read CIFAR label file " + labelFile, e);
        }
    }

    /**
     * Opens the raw data streams, downloading the archive first if the data directory is missing.
     *
     * <p>Any streams left open by a previous call (e.g. from {@link #reset()}) are closed first. When
     * special preprocessing is enabled for the training set and its serialized files do not exist
     * yet, they are generated here, which consumes both raw streams.
     *
     * @throws RuntimeException wrapping any failure to open or convert the data
     */
    protected void load() {
        if (!cifarRawFilesExist() && !fullDir.exists()) {
            generateMaps();
            fullDir.mkdirs();
            log.info("Downloading CIFAR data set");
            downloadAndUntar(cifarDataMap, new File(BASE_DIR, localDir));
        }
        closeStreams();
        try {
            trainInputStream = openTrainStream();
            testInputStream = new FileInputStream(new File(fullDir, TESTFILENAME));
            if (labels.isEmpty()) {
                defineLabels();
            }
            if (useSpecialPreProcessCifar && train && !cifarProcessedFilesExists()) {
                inputStream = trainInputStream;
                for (int i = fileNum + 1; i <= TRAINFILENAMES.length; i++) {
                    convertDataSet(numToConvertDS).save(new File(trainFilesSerialized + i + ".ser"));
                }
                inputStream = testInputStream;
                convertDataSet(numToConvertDS).save(new File(testFilesSerialized));
            }
            setInputStream();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Opens one stream over every {@code *.bin} file under {@link #fullDir} except the test batch.
     * Files are concatenated in pathname order so that the example order does not depend on the
     * order in which the file system lists directory entries.
     */
    private InputStream openTrainStream() throws IOException {
        List<File> trainFiles = new ArrayList<>();
        for (File file : FileUtils.listFiles(fullDir, new String[] {"bin"}, true)) {
            if (!TESTFILENAME.equals(file.getName())) {
                trainFiles.add(file);
            }
        }
        if (trainFiles.isEmpty()) {
            throw new IllegalStateException("No CIFAR training batch files (*.bin) found under " + fullDir);
        }
        trainFiles.sort(null); // File is Comparable: natural pathname order
        InputStream stream = null;
        try {
            for (File file : trainFiles) {
                InputStream next = new FileInputStream(file);
                stream = stream == null ? next : new SequenceInputStream(stream, next);
            }
        } catch (IOException e) {
            // Don't leak the files that were already opened.
            if (stream != null) {
                try {
                    stream.close();
                } catch (IOException suppressed) {
                    e.addSuppressed(suppressed);
                }
            }
            throw e;
        }
        return stream;
    }

    /** Closes the raw train/test streams, if open, and clears all stream references. */
    private void closeStreams() {
        for (InputStream stream : new InputStream[] {trainInputStream, testInputStream}) {
            if (stream != null) {
                try {
                    stream.close();
                } catch (IOException e) {
                    log.warn("Failed to close CIFAR input stream", e);
                }
            }
        }
        trainInputStream = null;
        testInputStream = null;
        inputStream = null;
    }

    /** Returns {@code true} if the test batch and all five training batches exist in {@link #fullDir}. */
    private boolean cifarRawFilesExist() {
        if (!new File(fullDir, TESTFILENAME).exists()) {
            return false;
        }
        for (String name : TRAINFILENAMES) {
            if (!new File(fullDir, name).exists()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether the serialized preprocessed data exists for the current mode. For training,
     * only the first serialized file is checked.
     */
    private boolean cifarProcessedFilesExists() {
        File marker = train ? new File(trainFilesSerialized + 1 + ".ser") : new File(testFilesSerialized);
        return marker.exists();
    }

    /**
     * Converts a BGR image to YCrCb and then applies {@link EqualizeHistTransform}. Both transforms
     * use OpenCV conversion code {@value #COLOR_BGR2YCRCB} and a fresh {@code Random(seed)}.
     * Also increments {@link #numExamples}.
     *
     * @param orgImage source image
     * @return the transformed image
     */
    public Mat convertCifar(Mat orgImage) {
        numExamples++;
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        ColorConversionTransform yCrCbTransform = new ColorConversionTransform(new Random(seed), COLOR_BGR2YCRCB);
        EqualizeHistTransform histEqualization = new EqualizeHistTransform(new Random(seed), COLOR_BGR2YCRCB);
        ImageWritable yCrCb = yCrCbTransform.transform(new ImageWritable(converter.convert(orgImage)));
        return converter.convert(histEqualization.transform(yCrCb).getFrame());
    }

    /**
     * Normalizes a serialized DataSet file in place and writes it back.
     *
     * <p>Channel 0 is divided by 255. Channels 1 and 2 are standardized with
     * ({@link #uMean}, {@link #uStd}) and ({@link #vMean}, {@link #vStd}). On the first call for a
     * training loader, these statistics are finalized from the running sums and persisted to
     * {@link #meanVarPath}. Otherwise, if {@code uMean == 0 && meanStdStored}, they are read back
     * from that file.
     *
     * @param file serialized DataSet to normalize; overwritten with the result
     */
    public void normalizeCifar(File file) {
        DataSet data = new DataSet();
        data.load(file);
        if (!meanStdStored && train) {
            // NOTE(review): within this class numExamples is only incremented by convertCifar(), which is
            // not called here; if it is still 0 these divisions produce NaN/Infinity. Confirm callers.
            uMean = Math.abs(uMean / numExamples);
            uStd = Math.sqrt(uStd);
            vMean = Math.abs(vMean / numExamples);
            vStd = Math.sqrt(vStd);
            try {
                FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
            } catch (IOException e) {
                log.warn("Could not persist CIFAR mean/std to " + meanVarPath, e);
            }
            meanStdStored = true;
        } else if (uMean == 0 && meanStdStored) {
            try {
                String[] stored = FileUtils.readFileToString(meanVarPath).split(",");
                uMean = Double.parseDouble(stored[0]);
                uStd = Double.parseDouble(stored[1]);
                vMean = Double.parseDouble(stored[2]);
                vStd = Double.parseDouble(stored[3]);
            } catch (IOException e) {
                log.warn("Could not read CIFAR mean/std from " + meanVarPath, e);
            }
        }
        for (int i = 0; i < data.numExamples(); i++) {
            // Modified in place: relies on DataSet.get(i) exposing views of the stored features
            // (NOTE(review): confirm for the ND4J version in use).
            INDArray features = data.get(i).getFeatures();
            features.tensorAlongDimension(0L, new int[] {0, 2, 3}).divi(255);
            features.tensorAlongDimension(1L, new int[] {0, 2, 3}).subi(uMean).divi(uStd);
            features.tensorAlongDimension(2L, new int[] {0, 2, 3}).subi(vMean).divi(vStd);
        }
        data.save(file);
    }

    /**
     * Decodes one raw CIFAR record.
     *
     * @param byteFeature record of {@value #BYTEFILELEN} bytes: label byte, then the first, second and
     *                    third 1024-byte colour planes (R, G, B in the CIFAR-10 binary format)
     * @return the one-hot label vector and a HEIGHT x WIDTH 3-channel 8-bit Mat whose channels are
     *         stored in reverse plane order (B, G, R)
     */
    public Pair<INDArray, Mat> convertMat(byte[] byteFeature) {
        INDArray label = FeatureUtil.toOutcomeVector(byteFeature[0], (long) NUM_LABELS);
        Mat image = new Mat(HEIGHT, WIDTH, opencv_core.CV_8UC(CHANNELS));
        ByteBuffer pixels = (ByteBuffer) image.createBuffer();
        int planeSize = HEIGHT * WIDTH;
        for (int p = 0; p < planeSize; p++) {
            pixels.put(CHANNELS * p, byteFeature[1 + 2 * planeSize + p]);
            pixels.put(CHANNELS * p + 1, byteFeature[1 + planeSize + p]);
            pixels.put(CHANNELS * p + 2, byteFeature[1 + p]);
        }
        return new Pair<>(label, image);
    }

    /**
     * Reads and decodes up to {@code num} records from {@link #inputStream}.
     *
     * <p>Without special preprocessing, features are scaled by 1/255. With special preprocessing,
     * features are left at raw scale, because {@link #normalizeCifar(File)} expects raw values, and
     * per-example statistics of channels 1 and 2 are added to the running sums.
     *
     * @param num maximum number of records to read
     * @return the merged batch, shuffled when {@code shuffle && num > 1}; an empty DataSet if no
     *         record could be read
     * @throws IllegalStateException if special preprocessing is enabled and the features do not
     *                               have 3 channels
     */
    public DataSet convertDataSet(int num) {
        List<DataSet> dataSets = new ArrayList<>();
        byte[] record = new byte[BYTEFILELEN];
        for (int count = 0; count != num; count++) {
            try {
                if (!readRecord(record)) {
                    break;
                }
            } catch (IOException e) {
                // After a failed read the stream position is undefined, so later records would be misaligned.
                log.warn("Failed to read CIFAR record; ending batch early", e);
                break;
            }
            Pair<INDArray, Mat> converted = convertMat(record);
            try {
                dataSets.add(new DataSet(asMatrix(converted.getSecond()), converted.getFirst()));
            } catch (Exception e) {
                log.warn("Failed to convert CIFAR record; skipping it", e);
            }
        }
        if (dataSets.isEmpty()) {
            return new DataSet();
        }
        DataSet result = DataSet.merge(dataSets);
        if (useSpecialPreProcessCifar) {
            accumulateChannelStats(result);
        } else {
            // Scale the merged batch itself. Iterating a DataSet yields separate per-example DataSet
            // objects, so calling setFeatures(div(255)) on them never reached the returned batch.
            result.getFeatures().divi(255);
        }
        if (shuffle && num > 1) {
            result.shuffle(seed);
        }
        return result;
    }

    /**
     * Fills {@code record} completely from {@link #inputStream}. A single read() call may return
     * fewer bytes than requested, so this loops until the record is full or the stream ends.
     *
     * @return {@code true} if a full record was read; {@code false} at end of stream
     */
    private boolean readRecord(byte[] record) throws IOException {
        int offset = 0;
        while (offset < record.length) {
            int read = inputStream.read(record, offset, record.length - offset);
            if (read == -1) {
                if (offset > 0) {
                    log.warn("Ignoring truncated trailing CIFAR record (" + offset + " of " + record.length + " bytes)");
                }
                return false;
            }
            offset += read;
        }
        return true;
    }

    /** Adds each example's channel-1/channel-2 mean and variance to the running u/v sums. */
    private void accumulateChannelStats(DataSet batch) {
        for (DataSet example : batch) {
            try {
                INDArray uChannel = example.getFeatures().tensorAlongDimension(1L, new int[] {0, 2, 3});
                INDArray vChannel = example.getFeatures().tensorAlongDimension(2L, new int[] {0, 2, 3});
                double uExampleMean = uChannel.meanNumber().doubleValue();
                uStd += varManual(uChannel, uExampleMean);
                uMean += uExampleMean;
                double vExampleMean = vChannel.meanNumber().doubleValue();
                vStd += varManual(vChannel, vExampleMean);
                vMean += vExampleMean;
            } catch (IllegalArgumentException e) {
                throw new IllegalStateException("The number of channels must be 3 to special preProcess Cifar with.", e);
            }
        }
    }

    /**
     * Computes the population variance of all elements of {@code values} around {@code mean}.
     *
     * @param values array of samples (not modified)
     * @param mean   mean to measure deviations from
     * @return sum of squared deviations divided by the element count
     */
    public double varManual(INDArray values, double mean) {
        INDArray centered = values.sub(mean);
        double sumOfSquares = Nd4j.getExecutioner()
                .execAndReturn(new Sum(centered.muli(centered), new int[0]))
                .getFinalResult().doubleValue();
        return sumOfSquares / values.length();
    }

    /**
     * Equivalent to {@code next(batchSize, 0)}. When serialized preprocessed files are in use,
     * {@code exampleNum == 0} loads the next serialized file, so every call starts from a freshly
     * loaded file.
     */
    public DataSet next(int batchSize) {
        return next(batchSize, 0);
    }

    /**
     * Returns the next batch of up to {@code batchSize} examples.
     *
     * <p>When special preprocessing is enabled and its serialized files exist, examples come from
     * those files. The next serialized file is loaded when {@code exampleNum == 0}, or for the
     * training set when the currently loaded file is exhausted and another one remains. Otherwise
     * the batch is decoded from the raw stream via {@link #convertDataSet(int)}.
     *
     * @param batchSize  maximum number of examples to return
     * @param exampleNum number of examples the caller has consumed so far; {@code 0} loads the next
     *                   serialized file (the first one after {@link #reset()})
     * @return the batch, or an empty DataSet when no examples remain
     */
    public DataSet next(int batchSize, int exampleNum) {
        if (!(cifarProcessedFilesExists() && useSpecialPreProcessCifar)) {
            return convertDataSet(batchSize);
        }
        boolean currentFileExhausted = loadDSIndex >= loadDS.numExamples();
        if (exampleNum == 0 || (train && currentFileExhausted && fileNum < TRAINFILENAMES.length)) {
            fileNum++;
            File source = train ? new File(trainFilesSerialized + fileNum + ".ser") : new File(testFilesSerialized);
            loadDS.load(source);
            // Shuffle the whole file once before batching.
            if (shuffle && batchSize > 1) {
                loadDS.shuffle(seed);
            }
            loadDSIndex = 0;
        }
        List<DataSet> batch = new ArrayList<>();
        int available = loadDS.numExamples();
        while (batch.size() < batchSize && loadDSIndex < available) {
            batch.add(loadDS.get(loadDSIndex));
            loadDSIndex++;
        }
        if (batch.isEmpty()) {
            return new DataSet();
        }
        return batch.size() > 1 ? DataSet.merge(batch) : batch.get(0);
    }

    public InputStream getInputStream() {
        return inputStream;
    }

    /** Points {@link #inputStream} at the train or test stream according to {@link #train}. */
    public void setInputStream() {
        inputStream = train ? trainInputStream : testInputStream;
    }

    public List<String> getLabels() {
        return labels;
    }

    /** Rewinds the loader: clears counters and cached serialized data, then reopens the streams. */
    public void reset() {
        numExamples = 0;
        fileNum = 0;
        loadDSIndex = 0;
        loadDS = new DataSet();
        load();
    }
}
