package org.deeplearning4j.spark.parameterserver.pw;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import org.bytedeco.javacpp.Loader;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.SleepyTrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.deeplearning4j.spark.parameterserver.iterators.VirtualDataSetIterator;
import org.deeplearning4j.spark.parameterserver.iterators.VirtualIterator;
import org.deeplearning4j.spark.parameterserver.iterators.VirtualMultiDataSetIterator;
import org.deeplearning4j.spark.parameterserver.networking.SilentTrainingDriver;
import org.deeplearning4j.spark.parameterserver.networking.WiredEncodingHandler;
import org.deeplearning4j.spark.parameterserver.networking.messages.SilentIntroductoryMessage;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;
import org.deeplearning4j.spark.parameterserver.util.BlockingObserver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.enums.TransportType;
import org.nd4j.parameterserver.distributed.transport.MulticastTransport;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.class */
/**
 * Process-wide singleton that funnels the data iterators of several calling threads into one
 * local training run.
 *
 * <p>Each calling thread first registers its iterator via {@link #attachDS(Iterator)} or
 * {@link #attachMDS(Iterator)} and then calls {@link #run(SharedTrainingWorker)}. The first
 * thread to enter {@code run} performs the actual training ("master" thread); every other thread
 * only blocks until its own iterator has been consumed ("feeder" threads).
 */
public class SharedTrainingWrapper {
    private static final Logger log = LoggerFactory.getLogger(SharedTrainingWrapper.class);
    public static SharedTrainingWrapper INSTANCE = new SharedTrainingWrapper();

    /** Multi-worker trainer; stays {@code null} while the single-worker (standalone) path is used. */
    protected ParallelWrapper wrapper;
    /** Virtual iterator backed by {@link #iteratorsDS}; recreated by {@link #init()}. */
    protected VirtualDataSetIterator iteratorDS;
    /** NOTE(review): never assigned inside this class, so it is {@code null} unless a subclass sets it. */
    protected VirtualMultiDataSetIterator iteratorMDS;
    protected List<Iterator<DataSet>> iteratorsDS;
    protected List<Iterator<MultiDataSet>> iteratorsMDS;
    /** {@code true} while some thread holds the master role inside {@link #run(SharedTrainingWorker)}. */
    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    /** Observer attached to the iterator registered by the current thread. */
    protected ThreadLocal<BlockingObserver> observer = new ThreadLocal<>();
    protected EncodedGradientsAccumulator accumulator;
    protected Model originalModel;
    protected SilentTrainingDriver driver;

    protected SharedTrainingWrapper() {
        init();
    }

    /**
     * (Re)creates the per-run iterator lists and the virtual DataSet iterator built on top of them.
     * Called from the constructor and again at the end of every master run.
     */
    protected void init() {
        this.iteratorsDS = new CopyOnWriteArrayList<>();
        this.iteratorsMDS = new CopyOnWriteArrayList<>();
        this.iteratorDS = new VirtualDataSetIterator(this.iteratorsDS);
    }

    /**
     * @return the shared singleton instance
     */
    public static SharedTrainingWrapper getInstance() {
        return INSTANCE;
    }

    /**
     * Registers a DataSet iterator for the calling thread and remembers its observer in a
     * thread-local, so the same thread can later wait for the iterator to be consumed.
     *
     * @param it iterator supplying this thread's DataSets
     */
    public void attachDS(Iterator<DataSet> it) {
        log.info("Attaching thread...");
        VirtualIterator<DataSet> wrapped = new VirtualIterator<>(it);
        BlockingObserver blockingObserver = new BlockingObserver();
        wrapped.addObserver(blockingObserver);
        this.iteratorsDS.add(wrapped);
        this.observer.set(blockingObserver);
    }

    /**
     * Registers a MultiDataSet iterator for the calling thread and remembers its observer in a
     * thread-local, so the same thread can later wait for the iterator to be consumed.
     *
     * @param it iterator supplying this thread's MultiDataSets
     */
    public void attachMDS(Iterator<MultiDataSet> it) {
        log.info("Attaching thread...");
        VirtualIterator<MultiDataSet> wrapped = new VirtualIterator<>(it);
        BlockingObserver blockingObserver = new BlockingObserver();
        wrapped.addObserver(blockingObserver);
        this.iteratorsMDS.add(wrapped);
        this.observer.set(blockingObserver);
    }

