package org.deeplearning4j.optimize.solvers.accumulation;

import com.google.common.util.concurrent.AtomicDouble;
import java.text.DecimalFormat;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithmReducer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.NDArrayCompressor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.class */
public class EncodingHandler implements MessageHandler {
    /** Minimum time, in milliseconds, between two threshold log messages (shared by all threads). */
    public static final long THRESHOLD_LOG_FREQ_MS = 10000;

    /** Receives encoded updates from {@link #sendMessage}; set by {@link #initialize}. */
    protected transient GradientsAccumulator accumulator;
    /** Prototype cloned once for every thread that calls {@link #encodeUpdates}. */
    protected ThresholdAlgorithm initialThresholdAlgorithm;
    /** Optional prototype (may be null), cloned once for every encoding thread. */
    protected ResidualPostProcessor initialResidualPostProcessor;
    /**
     * Optional fraction (may be null). Its product with the update length is passed to threshold
     * encoding as the boundary argument.
     */
    protected Double boundary;
    /** When true, statistics of the residual vector are logged at INFO level on every encode call. */
    protected boolean encodingDebugMode;
    /** The "THRESHOLD" compressor looked up in {@link #initialize}. */
    protected NDArrayCompressor compressor;
    /** Boundary in elements derived from {@link #boundary}; -1 until computed on the first encode call. */
    protected AtomicInteger atomicBoundary = new AtomicInteger(-1);
    /** This thread's clone of {@link #initialThresholdAlgorithm}. */
    protected ThreadLocal<ThresholdAlgorithm> thresholdAlgorithm = new ThreadLocal<>();
    /** Every per-thread threshold algorithm clone, keyed by thread id; consumed by {@link #getAverageThresholdAlgorithm}. */
    protected Map<Long, ThresholdAlgorithm> allThreadThresholdAlgorithms = new ConcurrentHashMap<>();
    /** This thread's clone of {@link #initialResidualPostProcessor}, if one was configured. */
    protected ThreadLocal<ResidualPostProcessor> residualPostProcessor = new ThreadLocal<>();
    /** Number of encode calls made by this thread. */
    protected ThreadLocal<AtomicLong> iterations = new ThreadLocal<>();
    /** Initialized per thread; not otherwise used in this class. */
    protected ThreadLocal<AtomicLong> lastStep = new ThreadLocal<>();
    /** Threshold used by this thread's most recent encode call. */
    protected ThreadLocal<AtomicDouble> lastThreshold = new ThreadLocal<>();
    /** Encoded-length / update-length ratio of this thread's last sparse encoding; null after a dense one. */
    protected ThreadLocal<AtomicDouble> lastSparsityRatio = new ThreadLocal<>();
    /** Threshold passed to the encoding ops for this thread's current call. */
    protected ThreadLocal<AtomicDouble> currentThreshold = new ThreadLocal<>();
    /** True while this thread encodes with bitmap encoding, false while it uses threshold encoding. */
    protected ThreadLocal<AtomicBoolean> bitmapMode = new ThreadLocal<>();
    /** Whether this thread's previous call produced a bitmap (dense) encoding. */
    protected ThreadLocal<AtomicBoolean> lastIterWasDense = new ThreadLocal<>();
    /** Time of the last threshold log message; shared by all threads to throttle logging. */
    protected AtomicLong lastThresholdLogTime = new AtomicLong();
    private static final Logger log = LoggerFactory.getLogger(EncodingHandler.class);
    /** Per-thread formatters, because {@link DecimalFormat} instances are not synchronized. */
    protected static ThreadLocal<DecimalFormat> formatter = new ThreadLocal<>();
    protected static ThreadLocal<DecimalFormat> formatter2 = new ThreadLocal<>();

