package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.class */
/**
 * Multi-consumer buffer of updates with per-consumer read positions.
 *
 * <p>Producers call {@link #put(INDArray)}; every stored update receives an increasing index taken
 * from {@link #updatesCounter}. Each consumer, identified by a thread id, tracks the index of the
 * next update it has to read in {@link #positions}, and {@link #drainTo(long, INDArray)} accumulates
 * every update it has not yet read into a target array. Once {@link #expectedConsumers} consumers
 * have registered, {@link #maintenance()} removes updates that every consumer has already read.
 *
 * <p>When {@link #allowCollapse} is enabled and at least {@link #collapseThreshold} stored updates
 * have not been read by any consumer, they are merged into a single dense array of {@link #shape};
 * subsequent puts are accumulated into that array until the next drain.
 *
 * <p>{@code put} holds the write lock; the reading section of {@code drainTo} and
 * {@link #getGlobalPosition()} hold the read lock.
 */
public class IndexedTail {
    private static final Logger log = LoggerFactory.getLogger(IndexedTail.class);

    /** Per-consumer index of the next update to read; -1 until the consumer first drains. */
    protected ConcurrentHashMap<Long, AtomicLong> positions = new ConcurrentHashMap<>();
    /** Stored updates keyed by their index. */
    protected Map<Long, INDArray> updates = new ConcurrentHashMap<>();
    /** Index that the next stored update will receive. */
    protected AtomicLong updatesCounter = new AtomicLong(0L);
    /** Updates with an index below this value have been removed by maintenance(); starts at -1. */
    protected AtomicLong lastDeletedIndex = new AtomicLong(-1L);
    protected final int expectedConsumers;
    protected AtomicBoolean dead = new AtomicBoolean(false);
    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    protected final boolean allowCollapse;
    /** Shape of the dense array allocated when collapsing; required if collapse is allowed. */
    protected final long[] shape;
    /** Minimum number of stored, unread-by-anyone updates that triggers a collapse. */
    protected final int collapseThreshold = 32;
    /** True while puts are accumulated into the collapsed update at {@link #collapsedIndex}. */
    protected AtomicBoolean collapsedMode = new AtomicBoolean(false);
    protected AtomicLong collapsedIndex = new AtomicLong(-1L);

    /**
     * Creates a tail without collapse support.
     *
     * @param expectedConsumers number of consumers that must register before maintenance runs
     */
    public IndexedTail(int expectedConsumers) {
        this(expectedConsumers, false, null);
    }

    /**
     * Creates a tail.
     *
     * @param expectedConsumers number of consumers that must register before maintenance and
     *     collapsing are enabled
     * @param allowCollapse whether unread backlog may be merged into a single dense array
     * @param shape shape of the collapsed array; must be non-null when {@code allowCollapse} is true
     * @throws IllegalArgumentException if collapse is allowed and {@code shape} is null
     */
    public IndexedTail(int expectedConsumers, boolean allowCollapse, long[] shape) {
        this.expectedConsumers = expectedConsumers;
        this.allowCollapse = allowCollapse;
        if (allowCollapse) {
            Preconditions.checkArgument(shape != null, "shape can't be null if collapse is allowed");
        }
        this.shape = shape;
    }

    /**
     * Stores an update, or merges it into the current collapsed update.
     *
     * @param update the update to store (possibly encoded, see {@link #smartDecompress})
     * @throws NullPointerException if {@code update} is null
     */
    public void put(@NonNull INDArray update) {
        if (update == null) {
            throw new NullPointerException("update is marked @NonNull but is null");
        }
        lock.writeLock().lock();
        try {
            if (collapsedMode.get()) {
                // no consumer has drained since the collapse, so keep accumulating into it
                INDArray lastUpdate = updates.get(collapsedIndex.get());
                Preconditions.checkArgument(!lastUpdate.isCompressed(), "lastUpdate should NOT be compressed during collapse mode");
                smartDecompress(update, lastUpdate);
            } else if (!allowCollapse || positions.size() < expectedConsumers) {
                updates.put(updatesCounter.getAndIncrement(), update);
            } else {
                collapseOrAppend(update);
            }
        } finally {
            lock.writeLock().unlock();
        }
    }

