package org.deeplearning4j.spark.datavec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/spark/datavec/RDDMiniBatches.class */
public class RDDMiniBatches implements Serializable {
    private int miniBatches;
    private JavaRDD<DataSet> toSplitJava;

    /* loaded from: input_file:org/deeplearning4j/spark/datavec/RDDMiniBatches$MiniBatchFunction.class */
    public static class MiniBatchFunction extends BaseFlatMapFunctionAdaptee<Iterator<DataSet>, DataSet> {
        public MiniBatchFunction(int i) {
            super(new MiniBatchFunctionAdapter(i));
        }
    }

    /* loaded from: input_file:org/deeplearning4j/spark/datavec/RDDMiniBatches$MiniBatchFunctionAdapter.class */
    static class MiniBatchFunctionAdapter implements FlatMapFunctionAdapter<Iterator<DataSet>, DataSet> {
        private int batchSize;

        public MiniBatchFunctionAdapter(int i) {
            this.batchSize = 10;
            this.batchSize = i;
        }

        public Iterable<DataSet> call(Iterator<DataSet> it) throws Exception {
            ArrayList arrayList = new ArrayList();
            ArrayList arrayList2 = new ArrayList();
            while (it.hasNext()) {
                arrayList2.add(it.next().copy());
                if (arrayList2.size() == this.batchSize) {
                    arrayList.add(DataSet.merge(arrayList2));
                    arrayList2.clear();
                }
            }
            if (arrayList2.size() > 1) {
                arrayList.add(DataSet.merge(arrayList2));
            }
            return arrayList;
        }
    }

    public RDDMiniBatches(int i, JavaRDD<DataSet> javaRDD) {
        this.miniBatches = 10;
        this.miniBatches = i;
        this.toSplitJava = javaRDD;
    }

    public JavaRDD<DataSet> miniBatchesJava() {
        return this.toSplitJava.mapPartitions(new MiniBatchFunction(this.miniBatches));
    }
}
