package org.nd4j.linalg.api.memory;

import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/memory/DeviceAllocationsTracker.class */
/**
 * Keeps one running {@code long} counter per {@link AllocationKind}.
 *
 * <p>A counter is created for every enum constant at construction time, so lookups for any
 * non-null kind always find an entry. The map itself is never modified after the constructor
 * returns; only the {@link AtomicLong} values it holds are updated.
 *
 * <p>NOTE(review): judging by the field name, the counters presumably hold byte counts of
 * device allocations; confirm against callers.
 */
public class DeviceAllocationsTracker {
    private static final Logger log = LoggerFactory.getLogger(DeviceAllocationsTracker.class);

    /** One counter per allocation kind; populated once in the constructor and then read-only. */
    private final Map<AllocationKind, AtomicLong> bytesMap = new EnumMap<>(AllocationKind.class);

    /** Creates a tracker with a zero-initialised counter for every {@link AllocationKind}. */
    public DeviceAllocationsTracker() {
        for (AllocationKind allocationKind : AllocationKind.values()) {
            this.bytesMap.put(allocationKind, new AtomicLong(0L));
        }
    }

    /**
     * Atomically adds a delta to the counter of the given kind.
     *
     * @param allocationKind the kind whose counter is adjusted; must not be {@code null}
     * @param j the signed amount added to the counter (a negative value decreases it)
     * @throws NullPointerException if {@code allocationKind} is {@code null}
     */
    public void updateState(@NonNull AllocationKind allocationKind, long j) {
        if (allocationKind == null) {
            throw new NullPointerException("kind is marked @NonNull but is null");
        }
        this.bytesMap.get(allocationKind).addAndGet(j);
    }

    /**
     * Returns the current value of the counter of the given kind.
     *
     * @param allocationKind the kind whose counter is read; must not be {@code null}
     * @return the counter's current value
     * @throws NullPointerException if {@code allocationKind} is {@code null}
     */
    public long getState(@NonNull AllocationKind allocationKind) {
        if (allocationKind == null) {
            throw new NullPointerException("kind is marked @NonNull but is null");
        }
        return this.bytesMap.get(allocationKind).get();
    }
}
