package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;
import lombok.NonNull;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.AtomicThrowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.class */
/**
 * {@link GradientsAccumulator} that fans encoded gradient updates out to a fixed number of consumers
 * ("parties").
 *
 * <p>Each party owns:
 * <ul>
 *   <li>a bounded {@link LinkedBlockingQueue} of capacity {@code queueSize};</li>
 *   <li>a dedicated workspace named {@code "CGA-<party>"}, assigned round-robin across devices;</li>
 *   <li>a {@link ReentrantLock} guarding writes into that workspace.</li>
 * </ul>
 *
 * <p>{@link #receiveUpdate(INDArray)} places a duplicate of an encoded message into every party's queue.
 * The {@code applyUpdate} overloads then:
 * <ol>
 *   <li>drain the calling thread's queue;</li>
 *   <li>decode each message (threshold or bitmap, selected by an int header at offset 3);</li>
 *   <li>optionally drain the external source;</li>
 *   <li>pass the result to a {@link StepFunction}.</li>
 * </ol>
 *
 * <p>NOTE(review): this source was recovered by decompilation. The bytecode of
 * {@link #synchronize(int, boolean)} and {@link #storeUpdate(INDArray, int, int)} could not be decompiled.
 * Both are stubs that throw {@link UnsupportedOperationException}. Restore them from the original sources
 * before relying on this class.
 */
public class EncodedGradientsAccumulator implements GradientsAccumulator, Registerable {
    private static final Logger log = LoggerFactory.getLogger(EncodedGradientsAccumulator.class);

    /** Default per-party workspace size in bytes (100 MiB). */
    public static final long DEFAULT_INITIAL_MEMORY = 100L * 1024 * 1024;

    /** Position of the int header that identifies how a message is encoded. */
    private static final long ENCODING_HEADER_OFFSET = 3L;
    /** Header value of messages decoded with {@code thresholdDecode}. */
    private static final int FLEXIBLE_ENCODING = 0;
    /** Header value of messages decoded with {@code bitmapDecode}. */
    private static final int BITMAP_ENCODING = 1;

    protected ThreadLocal<INDArray> accumulator = new ThreadLocal<>();
    /** Number of consumers; one queue, workspace and lock is created for each. */
    protected int parties;
    protected MessageHandler handler;
    /** Per-party queues of encoded messages waiting to be decoded by that party. */
    protected List<BlockingQueue<INDArray>> messages = new ArrayList<>();
    /** Per-party workspaces entered while a message is enqueued for that party. */
    protected List<MemoryWorkspace> workspaces = new ArrayList<>();
    /** Per-party locks held while a message is enqueued for that party. */
    protected List<ReentrantLock> locks = new ArrayList<>();
    protected AtomicInteger workersCounter = new AtomicInteger(0);
    /** Queue index of the calling thread, assigned by {@link #touch()}. */
    protected ThreadLocal<Integer> index = new ThreadLocal<>();
    protected long initialMemory;
    protected int queueSize;
    protected Double boundary;
    protected boolean encodingDebugMode;
    protected IndexedTail externalSource;
    // NOTE(review): isFirst/isDone/barrier/secondary are not read anywhere in the decompiled code;
    // presumably they are used by the undecompiled synchronize()/storeUpdate() bodies.
    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    protected AtomicBoolean isDone = new AtomicBoolean(true);
    protected AtomicInteger barrier = new AtomicInteger(0);
    protected AtomicInteger secondary = new AtomicInteger(0);
    protected AtomicBoolean registered = new AtomicBoolean(false);
    protected AtomicBoolean bypassMode = new AtomicBoolean(false);
    protected final AtomicInteger currentConsumers = new AtomicInteger(0);
    /** First failure seen by any thread; waiting loops rethrow it. */
    protected final AtomicThrowable throwable = new AtomicThrowable();
    protected boolean isDebug = false;
    /** True when there are several devices and cross-device access is not supported. */
    protected final boolean relocatable;
    protected ThreadLocal<AtomicLong> updatesApplied = new ThreadLocal<>();
    protected AtomicBoolean externalUpdatesAvailable = new AtomicBoolean(false);
    protected WorkspaceConfiguration appliedConfiguration = WorkspaceConfiguration.builder()
            .minSize(5242880L)
            .overallocationLimit(0.3d)
            .policyMirroring(MirroringPolicy.FULL)
            .policySpill(SpillPolicy.REALLOCATE)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .policyReset(ResetPolicy.BLOCK_LEFT)
            .build();

