package org.deeplearning4j.datasets.iterator.parallel;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.IOFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.deeplearning4j.datasets.iterator.FileSplitDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.FileCallback;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.class */
/**
 * Parallel DataSet iterator over files stored in a single folder.
 *
 * <p>Files whose names match {@code pattern} (with {@code %d} treated as a wildcard) are split
 * into {@code numProducers} contiguous chunks. Each chunk is read by a {@link FileSplitDataSetIterator}
 * (using the supplied {@link FileCallback}), wrapped in an {@link AsyncDataSetIterator} with
 * device index {@code producer % numberOfDevices}. Consumer {@code k} reads from chunk {@code k}.
 */
public class FileSplitParallelDataSetIterator extends BaseParallelDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(FileSplitParallelDataSetIterator.class);

    /** Default file name pattern; {@code %d} marks the variable part of the file name. */
    public static final String DEFAULT_PATTERN = "dataset-%d.bin";

    private String pattern;
    private int buffer;

    /** One asynchronous iterator per producer; index {@code k} serves consumer {@code k}. */
    protected List<DataSetIterator> asyncIterators;

    /**
     * Creates an iterator with one producer per device reported by the Nd4j affinity manager,
     * {@link InequalityHandling#STOP_EVERYONE} and a prefetch buffer of 2.
     *
     * @param rootFolder existing folder containing the dataset files
     * @param pattern    file name pattern, {@code %d} marking the variable part
     * @param callback   callback used to load each file
     * @throws NullPointerException     if any argument is null
     * @throws IllegalArgumentException if the folder is missing or has too few matching files
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback) {
        this(rootFolder, pattern, callback, Nd4j.getAffinityManager().getNumberOfDevices());
    }

    /**
     * Creates an iterator with {@link InequalityHandling#STOP_EVERYONE} and a prefetch buffer of 2.
     *
     * @param rootFolder   existing folder containing the dataset files
     * @param pattern      file name pattern, {@code %d} marking the variable part
     * @param callback     callback used to load each file
     * @param numProducers number of producers (and consumers); must be positive
     * @throws NullPointerException     if any reference argument is null
     * @throws IllegalArgumentException if arguments are invalid or too few files match
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback, int numProducers) {
        this(rootFolder, pattern, callback, numProducers, InequalityHandling.STOP_EVERYONE);
    }

    /**
     * Creates an iterator with a prefetch buffer of 2.
     *
     * @param rootFolder         existing folder containing the dataset files
     * @param pattern            file name pattern, {@code %d} marking the variable part
     * @param callback           callback used to load each file
     * @param numProducers       number of producers (and consumers); must be positive
     * @param inequalityHandling policy applied when producers run out of data at different times
     * @throws NullPointerException     if any reference argument is null
     * @throws IllegalArgumentException if arguments are invalid or too few files match
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback, int numProducers, @NonNull InequalityHandling inequalityHandling) {
        this(rootFolder, pattern, callback, numProducers, 2, inequalityHandling);
    }

    /**
     * Creates the iterator and starts one asynchronous producer per file chunk.
     *
     * @param rootFolder         existing folder containing the dataset files (subfolders are not searched)
     * @param pattern            file name pattern; {@code %d} is replaced by the regex {@code .*.}
     * @param callback           callback used to load each file
     * @param numProducers       number of producers (and consumers); must be positive
     * @param bufferPerProducer  prefetch buffer size passed to each {@link AsyncDataSetIterator}
     * @param inequalityHandling policy applied when producers run out of data at different times
     * @throws NullPointerException     if any reference argument is null
     * @throws IllegalArgumentException if {@code numProducers < 1}, the folder does not exist,
     *                                  or fewer than {@code numProducers} files match
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback, int numProducers, int bufferPerProducer, @NonNull InequalityHandling inequalityHandling) {
        super(numProducers);
        this.asyncIterators = new ArrayList<>();
        if (rootFolder == null) {
            throw new NullPointerException("rootFolder is marked non-null but is null");
        }
        if (pattern == null) {
            throw new NullPointerException("pattern is marked non-null but is null");
        }
        if (callback == null) {
            throw new NullPointerException("callback is marked non-null but is null");
        }
        if (inequalityHandling == null) {
            throw new NullPointerException("inequalityHandling is marked non-null but is null");
        }
        // Guard explicitly: a zero count would otherwise surface later as an ArithmeticException.
        if (numProducers < 1) {
            throw new IllegalArgumentException("Number of producers should be positive, but got " + numProducers);
        }
        if (!rootFolder.exists() || !rootFolder.isDirectory()) {
            throw new IllegalArgumentException("Root folder should point to existing folder");
        }
        this.pattern = pattern;
        this.inequalityHandling = inequalityHandling;
        this.buffer = bufferPerProducer;

        // A null directory filter means subdirectories are not searched.
        // NOTE(review): the order of listed files is filesystem-dependent, so chunk contents may vary between runs.
        List<File> files = new ArrayList<>(FileUtils.listFiles(rootFolder, new RegexFileFilter(pattern.replaceAll("\\%d", ".*.")), (IOFileFilter) null));
        log.debug("Files found: {}; Producers: {}", Integer.valueOf(files.size()), Integer.valueOf(this.numProducers));
        if (files.isEmpty()) {
            throw new IllegalArgumentException("No suitable files were found");
        }
        // Every producer must own at least one file, otherwise its consumer would have nothing to read.
        if (files.size() < numProducers) {
            throw new IllegalArgumentException("Found " + files.size() + " suitable files, but " + numProducers
                    + " producers were requested; each producer needs at least one file");
        }

        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        List<List<File>> chunks = partitionEvenly(files, numProducers);
        for (int producer = 0; producer < chunks.size(); producer++) {
            this.asyncIterators.add(new AsyncDataSetIterator(new FileSplitDataSetIterator(chunks.get(producer), callback), bufferPerProducer, true, Integer.valueOf(producer % numberOfDevices)));
        }
    }

    /**
     * Splits {@code items} into exactly {@code parts} contiguous, non-empty chunks whose sizes
     * differ by at most one, so that every item is assigned to some chunk.
     * Requires {@code 1 <= parts <= items.size()}.
     */
    private static List<List<File>> partitionEvenly(List<File> items, int parts) {
        int baseSize = items.size() / parts;
        int remainder = items.size() % parts;
        List<List<File>> chunks = new ArrayList<>(parts);
        int from = 0;
        for (int part = 0; part < parts; part++) {
            // The first `remainder` chunks take one extra item each.
            int to = from + baseSize + (part < remainder ? 1 : 0);
            chunks.add(new ArrayList<>(items.subList(from, to)));
            from = to;
        }
        return chunks;
    }

    /** Throws if {@code consumer} is not in {@code [0, numProducers)}. */
    private void checkConsumer(int consumer) {
        if (consumer >= this.numProducers || consumer < 0) {
            throw new ND4JIllegalStateException("Non-existent consumer was requested");
        }
    }

    /**
     * Returns whether the given consumer's producer has more data.
     *
     * @throws ND4JIllegalStateException if {@code consumer} is out of range
     */
    @Override // org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator
    public boolean hasNextFor(int consumer) {
        checkConsumer(consumer);
        return this.asyncIterators.get(consumer).hasNext();
    }

    /**
     * Returns the next DataSet for the given consumer.
     *
     * @throws ND4JIllegalStateException if {@code consumer} is out of range
     */
    @Override // org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator
    public DataSet nextFor(int consumer) {
        checkConsumer(consumer);
        return this.asyncIterators.get(consumer).next();
    }

    /**
     * Resets the given consumer's producer iterator.
     *
     * @throws ND4JIllegalStateException if {@code consumer} is out of range
     */
    @Override // org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator
    protected void reset(int consumer) {
        checkConsumer(consumer);
        this.asyncIterators.get(consumer).reset();
    }
}
