Using untyped storage ref instead of tensor ref in memory tracker (#50)
sanketpurandare authored Mar 28, 2024
1 parent 276d2f0 commit f655b9f
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions max_mem_tracker.py
@@ -4,29 +4,31 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.utils.weak import WeakIdKeyDictionary
 import weakref
+import math

 # Track all the memory being used by Tensors.
 # Only max is tracked but others can be added.
 MEMORY_USE = WeakIdKeyDictionary()
 MEMORY_MAX = 0
 MEMORY_ID = 0
+# Minimum allocation size
+PYTORCH_MIN_ALLOCATE = 2**9

 def update_stats():
     global MEMORY_MAX
     curr_use = 0
     for k, v in MEMORY_USE.items():
-        curr_use += k.nelement() * k.element_size()
+        curr_use += math.ceil(k.size() * k.element_size()/PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE

     if MEMORY_MAX < curr_use:
         MEMORY_MAX = curr_use

 # Should be called on every Tensor created
-def track(t):
+def track(t:torch.Tensor):
     def cb(_):
         update_stats()

-    wt = weakref.ref(t, cb)
-    MEMORY_USE[t] = wt
+    st = t.untyped_storage()
+    wt = weakref.ref(st, cb)
+    MEMORY_USE[st] = wt
     update_stats()

 # Use this Mode to call track on every Tensor being created by functions
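For an untyped storage, size() is the buffer length in bytes and element_size() is 1, so the new expression rounds the raw byte count up to a multiple of 512 bytes, matching the minimum block size PyTorch's CUDA caching allocator reserves. A minimal sketch of the rounding rule, where rounded_nbytes is a hypothetical helper and not part of the tracker:

import math

PYTORCH_MIN_ALLOCATE = 2**9  # 512-byte minimum block, as in the diff above

def rounded_nbytes(nbytes: int) -> int:
    # Round a raw byte count up to the next 512-byte boundary,
    # approximating what the allocator actually reserves.
    return math.ceil(nbytes / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE

assert rounded_nbytes(12) == 512    # a 3-element float32 tensor still costs one block
assert rounded_nbytes(600) == 1024  # 600 bytes spill into a second block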

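Keying MEMORY_USE by the untyped storage rather than by the Tensor means that views, which are distinct Tensor objects sharing one buffer, are charged once, and the weakref callback fires only when the underlying buffer is actually freed rather than when any individual view object dies. A small illustration of the shared-storage point, using nothing beyond stock PyTorch:

import torch

base = torch.zeros(1024)  # one 4 KiB buffer
view = base[:512]         # a view: a new Tensor object over the same buffer

assert base is not view   # two distinct Tensor objects...
# ...backed by the same untyped storage, so a storage-keyed tracker
# counts the 4 KiB buffer exactly once:
assert base.untyped_storage().data_ptr() == view.untyped_storage().data_ptr()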
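The trailing comment refers to a TorchDispatchMode that is collapsed out of this view. A hedged sketch of what such a mode could look like (the class name and body are assumptions, not the file's actual code), assuming tree_map_only from torch.utils._pytree and the track function from the diff above:

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only

class MemoryTrackingMode(TorchDispatchMode):
    # Sketch only: run each op, then call track() on every Tensor it returns.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        res = func(*args, **(kwargs or {}))
        tree_map_only(torch.Tensor, track, res)
        return res

Under such a mode, every op output flows through track, so MEMORY_MAX reflects the peak total over all live storages.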