package org.nd4j.linalg.cpu.nativecpu;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.api.ndarray.BaseShapeInfoProvider;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.ShapeDescriptor;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.class */
/**
 * Shape-information provider for the native CPU backend that memoizes shape buffers.
 *
 * <p>Results of {@link #createShapeInformation(long[], long[], long, char, long)} are cached by
 * their {@link LongShapeDescriptor}, so repeated requests for an identical shape return the same
 * {@link Pair} instance. At most {@value #MAX_ENTRIES} descriptors are cached. Once that limit is
 * reached, requests for shapes not already cached are delegated to the base provider uncached.
 *
 * <p>Cache reads are lock-free. Insertions and {@link #purgeCache()} are serialized on
 * {@code this}.
 */
public class DirectShapeInfoProvider extends BaseShapeInfoProvider {
    private static final Logger log = LoggerFactory.getLogger(DirectShapeInfoProvider.class);

    /** Upper bound on the number of descriptors held in {@link #longCache}. */
    private static final int MAX_ENTRIES = 1000;

    /** Legacy descriptor cache; nothing in this class populates it, it is only cleared on purge. */
    private final Map<ShapeDescriptor, Pair<DataBuffer, long[]>> shapeCache = new ConcurrentHashMap<>();

    /** Active cache of shape information, keyed by the full shape descriptor. */
    private final Map<LongShapeDescriptor, Pair<DataBuffer, long[]>> longCache = new ConcurrentHashMap<>();

    /** Number of entries inserted into {@link #longCache}; written only while holding {@code this}. */
    private final AtomicInteger counter = new AtomicInteger(0);

    /**
     * Creates (or returns cached) shape information, encoding {@code dataType} into the extras field.
     *
     * @param shape             array shape
     * @param stride            array strides
     * @param elementWiseStride element-wise stride; negative values are treated as 0
     * @param order             array ordering character
     * @param dataType          data type, stored as an option bit in the extras word
     * @return the shape buffer paired with its {@code long[]} form
     */
    public Pair<DataBuffer, long[]> createShapeInformation(long[] shape, long[] stride, long elementWiseStride,
                                                           char order, DataType dataType) {
        return createShapeInformation(shape, stride, elementWiseStride, order,
                ArrayOptionsHelper.setOptionBit(0L, dataType));
    }

    /**
     * Creates (or returns cached) shape information for the given descriptor components.
     *
     * @param shape             array shape
     * @param stride            array strides
     * @param elementWiseStride element-wise stride; negative values are treated as 0
     * @param order             array ordering character
     * @param extras            extras/options word
     * @return the shape buffer paired with its {@code long[]} form; cached when capacity allows
     */
    public Pair<DataBuffer, long[]> createShapeInformation(long[] shape, long[] stride, long elementWiseStride,
                                                           char order, long extras) {
        if (elementWiseStride < 0) {
            elementWiseStride = 0;
        }
        LongShapeDescriptor descriptor =
                new LongShapeDescriptor(shape, stride, 0L, elementWiseStride, order, extras);

        // Single lookup: a containsKey/get pair could observe a concurrent purge in between.
        Pair<DataBuffer, long[]> cached = longCache.get(descriptor);
        if (cached != null) {
            return cached;
        }
        if (counter.get() >= MAX_ENTRIES) {
            return super.createShapeInformation(shape, stride, elementWiseStride, order, extras);
        }

        synchronized (this) {
            cached = longCache.get(descriptor);
            if (cached != null) {
                return cached;
            }
            // Re-check under the lock so concurrent misses cannot push the cache past MAX_ENTRIES.
            if (counter.get() >= MAX_ENTRIES) {
                return super.createShapeInformation(shape, stride, elementWiseStride, order, extras);
            }
            Pair<DataBuffer, long[]> created =
                    super.createShapeInformation(shape, stride, elementWiseStride, order, extras);
            longCache.put(descriptor, created);
            // Count only after a successful insert, so a failed creation does not consume capacity.
            counter.incrementAndGet();

            long accountedBytes = created.getFirst().length() * 8 * 2;
            this.bytes.addAndGet(accountedBytes);
            AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, accountedBytes);
            return created;
        }
    }

    /**
     * Drops every cached entry and resets the entry counter so that caching can resume.
     *
     * <p>Only the cache's own references are removed. Objects previously returned to callers are
     * not touched. NOTE(review): the {@code bytes} counter and the {@link AllocationsTracker}
     * totals recorded on insert are not rolled back here; confirm whether they should be.
     */
    public synchronized void purgeCache() {
        shapeCache.clear();
        longCache.clear();
        counter.set(0);
    }
}
