Skip to content

Commit c19e280

Browse files
committed
Capture the threading.Thread.name attribute
Previously we only captured the thread name as set by `prctl`, but not the thread name as returned by `threading.current_thread().name`. Begin capturing the name for the Python thread as well. We retain only the last name set for each thread, so assignments to `Thread.name` override earlier calls to `prctl(PR_SET_NAME)`, and vice versa. This implementation uses a custom descriptor to intercept assignments to `Thread._name` and `Thread._ident` in order to detect when a thread has a name or a thread id assigned to it. Because this is tricky and a bit fragile (poking at the internals of `Thread`), I've implemented that descriptor in a Python module. At least that way if it ever breaks, it should be a bit easier for someone to investigate. Signed-off-by: Matt Wozniski <mwozniski@bloomberg.net>
1 parent e8ae4d0 commit c19e280

File tree

7 files changed

+141
-1
lines changed

7 files changed

+141
-1
lines changed

news/562.feature.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Capture the name attribute of Python `threading.Thread` objects.

src/memray/_memray.pyx

+12-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ from ._destination import FileDestination
7474
from ._destination import SocketDestination
7575
from ._metadata import Metadata
7676
from ._stats import Stats
77+
from ._thread_name_interceptor import ThreadNameInterceptor
7778

7879

7980
def set_log_level(int level):
@@ -691,7 +692,6 @@ cdef class Tracker:
691692

692693
@cython.profile(False)
693694
def __enter__(self):
694-
695695
if NativeTracker.getTracker() != NULL:
696696
raise RuntimeError("No more than one Tracker instance can be active at the same time")
697697

@@ -700,6 +700,14 @@ cdef class Tracker:
700700
raise RuntimeError("Attempting to use stale output handle")
701701
writer = move(self._writer)
702702

703+
for attr in ("_name", "_ident"):
704+
assert not hasattr(threading.Thread, attr)
705+
setattr(
706+
threading.Thread,
707+
attr,
708+
ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById),
709+
)
710+
703711
self._previous_profile_func = sys.getprofile()
704712
self._previous_thread_profile_func = threading._profile_hook
705713
threading.setprofile(start_thread_trace)
@@ -722,6 +730,9 @@ cdef class Tracker:
722730
sys.setprofile(self._previous_profile_func)
723731
threading.setprofile(self._previous_thread_profile_func)
724732

733+
for attr in ("_name", "_ident"):
734+
delattr(threading.Thread, attr)
735+
725736

726737
def start_thread_trace(frame, event, arg):
727738
if event in {"call", "c_call"}:

