package org.deeplearning4j.spark.parameterserver.functions;

import java.util.Collections;
import java.util.Iterator;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;
import org.nd4j.linalg.dataset.DataSet;

/* compiled from: SharedFlatMapDataSet.java */
/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSetAdapter.class */
/**
 * Flat-map adapter that hands an incoming {@link DataSet} iterator over to
 * {@link SharedTrainingWrapper} and reports back whatever the wrapper's
 * {@code run} call yields.
 *
 * @param <R> the kind of {@link TrainingResult} emitted by this adapter
 */
class SharedFlatMapDataSetAdapter<R extends TrainingResult> implements FlatMapFunctionAdapter<Iterator<DataSet>, R> {

    /** Worker forwarded to {@link SharedTrainingWrapper} on every {@link #call(Iterator)}. */
    private final SharedTrainingWorker worker;

    /**
     * Creates the adapter around the supplied worker.
     *
     * @param trainingWorker worker to use; it is downcast to {@link SharedTrainingWorker}
     * @throws ClassCastException if {@code trainingWorker} is non-null and not a
     *                            {@link SharedTrainingWorker}
     */
    public SharedFlatMapDataSetAdapter(final TrainingWorker<R> trainingWorker) {
        // The downcast is unconditional: only SharedTrainingWorker instances are accepted here.
        this.worker = (SharedTrainingWorker) trainingWorker;
    }

    /**
     * Attaches the given iterator to the instance returned by
     * {@link SharedTrainingWrapper#getInstance()} and then runs that wrapper with
     * this adapter's worker.
     *
     * @param dataSetIterator source of {@link DataSet} objects to attach
     * @return an immutable single-element list holding the value returned by the wrapper's run
     * @throws Exception declared by the adapter contract
     */
    public Iterable<R> call(final Iterator<DataSet> dataSetIterator) throws Exception {
        // Step 1: register the data source with the wrapper.
        SharedTrainingWrapper.getInstance().attachDS(dataSetIterator);

        // Step 2: run the wrapper with our worker and expose its outcome as a one-element list.
        return Collections.singletonList(
                SharedTrainingWrapper.getInstance().run(worker));
    }
}