    /** Fluent builder for {@link EncodedGradientsAccumulator}. */
    public static class Builder {
        protected int parties;
        protected ThresholdAlgorithm thresholdAlgorithm;
        protected ResidualPostProcessor residualPostProcessor;
        protected MessageHandler handler;
        protected boolean encodingDebugMode;
        protected long initialMemory = EncodedGradientsAccumulator.DEFAULT_INITIAL_MEMORY;
        protected int queueSize = 5;
        protected Double boundary = null;

        /**
         * @param parties number of consumers; must be at least 1
         * @throws DL4JInvalidConfigException if {@code parties < 1}
         */
        public Builder(int parties) {
            if (parties < 1) {
                throw new DL4JInvalidConfigException("Number of parties for GradientsAccumulation should be positive value");
            }
            this.parties = parties;
        }

        /**
         * Uses an explicit message handler.
         * When set, {@link #thresholdAlgorithm} and {@link #residualPostProcessor} are ignored by {@link #build()}.
         */
        public Builder messageHandler(@NonNull MessageHandler messageHandler) {
            if (messageHandler == null) {
                throw new NullPointerException("handler is marked @NonNull but is null");
            }
            this.handler = messageHandler;
            return this;
        }

        /** Threshold algorithm used to create the default {@link EncodingHandler} when no handler is set. */
        public Builder thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
            this.thresholdAlgorithm = thresholdAlgorithm;
            return this;
        }

        /** Residual post-processor passed to the default {@link EncodingHandler}. */
        public Builder residualPostProcessor(ResidualPostProcessor residualPostProcessor) {
            this.residualPostProcessor = residualPostProcessor;
            return this;
        }

        /**
         * Sets the updates boundary.
         * Values {@code >= 1.0} are ignored, and the builder keeps its current value ({@code null} by default).
         *
         * @throws DL4JInvalidConfigException if {@code boundary} is not a positive number (including NaN)
         */
        public Builder updatesBoundary(double boundary) {
            if (boundary >= 1.0d) {
                return this;
            }
            // Written as !(x > 0) so that NaN is rejected as well.
            if (!(boundary > 0.0d)) {
                throw new DL4JInvalidConfigException("Boundary should have positive value");
            }
            this.boundary = Double.valueOf(boundary);
            return this;
        }

        /**
         * @param initialMemory per-party workspace size, in bytes
         * @param queueSize     capacity of each party's message queue
         */
        public Builder memoryParameters(long initialMemory, int queueSize) {
            this.initialMemory = initialMemory;
            this.queueSize = queueSize;
            return this;
        }

        public Builder encodingDebugMode(boolean encodingDebugMode) {
            this.encodingDebugMode = encodingDebugMode;
            return this;
        }