    /**
     * Creates a handler; {@link #initialize} must be called before use.
     *
     * @param thresholdAlgorithm    prototype threshold algorithm, cloned for each encoding thread
     * @param residualPostProcessor optional residual post-processor prototype (may be null)
     * @param boundary              optional boundary fraction (may be null), see {@link #boundary}
     * @param encodingDebugMode     whether to log residual statistics on every encode call
     */
    public EncodingHandler(ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, Double boundary, boolean encodingDebugMode) {
        this.initialThresholdAlgorithm = thresholdAlgorithm;
        this.initialResidualPostProcessor = residualPostProcessor;
        this.boundary = boundary;
        this.encodingDebugMode = encodingDebugMode;
    }

    /**
     * Binds this handler to an accumulator and looks up the "THRESHOLD" compressor.
     *
     * @param gradientsAccumulator accumulator that will receive encoded updates; must not be null
     * @throws NullPointerException      if {@code gradientsAccumulator} is null
     * @throws ND4JIllegalStateException if no "THRESHOLD" compressor is registered
     */
    @Override // org.deeplearning4j.optimize.solvers.accumulation.MessageHandler
    public void initialize(@NonNull GradientsAccumulator gradientsAccumulator) {
        if (gradientsAccumulator == null) {
            throw new NullPointerException("accumulator is marked @NonNull but is null");
        }
        this.accumulator = gradientsAccumulator;
        this.compressor = Nd4j.getCompressor().getCompressor("THRESHOLD");
        if (this.compressor == null) {
            throw new ND4JIllegalStateException("Can't find Threshold compressor implementation!");
        }
    }

