package org.deeplearning4j.datasets.iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.class */
/**
 * A {@link DataSetIterator} that replays its data source for several epochs.
 *
 * <p>An instance is backed by exactly one of two sources:
 * <ul>
 *   <li>an underlying {@link DataSetIterator} ({@link #iter}), which is reset each time it is
 *       exhausted until {@link #numEpochs} epochs have been completed, or</li>
 *   <li>a single in-memory {@link DataSet} ({@link #ds}), returned either whole
 *       ({@code next()} / {@code next(-1)}) or split into mini-batches ({@code next(n)}).</li>
 * </ul>
 *
 * <p>Iteration additionally stops once {@link #totalIterations} calls to {@code next} have been
 * made. Methods that simply delegate to {@link #iter} (for example {@link #totalExamples()},
 * {@link #batch()}, {@link #cursor()}) are only usable on iterator-backed instances; on a
 * DataSet-backed instance {@code iter} is {@code null} and they throw
 * {@link NullPointerException}.
 */
public class MultipleEpochsIterator implements DataSetIterator {

    protected static final Logger log = LoggerFactory.getLogger(MultipleEpochsIterator.class);

    /** Number of epochs completed so far. */
    @VisibleForTesting
    protected int epochs = 0;
    /** Number of epochs to run before {@link #hasNext()} reports exhaustion. */
    protected int numEpochs;
    /** Number of batches returned so far in the current epoch. */
    protected int batch = 0;
    /** Batch count of the most recently completed epoch; only used for logging. */
    protected int lastBatch = 0;
    /** Underlying iterator, or {@code null} when this instance is DataSet-backed. */
    protected DataSetIterator iter;
    /** In-memory data set; only assigned by the DataSet-backed constructor. */
    protected DataSet ds;
    /** Mini-batches of {@link #ds}, built lazily on the first {@code next(n)} call with n &gt; 0. */
    protected List<DataSet> batchedDS = Lists.newArrayList();
    /** Optional pre-processor applied to every DataSet before it is returned. */
    protected DataSetPreProcessor preProcessor;
    /** Set when an epoch has just finished; consumed (and logged) by {@link #hasNext()}. */
    protected boolean newEpoch = false;
    /** Number of {@code next} calls made since construction or the last {@link #reset()}. */
    protected AtomicLong iterationsCounter = new AtomicLong(0L);
    /** Upper bound on the number of {@code next} calls; {@code Long.MAX_VALUE} means unbounded. */
    protected long totalIterations = Long.MAX_VALUE;

    /**
     * Creates an iterator-backed instance that runs for a fixed number of epochs.
     *
     * @param i               number of epochs
     * @param dataSetIterator underlying iterator to replay
     */
    public MultipleEpochsIterator(int i, DataSetIterator dataSetIterator) {
        this.numEpochs = i;
        this.iter = dataSetIterator;
    }

    /**
     * Same as {@link #MultipleEpochsIterator(int, DataSetIterator)}; the third argument is ignored.
     *
     * @param i               number of epochs
     * @param dataSetIterator underlying iterator to replay
     * @param i2              unused
     */
    @Deprecated
    public MultipleEpochsIterator(int i, DataSetIterator dataSetIterator, int i2) {
        this(i, dataSetIterator);
    }

    /**
     * Same as {@link #MultipleEpochsIterator(DataSetIterator, long)}; the second argument is
     * ignored.
     *
     * @param dataSetIterator underlying iterator to replay
     * @param i               unused
     * @param j               maximum number of {@code next} calls
     */
    @Deprecated
    public MultipleEpochsIterator(DataSetIterator dataSetIterator, int i, long j) {
        this(dataSetIterator, j);
    }

    /**
     * Creates an iterator-backed instance limited by a total iteration count rather than by an
     * epoch count (the epoch limit is set to {@code Integer.MAX_VALUE}).
     *
     * @param dataSetIterator underlying iterator to replay
     * @param j               maximum number of {@code next} calls
     */
    public MultipleEpochsIterator(DataSetIterator dataSetIterator, long j) {
        this.numEpochs = Integer.MAX_VALUE;
        this.iter = dataSetIterator;
        this.totalIterations = j;
    }

    /**
     * Creates a DataSet-backed instance that runs for a fixed number of epochs.
     *
     * @param i       number of epochs
     * @param dataSet data set to replay
     */
    public MultipleEpochsIterator(int i, DataSet dataSet) {
        this.numEpochs = i;
        this.ds = dataSet;
    }

