package org.deeplearning4j.spark.parameterserver.functions;

import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.input.PortableDataStream;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.parameterserver.callbacks.MultiDataSetDeserializationCallback;
import org.deeplearning4j.spark.parameterserver.callbacks.PortableDataStreamMDSCallback;
import org.deeplearning4j.spark.parameterserver.iterators.MultiPdsIterator;
import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;

/* compiled from: SharedFlatMapMultiPDS.java */
/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiPDSAdapter.class */
/**
 * Flat-map adapter that feeds one partition of {@link PortableDataStream}s into shared training.
 *
 * <p>Each {@link #call} wraps the incoming streams in a {@link MultiPdsIterator} built with the
 * configured {@link PortableDataStreamMDSCallback}. It attaches that iterator to the
 * {@link SharedTrainingWrapper} instance, runs the worker and emits the worker's result as a
 * single-element iterable.
 *
 * @param <R> training result type emitted by this function
 */
class SharedFlatMapMultiPDSAdapter<R extends TrainingResult>
        implements FlatMapFunctionAdapter<Iterator<PortableDataStream>, R> {
    /** Worker passed to {@link SharedTrainingWrapper}; always a {@link SharedTrainingWorker}. */
    protected final SharedTrainingWorker worker;

    /** Callback handed to {@link MultiPdsIterator} for each stream; never {@code null}. */
    protected final PortableDataStreamMDSCallback callback;

    /**
     * Creates an adapter that uses a {@link MultiDataSetDeserializationCallback}.
     *
     * @param trainingWorker worker to run; must be a {@link SharedTrainingWorker}
     * @throws ClassCastException if {@code trainingWorker} is not a {@link SharedTrainingWorker}
     */
    public SharedFlatMapMultiPDSAdapter(TrainingWorker<R> trainingWorker) {
        this(trainingWorker, null);
    }

    /**
     * Creates an adapter with an explicit stream callback.
     *
     * @param trainingWorker worker to run; must be a {@link SharedTrainingWorker}
     * @param portableDataStreamMDSCallback callback used to process each stream, or {@code null}
     *     to fall back to a new {@link MultiDataSetDeserializationCallback}
     * @throws ClassCastException if {@code trainingWorker} is not a {@link SharedTrainingWorker}
     */
    public SharedFlatMapMultiPDSAdapter(
            TrainingWorker<R> trainingWorker, PortableDataStreamMDSCallback portableDataStreamMDSCallback) {
        // This adapter only drives shared training; any other worker type fails fast here.
        this.worker = (SharedTrainingWorker) trainingWorker;
        this.callback = portableDataStreamMDSCallback != null
                ? portableDataStreamMDSCallback
                : new MultiDataSetDeserializationCallback();
    }

    /**
     * Attaches this partition's streams to the shared training wrapper and runs the worker.
     *
     * @param dataStreams the partition's streams, consumed through a {@link MultiPdsIterator}
     * @return a single-element iterable holding the result returned by the wrapper's {@code run}
     * @throws Exception propagated from attaching the iterator or running the worker
     */
    public Iterable<R> call(Iterator<PortableDataStream> dataStreams) throws Exception {
        SharedTrainingWrapper.getInstance().attachMDS(new MultiPdsIterator(dataStreams, this.callback));

        // SharedTrainingWrapper is not parameterized by this class's R, so its result has to be
        // cast back to R for the list to be typed as Iterable<R>.
        // NOTE(review): correctness relies on callers choosing R to match the wrapper's result type.
        @SuppressWarnings("unchecked")
        R result = (R) SharedTrainingWrapper.getInstance().run(this.worker);
        return Collections.singletonList(result);
    }
}