    /**
     * Encodes {@code updates} for the calling thread, using either bitmap or threshold encoding.
     *
     * <p>On a thread's first call, the initial threshold algorithm (and residual post-processor, if
     * any) is cloned for that thread. The threshold comes from the thread's {@link ThresholdAlgorithm},
     * given the previous threshold, whether the previous call was dense, and the previous sparsity
     * ratio (null when unknown or after a dense call).
     *
     * <p>Mode switching works as follows:
     * <ul>
     *   <li>In bitmap mode, the thread switches to threshold encoding for its next call when fewer than
     *       half of the bitmap buffer length values were encoded.</li>
     *   <li>In threshold mode, if the encoded length reaches {@code length / 16}, the updates are
     *       re-encoded as a bitmap, and the thread returns to bitmap mode.</li>
     * </ul>
     *
     * @param iteration current iteration number
     * @param epoch     current epoch number
     * @param updates   array passed directly to the encoding ops and the residual post-processor
     *                  (NOTE(review): presumably modified in place as a residual — confirm)
     * @return the encoded array, or {@code null} when threshold encoding returned no result
     */
    public INDArray encodeUpdates(int iteration, int epoch, INDArray updates) {
        if (this.thresholdAlgorithm.get() == null) {
            // Lazily give this thread its own copies; cloning the shared prototypes is done under a lock.
            synchronized (this) {
                this.thresholdAlgorithm.set(this.initialThresholdAlgorithm.m220clone());
                this.allThreadThresholdAlgorithms.put(Long.valueOf(Thread.currentThread().getId()), this.thresholdAlgorithm.get());
                if (this.initialResidualPostProcessor != null) {
                    this.residualPostProcessor.set(this.initialResidualPostProcessor.m217clone());
                }
            }
        }

        Double previousThreshold = null;
        Boolean previousWasDense = null;
        Double previousSparsity = null;
        if (this.lastThreshold.get() != null) {
            previousThreshold = Double.valueOf(this.lastThreshold.get().get());
            previousWasDense = Boolean.valueOf(this.lastIterWasDense.get().get());
            AtomicDouble sparsity = this.lastSparsityRatio.get();
            if (!previousWasDense.booleanValue() && sparsity != null) {
                previousSparsity = Double.valueOf(sparsity.get());
            }
        }
        double threshold = this.thresholdAlgorithm.get().calculateThreshold(iteration, epoch, previousThreshold, previousWasDense, previousSparsity, updates);

        if (this.bitmapMode.get() == null) {
            // First encode call on this thread: initialize per-thread state, starting in bitmap mode.
            this.bitmapMode.set(new AtomicBoolean(true));
            this.currentThreshold.set(new AtomicDouble(threshold));
            this.iterations.set(new AtomicLong(0L));
            this.lastStep.set(new AtomicLong(0L));
            this.lastThreshold.set(new AtomicDouble(threshold));
            this.lastIterWasDense.set(new AtomicBoolean());
        }
        this.currentThreshold.get().set(threshold);
        this.lastThreshold.get().set(threshold);
        residualDebugOutputIfRequired(updates);
        this.iterations.get().incrementAndGet();

        if (this.boundary != null && this.atomicBoundary.get() < 0) {
            // Computed once from the first update array seen by any thread.
            this.atomicBoundary.compareAndSet(-1, (int) (updates.lengthLong() * this.boundary.doubleValue()));
        }

        INDArray encoded;
        if (this.bitmapMode.get().get()) {
            encoded = newBitmapBuffer(updates);
            long encodedCount = Nd4j.getExecutioner().bitmapEncode(updates, encoded, this.currentThreshold.get().get());
            if (encodedCount < bitmapBufferLength(updates) / 2) {
                // Few enough values: threshold encoding is expected to be more compact from the next call on.
                this.bitmapMode.get().set(false);
                log.debug("Switched to threshold encoding: iteration {}, epoch {}, threshold {}, number of values {}", iteration, epoch, threshold, encodedCount);
            }
            this.lastSparsityRatio.set(null);
            this.lastIterWasDense.get().set(true);
        } else {
            Integer boundaryArg = this.boundary == null ? null : Integer.valueOf(this.atomicBoundary.get());
            encoded = Nd4j.getExecutioner().thresholdEncode(updates, this.currentThreshold.get().get(), boundaryArg);
            if (encoded == null) {
                // No encoded result (presumably nothing passed the threshold — confirm): record sparsity 0.
                this.bitmapMode.get().set(false);
                setLastSparsityRatio(0.0);
                this.lastIterWasDense.get().set(false);
                logThresholdIfReq(false, iteration, epoch);
                return null;
            }
            // The first int of the threshold-encoded buffer is read as the encoded length.
            double encodedLength = encoded.data().getInt(0L);
            if (encodedLength >= updates.lengthLong() / 16) {
                // Too dense for threshold encoding: re-encode as a bitmap and return to bitmap mode.
                log.debug("Switching back to bitmapEncoding: iteration {}, epoch {}, threshold {}, encoded length {}", iteration, epoch, threshold, encodedLength);
                this.bitmapMode.get().set(true);
                INDArray bitmap = newBitmapBuffer(updates);
                Nd4j.getExecutioner().bitmapEncode(updates, bitmap, this.currentThreshold.get().get());
                applyPostProcessor(iteration, epoch, Double.valueOf(threshold), updates);
                this.lastSparsityRatio.set(null);
                this.lastIterWasDense.get().set(true);
                logThresholdIfReq(true, iteration, epoch);
                return bitmap;
            }
            setLastSparsityRatio(encodedLength / updates.length());
            this.lastIterWasDense.get().set(false);
        }
        applyPostProcessor(iteration, epoch, Double.valueOf(threshold), updates);
        logThresholdIfReq(this.lastIterWasDense.get().get(), iteration, epoch);
        return encoded;
    }

    /** Length, in ints, of the buffer used for bitmap encoding of {@code updates}. */
    private static long bitmapBufferLength(INDArray updates) {
        return (updates.lengthLong() / 16) + 5;
    }

