package org.deeplearning4j.spark.parameterserver.training;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationFunction;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationTuple;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAggregateFunction;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapDataSet;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapMultiDataSet;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapMultiPDS;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPDS;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPaths;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPathsMDS;
import org.deeplearning4j.spark.parameterserver.networking.SilentTrainingDriver;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.enums.TransportType;
import org.nd4j.parameterserver.distributed.transport.MulticastTransport;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.class */
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker> implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
    private static final Logger log = LoggerFactory.getLogger(SharedTrainingMaster.class);
    // Hooks registered through addHook/removeHook; the list is created lazily by addHook.
    protected List<TrainingHook> trainingHooks;
    // Parameter-server configuration. prepareNetworkAndStuff fills in the controller address, the
    // shard addresses and the shard count before initialising the VoidParameterServer.
    protected VoidConfiguration voidConfiguration;
    // Number of workers; when null it is set to the SparkContext's defaultParallelism() at training time.
    protected Integer numWorkers;
    // Forwarded to SharedTrainingConfiguration.numberOfWorkersPerNode (the Builder uses -1 for "not set").
    // NOTE(review): unboxed without a null check in getWorkerInstance.
    protected Integer numWorkersPerNode;
    // Direct: train on the DataSet RDD as given; Export: run exportIfRequired and train from the resulting path RDD.
    protected RDDTrainingApproach rddTrainingApproach;
    // Storage level used to persist DataSet/MultiDataSet RDDs in the Direct approach (persisting is skipped when null).
    protected StorageLevel storageLevel;
    // When true, timing information is recorded into 'stats'.
    protected boolean collectTrainingStats;
    // Set by the main constructor from the same argument as batchSizePerWorker.
    protected int rddDataSetNumExamples;
    // Deprecated debugging option, forwarded unchanged to workers via SharedTrainingConfiguration.
    protected long debugLongerIterations;
    // Update-threshold parameters, forwarded to workers via SharedTrainingConfiguration in getWorkerInstance.
    protected double threshold;
    protected double thresholdStep;
    protected double minThreshold;
    protected double stepTrigger;
    protected int stepDelay;
    protected int shakeFrequency;
    // NOTE(review): not read by the methods visible before doIteration; presumably consumed by the doIteration* methods.
    protected Repartition repartition;
    protected RepartitionStrategy repartitionStrategy;
    // Timing helper; the constructor only creates it when statistics collection is enabled.
    protected ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    // Seed source for getSplitRDDs. NOTE(review): no constructor in this class assigns it.
    protected Random rng;
    // Guards the one-time VoidParameterServer initialisation in prepareNetworkAndStuff (compareAndSet false -> true).
    protected AtomicBoolean isFirstRun;
    // Broadcasts created on the first getWorkerInstance call and reused by later calls.
    protected transient Broadcast<NetBroadcastTuple> broadcastModel;
    protected transient Broadcast<SharedTrainingConfiguration> broadcastConfiguration;
    // Custom transport, used only when the VoidConfiguration transport type is neither ROUTED nor BROADCAST.
    protected transient Transport transport;
    // Driver created on the first training run; finalizeTraining calls finishTraining on it.
    protected transient SilentTrainingDriver trainingDriver;

    /**
     * Fluent builder for {@link SharedTrainingMaster}.
     *
     * <p>Every constructor forces the supplied {@link VoidConfiguration} into {@link ExecutionMode#MANAGED}.
     *
     * <p>Defaults:
     * <ul>
     *   <li>updates threshold 1e-3, threshold step 1e-5, minimum threshold 1e-5</li>
     *   <li>step trigger 0.05, step delay 50, shake frequency 0</li>
     *   <li>{@link Repartition#Always} and {@link RepartitionStrategy#Balanced}</li>
     *   <li>{@code MEMORY_ONLY_SER} storage and {@link RDDTrainingApproach#Export}</li>
     * </ul>
     */
    public static class Builder {
        protected double threshold = 0.001d;
        protected double thresholdStep = 1.0E-5d;
        protected double minThreshold = 1.0E-5d;
        protected double stepTrigger = 0.05d;
        protected int stepDelay = 50;
        protected int shakeFrequency = 0;
        protected Repartition repartition = Repartition.Always;
        protected RepartitionStrategy repartitionStrategy = RepartitionStrategy.Balanced;
        protected StorageLevel storageLevel = StorageLevel.MEMORY_ONLY_SER();
        // NOTE(review): there is no setter for this and build() does not forward it to the master.
        protected StorageLevel storageLevelStreams = StorageLevel.MEMORY_ONLY();
        protected VoidConfiguration voidConfiguration;
        protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;
        // NOTE(review): stored by rngSeed(long) but not forwarded by build().
        protected long rngSeed;
        // NOTE(review): stored by exportDirectory(String) but not forwarded by build().
        protected String exportDirectory = null;
        protected Integer numWorkers;
        protected boolean collectTrainingStats;
        protected Transport transport;
        protected int batchSize;
        protected long debugLongerIterations = 0L;
        protected int numWorkersPerNode = -1;

        /**
         * Creates a builder with the default updates threshold (1e-3).
         *
         * <p>Uses a VoidConfiguration whose controller address comes from the {@code SPARK_PUBLIC_DNS}
         * environment variable.
         *
         * @param i presumably the number of examples per DataSet object in the RDD; also the default batch
         *          size per worker
         */
        public Builder(int i) {
            this(0.001d, i);
        }

        /**
         * @param voidConfiguration parameter-server configuration (switched to MANAGED mode)
         * @param i default batch size per worker (see {@link #Builder(int)})
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, int i) {
            this(voidConfiguration, 0.001d, i);
        }

        /**
         * @param d updates threshold
         * @param i default batch size per worker (see {@link #Builder(int)})
         */
        public Builder(double d, int i) {
            this(VoidConfiguration.builder().executionMode(ExecutionMode.MANAGED).forcedRole(NodeRole.SHARD).controllerAddress(System.getenv("SPARK_PUBLIC_DNS")).build(), null, d, i);
        }

        /**
         * @param voidConfiguration parameter-server configuration (switched to MANAGED mode)
         * @param d updates threshold
         * @param i default batch size per worker (see {@link #Builder(int)})
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, double d, int i) {
            this(voidConfiguration, null, d, i);
        }

        /**
         * Primary constructor; every other constructor delegates here.
         *
         * @param voidConfiguration parameter-server configuration (switched to MANAGED mode)
         * @param num number of workers, or {@code null} to use the SparkContext's default parallelism
         * @param d updates threshold
         * @param i default batch size per worker; {@link #batchSizePerWorker(int)} overrides it
         * @throws NullPointerException if {@code voidConfiguration} is null
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, Integer num, double d, int i) {
            if (voidConfiguration == null) {
                throw new NullPointerException("voidConfiguration");
            }
            this.numWorkers = num;
            this.threshold = d;
            // Without this default the master would be built with a batch size of 0 unless
            // batchSizePerWorker(int) was called explicitly.
            this.batchSize = i;
            this.voidConfiguration = voidConfiguration;
            // Managed mode is enforced regardless of the caller's configuration.
            this.voidConfiguration.setExecutionMode(ExecutionMode.MANAGED);
        }

        /** Enables or disables collection of training timing statistics on the built master. */
        public Builder collectTrainingStats(boolean z) {
            this.collectTrainingStats = z;
            return this;
        }

        /** Sets when training data is repartitioned. */
        public Builder repartitionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /** Sets how training data is repartitioned. */
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /** Sets the storage level used to persist RDDs in the Direct training approach. */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /** Sets whether DataSet RDDs are trained on directly or exported first. */
        public Builder rddTrainingApproach(RDDTrainingApproach rDDTrainingApproach) {
            this.rddTrainingApproach = rDDTrainingApproach;
            return this;
        }

        /** Stores an export directory. NOTE(review): currently not forwarded by {@link #build()}. */
        public Builder exportDirectory(String str) {
            this.exportDirectory = str;
            return this;
        }

        /** Stores an RNG seed. NOTE(review): currently not forwarded by {@link #build()}. */
        public Builder rngSeed(long j) {
            this.rngSeed = j;
            return this;
        }

        /** Sets the initial updates threshold. */
        public Builder updatesThreshold(double d) {
            this.threshold = d;
            return this;
        }

        /** Sets the minimum updates threshold. */
        public Builder minUpdatesThreshold(double d) {
            this.minThreshold = d;
            return this;
        }

        /**
         * Sets the threshold step.
         *
         * @throws DL4JInvalidConfigException if {@code d} is negative
         */
        public Builder thresholdStep(double d) {
            if (d < 0.0d) {
                throw new DL4JInvalidConfigException("thresholdStep should be non-negative value");
            }
            this.thresholdStep = d;
            return this;
        }

        /**
         * Sets the step trigger.
         *
         * @throws DL4JInvalidConfigException if {@code d} is outside {@code [0, 100]}
         */
        public Builder stepTrigger(double d) {
            if (d < 0.0d || d > 100.0d) {
                throw new DL4JInvalidConfigException("stepTrigger value should be in range of 0..100");
            }
            this.stepTrigger = d;
            return this;
        }

        /** Sets the step delay. */
        public Builder stepDelay(int i) {
            this.stepDelay = i;
            return this;
        }

        /**
         * Sets the shake frequency (0 disables it).
         *
         * @throws DL4JInvalidConfigException if {@code i} is negative
         */
        public Builder shakeFrequency(int i) {
            if (i < 0) {
                throw new DL4JInvalidConfigException("shakeFrequency should be non-negative value");
            }
            if (i == 1) {
                SharedTrainingMaster.log.warn("shakeFrequency of 1 means that all updates will be sparse, and might lead to worse performance");
            }
            this.shakeFrequency = i;
            return this;
        }

        /** Sets the minibatch size used by each worker. */
        public Builder batchSizePerWorker(int i) {
            this.batchSize = i;
            return this;
        }

        /** Sets the number of workers per node; values below 1 are stored as -1 ("not set"). */
        public Builder workersPerNode(int i) {
            this.numWorkersPerNode = i < 1 ? -1 : i;
            return this;
        }

        /** Debugging option; negative values are clamped to 0. */
        @Deprecated
        public Builder debugLongerIterations(long j) {
            this.debugLongerIterations = Math.max(0L, j);
            return this;
        }

        /** Sets a custom transport, used when the VoidConfiguration transport type is neither ROUTED nor BROADCAST. */
        public Builder transport(Transport transport) {
            this.transport = transport;
            return this;
        }

        /**
         * Builds the master from the current settings.
         *
         * <p>Training statistics are collected only when {@link #collectTrainingStats(boolean)} was enabled.
         */
        public SharedTrainingMaster build() {
            SharedTrainingMaster sharedTrainingMaster = new SharedTrainingMaster(this.voidConfiguration, this.numWorkers, this.rddTrainingApproach, this.storageLevel, this.collectTrainingStats, this.repartitionStrategy, this.repartition, this.threshold, this.minThreshold, this.thresholdStep, this.stepTrigger, this.stepDelay, this.shakeFrequency, this.batchSize, this.debugLongerIterations, this.numWorkersPerNode);
            if (this.transport != null) {
                sharedTrainingMaster.transport = this.transport;
            }
            return sharedTrainingMaster;
        }
    }

    /**
     * No-argument constructor; presumably used by Jackson in {@link #fromJson(String)} and {@link #fromYaml(String)}.
     *
     * <p>Besides the scalar defaults, it initialises {@code isFirstRun} and {@code numWorkersPerNode}.
     * {@code prepareNetworkAndStuff} and {@code getWorkerInstance} dereference both without null checks.
     * Presumably, values present in the deserialized input overwrite these defaults.
     */
    protected SharedTrainingMaster() {
        this.debugLongerIterations = 0L;
        this.stepTrigger = 0.05d;
        this.stepDelay = 50;
        this.isFirstRun = new AtomicBoolean(false);
        // -1 is the Builder's "not set" value for workers per node.
        this.numWorkersPerNode = -1;
    }

    /**
     * Creates a master with every setting given explicitly; normally reached via {@link Builder#build()}.
     *
     * @param voidConfiguration parameter-server configuration (must not be null)
     * @param num number of workers, or {@code null} to use the SparkContext's default parallelism
     * @param rDDTrainingApproach how DataSet RDDs are fed to the workers
     * @param storageLevel storage level for persisting RDDs in the Direct approach ({@code null} skips persisting)
     * @param z whether to collect training timing statistics
     * @param repartitionStrategy repartitioning strategy
     * @param repartition when to repartition
     * @param d updates threshold
     * @param d2 minimum updates threshold
     * @param d3 threshold step
     * @param d4 step trigger
     * @param i step delay
     * @param i2 shake frequency
     * @param i3 batch size per worker; also stored as {@code rddDataSetNumExamples}
     * @param j debugLongerIterations (deprecated debugging option)
     * @param i4 number of workers per node (-1 if unspecified)
     * @throws NullPointerException if {@code voidConfiguration} is null
     */
    public SharedTrainingMaster(@NonNull VoidConfiguration voidConfiguration, Integer num, RDDTrainingApproach rDDTrainingApproach, StorageLevel storageLevel, boolean z, RepartitionStrategy repartitionStrategy, Repartition repartition, double d, double d2, double d3, double d4, int i, int i2, int i3, long j, int i4) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration");
        }
        this.voidConfiguration = voidConfiguration;
        this.numWorkers = num;
        this.threshold = d;
        this.minThreshold = d2;
        this.thresholdStep = d3;
        this.stepTrigger = d4;
        this.stepDelay = i;
        // Forwarded to workers through SharedTrainingConfiguration in getWorkerInstance.
        this.shakeFrequency = i2;
        this.rddTrainingApproach = rDDTrainingApproach;
        this.repartitionStrategy = repartitionStrategy;
        this.repartition = repartition;
        this.storageLevel = storageLevel;
        this.collectTrainingStats = z;
        this.isFirstRun = new AtomicBoolean(false);
        this.batchSizePerWorker = i3;
        this.rddDataSetNumExamples = i3;
        this.debugLongerIterations = j;
        this.numWorkersPerNode = i4;
        if (z) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        // UID = creation time + first (up to) 8 characters of the JVM UID.
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
    }

    /**
     * Unregisters a previously added training hook.
     *
     * <p>Does nothing if no hook was ever added.
     *
     * @param trainingHook the hook to remove
     */
    public void removeHook(TrainingHook trainingHook) {
        if (trainingHooks == null) {
            return;
        }
        trainingHooks.remove(trainingHook);
    }

    /**
     * Registers a training hook, creating the hook list on first use.
     *
     * @param trainingHook the hook to add
     * @throws NullPointerException if {@code trainingHook} is null
     */
    public void addHook(@NonNull TrainingHook trainingHook) {
        if (trainingHook == null) {
            throw new NullPointerException("trainingHook");
        }
        if (trainingHooks == null) {
            trainingHooks = new ArrayList<>();
        }
        trainingHooks.add(trainingHook);
    }

    /**
     * Serializes this master to JSON using the inherited JSON mapper.
     *
     * @return the JSON representation
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    public String toJson() {
        try {
            return getJsonMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing JSON representation for SharedTrainingMaster", e);
        }
    }

    /**
     * Serializes this master to YAML using the inherited YAML mapper.
     *
     * @return the YAML representation
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    public String toYaml() {
        try {
            return getYamlMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing YAML representation for SharedTrainingMaster", e);
        }
    }

    /**
     * Restores a master from the JSON produced by {@link #toJson()}.
     *
     * @param str JSON text
     * @return the deserialized master
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static SharedTrainingMaster fromJson(String str) {
        SharedTrainingMaster parsed;
        try {
            parsed = getJsonMapper().readValue(str, SharedTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
        return parsed;
    }

    /**
     * Restores a master from the YAML produced by {@link #toYaml()}.
     *
     * @param str YAML text
     * @return the deserialized master
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static SharedTrainingMaster fromYaml(String str) {
        SharedTrainingMaster parsed;
        try {
            parsed = getYamlMapper().readValue(str, SharedTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
        return parsed;
    }

    /**
     * Creates a {@link SharedTrainingWorker} for a MultiLayerNetwork.
     *
     * <p>Decompiled form of {@code getWorkerInstance(SparkDl4jMultiLayer)}, merged with its bridge method.
     *
     * <p>The network (configuration, parameters, updater state) and the {@link SharedTrainingConfiguration}
     * are broadcast only when no broadcast exists yet. Later calls reuse the cached broadcasts and discard
     * the freshly built tuple and configuration.
     *
     * @param sparkDl4jMultiLayer the network wrapper to train
     * @return a worker bound to the (possibly cached) broadcasts
     */
    public SharedTrainingWorker m10getWorkerInstance(SparkDl4jMultiLayer sparkDl4jMultiLayer) {
        NetBroadcastTuple modelTuple = new NetBroadcastTuple(
                sparkDl4jMultiLayer.getNetwork().getLayerWiseConfigurations(),
                sparkDl4jMultiLayer.getNetwork().params(),
                sparkDl4jMultiLayer.getNetwork().getUpdater().getStateViewArray());
        SharedTrainingConfiguration workerConfig = SharedTrainingConfiguration.builder()
                .threshold(threshold)
                .minThreshold(minThreshold)
                .shakeFrequency(shakeFrequency)
                .thresholdStep(thresholdStep)
                .stepTrigger(stepTrigger)
                .stepDelay(stepDelay)
                .voidConfiguration(voidConfiguration)
                .debugLongerIterations(debugLongerIterations)
                .numberOfWorkersPerNode(numWorkersPerNode.intValue())
                .build();

        if (collectTrainingStats) {
            stats.logBroadcastStart();
        }
        if (broadcastModel == null) {
            broadcastModel = sparkDl4jMultiLayer.getSparkContext().broadcast(modelTuple);
        }
        if (broadcastConfiguration == null) {
            broadcastConfiguration = sparkDl4jMultiLayer.getSparkContext().broadcast(workerConfig);
        }
        if (collectTrainingStats) {
            stats.logBroadcastEnd();
        }
        return new SharedTrainingWorker(broadcastModel, broadcastConfiguration);
    }

    /**
     * Creates a {@link SharedTrainingWorker} for a ComputationGraph.
     *
     * <p>Decompiled form of {@code getWorkerInstance(SparkComputationGraph)}, merged with its bridge method.
     *
     * <p>The network (configuration, parameters, updater state) and the {@link SharedTrainingConfiguration}
     * are broadcast only when no broadcast exists yet; later calls reuse the cached broadcasts. The
     * configuration carries the same settings as the MultiLayerNetwork variant, including
     * {@code stepTrigger} and {@code stepDelay}.
     *
     * @param sparkComputationGraph the graph wrapper to train
     * @return a worker bound to the (possibly cached) broadcasts
     */
    public SharedTrainingWorker m9getWorkerInstance(SparkComputationGraph sparkComputationGraph) {
        NetBroadcastTuple netBroadcastTuple = new NetBroadcastTuple(
                sparkComputationGraph.getNetwork().getConfiguration(),
                sparkComputationGraph.getNetwork().params(),
                sparkComputationGraph.getNetwork().getUpdater().getStateViewArray());
        SharedTrainingConfiguration build = SharedTrainingConfiguration.builder()
                .threshold(this.threshold)
                .minThreshold(this.minThreshold)
                .shakeFrequency(this.shakeFrequency)
                .thresholdStep(this.thresholdStep)
                .stepTrigger(this.stepTrigger)
                .stepDelay(this.stepDelay)
                .voidConfiguration(this.voidConfiguration)
                .debugLongerIterations(this.debugLongerIterations)
                .numberOfWorkersPerNode(this.numWorkersPerNode.intValue())
                .build();
        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        if (this.broadcastModel == null) {
            this.broadcastModel = sparkComputationGraph.getSparkContext().broadcast(netBroadcastTuple);
        }
        if (this.broadcastConfiguration == null) {
            this.broadcastConfiguration = sparkComputationGraph.getSparkContext().broadcast(build);
        }
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }
        return new SharedTrainingWorker(this.broadcastModel, this.broadcastConfiguration);
    }

    /**
     * Integer-divides the per-worker batch size by {@code i}.
     *
     * <p>Presumably {@code i} is the number of examples per RDD object, so the result is the number of
     * objects making up one worker minibatch.
     *
     * @param i divisor; a value of 0 throws {@link ArithmeticException}
     * @return {@code batchSizePerWorker / i}
     */
    protected int numObjectsEachWorker(int i) {
        int objectsPerWorker = batchSizePerWorker / i;
        return objectsPerWorker;
    }

    /**
     * Number of RDD objects per training split, summed over all workers.
     *
     * <p>When {@code i} is 1 this is {@code numWorkers * batchSizePerWorker}. Otherwise it is
     * {@code numWorkers * max(1, numObjectsEachWorker(i))}.
     *
     * @param i examples per RDD object (presumably)
     * @return objects per split
     */
    protected int getNumDataSetObjectsPerSplit(int i) {
        if (i != 1) {
            int perWorker = Math.max(numObjectsEachWorker(i), 1);
            return perWorker * numWorkers;
        }
        return numWorkers * batchSizePerWorker;
    }

    /**
     * Splits {@code javaRDD} into balanced random splits via {@link SparkUtils#balancedRandomSplit}.
     *
     * @param javaRDD data to split
     * @param i first argument to {@code balancedRandomSplit}; presumably the total object count — confirm
     * @param i2 passed to {@link #getNumDataSetObjectsPerSplit(int)} to size each split
     * @param <T> element type
     * @return the splits
     */
    protected <T> JavaRDD<T>[] getSplitRDDs(JavaRDD<T> javaRDD, int i, int i2) {
        int numDataSetObjectsPerSplit = getNumDataSetObjectsPerSplit(i2);
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        // None of this class's constructors assign rng, so fall back to an unseeded Random instead of
        // throwing a NullPointerException. The fallback stays local and does not alter the master's state.
        Random seedSource = this.rng != null ? this.rng : new Random();
        JavaRDD<T>[] balancedRandomSplit = SparkUtils.balancedRandomSplit(i, numDataSetObjectsPerSplit, javaRDD, seedSource.nextLong());
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        return balancedRandomSplit;
    }

    /**
     * Counts the elements of an RDD, recording count timing when statistics collection is enabled.
     *
     * <p>This triggers a Spark {@code count()} action.
     *
     * @param javaRDDLike the RDD to count
     * @return number of elements
     */
    protected <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(JavaRDDLike<T, Repr> javaRDDLike) {
        if (collectTrainingStats) {
            stats.logCountStart();
        }
        final long objectCount = javaRDDLike.count();
        if (collectTrainingStats) {
            stats.logCountEnd();
        }
        return objectCount;
    }

    /**
     * Direct approach for a MultiLayerNetwork: trains on the DataSet RDD as a single split.
     *
     * <p>The RDD is persisted first when {@code storageLevel} is set.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD training data
     */
    protected void executeTrainingDirect(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        StorageLevel level = storageLevel;
        if (level != null) {
            javaRDD.persist(level);
        }
        long objectCount = getTotalDataSetObjectCount(javaRDD);
        doIteration(sparkDl4jMultiLayer, javaRDD, 1, 1);
        if (collectTrainingStats) {
            stats.logFitEnd((int) objectCount);
        }
    }

    /**
     * Direct approach for a ComputationGraph: trains on the MultiDataSet RDD as a single split.
     *
     * <p>The RDD is persisted first when {@code storageLevel} is set.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD training data
     */
    protected void executeTrainingDirectMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        StorageLevel level = storageLevel;
        if (level != null) {
            javaRDD.persist(level);
        }
        long objectCount = getTotalDataSetObjectCount(javaRDD);
        doIterationMDS(sparkComputationGraph, javaRDD, 1, 1);
        if (collectTrainingStats) {
            stats.logFitEnd((int) objectCount);
        }
    }

    /**
     * Direct approach for a ComputationGraph: trains on the DataSet RDD as a single split.
     *
     * <p>The RDD is persisted first when {@code storageLevel} is set.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD training data
     */
    protected void executeTrainingDirect(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        StorageLevel level = storageLevel;
        if (level != null) {
            javaRDD.persist(level);
        }
        long objectCount = getTotalDataSetObjectCount(javaRDD);
        doIteration(sparkComputationGraph, javaRDD, 1, 1);
        if (collectTrainingStats) {
            stats.logFitEnd((int) objectCount);
        }
    }

    /**
     * Trains a MultiLayerNetwork from an RDD of paths to exported DataSets, as a single split.
     *
     * <p>Defaults {@code numWorkers} to the SparkContext's parallelism. Persists the path RDD when
     * {@code storageLevelStreams} is set.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD paths of the exported data
     * @param i forwarded to {@code doIterationPaths} (callers pass {@code batchSizePerWorker})
     */
    protected void executeTrainingPathsHelper(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<String> javaRDD, int i) {
        if (numWorkers == null) {
            numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        StorageLevel streamLevel = storageLevelStreams;
        if (streamLevel != null) {
            javaRDD.persist(streamLevel);
        }
        long pathCount = getTotalDataSetObjectCount(javaRDD);
        doIterationPaths(sparkDl4jMultiLayer, null, javaRDD, 1, 1, i);
        if (collectTrainingStats) {
            stats.logFitEnd((int) pathCount);
        }
    }

    /**
     * Trains a ComputationGraph from an RDD of paths to exported DataSets, as a single split.
     *
     * <p>Defaults {@code numWorkers} to the SparkContext's parallelism. Persists the path RDD when
     * {@code storageLevelStreams} is set.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD paths of the exported data
     * @param i forwarded to {@code doIterationPaths} (callers pass {@code batchSizePerWorker})
     */
    protected void executeTrainingPathsHelper(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i) {
        if (numWorkers == null) {
            numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (collectTrainingStats) {
            stats.logFitStart();
        }
        StorageLevel streamLevel = storageLevelStreams;
        if (streamLevel != null) {
            javaRDD.persist(streamLevel);
        }
        long pathCount = getTotalDataSetObjectCount(javaRDD);
        doIterationPaths(null, sparkComputationGraph, javaRDD, 1, 1, i);
        if (collectTrainingStats) {
            stats.logFitEnd((int) pathCount);
        }
    }

    /**
     * Trains a ComputationGraph from an RDD of paths to exported MultiDataSets, as a single split.
     *
     * <p>Defaults {@code numWorkers} to the SparkContext's parallelism. Persists the path RDD when
     * {@code storageLevelStreams} is set.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD paths of the exported data
     * @param i forwarded to {@code doIterationPathsMDS} (callers pass {@code batchSizePerWorker})
     */
    protected void executeTrainingPathsMDSHelper(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevelStreams != null) {
            javaRDD.persist(this.storageLevelStreams);
        }
        long totalDataSetObjectCount = getTotalDataSetObjectCount(javaRDD);
        doIterationPathsMDS(sparkComputationGraph, javaRDD, 1, 1, i);
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) totalDataSetObjectCount);
        }
    }

    /**
     * Prepares the driver side of the parameter server before each training call.
     *
     * <p>Steps:
     * <ol>
     *   <li>Defaults {@code numWorkers} to the SparkContext's default parallelism.</li>
     *   <li>Resolves the controller address: explicit setting, then {@code SPARK_PUBLIC_DNS}, then the
     *       configured network mask, then {@code DL4J_VOID_IP}.</li>
     *   <li>Registers that address as the single shard and selects the transport.</li>
     *   <li>Calls {@code init()} on the network.</li>
     *   <li>On the first call only, initialises the {@link VoidParameterServer} with a new
     *       {@link SilentTrainingDriver}.</li>
     * </ol>
     *
     * @param sparkDl4jMultiLayer MultiLayerNetwork wrapper, or {@code null} when training a ComputationGraph
     * @param sparkComputationGraph ComputationGraph wrapper, used only when {@code sparkDl4jMultiLayer} is null
     * @throws IllegalStateException if both wrappers are null
     * @throws DL4JInvalidConfigException if no controller address can be resolved or no transport is available
     */
    protected void prepareNetworkAndStuff(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph) {
        if (sparkDl4jMultiLayer == null && sparkComputationGraph == null) {
            throw new IllegalStateException("Both MLN & CG are undefined");
        }
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getSparkContext().defaultParallelism() : sparkComputationGraph.getSparkContext().defaultParallelism();
        }

        // Controller address fallbacks, in order: explicit configuration, SPARK_PUBLIC_DNS, network mask, DL4J_VOID_IP.
        if (this.voidConfiguration.getControllerAddress() == null) {
            this.voidConfiguration.setControllerAddress(System.getenv("SPARK_PUBLIC_DNS"));
        }
        if (this.voidConfiguration.getControllerAddress() == null && this.voidConfiguration.getNetworkMask() != null) {
            this.voidConfiguration.setControllerAddress(new NetworkOrganizer(this.voidConfiguration.getNetworkMask()).getMatchingAddress());
        }
        if (this.voidConfiguration.getControllerAddress() == null) {
            this.voidConfiguration.setControllerAddress(System.getenv("DL4J_VOID_IP"));
        }
        if (this.voidConfiguration.getControllerAddress() == null) {
            throw new DL4JInvalidConfigException("Can't get Spark Master local address. Please specify it manually using VoidConfiguration.setControllerAddress(String) method or VoidConfiguration.setNetworkMask(String) method");
        }
        log.info("Setting controller address to {}:{}", this.voidConfiguration.getControllerAddress(), Integer.valueOf(this.voidConfiguration.getUnicastPort()));
        // The controller doubles as the only shard.
        this.voidConfiguration.setShardAddresses(new String[]{this.voidConfiguration.getControllerAddress()});
        this.voidConfiguration.setNumberOfShards(1);

        // The field this.transport is only known to be a Transport, so the local must be typed as Transport;
        // a RoutedTransport-typed local cannot hold it.
        Transport selectedTransport;
        if (this.voidConfiguration.getTransportType() == TransportType.ROUTED) {
            selectedTransport = new RoutedTransport();
        } else if (this.voidConfiguration.getTransportType() == TransportType.BROADCAST) {
            selectedTransport = new MulticastTransport();
        } else {
            selectedTransport = this.transport;
        }
        if (selectedTransport == null) {
            throw new DL4JInvalidConfigException("No Transport implementation was defined for this training session!");
        }

        if (sparkDl4jMultiLayer != null) {
            sparkDl4jMultiLayer.getNetwork().init();
        } else {
            sparkComputationGraph.getNetwork().init();
        }
        // Only the first training call starts the parameter server.
        if (this.isFirstRun.compareAndSet(false, true)) {
            this.trainingDriver = new SilentTrainingDriver(sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getNetwork().params() : sparkComputationGraph.getNetwork().params(), sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getNetwork().getOptimizer().getStepFunction() : sparkComputationGraph.getNetwork().getOptimizer().getStepFunction());
            VoidParameterServer.getInstance().init(this.voidConfiguration, selectedTransport, this.trainingDriver);
        }
    }

    /**
     * Signals the training driver, if one was created, that training has finished.
     *
     * <p>Calls {@code finishTraining(0, 0)} on the driver.
     */
    protected void finalizeTraining() {
        SilentTrainingDriver driver = trainingDriver;
        if (driver == null) {
            return;
        }
        driver.finishTraining(0L, 0L);
    }

    /**
     * Trains a MultiLayerNetwork on a DataSet RDD using the configured {@link RDDTrainingApproach}.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD training data
     * @throws DL4JInvalidConfigException if the approach is neither Direct nor Export
     */
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        prepareNetworkAndStuff(sparkDl4jMultiLayer, null);
        RDDTrainingApproach approach = rddTrainingApproach;
        if (approach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(sparkDl4jMultiLayer, javaRDD);
            return;
        }
        if (approach == RDDTrainingApproach.Export) {
            JavaRDD<String> exportedPaths = (JavaRDD<String>) exportIfRequired(sparkDl4jMultiLayer.getSparkContext(), javaRDD);
            executeTrainingPathsHelper(sparkDl4jMultiLayer, exportedPaths, batchSizePerWorker);
            return;
        }
        throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
    }

    /**
     * Trains a MultiLayerNetwork from the {@link PortableDataStream} values of a pair RDD.
     *
     * <p>The keys are ignored.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaPairRDD keyed data streams
     */
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        prepareNetworkAndStuff(sparkDl4jMultiLayer, null);
        JavaRDD<PortableDataStream> streams = javaPairRDD.values();
        doIterationPDS(sparkDl4jMultiLayer, null, streams, 1, 1);
    }

    /**
     * Trains a MultiLayerNetwork from an RDD of paths to exported DataSets.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD paths of the exported data
     */
    public void executeTrainingPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<String> javaRDD) {
        prepareNetworkAndStuff(sparkDl4jMultiLayer, null);
        int perWorkerBatch = batchSizePerWorker;
        executeTrainingPathsHelper(sparkDl4jMultiLayer, javaRDD, perWorkerBatch);
    }

    /**
     * Trains a ComputationGraph on a DataSet RDD using the configured {@link RDDTrainingApproach}.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD training data
     * @throws DL4JInvalidConfigException if the approach is neither Direct nor Export
     */
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        RDDTrainingApproach approach = rddTrainingApproach;
        if (approach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(sparkComputationGraph, javaRDD);
            return;
        }
        if (approach == RDDTrainingApproach.Export) {
            JavaRDD<String> exportedPaths = (JavaRDD<String>) exportIfRequired(sparkComputationGraph.getSparkContext(), javaRDD);
            executeTrainingPathsHelper(sparkComputationGraph, exportedPaths, batchSizePerWorker);
            return;
        }
        throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
    }

    /**
     * Trains a ComputationGraph from the {@link PortableDataStream} values of a pair RDD.
     *
     * <p>The keys are ignored.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaPairRDD keyed data streams
     */
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        JavaRDD<PortableDataStream> streams = javaPairRDD.values();
        doIterationPDS(null, sparkComputationGraph, streams, 1, 1);
    }

    /**
     * Trains a ComputationGraph from an RDD of paths to exported DataSets.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD paths of the exported data
     */
    public void executeTrainingPaths(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        int perWorkerBatch = batchSizePerWorker;
        executeTrainingPathsHelper(sparkComputationGraph, javaRDD, perWorkerBatch);
    }

    /**
     * Trains a ComputationGraph from an RDD of paths to exported MultiDataSets.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD paths of the exported data
     */
    public void executeTrainingPathsMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        int perWorkerBatch = batchSizePerWorker;
        executeTrainingPathsMDSHelper(sparkComputationGraph, javaRDD, perWorkerBatch);
    }

    /**
     * Trains a ComputationGraph on a MultiDataSet RDD using the configured {@link RDDTrainingApproach}.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD training data
     * @throws DL4JInvalidConfigException if the approach is neither Direct nor Export
     */
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        RDDTrainingApproach approach = rddTrainingApproach;
        if (approach == RDDTrainingApproach.Direct) {
            executeTrainingDirectMDS(sparkComputationGraph, javaRDD);
            return;
        }
        if (approach == RDDTrainingApproach.Export) {
            executeTrainingPathsMDSHelper(sparkComputationGraph, exportIfRequiredMDS(sparkComputationGraph.getSparkContext(), javaRDD), batchSizePerWorker);
            return;
        }
        throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
    }

    /**
     * Trains a ComputationGraph from the {@link PortableDataStream} values of a pair RDD of MultiDataSets.
     *
     * <p>The keys are ignored.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaPairRDD keyed data streams
     */
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaPairRDD<String, PortableDataStream> javaPairRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        JavaRDD<PortableDataStream> streams = javaPairRDD.values();
        doIterationMultiPDS(sparkComputationGraph, streams, 1, 1);
    }

    /**
     * Enables or disables collection of training timing statistics.
     *
     * <p>The constructor only creates the stats helper when collection is enabled up front. Enabling it
     * here creates the helper on demand, so the guarded {@code stats.log*} calls do not hit a null
     * reference.
     *
     * @param z {@code true} to collect statistics
     */
    public void setCollectTrainingStats(boolean z) {
        this.collectTrainingStats = z;
        if (z && this.stats == null) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
    }

    /**
     * @return whether training statistics collection is currently enabled
     */
    public boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    /**
     * Training statistics are not exposed by this master.
     *
     * @return always {@code null}
     */
    public SparkTrainingStats getTrainingStats() {
        return null;
    }

    /**
     * Not supported by this master; the listeners are ignored.
     *
     * @param collection ignored
     */
    public void setListeners(Collection<IterationListener> collection) {
        // no-op
    }

    /**
     * Not supported by this master; both arguments are ignored.
     *
     * <p>NOTE(review): processResults reads the inherited {@code statsStorage}, which this method does not set.
     *
     * @param statsStorageRouter ignored
     * @param collection ignored
     */
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<IterationListener> collection) {
        // no-op
    }

    /**
     * Merges the per-worker {@link SharedTrainingResult}s back into the driver-side network.
     *
     * <p>Steps:
     * <ol>
     *   <li>Finishes the training driver.</li>
     *   <li>Tree-aggregates (depth 4) the worker results.</li>
     *   <li>Divides the combined updater state and score sum by the number of aggregated results.</li>
     *   <li>Applies them to whichever network wrapper is non-null.</li>
     *   <li>Forwards any listener metadata, static info and updates to {@code statsStorage}.</li>
     * </ol>
     *
     * <p>If aggregation yields {@code null} (e.g. an empty results RDD), a warning is logged and the
     * network is left unchanged.
     *
     * @param sparkDl4jMultiLayer MultiLayerNetwork wrapper, or {@code null} when training a ComputationGraph
     * @param sparkComputationGraph ComputationGraph wrapper, used only when {@code sparkDl4jMultiLayer} is null
     * @param javaRDD per-worker training results
     * @throws IllegalStateException if both wrappers are null
     */
    protected void processResults(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<SharedTrainingResult> javaRDD) {
        if (sparkDl4jMultiLayer == null && sparkComputationGraph == null) {
            throw new IllegalStateException("Both MLN & CG are null");
        }
        finalizeTraining();

        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        // Zero value is null; the result type is inferred from the aggregate/accumulation functions.
        SharedTrainingAccumulationTuple finalResult = javaRDD.treeAggregate(null, new SharedTrainingAggregateFunction(), new SharedTrainingAccumulationFunction(), 4);
        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
        }
        if (finalResult == null) {
            log.warn("No training results were aggregated; network updater state and score were left unchanged");
            Nd4j.getExecutioner().commit();
            return;
        }
        SparkTrainingStats workerStats = finalResult.getSparkTrainingStats();

        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterStart();
        }
        if (finalResult.getUpdaterStateArray() != null) {
            // Divide by the number of aggregated results to average the combined updater state.
            if (finalResult.getAggregationsCount() > 1) {
                finalResult.getUpdaterStateArray().divi(Integer.valueOf(finalResult.getAggregationsCount()));
            }
            if (sparkDl4jMultiLayer != null) {
                if (sparkDl4jMultiLayer.getNetwork().getUpdater() != null && sparkDl4jMultiLayer.getNetwork().getUpdater().getStateViewArray() != null) {
                    sparkDl4jMultiLayer.getNetwork().getUpdater().getStateViewArray().assign(finalResult.getUpdaterStateArray());
                }
            } else if (sparkComputationGraph.getNetwork().getUpdater() != null && sparkComputationGraph.getNetwork().getUpdater().getStateViewArray() != null) {
                sparkComputationGraph.getNetwork().getUpdater().getStateViewArray().assign(finalResult.getUpdaterStateArray());
            }
        }

        double averageScore = finalResult.getScoreSum() / Math.max(1, finalResult.getAggregationsCount());
        if (sparkDl4jMultiLayer != null) {
            sparkDl4jMultiLayer.getNetwork().setScore(averageScore);
        } else {
            sparkComputationGraph.getNetwork().setScore(averageScore);
        }

        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(workerStats);
        }

        if (this.statsStorage != null) {
            Collection<StorageMetaData> listenerMetaData = finalResult.getListenerMetaData();
            if (listenerMetaData != null && !listenerMetaData.isEmpty()) {
                this.statsStorage.putStorageMetaData(listenerMetaData);
            }
            Collection<Persistable> listenerStaticInfo = finalResult.getListenerStaticInfo();
            if (listenerStaticInfo != null && !listenerStaticInfo.isEmpty()) {
                this.statsStorage.putStaticInfo(listenerStaticInfo);
            }
            Collection<Persistable> listenerUpdates = finalResult.getListenerUpdates();
            if (listenerUpdates != null && !listenerUpdates.isEmpty()) {
                this.statsStorage.putUpdate(listenerUpdates);
            }
        }
        Nd4j.getExecutioner().commit();
    }

    /**
     * Trains the MultiLayerNetwork on one split of an in-memory {@code DataSet} RDD.
     *
     * <p>The split is repartitioned with the configured {@link Repartition} policy and
     * {@link RepartitionStrategy}; every partition is then mapped through a
     * {@code SharedFlatMapDataSet} wrapping this network's worker, and the resulting lazy RDD is
     * handed to {@code processResults}, whose aggregate triggers the actual computation.
     *
     * @param sparkDl4jMultiLayer network wrapper being trained
     * @param javaRDD             data for this split
     * @param i                   index of this split (logging only)
     * @param i2                  total number of splits (logging only)
     */
    protected void doIteration(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        i, i2, batchSizePerWorker, threshold, numWorkers);

        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }

        JavaRDD<DataSet> splitData = SparkUtils.repartition(javaRDD, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers.intValue());
        int numPartitions = splitData.partitions().size();

        boolean repartitioned = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioned) {
            stats.logRepartitionEnd();
        }

        JavaRDD<SharedTrainingResult> results =
                        splitData.mapPartitions(new SharedFlatMapDataSet(m10getWorkerInstance(sparkDl4jMultiLayer)));
        processResults(sparkDl4jMultiLayer, null, results);

        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains the ComputationGraph on one split of an in-memory {@code MultiDataSet} RDD.
     *
     * <p>The split is repartitioned with the configured policy and strategy, each partition is
     * mapped through a {@code SharedFlatMapMultiDataSet} wrapping the graph's worker, and the lazy
     * result RDD is passed to {@code processResults}, which executes and applies it.
     *
     * @param sparkComputationGraph graph wrapper being trained
     * @param javaRDD               data for this split
     * @param i                     index of this split (logging only)
     * @param i2                    total number of splits (logging only)
     */
    protected void doIterationMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        i, i2, batchSizePerWorker, threshold, numWorkers);

        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }

        JavaRDD<MultiDataSet> splitData = SparkUtils.repartition(javaRDD, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers.intValue());
        int numPartitions = splitData.partitions().size();

        boolean repartitioned = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioned) {
            stats.logRepartitionEnd();
        }

        JavaRDD<SharedTrainingResult> results = splitData
                        .mapPartitions(new SharedFlatMapMultiDataSet(m9getWorkerInstance(sparkComputationGraph)));
        processResults(null, sparkComputationGraph, results);

        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains the ComputationGraph on one split of an in-memory {@code DataSet} RDD.
     *
     * <p>After repartitioning with the configured policy and strategy, each partition runs through
     * a {@code SharedFlatMapDataSet} wrapping the graph's worker; the lazy result RDD is passed to
     * {@code processResults}, which executes and applies it.
     *
     * @param sparkComputationGraph graph wrapper being trained
     * @param javaRDD               data for this split
     * @param i                     index of this split (logging only)
     * @param i2                    total number of splits (logging only)
     */
    protected void doIteration(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        i, i2, batchSizePerWorker, threshold, numWorkers);

        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }

        JavaRDD<DataSet> splitData = SparkUtils.repartition(javaRDD, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers.intValue());
        int numPartitions = splitData.partitions().size();

        boolean repartitioned = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioned) {
            stats.logRepartitionEnd();
        }

        JavaRDD<SharedTrainingResult> results =
                        splitData.mapPartitions(new SharedFlatMapDataSet(m9getWorkerInstance(sparkComputationGraph)));
        processResults(null, sparkComputationGraph, results);

        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains the ComputationGraph on one split given as an RDD of path strings, read through
     * {@code SharedFlatMapPathsMDS}.
     *
     * <p>Partition sizing uses {@code numObjectsEachWorker(i3)} rather than the in-memory
     * {@code rddDataSetNumExamples}. The lazy result RDD is passed to {@code processResults},
     * which executes and applies it.
     *
     * @param sparkComputationGraph graph wrapper being trained
     * @param javaRDD               paths for this split
     * @param i                     index of this split (logging only)
     * @param i2                    total number of splits (logging only)
     * @param i3                    value passed to {@code numObjectsEachWorker} when sizing partitions
     *                              (presumably examples per saved file — confirm against callers)
     */
    protected void doIterationPathsMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2, int i3) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        i, i2, batchSizePerWorker, threshold, numWorkers);

        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }

        JavaRDD<String> splitPaths = SparkUtils.repartition(javaRDD, repartition, repartitionStrategy,
                        numObjectsEachWorker(i3), numWorkers.intValue());
        int numPartitions = splitPaths.partitions().size();

        boolean repartitioned = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioned) {
            stats.logRepartitionEnd();
        }

        JavaRDD<SharedTrainingResult> results =
                        splitPaths.mapPartitions(new SharedFlatMapPathsMDS(m9getWorkerInstance(sparkComputationGraph)));
        processResults(null, sparkComputationGraph, results);

        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains either network type on one split given as an RDD of path strings, read through
     * {@code SharedFlatMapPaths}.
     *
     * <p>The worker is built from the MultiLayerNetwork when one is supplied, otherwise from the
     * ComputationGraph. Partition sizing uses {@code numObjectsEachWorker(i3)}.
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper, or {@code null}
     * @param sparkComputationGraph ComputationGraph wrapper, or {@code null}
     * @param javaRDD               paths for this split
     * @param i                     index of this split (logging only)
     * @param i2                    total number of splits (logging only)
     * @param i3                    value passed to {@code numObjectsEachWorker} when sizing partitions
     *                              (presumably examples per saved file — confirm against callers)
     * @throws DL4JInvalidConfigException if both network wrappers are {@code null}
     */
    protected void doIterationPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2, int i3) {
        boolean noNetwork = sparkDl4jMultiLayer == null && sparkComputationGraph == null;
        if (noNetwork) {
            throw new DL4JInvalidConfigException("Both MLN & CompGraph are NULL");
        }

        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        i, i2, batchSizePerWorker, threshold, numWorkers);

        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }

        JavaRDD<String> splitPaths = SparkUtils.repartition(javaRDD, repartition, repartitionStrategy,
                        numObjectsEachWorker(i3), numWorkers.intValue());
        int numPartitions = splitPaths.partitions().size();

        boolean repartitioned = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioned) {
            stats.logRepartitionEnd();
        }

        SharedFlatMapPaths pathsFunction = new SharedFlatMapPaths(sparkDl4jMultiLayer != null
                        ? m10getWorkerInstance(sparkDl4jMultiLayer)
                        : m9getWorkerInstance(sparkComputationGraph));
        JavaRDD<SharedTrainingResult> results = splitPaths.mapPartitions(pathsFunction);
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, results);

        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains either network type on one split of {@link PortableDataStream}s, read through
     * {@code SharedFlatMapPDS}.
     *
     * <p>The worker is built from the MultiLayerNetwork when one is supplied, otherwise from the
     * ComputationGraph (the same selection {@code doIterationPaths} uses). Previously the worker
     * was always built from the MultiLayerNetwork, so a ComputationGraph-only call built it from
     * {@code null}. The lazy result RDD is passed to {@code processResults}, which executes and
     * applies it.
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper, or {@code null}
     * @param sparkComputationGraph ComputationGraph wrapper, or {@code null}
     * @param javaRDD               serialized data for this split
     * @param i                     index of this split (logging only)
     * @param i2                    total number of splits (logging only)
     * @throws DL4JInvalidConfigException if both network wrappers are {@code null}
     */
    protected void doIterationPDS(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        if (sparkDl4jMultiLayer == null && sparkComputationGraph == null) {
            throw new DL4JInvalidConfigException("Both MLN & CompGraph are NULL");
        }

        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), Double.valueOf(this.threshold), this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD<PortableDataStream> splitData = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                        numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers.intValue());
        int size = splitData.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        // Pick the worker matching the network actually supplied (MLN preferred, as in processResults).
        SharedFlatMapPDS pdsFunction = new SharedFlatMapPDS(sparkDl4jMultiLayer != null
                        ? m10getWorkerInstance(sparkDl4jMultiLayer)
                        : m9getWorkerInstance(sparkComputationGraph));
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, splitData.mapPartitions(pdsFunction));

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(size);
        }
    }

    /**
     * Trains the ComputationGraph on one split of {@link PortableDataStream}s holding
     * multi-input/multi-output data, read through {@code SharedFlatMapMultiPDS}.
     *
     * <p>The split is repartitioned with the configured policy and strategy; the lazy result RDD
     * is then passed to {@code processResults}, which executes and applies it.
     *
     * @param sparkComputationGraph graph wrapper being trained
     * @param javaRDD               serialized data for this split
     * @param i                     index of this split (logging only)
     * @param i2                    total number of splits (logging only)
     */
    protected void doIterationMultiPDS(SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, updatesThreshold={}, Configured for {} workers",
                        i, i2, batchSizePerWorker, threshold, numWorkers);

        if (collectTrainingStats) {
            stats.logMapPartitionsStart();
            stats.logRepartitionStart();
        }

        JavaRDD<PortableDataStream> splitData = SparkUtils.repartition(javaRDD, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers.intValue());
        int numPartitions = splitData.partitions().size();

        boolean repartitioned = repartition != Repartition.Never;
        if (collectTrainingStats && repartitioned) {
            stats.logRepartitionEnd();
        }

        JavaRDD<SharedTrainingResult> results =
                        splitData.mapPartitions(new SharedFlatMapMultiPDS(m9getWorkerInstance(sparkComputationGraph)));
        processResults(null, sparkComputationGraph, results);

        if (collectTrainingStats) {
            stats.logMapPartitionsEnd(numPartitions);
        }
    }
}
