package org.deeplearning4j.spark.parameterserver.networking;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.FancyBlockingQueue;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.spark.parameterserver.networking.messages.SilentUpdatesMessage;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TrainingDriver} for {@link SilentUpdatesMessage}s. The constructor used selects one of two
 * roles:
 * <ul>
 *   <li><b>worker</b> ({@link #SilentTrainingDriver(GradientsAccumulator)}): encoded updates
 *       received from other nodes are queued into {@link #updatesBuffer}, which is registered as
 *       the accumulator's external source;</li>
 *   <li><b>master</b> ({@link #SilentTrainingDriver(INDArray, StepFunction)}): encoded updates are
 *       decoded into the local {@link #updates} array, periodically applied to the parameters via
 *       the {@link StepFunction}, and re-broadcast to the other known clients.</li>
 * </ul>
 *
 * <p>On the master, {@link #updates}, {@link #params} and {@link #hasSomething} are read and written
 * only while holding this object's monitor.
 */
public class SilentTrainingDriver implements TrainingDriver<SilentUpdatesMessage> {
    private static final Logger log = LoggerFactory.getLogger(SilentTrainingDriver.class);

    /** Capacity of the worker-side incoming update queue. */
    private static final int UPDATES_BUFFER_CAPACITY = 1024;

    // Master-only state.
    protected transient INDArray params;
    protected transient INDArray updates;
    protected transient StepFunction stepFunction;

    // Worker-only state.
    protected transient GradientsAccumulator accumulator;

    protected transient VoidConfiguration voidConfiguration;
    protected transient Transport transport;
    protected transient AtomicLong updatesCount;
    // Master only: true while {@link #updates} holds decoded data not yet applied to params.
    protected transient AtomicBoolean hasSomething;

    // When set, worker-side incoming updates are dropped instead of queued.
    protected transient AtomicBoolean bypassMode = new AtomicBoolean(false);

    // Number of bitmap-decoded (dense) and threshold-decoded (sparse) messages seen on the master.
    protected transient AtomicLong denseCounter = new AtomicLong(0);
    protected transient AtomicLong sparseCounter = new AtomicLong(0);

    // Worker only: incoming encoded updates, consumed by the accumulator as its external source.
    protected transient BlockingQueue<INDArray> updatesBuffer;

    // Not assigned or read by this driver.
    protected transient Storage storage;
    protected transient Clipboard clipboard;

    /**
     * Creates a worker-side driver.
     *
     * @param accumulator accumulator whose external source is set to this driver's update buffer
     * @throws NullPointerException if {@code accumulator} is null
     */
    public SilentTrainingDriver(@NonNull GradientsAccumulator accumulator) {
        if (accumulator == null) {
            throw new NullPointerException("accumulator");
        }
        log.info("Creating TrainingDriver for worker...");
        this.accumulator = accumulator;
        this.updatesCount = new AtomicLong(0L);
        this.updatesBuffer = new FancyBlockingQueue<>(new LinkedBlockingQueue<>(UPDATES_BUFFER_CAPACITY));
        this.accumulator.setExternalSource(this.updatesBuffer);
    }

    /**
     * Creates a master-side driver that applies decoded updates to {@code params}.
     *
     * @param params       parameters that {@code stepFunction} updates in place
     * @param stepFunction step function used to apply accumulated updates to {@code params}
     * @throws NullPointerException if either argument is null
     */
    public SilentTrainingDriver(@NonNull INDArray params, @NonNull StepFunction stepFunction) {
        if (params == null) {
            throw new NullPointerException("params");
        }
        if (stepFunction == null) {
            throw new NullPointerException("stepFunction");
        }
        log.info("Creating TrainingDriver for master...");
        log.info("Params at Master BEFORE: {}", params.meanNumber().doubleValue());
        this.params = params;
        this.stepFunction = stepFunction;
        this.updatesCount = new AtomicLong(0L);
        this.hasSomething = new AtomicBoolean(false);

        // Allocate the accumulation buffer with workspaces scoped out, presumably so that its
        // lifetime is not tied to whatever workspace is active on the calling thread.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            this.updates = Nd4j.create(params.shape(), params.ordering());
        }
    }

    /**
     * Returns the worker-side queue of incoming encoded updates.
     *
     * @return the update queue, or {@code null} for a master-side driver
     */
    public BlockingQueue<INDArray> getUpdatesBuffer() {
        return this.updatesBuffer;
    }

    /**
     * Stores the configuration and transport used by this driver. {@code storage} and
     * {@code clipboard} are accepted for the {@link TrainingDriver} contract but are not used.
     *
     * @throws NullPointerException if {@code voidConfiguration} or {@code transport} is null
     */
    @Override
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, Storage storage, Clipboard clipboard) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration");
        }
        if (transport == null) {
            throw new NullPointerException("transport");
        }
        this.voidConfiguration = voidConfiguration;
        this.transport = transport;
    }

    /**
     * Enables or disables bypass mode. When enabled, worker-side incoming updates are dropped and
     * any already-queued updates are discarded.
     *
     * @param enabled whether bypass mode should be active
     */
    public void bypassMode(boolean enabled) {
        this.bypassMode.set(enabled);
        // A master-side driver has no update buffer.
        if (enabled && this.updatesBuffer != null) {
            this.updatesBuffer.clear();
        }
    }

    /**
     * Handles an incoming update message.
     *
     * <p>On a worker, the message's updates are queued for the accumulator unless the message
     * originated from this node or bypass mode is on. On the master, the updates are decoded into
     * the local accumulation array. Every {@code max(numberOfKnownClients, 5)} messages they are
     * applied to the parameters and the array is cleared. The message is then forwarded to all
     * other clients when more than one client is known.
     *
     * @throws DL4JInvalidConfigException if neither an accumulator nor params/step function is set,
     *                                    or if the message carries an unknown compression header
     * @throws RuntimeException           if the worker thread is interrupted while queueing
     */
    @Override
    public void startTraining(SilentUpdatesMessage message) {
        if (this.accumulator != null) {
            enqueueForAccumulator(message);
            return;
        }
        if (this.params == null || this.stepFunction == null) {
            throw new DL4JInvalidConfigException("Neither GradientsAccumulator or StepFunction is defined!");
        }
        synchronized (this) {
            decodeIntoUpdates(message.getUpdates());
            this.hasSomething.set(true);
            if (this.updatesCount.incrementAndGet() % Math.max(this.transport.numberOfKnownClients(), 5) == 0) {
                this.stepFunction.step(this.params, this.updates);
                Nd4j.getMemoryManager().memset(this.updates);
                this.hasSomething.set(false);
            }
        }
        if (this.transport.numberOfKnownClients() > 1) {
            // Do not send the message back to its originator or to ourselves.
            this.transport.sendMessageToAllClients(message,
                    new Long[] {message.getOriginatorId(), this.transport.getOwnOriginatorId()});
        }
    }

    /**
     * Worker path: queues the message's encoded updates for the accumulator. Messages originating
     * from this node are skipped, and so is every message while bypass mode is on.
     */
    private void enqueueForAccumulator(SilentUpdatesMessage message) {
        if (message.getOriginatorId() == this.transport.getOwnOriginatorId()) {
            return;
        }
        if (this.bypassMode.get()) {
            return;
        }
        try {
            this.updatesBuffer.put(message.getUpdates());
        } catch (InterruptedException e) {
            // Preserve the interrupt for code further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Master path: decodes {@code encoded} into {@link #updates}. The int at index 3 of the encoded
     * buffer is treated as the compression header: 0 selects threshold decoding (counted as
     * sparse), 1 selects bitmap decoding (counted as dense). Called only while holding this
     * object's monitor.
     */
    private void decodeIntoUpdates(INDArray encoded) {
        int header = encoded.data().getInt(3L);
        switch (header) {
            case 0:
                Nd4j.getExecutioner().thresholdDecode(encoded, this.updates);
                this.sparseCounter.incrementAndGet();
                break;
            case 1:
                Nd4j.getExecutioner().bitmapDecode(encoded, this.updates);
                this.denseCounter.incrementAndGet();
                break;
            default:
                throw new DL4JInvalidConfigException("Unknown compression header received: " + header);
        }
    }

    /** Not supported by this driver. */
    @Override
    public void pickTraining(SilentUpdatesMessage message) {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this driver. */
    @Override
    public void aggregationFinished(VoidAggregation aggregation) {
        throw new UnsupportedOperationException();
    }

    /**
     * On the master, applies any decoded but not-yet-applied updates to the parameters and clears
     * the accumulation array. Does nothing on a worker or when there is nothing pending.
     *
     * <p>Runs under this object's monitor so it cannot interleave with a concurrent decode in
     * {@link #startTraining(SilentUpdatesMessage)}.
     *
     * @param originatorId unused
     * @param taskId       unused
     */
    @Override
    public void finishTraining(long originatorId, long taskId) {
        if (this.params == null || this.stepFunction == null) {
            return;
        }
        synchronized (this) {
            if (!this.hasSomething.get()) {
                return;
            }
            this.stepFunction.step(this.params, this.updates);
            this.updates.assign(0.0);
            this.hasSomething.set(false);
        }
    }

    /** Not supported by this driver. */
    @Override
    public void addCompletionHook(long originatorId, long frameId, long messageId) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the simple class name of the message type handled by this driver.
     *
     * @return {@code "SilentUpdatesMessage"}'s simple class name
     */
    @Override
    public String targetMessageClass() {
        return SilentUpdatesMessage.class.getSimpleName();
    }
}