    /**
     * Returns the next DataSet, advancing to a new epoch when the current one is exhausted.
     *
     * @param i requested batch size, or {@code -1} to use the source's default ({@code iter.next()}
     *          for iterator-backed instances, the whole DataSet for DataSet-backed instances)
     * @return the next DataSet, after the pre-processor (if any) has been applied to it
     * @throws NoSuchElementException if {@link #hasNext()} is {@code false}
     * @throws IllegalStateException  if the underlying iterator returns {@code null}
     */
    public DataSet next(int i) {
        DataSet next;
        if (!hasNext()) {
            throw new NoSuchElementException("No next element");
        }
        this.batch++;
        this.iterationsCounter.incrementAndGet();
        if (this.iter != null) {
            next = i == -1 ? (DataSet) this.iter.next() : this.iter.next(i);
            if (next == null) {
                throw new IllegalStateException("Iterator returned null DataSet");
            }
            if (!this.iter.hasNext()) {
                trackEpochs();
                // Only rewind if another epoch is still due; after the last epoch the
                // underlying iterator is left exhausted.
                if (this.epochs < this.numEpochs) {
                    this.iter.reset();
                    this.lastBatch = this.batch;
                    this.batch = 0;
                }
            }
        } else if (i == -1) {
            // Whole DataSet per call: every call is one epoch.
            next = this.ds;
            if (this.epochs < this.numEpochs) {
                trackEpochs();
            }
        } else {
            if (this.batchedDS.isEmpty() && i > 0) {
                this.batchedDS = this.ds.batchBy(i);
            }
            // batch was incremented above, so it is the 1-based position within the epoch.
            // Indexing with batch - 1 makes the first call return mini-batch 0 instead of
            // skipping it (or overrunning the list when there is only one mini-batch).
            next = this.batchedDS.get(this.batch - 1);
            if (this.batch == this.batchedDS.size()) {
                trackEpochs();
                if (this.epochs < this.numEpochs) {
                    this.lastBatch = this.batch;
                    this.batch = 0;
                }
            }
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(next);
        }
        return next;
    }

    /** Marks the current epoch as finished: bumps the epoch count and flags it for logging. */
    public void trackEpochs() {
        this.epochs++;
        this.newEpoch = true;
    }

    /**
     * Returns the next DataSet using the source's default batch size.
     *
     * @return the result of {@code next(-1)}
     * @throws NoSuchElementException if {@link #hasNext()} is {@code false}
     */
    public DataSet next() {
        return next(-1);
    }

    /**
     * Alias of {@link #next()}. The decompiler renamed the original no-arg {@code next()} to this
     * name when it merged the bridge method; it is kept so existing callers keep working.
     *
     * @return the result of {@code next(-1)}
     */
    public DataSet m23next() {
        return next();
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public int totalExamples() {
        return this.iter.totalExamples();
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public int inputColumns() {
        return this.iter.inputColumns();
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public int totalOutcomes() {
        return this.iter.totalOutcomes();
    }

    /**
     * Reports whether {@link #reset()} can succeed.
     *
     * @return {@code true} for DataSet-backed instances, otherwise whatever the underlying
     *         iterator reports
     */
    public boolean resetSupported() {
        return this.iter == null || this.iter.resetSupported();
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public boolean asyncSupported() {
        return this.iter.asyncSupported();
    }

    /**
     * Rewinds to the start of the first epoch: clears the epoch, batch and iteration counters and
     * resets the underlying iterator when there is one.
     *
     * @throws IllegalStateException if the underlying iterator does not support reset
     */
    public void reset() {
        if (this.iter != null && !this.iter.resetSupported()) {
            throw new IllegalStateException("Cannot reset MultipleEpochsIterator with base iter that does not support reset");
        }
        this.epochs = 0;
        this.lastBatch = this.batch;
        this.batch = 0;
        this.iterationsCounter.set(0L);
        if (this.iter != null) {
            this.iter.reset();
        }
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public int batch() {
        return this.iter.batch();
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public int cursor() {
        return this.iter.cursor();
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public int numExamples() {
        return this.iter.numExamples();
    }

    /**
     * Sets the pre-processor applied to every DataSet returned by {@code next}.
     *
     * @param dataSetPreProcessor pre-processor to apply, or {@code null} for none
     */
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public List<String> getLabels() {
        return this.iter.getLabels();
    }

    /**
     * Reports whether another DataSet is available. As a side effect, logs a one-line summary the
     * first time it is called after an epoch has finished.
     *
     * @return {@code false} once the iteration limit or the epoch limit has been reached
     */
    public boolean hasNext() {
        if (this.iterationsCounter.get() >= this.totalIterations) {
            return false;
        }
        if (this.newEpoch) {
            log.info("Epoch {}, number of batches completed {}", this.epochs, this.lastBatch);
            this.newEpoch = false;
        }
        if (this.iter == null) {
            // DataSet-backed: more epochs are due and, if mini-batches have been built,
            // the position is still inside the list.
            return this.epochs < this.numEpochs
                    && (this.batchedDS.isEmpty() || this.batchedDS.size() > this.batch);
        }
        return this.epochs < this.numEpochs
                || (this.iter.hasNext() && (this.epochs == 0 || this.epochs == this.numEpochs));
    }

    /** Delegates to the underlying iterator; iterator-backed instances only. */
    public void remove() {
        this.iter.remove();
    }

    /**
     * Returns the pre-processor applied to every DataSet returned by {@code next}.
     *
     * @return the configured pre-processor, or {@code null} if none is set
     */
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }
}