src/memray/_memray/tracking_api.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ Tracker::trackAllocationImpl(
842842
hooks::Allocator func,
843843
const std::optional<NativeTrace>& trace)
844844
{
845+
registerCachedThreadName();
845846
PythonStackTracker::get().emitPendingPushesAndPops();
846847

847848
if (d_unwind_native_frames) {
@@ -871,6 +872,7 @@ Tracker::trackAllocationImpl(
871872
void
872873
Tracker::trackDeallocationImpl(void* ptr, size_t size, hooks::Allocator func)
873874
{
875+
registerCachedThreadName();
874876
AllocationRecord record{reinterpret_cast<uintptr_t>(ptr), size, func};
875877
if (!d_writer->writeThreadSpecificRecord(thread_id(), record)) {
876878
std::cerr << "Failed to write output, deactivating tracking" << std::endl;
@@ -963,12 +965,37 @@ void
963965
Tracker::registerThreadNameImpl(const char* name)
964966
{
965967
RecursionGuard guard;
968+
dropCachedThreadName();
966969
if (!d_writer->writeThreadSpecificRecord(thread_id(), ThreadRecord{name})) {
967970
std::cerr << "memray: Failed to write output, deactivating tracking" << std::endl;
968971
deactivate();
969972
}
970973
}
971974

975+
void
976+
Tracker::registerCachedThreadName()
977+
{
978+
if (d_cached_thread_names.empty()) {
979+
return;
980+
}
981+
982+
auto it = d_cached_thread_names.find((uint64_t)(pthread_self()));
983+
if (it != d_cached_thread_names.end()) {
984+
auto& name = it->second;
985+
if (!d_writer->writeThreadSpecificRecord(thread_id(), ThreadRecord{name.c_str()})) {
986+
std::cerr << "memray: Failed to write output, deactivating tracking" << std::endl;
987+
deactivate();
988+
}
989+
d_cached_thread_names.erase(it);
990+
}
991+
}
992+
993+
void
994+
Tracker::dropCachedThreadName()
995+
{
996+
d_cached_thread_names.erase((uint64_t)(pthread_self()));
997+
}
998+
972999
frame_id_t
9731000
Tracker::registerFrame(const RawFrame& frame)
9741001
{

src/memray/_memray/tracking_api.h

+24
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,27 @@ class Tracker
290290
}
291291
}
292292

293+
inline static void registerThreadNameById(uint64_t thread, const char* name)
294+
{
295+
if (RecursionGuard::isActive || !Tracker::isActive()) {
296+
return;
297+
}
298+
RecursionGuard guard;
299+
300+
std::unique_lock<std::mutex> lock(*s_mutex);
301+
Tracker* tracker = getTracker();
302+
if (tracker) {
303+
if (thread == (uint64_t)(pthread_self())) {
304+
tracker->registerThreadNameImpl(name);
305+
} else {
306+
// We've got a different thread's name, but don't know what id
307+
// has been assigned to that thread (if any!). Set this update
308+
// aside to be handled later, from that thread.
309+
tracker->d_cached_thread_names.emplace(thread, name);
310+
}
311+
}
312+
}
313+
293314
// RawFrame stack interface
294315
bool pushFrame(const RawFrame& frame);
295316
bool popFrames(uint32_t count);
@@ -359,6 +380,7 @@ class Tracker
359380
const bool d_trace_python_allocators;
360381
linker::SymbolPatcher d_patcher;
361382
std::unique_ptr<BackgroundThread> d_background_thread;
383+
std::unordered_map<uint64_t, std::string> d_cached_thread_names;
362384

363385
// Methods
364386
static size_t computeMainTidSkip();
@@ -373,6 +395,8 @@ class Tracker
373395
void invalidate_module_cache_impl();
374396
void updateModuleCacheImpl();
375397
void registerThreadNameImpl(const char* name);
398+
void registerCachedThreadName();
399+
void dropCachedThreadName();
376400
void registerPymallocHooks() const noexcept;
377401
void unregisterPymallocHooks() const noexcept;
378402

src/memray/_memray/tracking_api.pxd

+4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from _memray.record_writer cimport RecordWriter
2+
from libc.stdint cimport uint64_t
23
from libcpp cimport bool
34
from libcpp.memory cimport unique_ptr
45
from libcpp.string cimport string
@@ -31,3 +32,6 @@ cdef extern from "tracking_api.h" namespace "memray::tracking_api":
3132

3233
@staticmethod
3334
void handleGreenletSwitch(object, object) except+
35+
36+
@staticmethod
37+
void registerThreadNameById(uint64_t, const char*) except+
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import threading
2+
from typing import Callable
3+
4+
5+
class ThreadNameInterceptor:
6+
"""Record the name of each threading.Thread for Memray's reports.
7+
8+
The name can be set either before or after the thread is started, and from
9+
either the same thread or a different thread. Whenever an assignment to
10+
either `Thread._name` or `Thread._ident` is performed and the other has
11+
already been set, we call a callback with the thread's ident and name.
12+
"""
13+
14+
def __init__(self, attr: str, callback: Callable[[int, str], None]) -> None:
15+
self._attr = attr
16+
self._callback = callback
17+
18+
def __set__(self, instance: threading.Thread, value: object) -> None:
19+
instance.__dict__[self._attr] = value
20+
ident = instance.__dict__.get("_ident")
21+
name = instance.__dict__.get("_name")
22+
if ident is not None and name is not None:
23+
self._callback(ident, name)

tests/integration/test_threads.py

+50
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,53 @@ def allocating_function():
9494
(valloc,) = vallocs
9595
assert valloc.size == 1234
9696
assert "my thread name" == valloc.thread_name
97+
98+
99+
def test_setting_python_thread_name(tmpdir):
100+
# GIVEN
101+
output = Path(tmpdir) / "test.bin"
102+
allocator = MemoryAllocator()
103+
name_set_inside_thread = threading.Event()
104+
name_set_outside_thread = threading.Event()
105+
prctl_rc = -1
106+
107+
def allocating_function():
108+
allocator.valloc(1234)
109+
allocator.free()
110+
111+
threading.current_thread().name = "set inside thread"
112+
allocator.valloc(1234)
113+
allocator.free()
114+
115+
name_set_inside_thread.set()
116+
name_set_outside_thread.wait()
117+
allocator.valloc(1234)
118+
allocator.free()
119+
120+
nonlocal prctl_rc
121+
prctl_rc = set_thread_name("set by prctl")
122+
allocator.valloc(1234)
123+
allocator.free()
124+
125+
# WHEN
126+
with Tracker(output):
127+
t = threading.Thread(target=allocating_function, name="set before start")
128+
t.start()
129+
name_set_inside_thread.wait()
130+
t.name = "set outside running thread"
131+
name_set_outside_thread.set()
132+
t.join()
133+
134+
# THEN
135+
expected_names = [
136+
"set before start",
137+
"set inside thread",
138+
"set outside running thread",
139+
"set by prctl" if prctl_rc == 0 else "set outside running thread",
140+
]
141+
names = [
142+
rec.thread_name
143+
for rec in FileReader(output).get_allocation_records()
144+
if rec.allocator == AllocatorType.VALLOC
145+
]
146+
assert names == expected_names

0 commit comments

Comments
 (0)