    /**
     * Merges all stored updates not yet read by any consumer, plus {@code update}, into one dense
     * array if that backlog reaches {@link #collapseThreshold}; otherwise appends {@code update}.
     * Caller must hold the write lock.
     */
    private void collapseOrAppend(INDArray update) {
        long lastIndex = updatesCounter.get();
        long firstIndex = firstNotAppliedIndexEverywhere();
        long backlog = lastIndex - firstIndex;
        if (backlog < collapseThreshold) {
            updates.put(updatesCounter.getAndIncrement(), update);
            return;
        }

        log.info("Max delta to collapse: {}; Range: <{}...{}>", backlog, firstIndex, lastIndex);
        // allocated only when actually collapsing
        INDArray collapsed = Nd4j.create(shape);
        for (long index = firstIndex; index < lastIndex; index++) {
            INDArray pending = updates.get(index);
            if (pending == null) {
                // previously this fell through to smartDecompress(null, ...) and threw an NPE
                log.error("Failed on index {}", index);
                continue;
            }
            smartDecompress(pending, collapsed);
            updates.remove(index);
        }
        smartDecompress(update, collapsed);

        updates.put(lastIndex, collapsed);
        collapsedIndex.set(lastIndex);
        updatesCounter.getAndIncrement();
        collapsedMode.set(true);
    }

    /**
     * Returns one past the largest consumer position, or -1 if nothing was ever stored. Despite
     * the name, this is the first index that no consumer has read yet.
     */
    protected long firstNotAppliedIndexEverywhere() {
        if (updatesCounter.get() == 0) {
            return -1L;
        }
        long max = -1;
        for (AtomicLong position : positions.values()) {
            max = Math.max(max, position.get());
        }
        return max + 1;
    }

    /**
     * Returns the smallest consumer position, or {@link Long#MAX_VALUE} if no consumer is
     * registered. Updates below this index have been read by every registered consumer.
     */
    protected long maxAppliedIndexEverywhere() {
        long min = Long.MAX_VALUE;
        for (AtomicLong position : positions.values()) {
            min = Math.min(min, position.get());
        }
        return min;
    }

    /**
     * Returns whether the calling thread has unread updates. Registers the thread as a consumer
     * if it is not registered yet.
     */
    public boolean hasAnything() {
        return hasAnything(Thread.currentThread().getId());
    }

    /**
     * Returns whether consumer {@code threadId} has unread updates. Registers the consumer if it
     * is not registered yet.
     */
    public boolean hasAnything(long threadId) {
        long localPosition = getLocalPosition(threadId);
        // read once so the logged value matches the one used for the decision
        long globalPosition = updatesCounter.get();
        boolean result = localPosition < globalPosition;
        log.info("hasAnything({}): {}; position: {}; updates: {}", threadId, result, localPosition, globalPosition);
        return result;
    }

    /**
     * Accumulates every update the calling thread has not read yet into {@code array}.
     *
     * @return true if there was anything to drain
     * @throws NullPointerException if {@code array} is null
     */
    public boolean drainTo(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        return drainTo(Thread.currentThread().getId(), array);
    }

    /** Returns the index that the next stored update will receive, read under the read lock. */
    protected long getGlobalPosition() {
        lock.readLock().lock();
        try {
            return updatesCounter.get();
        } finally {
            lock.readLock().unlock();
        }
    }

    /** Returns the calling thread's read position (see {@link #getLocalPosition(long)}). */
    protected long getLocalPosition() {
        return getLocalPosition(Thread.currentThread().getId());
    }

    /** Returns how many updates the calling thread has not read yet. */
    protected long getDelta() {
        return getDelta(Thread.currentThread().getId());
    }

    /** Returns how many updates consumer {@code threadId} has not read yet. */
    protected long getDelta(long threadId) {
        return getGlobalPosition() - getLocalPosition(threadId);
    }

    /**
     * Returns the index of the next update consumer {@code threadId} should read, treating a
     * never-drained consumer as 0. Registers the consumer if needed.
     */
    protected long getLocalPosition(long threadId) {
        return Math.max(positionOf(threadId).get(), 0L);
    }

    /** Returns the position counter for {@code threadId}, atomically registering it at -1 if absent. */
    private AtomicLong positionOf(long threadId) {
        return positions.computeIfAbsent(threadId, id -> new AtomicLong(-1L));
    }