    /** Allocates an int-backed bitmap buffer that reuses the shape info of {@code updates}. */
    private static INDArray newBitmapBuffer(INDArray updates) {
        return Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(bitmapBufferLength(updates)), updates.shapeInfoDataBuffer());
    }

    /** Stores {@code ratio} in this thread's sparsity slot, creating the holder on first use. */
    private void setLastSparsityRatio(double ratio) {
        AtomicDouble holder = this.lastSparsityRatio.get();
        if (holder == null) {
            this.lastSparsityRatio.set(new AtomicDouble(ratio));
        } else {
            holder.set(ratio);
        }
    }

    /**
     * Runs this thread's residual post-processor on {@code residual}; does nothing when no
     * post-processor was configured.
     *
     * @param iteration current iteration number
     * @param epoch     current epoch number
     * @param threshold threshold used for the current encoding; must not be null when a post-processor is configured
     * @param residual  array handed to the post-processor
     */
    public void applyPostProcessor(int iteration, int epoch, Double threshold, INDArray residual) {
        if (this.initialResidualPostProcessor == null) {
            return;
        }
        this.residualPostProcessor.get().processResidual(iteration, epoch, threshold.doubleValue(), residual);
    }

    /**
     * Not supported by this handler.
     *
     * @throws UnsupportedOperationException always
     */
    @Deprecated
    public INDArray decodeUpdates(INDArray message) {
        throw new UnsupportedOperationException();
    }

    /** Hands an encoded update to the accumulator; subclasses may override to send it elsewhere. */
    protected void sendMessage(INDArray message, int iteration, int epoch) {
        this.accumulator.receiveUpdate(message);
    }

    /**
     * Encodes {@code updates} and, if encoding produced a result, sends it via {@link #sendMessage}.
     *
     * @return {@code true} if an encoded update was sent, {@code false} if encoding returned null
     */
    @Override // org.deeplearning4j.optimize.solvers.accumulation.MessageHandler
    public boolean broadcastUpdates(INDArray updates, int iteration, int epoch) {
        INDArray encoded = encodeUpdates(iteration, epoch, updates);
        if (encoded == null) {
            return false;
        }
        sendMessage(encoded, iteration, epoch);
        return true;
    }

    /**
     * Logs the current thread's last threshold at INFO level, at most once per
     * {@link #THRESHOLD_LOG_FREQ_MS} across all threads.
     *
     * @param denseUpdates whether the last encoding was dense (bitmap)
     * @param iteration    current iteration number
     * @param epoch        current epoch number
     */
    protected void logThresholdIfReq(boolean denseUpdates, int iteration, int epoch) {
        long now = System.currentTimeMillis();
        long lastLogTime = this.lastThresholdLogTime.get();
        // Only the thread that wins the CAS logs for this interval.
        if (lastLogTime + THRESHOLD_LOG_FREQ_MS > now || !this.lastThresholdLogTime.compareAndSet(lastLogTime, now)) {
            return;
        }
        long threadId = Thread.currentThread().getId();
        String thresholdStr = format(this.lastThreshold.get().get());
        if (denseUpdates) {
            log.info("Threshold at iter {}, epoch {} [thread {}]: {}, DENSE updates", iteration, epoch, threadId, thresholdStr);
        } else {
            AtomicDouble sparsity = this.lastSparsityRatio.get();
            String sparsityStr = sparsity == null ? "-" : format(sparsity.get());
            // The placeholder count matches the argument count (the thread id previously had no placeholder).
            log.info("Threshold at iter {}, epoch {} [thread {}]: {}, SPARSE updates, last sparsity ratio: {}", iteration, epoch, threadId, thresholdStr, sparsityStr);
        }
    }

    /**
     * When {@link #encodingDebugMode} is enabled, logs statistics of {@code |residual|} at INFO level:
     * <ul>
     *   <li>length and threshold;</li>
     *   <li>the count and fraction of elements {@code >=} the threshold;</li>
     *   <li>the mean, the max, and several percentiles, each also shown as a multiple of the threshold.</li>
     * </ul>
     *
     * @param residual array to summarize; an absolute-value copy is taken, so it is not modified here
     */
    protected void residualDebugOutputIfRequired(INDArray residual) {
        if (!this.encodingDebugMode) {
            return;
        }
        double threshold = this.currentThreshold.get().get();
        INDArray absResidual = Transforms.abs(residual, true);
        double amean = absResidual.meanNumber().doubleValue();
        double amax = absResidual.maxNumber().doubleValue();
        double p50 = absResidual.percentileNumber(50).doubleValue();
        double p95 = absResidual.percentileNumber(95).doubleValue();
        double p99 = absResidual.percentileNumber(99).doubleValue();
        double p999 = absResidual.percentileNumber(Double.valueOf(99.9d)).doubleValue();
        double p9999 = absResidual.percentileNumber(Double.valueOf(99.99d)).doubleValue();
        long length = absResidual.length();
        long countAboveThreshold = absResidual.gte(Double.valueOf(threshold)).sumNumber().longValue();
        // Floating-point division: long / long would truncate the ratio to 0 (or 1).
        double sparsity = countAboveThreshold / (double) length;
        log.info("Encoding debug info, residual vector: length: {}, threshold: {}, count > thr: {}, sparsity: {}, amean: {} ({}x); amax: {} ({}x); 50%: {} ({}x); 95%: {} ({}x); 99%: {} ({}x); 99.9%: {} ({}x); 99.99%: {} ({}x)",
                length, format(threshold), countAboveThreshold, format(sparsity),
                format(amean), format(amean / threshold),
                format(amax), format(amax / threshold),
                format(p50), format(p50 / threshold),
                format(p95), format(p95 / threshold),
                format(p99), format(p99 / threshold),
                format(p999), format(p999 / threshold),
                format(p9999), format(p9999 / threshold));
    }

    /**
     * Formats a value compactly for logging.
     * <ul>
     *   <li>Zero becomes {@code "0.0"}.</li>
     *   <li>Values with magnitude in [0.1, 100) use up to three decimals.</li>
     *   <li>All other values use scientific notation with a lower-case {@code 'e'}.</li>
     * </ul>
     *
     * @param d value to format
     * @return the formatted value
     */
    protected static String format(double d) {
        if (d == 0.0) {
            return "0.0";
        }
        if ((d > -0.1d || d <= -100.0d) && (d < 0.1d || d >= 100.0d)) {
            if (formatter.get() == null) {
                formatter.set(new DecimalFormat("0.###E0"));
            }
            return formatter.get().format(d).replace('E', 'e');
        }
        if (formatter2.get() == null) {
            formatter2.set(new DecimalFormat("0.###"));
        }
        return formatter2.get().format(d);
    }

    /**
     * Combines the per-thread threshold algorithms into one.
     *
     * <ul>
     *   <li>Returns {@code null} if no thread has encoded yet.</li>
     *   <li>With a single per-thread instance, returns that live instance; nothing is reset.</li>
     *   <li>Otherwise, reduces all instances with the reducer of the first one. It then resets the
     *       per-thread algorithms, so that each thread clones a fresh copy of
     *       {@link #initialThresholdAlgorithm} on its next {@link #encodeUpdates} call.</li>
     * </ul>
     *
     * @return the combined threshold algorithm, or {@code null} if none exist
     */
    public ThresholdAlgorithm getAverageThresholdAlgorithm() {
        Collection<ThresholdAlgorithm> algorithms = this.allThreadThresholdAlgorithms.values();
        if (algorithms.isEmpty()) {
            return null;
        }
        if (algorithms.size() == 1) {
            return algorithms.iterator().next();
        }
        ThresholdAlgorithmReducer reducer = null;
        for (ThresholdAlgorithm algorithm : algorithms) {
            if (reducer == null) {
                reducer = algorithm.newReducer();
            }
            reducer.add(algorithm);
        }
        ThresholdAlgorithm combined = reducer.getFinalResult();
        // Drop the old per-thread instances so the next epoch starts from fresh clones.
        this.thresholdAlgorithm = new ThreadLocal<>();
        this.allThreadThresholdAlgorithms.clear();
        return combined;
    }
}
