package org.deeplearning4j.datasets.datavec;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.records.Record;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.class */
public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, Serializable {
    // Minibatch size used when the iterator is advanced without an explicit example count.
    private int batchSize;
    // How sequences of differing length are positioned along the time axis.
    private AlignmentMode alignmentMode;
    // Non-sequence readers, keyed by the name they were registered under in the Builder.
    private Map<String, RecordReader> recordReaders;
    // Sequence (time series) readers, keyed by the name they were registered under in the Builder.
    private Map<String, SequenceRecordReader> sequenceRecordReaders;
    // One entry per feature array of the produced MultiDataSet, in output order.
    private List<SubsetDetails> inputs;
    // One entry per label array of the produced MultiDataSet, in output order.
    private List<SubsetDetails> outputs;
    // When true, next(int) reads records one at a time and gathers their metadata.
    // The constructor sets this to false; any setter lies outside the visible part of the file.
    private boolean collectMetaData;
    // When true, each sequence is placed at a random offset within the time axis.
    private boolean timeSeriesRandomOffset;
    // Source of the per-minibatch seed for random offsets; only created when the offset is enabled.
    private Random timeSeriesRandomOffsetRng;
    // Optional pre-processor applied to every MultiDataSet before it is returned (may be null).
    private MultiDataSetPreProcessor preProcessor;
    // True only if every underlying reader reports resetSupported() == true.
    private boolean resetSupported;

    /* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator$AlignmentMode.class */
    /**
     * Controls where each sequence is placed along the time axis when the sequences of a
     * minibatch do not all have the same length.
     */
    public enum AlignmentMode {
        /** All sequences must have the same length; variable-length data raises UnsupportedOperationException. */
        EQUAL_LENGTH,
        /** Every sequence starts at time step 0; shorter sequences are masked at the end. */
        ALIGN_START,
        /**
         * Each sequence is shifted so that it ends where the longest sequence for the same
         * example (across all sequence readers) ends; the leading time steps are masked.
         */
        ALIGN_END
    }

    /* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator$Builder.class */
    /**
     * Fluent builder for {@link RecordReaderMultiDataSetIterator}.
     *
     * <p>Readers are registered under a name, and each input/output then refers to a reader by
     * that name, optionally selecting a column range or a single column to one-hot encode.
     */
    public static class Builder {
        private int batchSize;
        private AlignmentMode alignmentMode = AlignmentMode.ALIGN_START;
        private Map<String, RecordReader> recordReaders = new HashMap<>();
        private Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<>();
        private List<SubsetDetails> inputs = new ArrayList<>();
        private List<SubsetDetails> outputs = new ArrayList<>();
        private boolean timeSeriesRandomOffset = false;
        private long timeSeriesRandomOffsetSeed = System.currentTimeMillis();

        /**
         * @param i default minibatch size; must be positive or {@link #build()} fails
         */
        public Builder(int i) {
            this.batchSize = i;
        }

        /**
         * Registers a non-sequence reader.
         *
         * @param str          name used to refer to the reader from inputs/outputs
         * @param recordReader the reader
         */
        public Builder addReader(String str, RecordReader recordReader) {
            this.recordReaders.put(str, recordReader);
            return this;
        }

        /**
         * Registers a sequence reader.
         *
         * @param str                  name used to refer to the reader from inputs/outputs
         * @param sequenceRecordReader the reader
         */
        public Builder addSequenceReader(String str, SequenceRecordReader sequenceRecordReader) {
            this.sequenceRecordReaders.put(str, sequenceRecordReader);
            return this;
        }

        /** Sets the alignment mode for variable-length sequences (default: ALIGN_START). */
        public Builder sequenceAlignmentMode(AlignmentMode alignmentMode) {
            this.alignmentMode = alignmentMode;
            return this;
        }

        /** Adds an input made of every column of the named reader. */
        public Builder addInput(String str) {
            this.inputs.add(new SubsetDetails(str, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Adds an input made of a column range of the named reader.
         *
         * @param str reader name
         * @param i   first column (inclusive)
         * @param i2  last column (inclusive)
         */
        public Builder addInput(String str, int i, int i2) {
            this.inputs.add(new SubsetDetails(str, false, false, -1, i, i2));
            return this;
        }

        /**
         * Adds an input that is the one-hot encoding of a single class-index column.
         *
         * @param str reader name
         * @param i   column holding the class index
         * @param i2  number of classes
         */
        public Builder addInputOneHot(String str, int i, int i2) {
            this.inputs.add(new SubsetDetails(str, false, true, i2, i, i));
            return this;
        }

        /** Adds an output made of every column of the named reader. */
        public Builder addOutput(String str) {
            this.outputs.add(new SubsetDetails(str, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Adds an output made of a column range of the named reader.
         *
         * @param str reader name
         * @param i   first column (inclusive)
         * @param i2  last column (inclusive)
         */
        public Builder addOutput(String str, int i, int i2) {
            this.outputs.add(new SubsetDetails(str, false, false, -1, i, i2));
            return this;
        }

        /**
         * Adds an output that is the one-hot encoding of a single class-index column.
         *
         * @param str reader name
         * @param i   column holding the class index
         * @param i2  number of classes
         */
        public Builder addOutputOneHot(String str, int i, int i2) {
            this.outputs.add(new SubsetDetails(str, false, true, i2, i, i));
            return this;
        }

        /**
         * Enables or disables random placement of sequences along the time axis.
         *
         * @param z true to place each sequence at a random start offset
         * @param j seed for the random number generator that drives the offsets
         */
        public Builder timeSeriesRandomOffset(boolean z, long j) {
            this.timeSeriesRandomOffset = z;
            this.timeSeriesRandomOffsetSeed = j;
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @throws IllegalStateException if no reader was added, the batch size is not positive,
         *                               no input/output was added, or an input/output names a
         *                               reader that was never registered
         */
        public RecordReaderMultiDataSetIterator build() {
            if (this.recordReaders.isEmpty() && this.sequenceRecordReaders.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no readers");
            }
            if (this.batchSize <= 0) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with batch size <= 0");
            }
            if (this.inputs.isEmpty() && this.outputs.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no inputs/outputs");
            }
            validateReaderNames(this.inputs, "input");
            validateReaderNames(this.outputs, "output");
            return new RecordReaderMultiDataSetIterator(this);
        }

        /** Ensures every subset refers to a registered reader; {@code kind} is "input" or "output". */
        private void validateReaderNames(List<SubsetDetails> subsets, String kind) {
            for (SubsetDetails subset : subsets) {
                String name = subset.readerName;
                if (!this.recordReaders.containsKey(name) && !this.sequenceRecordReaders.containsKey(name)) {
                    throw new IllegalStateException("Invalid " + kind + " name: \"" + name + "\" - no reader found with this name");
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator$SubsetDetails.class */
    /**
     * Immutable description of one input or output array: which reader it comes from and which
     * of that reader's columns it uses.
     */
    public static class SubsetDetails implements Serializable {
        /** Name of the reader that supplies the data. */
        private final String readerName;
        /** True when every column of the reader is used. */
        private final boolean entireReader;
        /** True when a single class-index column is expanded to a one-hot vector. */
        private final boolean oneHot;
        /** Number of classes for one-hot encoding (-1 when not one-hot). */
        private final int oneHotNumClasses;
        /** First selected column, inclusive (-1 when the entire reader is used). */
        private final int subsetStart;
        /** Last selected column, inclusive (-1 when the entire reader is used). */
        private final int subsetEndInclusive;

        public SubsetDetails(String readerName, boolean entireReader, boolean oneHot,
                             int oneHotNumClasses, int subsetStart, int subsetEndInclusive) {
            this.readerName = readerName;
            this.entireReader = entireReader;
            this.oneHot = oneHot;
            this.oneHotNumClasses = oneHotNumClasses;
            this.subsetStart = subsetStart;
            this.subsetEndInclusive = subsetEndInclusive;
        }
    }

    /**
     * Creates the iterator from a validated builder. The reader maps are shared with the
     * builder, while the input/output lists are copied.
     */
    private RecordReaderMultiDataSetIterator(Builder builder) {
        this.collectMetaData = false;
        this.batchSize = builder.batchSize;
        this.alignmentMode = builder.alignmentMode;
        this.recordReaders = builder.recordReaders;
        this.sequenceRecordReaders = builder.sequenceRecordReaders;
        this.inputs = new ArrayList<>(builder.inputs);
        this.outputs = new ArrayList<>(builder.outputs);
        this.timeSeriesRandomOffset = builder.timeSeriesRandomOffset;
        if (this.timeSeriesRandomOffset) {
            this.timeSeriesRandomOffsetRng = new Random(builder.timeSeriesRandomOffsetSeed);
        }

        // Reset is possible only if every single underlying reader can be reset.
        boolean everyReaderResets = true;
        if (this.recordReaders != null) {
            for (RecordReader reader : this.recordReaders.values()) {
                everyReaderResets &= reader.resetSupported();
            }
        }
        if (this.sequenceRecordReaders != null) {
            for (SequenceRecordReader reader : this.sequenceRecordReaders.values()) {
                everyReaderResets &= reader.resetSupported();
            }
        }
        this.resetSupported = everyReaderResets;
    }

    /* renamed from: next, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the next minibatch, using the batch size configured at build time.
     *
     * <p>This is the {@code Iterator.next()} implementation; the decompiler had renamed it to
     * {@code m2next()}, which left the iterator contract unimplemented.
     *
     * @throws NoSuchElementException if no more data is available
     */
    public MultiDataSet next() {
        return next(this.batchSize);
    }

    /**
     * Decompiler-generated alias of {@link #next()}, kept so that existing callers of this
     * name continue to work.
     */
    public MultiDataSet m2next() {
        return next();
    }

    /**
     * Removal is not meaningful for a data set iterator.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
        final String reason = "Remove not supported";
        throw new UnsupportedOperationException(reason);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v120, types: [java.util.List] */
    /* JADX WARN: Type inference failed for: r6v0, types: [org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator] */
    /**
     * Reads up to {@code i} examples from every underlying reader and assembles them into a
     * single MultiDataSet.
     *
     * <p>Non-sequence readers that support batches are read in one call (unless metadata is
     * being collected, which requires record-by-record reads); all other readers are read one
     * example at a time until {@code i} examples were read or the reader is exhausted.
     *
     * @param i maximum number of examples to read
     * @return the assembled minibatch
     * @throws NoSuchElementException if {@link #hasNext()} is false
     */
    public MultiDataSet next(int i) {
        if (!hasNext()) {
            throw new NoSuchElementException("No next elements");
        }
        Map<String, List<List<Writable>>> recordsByReader = new HashMap<>();
        // Lazily created: stays null when no reader takes the batched path.
        Map<String, List<INDArray>> batchedByReader = null;
        Map<String, List<List<List<Writable>>>> sequencesByReader = new HashMap<>();
        List<RecordMetaDataComposableMap> metaData = null;
        if (this.collectMetaData) {
            metaData = new ArrayList<>();
        }

        for (Map.Entry<String, RecordReader> entry : this.recordReaders.entrySet()) {
            String readerName = entry.getKey();
            RecordReader reader = entry.getValue();
            if (this.collectMetaData || !reader.batchesSupported()) {
                // Capacity is capped so that a very large requested size does not pre-allocate a huge list.
                List<List<Writable>> records = new ArrayList<>(Math.min(i, 100000));
                for (int example = 0; example < i && reader.hasNext(); example++) {
                    List<Writable> record;
                    if (this.collectMetaData) {
                        Record withMeta = reader.nextRecord();
                        record = withMeta.getRecord();
                        // One composable metadata map per example, shared by all readers.
                        if (metaData.size() <= example) {
                            metaData.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                        }
                        metaData.get(example).getMeta().put(readerName, withMeta.getMetaData());
                    } else {
                        record = reader.next();
                    }
                    records.add(record);
                }
                recordsByReader.put(readerName, records);
            } else {
                List<List<Writable>> batch = reader.next(i);
                List<INDArray> columnArrays;
                if (batch instanceof NDArrayRecordBatch) {
                    // The batch already carries its data as arrays; use them directly.
                    columnArrays = ((NDArrayRecordBatch) batch).getArrays();
                } else {
                    List<List<Writable>> filtered = filterRequiredColumns(readerName, batch);
                    columnArrays = new ArrayList<>();
                    List<Writable> column = new ArrayList<>();
                    int numColumns = filtered.get(0).size();
                    // Transpose: one minibatch array per column.
                    for (int col = 0; col < numColumns; col++) {
                        column.clear();
                        for (List<Writable> row : filtered) {
                            column.add(row.get(col));
                        }
                        columnArrays.add(RecordConverter.toMinibatchArray(column));
                    }
                }
                if (batchedByReader == null) {
                    batchedByReader = new HashMap<>();
                }
                batchedByReader.put(readerName, columnArrays);
            }
        }

        for (Map.Entry<String, SequenceRecordReader> entry : this.sequenceRecordReaders.entrySet()) {
            String readerName = entry.getKey();
            SequenceRecordReader reader = entry.getValue();
            // Same capacity cap as for non-sequence readers above.
            List<List<List<Writable>>> sequences = new ArrayList<>(Math.min(i, 100000));
            for (int example = 0; example < i && reader.hasNext(); example++) {
                List<List<Writable>> sequence;
                if (this.collectMetaData) {
                    SequenceRecord withMeta = reader.nextSequence();
                    sequence = withMeta.getSequenceRecord();
                    if (metaData.size() <= example) {
                        metaData.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                    }
                    metaData.get(example).getMeta().put(readerName, withMeta.getMetaData());
                } else {
                    sequence = reader.sequenceRecord();
                }
                sequences.add(sequence);
            }
            sequencesByReader.put(readerName, sequences);
        }

        return nextMultiDataSet(recordsByReader, batchedByReader, sequencesByReader, metaData);
    }

    /**
     * Replaces every column of a batch that no input/output uses with a placeholder
     * {@code IntWritable(0)}, so that unused columns need not be convertible to numbers.
     *
     * @param str  name of the reader the batch came from
     * @param list the batch, one inner list per example
     * @return {@code list} itself if any input/output uses the entire reader; otherwise a new
     *         batch of the same dimensions with unused columns replaced
     * @throws IllegalStateException if no input or output refers to the reader
     */
    private List<List<Writable>> filterRequiredColumns(String str, List<List<Writable>> list) {
        List<SubsetDetails> subsets = null;
        int maxColumn = -1;
        for (List<SubsetDetails> group : Arrays.asList(this.inputs, this.outputs)) {
            for (SubsetDetails details : group) {
                if (!str.equals(details.readerName)) {
                    continue;
                }
                if (details.entireReader) {
                    // Every column is needed, so there is nothing to filter.
                    return list;
                }
                if (subsets == null) {
                    subsets = new ArrayList<>();
                }
                subsets.add(details);
                maxColumn = Math.max(maxColumn, details.subsetEndInclusive);
            }
        }
        if (subsets == null) {
            throw new IllegalStateException("Found no usages of reader: " + str);
        }

        // required[c] is true when column c is part of at least one input/output.
        boolean[] required = new boolean[maxColumn + 1];
        for (SubsetDetails details : subsets) {
            for (int col = details.subsetStart; col <= details.subsetEndInclusive; col++) {
                required[col] = true;
            }
        }

        List<List<Writable>> filtered = new ArrayList<>();
        // A single shared placeholder instance is used for all unused cells.
        Writable placeholder = new IntWritable(0);
        for (List<Writable> row : list) {
            List<Writable> filteredRow = new ArrayList<>(row.size());
            for (int col = 0; col < row.size(); col++) {
                boolean keep = col < required.length && required[col];
                filteredRow.add(keep ? row.get(col) : placeholder);
            }
            filtered.add(filteredRow);
        }
        return filtered;
    }

    /**
     * Builds a MultiDataSet from data that has already been read from the readers.
     *
     * @param map   per-reader records read one at a time (reader name to examples)
     * @param map2  per-reader arrays read in batched mode, or null if none
     * @param map3  per-reader sequences (reader name to examples to time steps)
     * @param list  per-example metadata; attached only when metadata collection is enabled
     * @return the assembled (and, if a pre-processor is set, pre-processed) MultiDataSet
     * @throws RuntimeException if none of the maps contains any data source
     */
    public MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> map, Map<String, List<INDArray>> map2, Map<String, List<List<List<Writable>>>> map3, List<RecordMetaDataComposableMap> list) {
        // The minibatch size is the smallest number of examples any source delivered.
        int numExamples = Integer.MAX_VALUE;
        for (List<List<Writable>> records : map.values()) {
            numExamples = Math.min(numExamples, records.size());
        }
        if (map2 != null) {
            for (List<INDArray> arrays : map2.values()) {
                for (INDArray array : arrays) {
                    numExamples = (int) Math.min(numExamples, array.size(0));
                }
            }
        }
        for (List<List<List<Writable>>> sequences : map3.values()) {
            numExamples = Math.min(numExamples, sequences.size());
        }
        if (numExamples == Integer.MAX_VALUE) {
            throw new RuntimeException("Error occurred during data set generation: no readers?");
        }

        // Longest sequence per example, needed for end alignment and random offsets.
        int[] longestPerExample = null;
        if (this.alignmentMode == AlignmentMode.ALIGN_END || this.timeSeriesRandomOffset) {
            longestPerExample = new int[numExamples];
            for (List<List<List<Writable>>> sequences : map3.values()) {
                int limit = Math.min(sequences.size(), numExamples);
                for (int example = 0; example < limit; example++) {
                    longestPerExample[example] = Math.max(longestPerExample[example], sequences.get(example).size());
                }
            }
        }

        // Longest sequence overall; stays -1 for EQUAL_LENGTH.
        int longestOverall = -1;
        if (this.alignmentMode != AlignmentMode.EQUAL_LENGTH) {
            for (List<List<List<Writable>>> sequences : map3.values()) {
                for (List<List<Writable>> sequence : sequences) {
                    longestOverall = Math.max(longestOverall, sequence.size());
                }
            }
        }

        // One seed per minibatch so that features and labels receive identical offsets.
        long offsetSeed = -1L;
        if (this.timeSeriesRandomOffset) {
            offsetSeed = this.timeSeriesRandomOffsetRng.nextLong();
        }

        int numInputs = this.inputs.size();
        int numOutputs = this.outputs.size();
        Pair<INDArray[], INDArray[]> features = convertFeaturesOrLabels(new INDArray[numInputs], new INDArray[numInputs], this.inputs, numExamples, map, map2, map3, longestOverall, longestPerExample, offsetSeed);
        Pair<INDArray[], INDArray[]> labels = convertFeaturesOrLabels(new INDArray[numOutputs], new INDArray[numOutputs], this.outputs, numExamples, map, map2, map3, longestOverall, longestPerExample, offsetSeed);

        org.nd4j.linalg.dataset.MultiDataSet result = new org.nd4j.linalg.dataset.MultiDataSet(features.getFirst(), labels.getFirst(), features.getSecond(), labels.getSecond());
        if (this.collectMetaData) {
            result.setExampleMetaData(list);
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(result);
        }
        return result;
    }

    /**
     * Converts every subset in {@code list} to an array, choosing the batched, record or
     * sequence conversion depending on which map holds the subset's reader.
     *
     * @param iNDArrayArr  receives the data arrays, one per subset
     * @param iNDArrayArr2 receives the mask arrays (only sequence subsets can set one)
     * @return the data arrays paired with the mask arrays, or with null if no mask was created
     */
    private Pair<INDArray[], INDArray[]> convertFeaturesOrLabels(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, List<SubsetDetails> list, int i, Map<String, List<List<Writable>>> map, Map<String, List<INDArray>> map2, Map<String, List<List<List<Writable>>>> map3, int i2, int[] iArr, long j) {
        boolean anyMask = false;
        int slot = 0;
        for (SubsetDetails details : list) {
            String reader = details.readerName;
            boolean isBatched = map2 != null && map2.containsKey(reader);
            if (isBatched) {
                iNDArrayArr[slot] = convertWritablesBatched(map2.get(reader), details);
            } else if (map.containsKey(reader)) {
                iNDArrayArr[slot] = convertWritables(map.get(reader), i, details);
            } else {
                Pair<INDArray, INDArray> dataAndMask = convertWritablesSequence(map3.get(reader), i, i2, details, iArr, j);
                iNDArrayArr[slot] = dataAndMask.getFirst();
                iNDArrayArr2[slot] = dataAndMask.getSecond();
                anyMask |= iNDArrayArr2[slot] != null;
            }
            slot++;
        }
        // A null mask array signals "no masks at all" to the MultiDataSet.
        INDArray[] masksOrNull = anyMask ? iNDArrayArr2 : null;
        return new Pair<>(iNDArrayArr, masksOrNull);
    }

    /**
     * Selects (and, if required, one-hot encodes) the columns of a batched read.
     *
     * @param list          one array per column of the reader
     * @param subsetDetails which columns to use
     * @return the selected columns concatenated along dimension 1, or their one-hot encoding
     * @throws UnsupportedOperationException if one-hot encoding is requested but the selected
     *                                       array has neither 1 nor numClasses columns
     */
    private INDArray convertWritablesBatched(List<INDArray> list, SubsetDetails subsetDetails) {
        final int first = subsetDetails.subsetStart;
        final int last = subsetDetails.subsetEndInclusive;

        INDArray selected;
        if (subsetDetails.entireReader) {
            if (list.size() == 1) {
                selected = list.get(0);
            } else {
                selected = Nd4j.concat(1, list.toArray(new INDArray[list.size()]));
            }
        } else if (subsetDetails.oneHot || first == last) {
            selected = list.get(first);
        } else {
            INDArray[] parts = new INDArray[(last - first) + 1];
            for (int col = first; col <= last; col++) {
                parts[col - first] = list.get(col);
            }
            selected = Nd4j.concat(1, parts);
        }

        // Nothing more to do unless a one-hot encoding is requested and not already present.
        boolean needsEncoding = subsetDetails.oneHot && selected.size(1) != subsetDetails.oneHotNumClasses;
        if (!needsEncoding) {
            return selected;
        }
        if (selected.size(1) != 1) {
            throw new UnsupportedOperationException("Cannot do conversion to one hot using batched reader: " + subsetDetails.oneHotNumClasses + " output classes, but array.size(1) is " + selected.size(1) + " (must be equal to 1 or numClasses = " + subsetDetails.oneHotNumClasses + ")");
        }

        long numRows = selected.size(0);
        INDArray encoded = Nd4j.create(new long[]{numRows, subsetDetails.oneHotNumClasses});
        for (int row = 0; row < numRows; row++) {
            int classIndex = selected.getInt(new int[]{row, 0});
            encoded.putScalar(row, classIndex, 1.0d);
        }
        return encoded;
    }

    /** Counts the number of values contributed by all writables of a record. */
    private int countLength(List<Writable> list) {
        final int lastIndex = list.size() - 1;
        return countLength(list, 0, lastIndex);
    }

    /**
     * Counts the number of values contributed by the writables at positions {@code i} to
     * {@code i2} (both inclusive): an NDArrayWritable contributes the length of its array,
     * any other writable contributes one value.
     *
     * @throws UnsupportedOperationException if an NDArrayWritable holds an array that is
     *                                       neither a row vector nor a scalar
     */
    private int countLength(List<Writable> list, int i, int i2) {
        int total = 0;
        for (int idx = i; idx <= i2; idx++) {
            Writable writable = list.get(idx);
            if (writable instanceof NDArrayWritable) {
                INDArray array = ((NDArrayWritable) writable).get();
                if (!array.isRowVectorOrScalar()) {
                    throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is not a row vector. Can only concat row vectors with other writables. Shape: " + Arrays.toString(array.shape()));
                }
                total = (int) (total + array.length());
            } else {
                total++;
            }
        }
        return total;
    }

    /**
     * Converts records to an array, translating conversion failures into RuntimeExceptions
     * with a descriptive message. IllegalStateExceptions are passed through unchanged.
     */
    private INDArray convertWritables(List<List<Writable>> list, int i, SubsetDetails subsetDetails) {
        final INDArray converted;
        try {
            converted = convertWritablesHelper(list, i, subsetDetails);
        } catch (NumberFormatException nonNumeric) {
            throw new RuntimeException("Error parsing data (writables) from record readers - value is non-numeric", nonNumeric);
        } catch (IllegalStateException passThrough) {
            throw passThrough;
        } catch (Throwable cause) {
            throw new RuntimeException("Error parsing data (writables) from record readers", cause);
        }
        return converted;
    }

    /**
     * Converts the first {@code i} records to one array whose leading dimension is the
     * example index.
     *
     * <p>The output shape is derived from the first record: a lone NDArrayWritable keeps its
     * own shape (with the leading dimension replaced by {@code i}); otherwise the result has
     * {@code i} rows and one column per value (or per class, for one-hot subsets).
     *
     * @param list          the records, one inner list per example
     * @param i             number of examples to convert
     * @param subsetDetails which columns to use and how
     * @throws IllegalStateException if a one-hot class index is not below the number of classes
     */
    private INDArray convertWritablesHelper(List<List<Writable>> list, int i, SubsetDetails subsetDetails) {
        final int start = subsetDetails.subsetStart;
        final int end = subsetDetails.subsetEndInclusive;

        // Allocate the output array.
        INDArray create;
        if (subsetDetails.entireReader) {
            if (list.get(0).size() == 1 && (list.get(0).get(0) instanceof NDArrayWritable)) {
                long[] shape = ArrayUtils.clone(((NDArrayWritable) list.get(0).get(0)).get().shape());
                shape[0] = i;
                create = Nd4j.create(shape);
            } else {
                create = Nd4j.create(i, countLength(list.get(0)));
            }
        } else if (subsetDetails.oneHot) {
            create = Nd4j.zeros(i, subsetDetails.oneHotNumClasses);
        } else if (start == end && (list.get(0).get(start) instanceof NDArrayWritable)) {
            long[] shape = ArrayUtils.clone(((NDArrayWritable) list.get(0).get(start)).get().shape());
            shape[0] = i;
            create = Nd4j.create(shape);
        } else {
            create = Nd4j.create(i, countLength(list.get(0), start, end));
        }

        // Fill it one example at a time.
        for (int example = 0; example < i; example++) {
            List<Writable> record = list.get(example);
            if (subsetDetails.entireReader) {
                putExample(create, RecordConverter.toArray(record), example);
            } else if (subsetDetails.oneHot) {
                int classIndex = record.get(start).toInt();
                if (classIndex >= subsetDetails.oneHotNumClasses) {
                    throw new IllegalStateException("Cannot convert sequence writables to one-hot: class index " + classIndex + " >= numClass (" + subsetDetails.oneHotNumClasses + "). (Note that classes are zero-indexed, thus only values 0 to nClasses-1 are valid)");
                }
                create.putScalar(example, classIndex, 1.0d);
            } else if (start == end && (record.get(start) instanceof NDArrayWritable)) {
                putExample(create, ((NDArrayWritable) record.get(start)).get(), example);
            } else {
                // Walk with an iterator so that records that are not random-access stay cheap.
                Iterator<Writable> it = record.iterator();
                for (int skipped = 0; skipped < start; skipped++) {
                    it.next();
                }
                int column = 0;
                for (int idx = start; idx <= end; idx++) {
                    Writable writable = it.next();
                    if (writable instanceof NDArrayWritable) {
                        INDArray values = ((NDArrayWritable) writable).get();
                        create.put(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.interval(column, column + values.length())}, values);
                        column = (int) (column + values.length());
                    } else {
                        create.putScalar(example, column, writable.toDouble());
                        column++;
                    }
                }
            }
        }
        return create;
    }

    /**
     * Copies a single example into position {@code i} of the output array.
     *
     * @param iNDArray  output array; rank must be between 2 and 5 inclusive
     * @param iNDArray2 the example; same rank as the output, leading dimension 1, and equal
     *                  sizes in all other dimensions
     * @param i         example index within the output array
     */
    private void putExample(INDArray iNDArray, INDArray iNDArray2, int i) {
        Preconditions.checkState(iNDArray2.size(0) == 1 && iNDArray2.rank() == iNDArray.rank(), "Cannot put array: array should have leading dimension of 1 and equal rank to output array. Attempting to put array of shape %s into output array of shape %s", iNDArray2.shape(), iNDArray.shape());
        final long[] outShape = iNDArray.shape();
        final long[] exampleShape = iNDArray2.shape();
        final int rank = iNDArray.rank();
        for (int dim = 1; dim < rank; dim++) {
            Preconditions.checkState(outShape[dim] == exampleShape[dim], "Single example array and output arrays differ at position %s:single example shape %s, output array shape %s", Integer.valueOf(dim), exampleShape, outShape);
        }

        if (rank < 2 || rank > 5) {
            throw new RuntimeException("Unexpected array rank: " + iNDArray.rank() + " with shape " + Arrays.toString(iNDArray.shape()) + " input arrays should be rank 2 to 5 inclusive");
        }

        // Index = the example position followed by "all" for every remaining dimension.
        INDArrayIndex[] target = new INDArrayIndex[rank];
        target[0] = NDArrayIndex.point(i);
        for (int dim = 1; dim < rank; dim++) {
            target[dim] = NDArrayIndex.all();
        }
        iNDArray.put(target, iNDArray2);
    }

    /**
     * Converts the first {@code i} sequences to a rank-3 array of shape
     * [examples, values per time step, time steps] and, if any sequence is shorter than the
     * time axis, a [examples, time steps] mask with 1 for real data and 0 for padding.
     *
     * @param list          the sequences: examples, then time steps, then values
     * @param i             number of examples to convert
     * @param i2            length of the time axis, or -1 to use the length of the first sequence
     * @param subsetDetails which columns to use and how
     * @param iArr          longest sequence per example; read only for ALIGN_END
     * @param j             seed for the random offsets; used only when random offsets are enabled
     * @return the data array paired with the mask array (null when no masking is needed)
     * @throws ZeroLengthSequenceException   if the first sequence is empty
     * @throws UnsupportedOperationException if sequence lengths differ in EQUAL_LENGTH mode
     * @throws IllegalStateException         if a one-hot class index is not below the number of classes
     */
    private Pair<INDArray, INDArray> convertWritablesSequence(List<List<List<Writable>>> list, int i, int i2, SubsetDetails subsetDetails, int[] iArr, long j) {
        if (i2 == -1) {
            i2 = list.get(0).size();
        }
        if (list.get(0).isEmpty()) {
            throw new ZeroLengthSequenceException("Zero length sequence encountered");
        }

        // The number of values per time step is derived from the first step of the first sequence.
        List<Writable> firstStep = list.get(0).get(0);
        int valuesPerStep = 0;
        if (subsetDetails.entireReader) {
            for (Writable writable : firstStep) {
                if (writable instanceof NDArrayWritable) {
                    valuesPerStep = (int) (valuesPerStep + ((NDArrayWritable) writable).get().size(1));
                } else {
                    valuesPerStep++;
                }
            }
        } else if (subsetDetails.oneHot) {
            valuesPerStep = subsetDetails.oneHotNumClasses;
        } else {
            for (int col = subsetDetails.subsetStart; col <= subsetDetails.subsetEndInclusive; col++) {
                Writable writable = firstStep.get(col);
                if (writable instanceof NDArrayWritable) {
                    valuesPerStep = (int) (valuesPerStep + ((NDArrayWritable) writable).get().size(1));
                } else {
                    valuesPerStep++;
                }
            }
        }
        INDArray create = Nd4j.create(new int[]{i, valuesPerStep, i2}, 'f');

        // A mask is needed as soon as any sequence does not fill the whole time axis.
        boolean needsMask = false;
        for (List<List<Writable>> sequence : list) {
            if (sequence.size() < i2) {
                needsMask = true;
            }
        }
        if (needsMask && this.alignmentMode == AlignmentMode.EQUAL_LENGTH) {
            throw new UnsupportedOperationException("Alignment mode is set to EQUAL_LENGTH but variable length data was encountered. Use AlignmentMode.ALIGN_START or AlignmentMode.ALIGN_END with variable length data");
        }
        INDArray ones = needsMask ? Nd4j.ones(i, i2) : null;

        // Seeded per call so that every subset of one minibatch gets the same offsets.
        Random random = this.timeSeriesRandomOffset ? new Random(j) : null;

        for (int example = 0; example < i; example++) {
            List<List<Writable>> sequence = list.get(example);

            // Time index at which this sequence starts.
            int startOffset = (this.alignmentMode == AlignmentMode.ALIGN_START || this.alignmentMode == AlignmentMode.EQUAL_LENGTH) ? 0 : iArr[example] - sequence.size();
            if (this.timeSeriesRandomOffset) {
                startOffset = random.nextInt((i2 - sequence.size()) + 1);
            }

            int step = 0;
            for (List<Writable> timeStep : sequence) {
                int timeIndex = startOffset + step;
                step++;
                if (subsetDetails.entireReader) {
                    int column = 0;
                    for (Writable writable : timeStep) {
                        if (writable instanceof NDArrayWritable) {
                            INDArray values = ((NDArrayWritable) writable).get();
                            create.put(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.interval(column, column + values.length()), NDArrayIndex.point(timeIndex)}, values);
                            column = (int) (column + values.length());
                        } else {
                            create.putScalar(example, column, timeIndex, writable.toDouble());
                            column++;
                        }
                    }
                } else if (subsetDetails.oneHot) {
                    int classIndex = timeStep.get(subsetDetails.subsetStart).toInt();
                    if (classIndex >= subsetDetails.oneHotNumClasses) {
                        throw new IllegalStateException("Cannot convert sequence writables to one-hot: class index " + classIndex + " >= numClass (" + subsetDetails.oneHotNumClasses + "). (Note that classes are zero-indexed, thus only values 0 to nClasses-1 are valid)");
                    }
                    create.putScalar(example, classIndex, timeIndex, 1.0d);
                } else {
                    int column = 0;
                    for (int col = subsetDetails.subsetStart; col <= subsetDetails.subsetEndInclusive; col++) {
                        Writable writable = timeStep.get(col);
                        if (writable instanceof NDArrayWritable) {
                            INDArray values = ((NDArrayWritable) writable).get();
                            create.put(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.interval(column, column + values.length()), NDArrayIndex.point(timeIndex)}, values);
                            column = (int) (column + values.length());
                        } else {
                            create.putScalar(example, column, timeIndex, writable.toDouble());
                            column++;
                        }
                    }
                }
            }

            if (needsMask) {
                // Mask the padding before the sequence (empty when startOffset is 0) ...
                for (int t = 0; t < startOffset; t++) {
                    ones.putScalar(example, t, 0.0d);
                }
                // ... and the padding after it (empty when the sequence reaches the end).
                for (int t = startOffset + sequence.size(); t < i2; t++) {
                    ones.putScalar(example, t, 0.0d);
                }
            }
        }
        return new Pair<>(create, ones);
    }

    /**
     * Stores the given pre-processor in this iterator's {@code preProcessor} field,
     * replacing whatever was stored there before.
     *
     * @param multiDataSetPreProcessor the pre-processor to store; the value is stored as-is
     *                                 without any validation
     */
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
        preProcessor = multiDataSetPreProcessor;
    }

    /**
     * Returns the pre-processor currently stored in this iterator.
     *
     * @return the value of the {@code preProcessor} field
     */
    public MultiDataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    /**
     * Reports whether {@link #reset()} may be called on this iterator.
     *
     * @return the value of the {@code resetSupported} field; {@link #reset()} throws when it is {@code false}
     */
    public boolean resetSupported() {
        return resetSupported;
    }

    /**
     * Reports whether this iterator supports asynchronous use.
     *
     * @return always {@code true}
     */
    public boolean asyncSupported() {
        final boolean supported = true;
        return supported;
    }

    /**
     * Resets every underlying record reader and sequence record reader.
     *
     * <p>Record readers are reset first, followed by sequence record readers, each group in the
     * iteration order of its map's values.
     *
     * @throws IllegalStateException if the {@code resetSupported} field is {@code false}; in that
     *                               case no reader is touched
     */
    public void reset() {
        if (!resetSupported) {
            throw new IllegalStateException("Cannot reset iterator - reset not supported (resetSupported() == false): one or more underlying (sequence) record readers do not support resetting");
        }
        for (RecordReader reader : recordReaders.values()) {
            reader.reset();
        }
        for (SequenceRecordReader sequenceReader : sequenceRecordReaders.values()) {
            sequenceReader.reset();
        }
    }

    /**
     * Checks whether all underlying readers still have data.
     *
     * <p>Record readers are queried first, then sequence record readers; the check stops at the
     * first reader that reports no further data.
     *
     * @return {@code false} as soon as any record reader or sequence record reader reports
     *         {@code hasNext() == false}; {@code true} otherwise (including when both maps are empty)
     */
    public boolean hasNext() {
        for (RecordReader reader : recordReaders.values()) {
            if (!reader.hasNext()) {
                return false;
            }
        }
        for (SequenceRecordReader sequenceReader : sequenceRecordReaders.values()) {
            if (!sequenceReader.hasNext()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads a {@link MultiDataSet} for a single metadata entry.
     *
     * <p>Convenience overload: wraps the argument in a one-element list and delegates to
     * {@link #loadFromMetaData(List)}.
     *
     * @param recordMetaData the metadata describing the example to load
     * @return the result of the list-based overload for the one-element list
     * @throws IOException if the list-based overload throws it
     */
    public MultiDataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        List<RecordMetaData> single = Collections.singletonList(recordMetaData);
        return loadFromMetaData(single);
    }

    /**
     * Loads a {@link MultiDataSet} for the examples described by the given metadata.
     *
     * <p>For every configured record reader and sequence record reader, the per-reader metadata
     * is looked up (by reader name) in each element of {@code list}, the reader is asked to load
     * the corresponding records, and the loaded writables are collected per reader name. The
     * collected values are then handed to {@code nextMultiDataSet}.
     *
     * <p>Each element of {@code list} must be a {@link RecordMetaDataComposableMap}, because the
     * per-reader metadata is obtained through its {@code getMeta()} map.
     *
     * @param list metadata for the examples to load; every element must be a
     *             {@link RecordMetaDataComposableMap}
     * @return the value returned by {@code nextMultiDataSet} for the loaded records
     * @throws IOException        if an underlying reader fails while loading
     * @throws ClassCastException if an element of {@code list} is not a
     *                            {@link RecordMetaDataComposableMap}
     */
    public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        Map<String, List<List<Writable>>> recordValues = new HashMap<>();
        Map<String, List<List<List<Writable>>>> sequenceValues = new HashMap<>();
        // When metadata collection is enabled an empty list is passed on; nothing is added to it here.
        List<RecordMetaDataComposableMap> collectedMeta = this.collectMetaData ? new ArrayList<>() : null;

        for (Map.Entry<String, RecordReader> entry : this.recordReaders.entrySet()) {
            String readerName = entry.getKey();
            List<Record> loaded = entry.getValue().loadFromMetaData(metaDataForReader(list, readerName));
            List<List<Writable>> writables = new ArrayList<>(list.size());
            for (Record record : loaded) {
                writables.add(record.getRecord());
            }
            recordValues.put(readerName, writables);
        }

        for (Map.Entry<String, SequenceRecordReader> entry : this.sequenceRecordReaders.entrySet()) {
            String readerName = entry.getKey();
            List<SequenceRecord> loaded =
                    entry.getValue().loadSequenceFromMetaData(metaDataForReader(list, readerName));
            List<List<List<Writable>>> sequences = new ArrayList<>(list.size());
            for (SequenceRecord sequence : loaded) {
                sequences.add(sequence.getSequenceRecord());
            }
            sequenceValues.put(readerName, sequences);
        }

        return nextMultiDataSet(recordValues, null, sequenceValues, collectedMeta);
    }

    /**
     * Extracts, for one named reader, the per-reader metadata from each composite metadata entry.
     *
     * @param list       composite metadata entries; each must be a {@link RecordMetaDataComposableMap}
     * @param readerName the key under which the reader's metadata is stored in each entry's map
     * @return one metadata object per element of {@code list}, in the same order; an element is
     *         {@code null} when the entry's map has no value for {@code readerName}
     */
    private static List<RecordMetaData> metaDataForReader(List<RecordMetaData> list, String readerName) {
        List<RecordMetaData> perReader = new ArrayList<>(list.size());
        for (RecordMetaData meta : list) {
            // getMeta() lives on the composable map, not on the RecordMetaData interface.
            RecordMetaDataComposableMap composite = (RecordMetaDataComposableMap) meta;
            perReader.add(composite.getMeta().get(readerName));
        }
        return perReader;
    }

    /**
     * Returns the configured batch size.
     *
     * @return the value of the {@code batchSize} field
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Returns the configured alignment mode.
     *
     * @return the value of the {@code alignmentMode} field
     */
    public AlignmentMode getAlignmentMode() {
        return alignmentMode;
    }

    /**
     * Returns the map of record readers keyed by name.
     *
     * @return the {@code recordReaders} map itself (not a copy), so changes made through the
     *         returned reference affect this iterator
     */
    public Map<String, RecordReader> getRecordReaders() {
        return recordReaders;
    }

    /**
     * Returns the map of sequence record readers keyed by name.
     *
     * @return the {@code sequenceRecordReaders} map itself (not a copy), so changes made through
     *         the returned reference affect this iterator
     */
    public Map<String, SequenceRecordReader> getSequenceRecordReaders() {
        return sequenceRecordReaders;
    }

    /**
     * Returns the subset descriptions stored as this iterator's inputs.
     *
     * @return the {@code inputs} list itself (not a copy)
     */
    public List<SubsetDetails> getInputs() {
        return inputs;
    }

    /**
     * Returns the subset descriptions stored as this iterator's outputs.
     *
     * @return the {@code outputs} list itself (not a copy)
     */
    public List<SubsetDetails> getOutputs() {
        return outputs;
    }

    /**
     * Returns the time-series random-offset flag.
     *
     * @return the value of the {@code timeSeriesRandomOffset} field
     */
    public boolean isTimeSeriesRandomOffset() {
        return timeSeriesRandomOffset;
    }

    /**
     * Returns the random number generator stored for time-series random offsets.
     *
     * @return the {@code timeSeriesRandomOffsetRng} field itself (not a copy)
     */
    public Random getTimeSeriesRandomOffsetRng() {
        return timeSeriesRandomOffsetRng;
    }

    /**
     * Bean-style accessor for the reset-supported flag; returns the same field as
     * {@link #resetSupported()}.
     *
     * @return the value of the {@code resetSupported} field
     */
    public boolean isResetSupported() {
        return resetSupported;
    }

    /**
     * Returns the metadata-collection flag.
     *
     * @return the value of the {@code collectMetaData} field
     */
    public boolean isCollectMetaData() {
        return collectMetaData;
    }

    /**
     * Sets the metadata-collection flag.
     *
     * @param z the new value for the {@code collectMetaData} field
     */
    public void setCollectMetaData(boolean z) {
        collectMetaData = z;
    }
}