    /**
     * Runs one training pass.
     *
     * <p>The first thread to arrive becomes the master: it lazily sets up the gradients
     * accumulator, the parameter-server client and either a {@link ParallelWrapper} (more than one
     * worker) or the bare model (one worker), fits on the attached iterators, and resets the shared
     * state for the next pass. Any other thread just waits for its own iterator to be drained and
     * returns an empty result.
     *
     * @param sharedTrainingWorker source of the broadcast configuration and of the initial model
     * @return an empty result for feeder threads; for the master thread a result carrying the
     *         model score and, when available, the updater state
     * @throws DL4JInvalidConfigException if no model, transport or iterator is available
     * @throws RuntimeException if a feeder thread is interrupted while waiting
     */
    public SharedTrainingResult run(SharedTrainingWorker sharedTrainingWorker) {
        if (!this.isFirst.compareAndSet(false, true)) {
            // Somebody else is the master: only wait until our own iterator has been consumed.
            try {
                this.observer.get().waitTillDone();
            } catch (InterruptedException e) {
                // Restore the flag so code further up the stack can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            log.info("Feeder thread done...");
            return new SharedTrainingResult();
        }

        SharedTrainingConfiguration trainingConfiguration =
                (SharedTrainingConfiguration) sharedTrainingWorker.getBroadcastConfiguration().getValue();
        VoidConfiguration voidConfiguration = trainingConfiguration.getVoidConfiguration();

        // Only populated when the training backend is (re)built below; stays null when an
        // existing wrapper is reused.
        Model model = null;

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int numWorkers = resolveWorkerCount(trainingConfiguration, numDevices);
        if (numDevices > 1 && numWorkers > numDevices) {
            log.warn("WARNING! Using more workers then number of available computational devices!");
        }

        if (this.wrapper == null) {
            log.info("Starting ParallelWrapper at thread {}", Thread.currentThread().getId());

            model = sharedTrainingWorker.getInitialModel();
            if (model == null) {
                model = sharedTrainingWorker.getInitialModelGraph();
            }
            if (model == null) {
                throw new DL4JInvalidConfigException("No model was defined for training");
            }

            WiredEncodingHandler handler = new WiredEncodingHandler(trainingConfiguration.getThreshold(),
                    trainingConfiguration.getMinThreshold(), trainingConfiguration.getThresholdStep(),
                    trainingConfiguration.getStepTrigger(), trainingConfiguration.getStepDelay(),
                    trainingConfiguration.getShakeFrequency());

            // Accumulator, driver and parameter-server client are created once and then reused.
            if (this.accumulator == null) {
                this.accumulator = new EncodedGradientsAccumulator.Builder(numWorkers)
                        .messageHandler(handler)
                        .encodingThreshold(trainingConfiguration.getThreshold())
                        .memoryParameters(
                                trainingConfiguration.getBufferSize() > 0
                                        ? trainingConfiguration.getBufferSize()
                                        : EncodedGradientsAccumulator.getOptimalBufferSize(model, numWorkers, 2),
                                numWorkers * 2)
                        .build();

                Transport transport = null;
                if (voidConfiguration.getTransportType() == TransportType.ROUTED) {
                    transport = new RoutedTransport();
                } else if (voidConfiguration.getTransportType() == TransportType.BROADCAST) {
                    transport = new MulticastTransport();
                }
                if (transport == null) {
                    throw new DL4JInvalidConfigException(
                            "No Transport implementation was defined for this training session!");
                }

                if (!VoidParameterServer.getInstance().isInit()) {
                    voidConfiguration.setForcedRole((NodeRole) null);
                }

                this.driver = new SilentTrainingDriver(this.accumulator);
                VoidParameterServer.getInstance().init(voidConfiguration, transport, this.driver);
                this.originalModel = model;

                String localAddress = resolveLocalAddress(voidConfiguration);
                VoidParameterServer.getInstance().sendMessageToAllShards(
                        new SilentIntroductoryMessage(localAddress, voidConfiguration.getUnicastPort()));
            }

            if (trainingConfiguration.getDebugLongerIterations() > 0) {
                log.warn("Adding SleepyListener: {} ms", trainingConfiguration.getDebugLongerIterations());
                model.addListeners(SleepyTrainingListener.builder()
                        .timerIteration(trainingConfiguration.getDebugLongerIterations()).build());
            }

            if (numWorkers > 1) {
                log.info("Params at PW: {}", this.originalModel.params().meanNumber().doubleValue());
                this.wrapper = new ParallelWrapper.Builder<>(this.originalModel)
                        .workers(numWorkers)
                        .workspaceMode(trainingConfiguration.getWorkspaceMode())
                        .trainingMode(ParallelWrapper.TrainingMode.CUSTOM)
                        .gradientsAccumulator(this.accumulator)
                        .prefetchBuffer(trainingConfiguration.getPrefetchSize())
                        .build();
            } else {
                log.info("Using standalone model instead...");
                this.accumulator.fallbackToSingleConsumerMode(true);
                this.accumulator.touch();

                // Network-specific setters are not part of Model, hence the casts.
                if (model instanceof ComputationGraph) {
                    ((ComputationGraph) this.originalModel).getConfiguration()
                            .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode());
                    ((ComputationGraph) this.originalModel).setGradientsAccumulator(this.accumulator);
                } else if (model instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork) this.originalModel).getLayerWiseConfigurations()
                            .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode());
                    ((MultiLayerNetwork) this.originalModel).setGradientsAccumulator(this.accumulator);
                }
            }
        }

        this.driver.bypassMode(false);

        if (this.wrapper != null) {
            if (this.iteratorDS != null) {
                this.wrapper.fit(this.iteratorDS);
            } else if (this.iteratorMDS != null) {
                this.wrapper.fit(this.iteratorMDS);
            } else {
                throw new DL4JInvalidConfigException("No iterators were defined for training");
            }
        } else if (this.iteratorDS != null) {
            if (model instanceof ComputationGraph) {
                ((ComputationGraph) this.originalModel).fit(this.iteratorDS);
            } else if (model instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) this.originalModel).fit(this.iteratorDS);
            }
        } else if (this.iteratorMDS != null) {
            ((ComputationGraph) this.originalModel).fit(this.iteratorMDS);
        } else {
            throw new DL4JInvalidConfigException("No iterators were defined for training");
        }

        // In standalone mode there is no wrapper to shut down, so guard against null.
        if (trainingConfiguration.isEpochReset() && this.wrapper != null) {
            this.wrapper.shutdown();
            this.wrapper = null;
        }

        // Reset shared state so the next pass starts with fresh iterator lists and a free master slot.
        init();
        this.accumulator.reset();
        this.driver.bypassMode(true);
        this.isFirst.set(false);
        log.info("Master thread done...");

        // NOTE(review): when an existing wrapper was reused, `model` is null here and the updater
        // state is reported as null - confirm whether that is intended.
        INDArray updaterState = null;
        if (model instanceof ComputationGraph) {
            updaterState = ((ComputationGraph) this.originalModel).getUpdater().getUpdaterStateViewArray();
        } else if (model instanceof MultiLayerNetwork) {
            updaterState = ((MultiLayerNetwork) this.originalModel).getUpdater().getStateViewArray();
        }

        return SharedTrainingResult.builder()
                .aggregationsCount(1)
                .scoreSum(this.originalModel.score())
                .updaterStateArray(updaterState)
                .listenerMetaData(new ArrayList<>())
                .listenerStaticInfo(new ArrayList<>())
                .listenerUpdates(new ArrayList<>())
                .build();
    }

    /** Picks the worker count: explicit setting first, then device count, then a core-based default. */
    private static int resolveWorkerCount(SharedTrainingConfiguration configuration, int numDevices) {
        if (configuration.getNumberOfWorkersPerNode() > 0) {
            return configuration.getNumberOfWorkersPerNode();
        }
        if (numDevices > 1) {
            return numDevices;
        }
        return Math.min(6, Math.max(1, Loader.totalCores() / 4));
    }

    /** Picks the address announced to the shards: SPARK_PUBLIC_DNS, network mask match, DL4J_VOID_IP, localhost. */
    private static String resolveLocalAddress(VoidConfiguration voidConfiguration) {
        String address = System.getenv("SPARK_PUBLIC_DNS");
        if (address == null && voidConfiguration.getNetworkMask() != null) {
            address = new NetworkOrganizer(voidConfiguration.getNetworkMask()).getMatchingAddress();
        }
        if (address == null) {
            address = System.getenv("DL4J_VOID_IP");
        }
        if (address == null) {
            address = "127.0.0.1";
            log.warn("Can't get IP address to start VoidParameterServer client. Using localhost instead");
        }
        return address;
    }

    /**
     * Currently a no-op.
     *
     * @param dataSet ignored
     */
    public void passDataSet(DataSet dataSet) {
    }

    /**
     * Currently a no-op.
     *
     * @param multiDataSet ignored
     */
    public void passDataSet(MultiDataSet multiDataSet) {
    }

    /**
     * Blocks the calling thread until the iterator it attached has been fully consumed.
     *
     * @throws IllegalStateException if the calling thread has not attached an iterator yet
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public void blockUntilFinished() throws InterruptedException {
        BlockingObserver current = this.observer.get();
        if (current == null) {
            throw new IllegalStateException("This method can't be called before iterators initialization");
        }
        // Object.wait() would throw IllegalMonitorStateException here because the monitor is not
        // held; waitTillDone() is the observer's own blocking call, as used in run().
        current.waitTillDone();
    }
}