    /**
     * Accumulates every update consumer {@code threadId} has not read yet into {@code array}, then
     * advances its position and runs {@link #maintenance()}.
     *
     * @return true if there was anything to drain
     * @throws NullPointerException if {@code array} is null
     * @throws RuntimeException if an expected update is missing and collapse is not allowed
     */
    public boolean drainTo(long threadId, @NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        AtomicLong position = positionOf(threadId);
        ArrayList<INDArray> pending = new ArrayList<>();
        long delta;

        lock.readLock().lock();
        try {
            // once a consumer reads the collapsed update, later puts must not be merged into it,
            // or this consumer would never see them
            collapsedMode.set(false);

            long globalPosition = updatesCounter.get();
            long localPosition = getLocalPosition(threadId);
            delta = getDelta(threadId);
            for (long index = localPosition; index < localPosition + delta; index++) {
                INDArray update = updates.get(index);
                if (update == null) {
                    if (allowCollapse) {
                        // entries removed by a collapse are merged into the collapsed entry
                        continue;
                    }
                    log.info("Global: [{}]; Local: [{}]", globalPosition, localPosition);
                    throw new RuntimeException("Element [" + index + "] is absent");
                }
                pending.add(update);
            }
            position.set(globalPosition);
        } finally {
            // released exactly once, regardless of where a failure happens
            lock.readLock().unlock();
        }

        // decoding happens outside the lock on private duplicates
        for (INDArray update : pending) {
            smartDecompress(update.unsafeDuplication(true), array);
        }
        maintenance();
        return delta > 0;
    }

    /**
     * Removes stored updates that every registered consumer has already read. Does nothing until
     * {@link #expectedConsumers} consumers have registered.
     */
    protected synchronized void maintenance() {
        if (positions.size() < expectedConsumers) {
            log.info("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", positions.size(), expectedConsumers);
            return;
        }
        if (positions.isEmpty()) {
            // otherwise the minimum below would be Long.MAX_VALUE and the removal loop unbounded
            return;
        }

        long minPosition = maxAppliedIndexEverywhere();
        // stream snapshot: safe even if consumers register concurrently
        long[] snapshot = positions.values().stream().mapToLong(AtomicLong::get).toArray();
        log.info("Min idx: {}; last deleted index: {}; stored updates: {}; positions: {}", minPosition, lastDeletedIndex.get(), updates.size(), snapshot);

        long lastDeleted = lastDeletedIndex.get();
        if (minPosition <= lastDeleted) {
            return;
        }
        for (long index = lastDeleted; index < minPosition; index++) {
            updates.remove(index);
        }
        lastDeletedIndex.set(minPosition);
    }

    /** Returns the number of currently stored updates. */
    protected int updatesSize() {
        return updates.size();
    }

    /**
     * Adds {@code update} into {@code target}.
     *
     * <p>If {@code update} is compressed or has INT data, it is treated as an encoded buffer.
     * Element 3 of that buffer selects the decoder: 0 selects threshold decoding and 1 selects
     * bitmap decoding. Otherwise the update is added element-wise.
     *
     * @return {@code target}
     * @throws ND4JIllegalStateException on an unknown encoding mode
     */
    protected INDArray smartDecompress(INDArray update, @NonNull INDArray target) {
        if (target == null) {
            throw new NullPointerException("target is marked @NonNull but is null");
        }
        if (!update.isCompressed() && update.data().dataType() != DataType.INT) {
            target.addi(update);
            return target;
        }
        int encoding = update.data().getInt(3L);
        switch (encoding) {
            case 0:
                Nd4j.getExecutioner().thresholdDecode(update, target);
                break;
            case 1:
                Nd4j.getExecutioner().bitmapDecode(update, target);
                break;
            default:
                throw new ND4JIllegalStateException("Unknown encoding mode: [" + encoding + "]");
        }
        return target;
    }

    protected boolean isDead() {
        return dead.get();
    }

    protected void notifyDead() {
        dead.set(true);
    }

    /**
     * Drops all updates and consumer registrations and resets the counters. Does not take
     * {@link #lock}; NOTE(review): callers presumably ensure no concurrent put/drainTo — confirm.
     */
    public void purge() {
        positions.clear();
        updates.clear();
        updatesCounter.set(0L);
        lastDeletedIndex.set(-1L);
        collapsedMode.set(false);
        collapsedIndex.set(-1L);
    }
}