        /**
         * Builds the accumulator. If no handler was given, an {@link EncodingHandler} is created from the
         * threshold algorithm, residual post-processor, boundary and debug flag.
         *
         * @throws NullPointerException if neither a handler nor a threshold algorithm was set
         */
        public EncodedGradientsAccumulator build() {
            if (this.handler == null) {
                Preconditions.checkNotNull(this.thresholdAlgorithm, "Both threshold algorithm and handler are null - one or the other must be set");
                this.handler = new EncodingHandler(this.thresholdAlgorithm, this.residualPostProcessor, this.boundary, this.encodingDebugMode);
            }
            return new EncodedGradientsAccumulator(this.parties, this.handler, this.initialMemory, this.queueSize, this.boundary, this.encodingDebugMode);
        }
    }

    /**
     * Creates an accumulator that uses an {@link AdaptiveThresholdAlgorithm} seeded with {@code threshold}
     * and a {@link ResidualClippingPostProcessor}{@code (5.0, 5)}.
     */
    public EncodedGradientsAccumulator(int parties, double threshold) {
        this(parties, new AdaptiveThresholdAlgorithm(threshold), new ResidualClippingPostProcessor(5.0d, 5), false);
    }

    /**
     * Creates an accumulator with an {@link EncodingHandler}.
     * Uses the default memory size, a queue size of 10 and a boundary of 1.0.
     */
    public EncodedGradientsAccumulator(int parties, ThresholdAlgorithm thresholdAlgorithm,
                                       ResidualPostProcessor residualPostProcessor, boolean encodingDebugMode) {
        this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, Double.valueOf(1.0d), encodingDebugMode),
                DEFAULT_INITIAL_MEMORY, 10, Double.valueOf(1.0d), encodingDebugMode);
    }

    /**
     * Creates the per-party queues, workspaces and locks, then calls {@code handler.initialize(this)}.
     *
     * @param parties           number of consumers
     * @param handler           message handler; must not be null
     * @param initialMemory     per-party workspace size, in bytes
     * @param queueSize         capacity of each party's queue
     * @param boundary          updates boundary (may be null)
     * @param encodingDebugMode debug flag stored on this instance
     * @throws NullPointerException       if {@code handler} is null
     * @throws ND4JIllegalStateException  if there are several devices and {@code parties} exceeds their number
     */
    protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
                                          int queueSize, Double boundary, boolean encodingDebugMode) {
        if (handler == null) {
            throw new NullPointerException("handler is marked @NonNull but is null");
        }
        this.parties = parties;
        this.handler = handler;
        this.initialMemory = initialMemory;
        this.queueSize = queueSize;
        this.boundary = boundary;
        this.encodingDebugMode = encodingDebugMode;

        WorkspaceConfiguration bufferConfig = WorkspaceConfiguration.builder()
                .initialSize(initialMemory)
                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                .policyAllocation(AllocationPolicy.STRICT)
                .policySpill(SpillPolicy.FAIL)
                .policyLearning(LearningPolicy.NONE)
                .build();

        this.relocatable = Nd4j.getAffinityManager().getNumberOfDevices() > 1
                && !Nd4j.getAffinityManager().isCrossDeviceAccessSupported();

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        if (parties > numDevices && numDevices != 1) {
            throw new ND4JIllegalStateException("Number of parties [" + parties + "] should be less or equal to number of devices [" + numDevices + "]");
        }

        int originalDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        try {
            for (int party = 0; party < parties; party++) {
                this.messages.add(new LinkedBlockingQueue<>(queueSize));
                int device = numDevices > 1 ? party % numDevices : 0;
                // Each workspace is created while the thread is bound to the workspace's device.
                Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(device));
                this.workspaces.add(Nd4j.getWorkspaceManager().createNewWorkspace(bufferConfig, "CGA-" + party, Integer.valueOf(device)));
                this.locks.add(new ReentrantLock());
            }
        } finally {
            // Always rebind the caller's original device, even if workspace creation failed.
            Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(originalDevice));
        }
        handler.initialize(this);
    }

    /**
     * Heuristic buffer size: {@code (paramsLength / 16 + 65536) * numWorkers * queueSize * 4}.
     * NOTE(review): the trailing {@code * 4} presumably means 4 bytes per encoded element; confirm.
     */
    public static long getOptimalBufferSize(long paramsLength, int numWorkers, int queueSize) {
        return ((paramsLength / 16) + 65536) * numWorkers * queueSize * 4;
    }

    /** Same as {@link #getOptimalBufferSize(long, int, int)} using the model's parameter count. */
    public static long getOptimalBufferSize(Model model, int numWorkers, int queueSize) {
        return getOptimalBufferSize(model.params().length(), numWorkers, queueSize);
    }

    /** Sets bypass mode here and forwards it to the external source when that is {@link Registerable}. */
    @Override
    public void fallbackToSingleConsumerMode(boolean reallyFallback) {
        IndexedTail source = this.externalSource;
        if (source instanceof Registerable) {
            ((Registerable) source).fallbackToSingleConsumerMode(reallyFallback);
        }
        this.bypassMode.set(reallyFallback);
    }

    /**
     * Registers the number of consumers for the next round.
     *
     * <p>If a round is still registered, this spins (1 ms sleeps) until another thread clears
     * {@code registered}. NOTE(review): that reset is not visible in the decompiled code, so presumably it
     * lives in synchronize()/storeUpdate(). The wait aborts with a RuntimeException if any thread has
     * recorded a failure.
     */
    @Override
    public void registerConsumers(int numConsumers) {
        if (this.registered.get()) {
            if (this.isDebug) {
                log.info("Master thread locks at RC");
            }
            while (this.registered.get()) {
                ThreadUtils.uncheckedSleep(1L);
                if (this.throwable.isTriggered()) {
                    throw new RuntimeException(this.throwable.get());
                }
            }
            if (this.isDebug) {
                log.info("Master thread unlocks at RC");
            }
        }
        IndexedTail source = this.externalSource;
        if (source instanceof Registerable) {
            ((Registerable) source).registerConsumers(numConsumers);
        }
        this.currentConsumers.set(numConsumers);
        this.registered.set(true);
    }

    @Override
    public IndexedTail getExternalSource() {
        return this.externalSource;
    }

    @Override
    public void markExternalUpdates(boolean updatesAvailable) {
        this.externalUpdatesAvailable.set(updatesAvailable);
    }

    /** Equivalent to {@code synchronize(consumers, false)}. */
    protected void synchronize(int consumers) {
        synchronize(consumers, false);
    }

    /**
     * Presumably a barrier between consumers (name-based guess).
     *
     * <p>NOTE(review): the bytecode of this method (269 instructions) could not be decompiled. This is a
     * stub that always throws {@link UnsupportedOperationException}. As a result, the following currently
     * always fail:
     * <ul>
     *   <li>{@link #synchronize(int)};</li>
     *   <li>{@code applyUpdate(..., true)};</li>
     *   <li>the {@code double} overload of {@code applyUpdate}.</li>
     * </ul>
     * Restore the body from the original sources.
     */
    protected void synchronize(int consumers, boolean finalLock) {
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator.synchronize(int, boolean):void");
    }

    /**
     * Removes every message currently queued for the calling thread and decodes it into {@code updates}.
     *
     * <p>Requires {@link #touch()} to have been called on this thread; otherwise the index lookup throws
     * a NullPointerException.
     *
     * @param updates array that decoded messages are written into
     * @return number of messages decoded
     * @throws DL4JInvalidConfigException if a message carries an unknown encoding header
     */
    private int decodeQueuedMessages(INDArray updates) {
        BlockingQueue<INDArray> queue = this.messages.get(this.index.get().intValue());
        int decoded = 0;
        INDArray encoded;
        // poll() alone rather than isEmpty() followed by poll(): in multi-device mode touch() keys queues by
        // device id, so two threads may share a queue and race between the emptiness check and the poll.
        while ((encoded = queue.poll()) != null) {
            int encoding = encoded.data().getInt(ENCODING_HEADER_OFFSET);
            if (encoding == FLEXIBLE_ENCODING) {
                Nd4j.getExecutioner().thresholdDecode(encoded, updates);
            } else if (encoding == BITMAP_ENCODING) {
                Nd4j.getExecutioner().bitmapDecode(encoded, updates);
            } else {
                throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
            }
            decoded++;
        }
        return decoded;
    }

    /**
     * Clears {@code updates} and decodes the calling thread's queued messages into it.
     *
     * <p>Then drains the external source once, if it has anything. Calls
     * {@code synchronize(currentConsumers, true)} only when {@code isFinalStep} is set. Applies
     * {@code function.step(params, updates)} if at least one update was collected.
     *
     * @throws RuntimeException wrapping any failure, which is also recorded in {@link #throwable}
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates, boolean isFinalStep) {
        if (this.updatesApplied.get() == null) {
            this.updatesApplied.set(new AtomicLong(0L));
        }
        try {
            Nd4j.getMemoryManager().memset(updates);
            int count = decodeQueuedMessages(updates);
            if (count > 0 && this.isDebug) {
                log.info("Local updates to be applied: {}", Integer.valueOf(count));
            }

            IndexedTail source = this.externalSource;
            if (source != null) {
                int external = 0;
                if (source.hasAnything()) {
                    source.drainTo(updates);
                    count++;
                    external++;
                }
                if (this.isDebug) {
                    log.info("thread {} finished at Externals", Long.valueOf(Thread.currentThread().getId()));
                }
                if (external > 0 && this.isDebug) {
                    log.info("External updates to be applied: {}", Integer.valueOf(external));
                }
            }

            if (isFinalStep) {
                synchronize(this.currentConsumers.get(), true);
            }

            if (count > 0) {
                function.step(params, updates);
                this.updatesApplied.get().addAndGet(count);
                if (this.isDebug) {
                    log.info("Total updates applied so far for thread [{}]: [{}]", Thread.currentThread().getName(), this.updatesApplied.get());
                }
            }
        } catch (Exception e) {
            this.throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Same collection logic as {@link #applyUpdate(StepFunction, INDArray, INDArray, boolean)}, with two
     * differences: {@code synchronize(currentConsumers, true)} is always called, and the collected updates
     * are applied via {@code function.step(params, updates, stepSize)}.
     * This overload does not maintain {@link #updatesApplied}.
     *
     * @throws RuntimeException wrapping any failure, which is also recorded in {@link #throwable}
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates, double stepSize) {
        try {
            Nd4j.getMemoryManager().memset(updates);
            int count = decodeQueuedMessages(updates);
            if (count > 0 && this.isDebug) {
                log.info("Local updates to be applied: {}", Integer.valueOf(count));
            }

            IndexedTail source = this.externalSource;
            if (source != null) {
                int external = 0;
                if (source.hasAnything()) {
                    source.drainTo(updates);
                    count++;
                    external++;
                }
                if (external > 0 && this.isDebug) {
                    log.info("External updates to be applied: {}", Integer.valueOf(external));
                }
            }

            synchronize(this.currentConsumers.get(), true);

            if (count > 0) {
                function.step(params, updates, stepSize);
            }
        } catch (Exception e) {
            this.throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setExternalSource(IndexedTail source) {
        this.externalSource = source;
    }

    /**
     * Assigns the calling thread its queue index on first call; later calls are no-ops.
     *
     * <p>With a single device or a single party, indices come from a shared counter. Otherwise the index
     * is the thread's current device id.
     */
    @Override
    public void touch() {
        if (this.index.get() != null) {
            return;
        }
        boolean perDevice = Nd4j.getAffinityManager().getNumberOfDevices() > 1 && this.parties > 1;
        int slot = perDevice
                ? Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue()
                : this.workersCounter.getAndIncrement();
        this.index.set(Integer.valueOf(slot));
    }

    /**
     * Stores an update produced by the calling consumer.
     *
     * <p>NOTE(review): the bytecode of this method (316 instructions) could not be decompiled. This is a
     * stub that always throws {@link UnsupportedOperationException}; restore the body from the original
     * sources. The parameter names r6/r7/r8 are kept so they match the recovered fragments inside.
     */
    @Override
    public void storeUpdate(INDArray r6, int r7, int r8) {
        /*
         * Fragments recovered by the decompiler (r5 = this); the control flow between them was lost:
         *   B:38  if (r5.bypassMode.get() == false) goto L32;
         *   B:40  if (r5.registered.get() != false) goto L54;
         *   B:41  org.deeplearning4j.util.ThreadUtils.uncheckedSleep(1);
         *   B:42  if (r5.throwable.isTriggered() == false) goto L55;
         *   B:45  throw new java.lang.RuntimeException(r5.throwable.get());
         *   B:49  if (r5.isDebug == false) goto L41;
         *   B:50  EncodedGradientsAccumulator.log.info("thread {} unlocking at Register",
         *             java.lang.Long.valueOf(java.lang.Thread.currentThread().getId()));
         *   B:51  r5.handler.broadcastUpdates(r5.accumulator.get(), r7, r8);
         *         synchronize(r5.currentConsumers.get());
         *   B:52  return;
         */
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator.storeUpdate(org.nd4j.linalg.api.ndarray.INDArray, int, int):void");
    }

    /**
     * Places a duplicate of an encoded message into every party's queue.
     *
     * <p>For each party this happens while holding that party's lock and inside that party's workspace.
     * The lock is always released, even when the size check fails or the put is interrupted.
     *
     * @throws RuntimeException wrapping any failure, which is also recorded in {@link #throwable}. Causes
     *                          include an {@link ND4JIllegalStateException} when the message is larger than
     *                          {@code initialMemory / queueSize} bytes.
     */
    @Override
    public void receiveUpdate(INDArray array) {
        for (int party = 0; party < this.parties; party++) {
            try {
                ReentrantLock lock = this.locks.get(party);
                lock.lock();
                // try-with-resources tolerates a null workspace, matching the original null check.
                try (MemoryWorkspace ignored = this.workspaces.get(party).notifyScopeEntered()) {
                    long elementSize = Nd4j.sizeOfDataType(array.data().dataType());
                    // Each queue holds up to queueSize messages, so one message gets
                    // initialMemory / queueSize bytes of the workspace.
                    if (array.data().length() > (this.initialMemory / this.queueSize) / elementSize) {
                        throw new ND4JIllegalStateException("Not enough memory to handle update: [" + (array.data().length() * elementSize) + " bytes required]. Please increase memory amount for GradientsAccumulator");
                    }
                    try {
                        this.messages.get(party).put(array.unsafeDuplication());
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        log.warn("Something bad at index_{}", Integer.valueOf(party));
                        throw new RuntimeException(e);
                    }
                } finally {
                    lock.unlock();
                }
            } catch (Exception e) {
                this.throwable.setIfFirst(e);
                throw new RuntimeException(e);
            }
        }
    }

    /** Drops thread-local state, resets the worker counter and empties every party's queue. */
    @Override
    public void reset() {
        this.accumulator = new ThreadLocal<>();
        this.workersCounter.set(0);
        this.index = new ThreadLocal<>();
        for (int party = 0; party < this.parties; party++) {
            this.messages.get(party).clear();
        }
    }

    /** Returns true only if an external source is set and reports pending updates. */
    @Override
    public boolean hasAnything() {
        IndexedTail source = this.externalSource;
        return source != null && source.hasAnything();
    }

    public MessageHandler getHandler() {
        return this.handler;
    }
}
