package org.deeplearning4j.datasets.iterator.parallel;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/parallel/JointParallelDataSetIterator.class */
/**
 * Parallel data set iterator that joins several independent source iterators.
 *
 * <p>Every source iterator handed to the constructor is wrapped into its own
 * {@link AsyncDataSetIterator}; the wrapper at position {@code n} serves the consumer with
 * index {@code n} through {@link #hasNextFor(int)}, {@link #nextFor(int)} and
 * {@link #reset(int)}.
 *
 * <p>Instances are usually created through {@link Builder}.
 */
public class JointParallelDataSetIterator extends BaseParallelDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(JointParallelDataSetIterator.class);

    /** Async wrappers, one per source iterator, in the order the sources were supplied. */
    protected List<DataSetIterator> asyncIterators;

    /**
     * Value of the {@code enforceSingleDevice} option.
     * NOTE(review): stored but not read anywhere in this class - confirm whether a subclass
     * or the base class relies on it.
     */
    protected boolean enforceSingleDevice;

    /** Buffer size passed to every {@link AsyncDataSetIterator} created by this class. */
    protected int bufferSizePerDevice;

    /**
     * Fluent builder for {@link JointParallelDataSetIterator}.
     *
     * <p>Defaults: {@code enforceSingleDevice = true}, buffer size {@code 4}.
     */
    public static class Builder {
        private final List<DataSetIterator> iterators = new ArrayList<>();
        private boolean enforceSingleDevice = true;
        private int bufferSize = 4;
        private final InequalityHandling inequalityHandling;

        /**
         * Creates a builder without any source iterators.
         *
         * @param inequalityHandling handling mode forwarded to the built iterator; must not be null
         * @throws NullPointerException if {@code inequalityHandling} is null
         */
        public Builder(@NonNull InequalityHandling inequalityHandling) {
            if (inequalityHandling == null) {
                throw new NullPointerException("inequalityHandling is marked @NonNull but is null");
            }
            this.inequalityHandling = inequalityHandling;
        }

        /**
         * Creates a builder pre-populated with the given source iterators. Each element goes
         * through {@link #addSourceIterator(DataSetIterator)} and is validated there.
         *
         * @param list               source iterators; must not be null
         * @param inequalityHandling handling mode forwarded to the built iterator; must not be null
         * @throws NullPointerException     if either argument, or any element of {@code list}, is null
         * @throws IllegalArgumentException if an element does not support async mode or is
         *                                  present more than once
         */
        public Builder(@NonNull List<DataSetIterator> list, @NonNull InequalityHandling inequalityHandling) {
            if (list == null) {
                throw new NullPointerException("iterators is marked @NonNull but is null");
            }
            if (inequalityHandling == null) {
                throw new NullPointerException("inequalityHandling is marked @NonNull but is null");
            }
            this.inequalityHandling = inequalityHandling;
            for (DataSetIterator source : list) {
                addSourceIterator(source);
            }
        }

        /**
         * Registers one more source iterator.
         *
         * @param dataSetIterator iterator to add; must not be null
         * @return this builder
         * @throws NullPointerException     if {@code dataSetIterator} is null
         * @throws IllegalArgumentException if {@code asyncSupported()} returns false, or if the
         *                                  very same instance was already added
         */
        public Builder addSourceIterator(@NonNull DataSetIterator dataSetIterator) {
            if (dataSetIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            if (!dataSetIterator.asyncSupported()) {
                throw new IllegalArgumentException("Source iterators should support async mode");
            }
            if (hasIterator(dataSetIterator)) {
                throw new IllegalArgumentException("You can't put equal iterators into this joint iterator");
            }
            this.iterators.add(dataSetIterator);
            return this;
        }

        /**
         * Tells whether the given instance has already been registered.
         *
         * @param dataSetIterator iterator to look for
         * @return {@code true} if the same object (reference identity, not {@code equals})
         *         is already in the list
         */
        protected boolean hasIterator(DataSetIterator dataSetIterator) {
            for (DataSetIterator registered : this.iterators) {
                // Identity comparison on purpose: only the very same instance counts as a duplicate.
                if (registered == dataSetIterator) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Sets the buffer size used for each source iterator's async wrapper.
         *
         * @param i buffer size
         * @return this builder
         */
        public Builder setBufferSizePerSplit(int i) {
            this.bufferSize = i;
            return this;
        }

        /**
         * Sets the {@code enforceSingleDevice} option passed to the built iterator.
         *
         * @param z new value of the option
         * @return this builder
         */
        public Builder enforceSingleDevice(boolean z) {
            this.enforceSingleDevice = z;
            return this;
        }

        /**
         * Builds the iterator from the current builder state.
         *
         * @return a new {@link JointParallelDataSetIterator}
         * @throws IllegalArgumentException if no source iterator has been added
         */
        public JointParallelDataSetIterator build() {
            return new JointParallelDataSetIterator(this.iterators, this.enforceSingleDevice, this.bufferSize, this.inequalityHandling);
        }
    }

    /**
     * Creates the joint iterator and wraps every source iterator into an async iterator.
     *
     * @param list               source iterators; must not be null or empty
     * @param z                  value stored in {@link #enforceSingleDevice}
     * @param i                  buffer size stored in {@link #bufferSizePerDevice}
     * @param inequalityHandling handling mode stored in the inherited field; must not be null
     * @throws NullPointerException     if {@code list} or {@code inequalityHandling} is null
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public JointParallelDataSetIterator(@NonNull List<DataSetIterator> list, boolean z, int i, @NonNull InequalityHandling inequalityHandling) {
        // The null check has to run inside the super(...) argument: checking after the call
        // is too late because list.size() would already have thrown a message-less NPE.
        super(requireIterators(list).size());
        this.asyncIterators = new ArrayList<>();
        if (inequalityHandling == null) {
            throw new NullPointerException("inequalityHandling is marked @NonNull but is null");
        }
        this.enforceSingleDevice = z;
        this.bufferSizePerDevice = i;
        this.numProducers = list.size();
        this.inequalityHandling = inequalityHandling;
        if (this.numProducers == 0) {
            throw new IllegalArgumentException("You can't start ParallelDataSetIterator without input data");
        }
        initializeIterators(list);
    }

    /** Returns {@code list} unchanged, or throws the descriptive NPE when it is null. */
    private static List<DataSetIterator> requireIterators(List<DataSetIterator> list) {
        if (list == null) {
            throw new NullPointerException("iterators is marked @NonNull but is null");
        }
        return list;
    }

    /**
     * Wraps each source iterator into an {@link AsyncDataSetIterator} and appends it to
     * {@link #asyncIterators}. The wrapper for the source at position {@code n} receives
     * {@code n % numberOfDevices} as its last constructor argument, so sources are spread
     * round-robin over the device count reported by the affinity manager.
     *
     * <p>Called from the constructor, so an override runs before subclass fields are initialized.
     *
     * @param list source iterators to wrap
     */
    protected void initializeIterators(List<DataSetIterator> list) {
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        // NOTE(review): the result is unused, but the call is kept because it may have side
        // effects inside the affinity manager that are not visible here - confirm before removing.
        Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        if (list.size() % numberOfDevices != 0) {
            // Non-fatal condition: initialization proceeds, so this is a warning, not an error.
            log.warn("WARNING: number of splits doesn't match number of devices!");
        }
        int position = 0;
        for (DataSetIterator source : list) {
            this.asyncIterators.add(new AsyncDataSetIterator(source, this.bufferSizePerDevice, true, Integer.valueOf(position % numberOfDevices)));
            position++;
        }
    }

    /**
     * Tells whether the iterator serving the given consumer has more data.
     *
     * @param i consumer index, {@code 0 <= i < numProducers}
     * @return result of {@code hasNext()} on the matching async iterator
     * @throws ND4JIllegalStateException if {@code i} is out of range
     */
    @Override // org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator
    public boolean hasNextFor(int i) {
        requireValidProducerIndex(i);
        return this.asyncIterators.get(i).hasNext();
    }

    /**
     * Returns the next data set from the iterator serving the given consumer.
     *
     * @param i consumer index, {@code 0 <= i < numProducers}
     * @return result of {@code next()} on the matching async iterator
     * @throws ND4JIllegalStateException if {@code i} is out of range
     */
    @Override // org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator
    public DataSet nextFor(int i) {
        requireValidProducerIndex(i);
        return (DataSet) this.asyncIterators.get(i).next();
    }

    /**
     * Resets the iterator serving the given consumer.
     *
     * @param i consumer index, {@code 0 <= i < numProducers}
     * @throws ND4JIllegalStateException if {@code i} is out of range
     */
    @Override // org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator
    protected void reset(int i) {
        requireValidProducerIndex(i);
        this.asyncIterators.get(i).reset();
    }

    /** Throws {@link ND4JIllegalStateException} unless {@code 0 <= i < numProducers}. */
    private void requireValidProducerIndex(int i) {
        if (i >= this.numProducers || i < 0) {
            throw new ND4JIllegalStateException("Non-existent consumer was requested");
        }
    }
}
