diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py index 60dbd1591d..0c8b581b0e 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py @@ -14,6 +14,7 @@ MarkdownRenderer SourceCodeRenderer TableRenderer + PyTorchProfilingRenderer """ from .renderer import ( @@ -24,4 +25,16 @@ MarkdownRenderer, SourceCodeRenderer, TableRenderer, + PyTorchProfilingRenderer ) + +__all__ = [ + "BoxRenderer", + "FrameProfilingRenderer", + "GanttChartRenderer", + "ImageRenderer", + "MarkdownRenderer", + "SourceCodeRenderer", + "TableRenderer", + "PyTorchProfilingRenderer" +] diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/pytorch_memory_viz.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/pytorch_memory_viz.py new file mode 100644 index 0000000000..5a7a7851c0 --- /dev/null +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/pytorch_memory_viz.py @@ -0,0 +1,762 @@ +# mypy: allow-untyped-defs +import base64 +import io +import json +import operator +import os +import pickle +import subprocess +import sys +import warnings +from functools import lru_cache +from itertools import groupby +from typing import Any +import tempfile +from pathlib import Path +from subprocess import CalledProcessError + + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + + +def _frame_fmt(f, full_filename=False): + i = f["line"] + fname = f["filename"] + if not full_filename: + fname = fname.split("/")[-1] + func = f["name"] + return f"{fname}:{i}:{func}" + + +@cache +def _frame_filter(name, filename): + omit_functions = [ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [ + _frame_fmt(f, full_filename) + for f in frames + if _frame_filter(f["name"], f["filename"]) + ] + + +def _block_extra_legacy(b): + if "history" in b: + frames = b["history"][0].get("frames", []) + real_size = b["history"][0]["real_size"] + else: + real_size = b.get("requested_size", b["size"]) + frames = [] + return frames, real_size + + +def _block_extra(b): + if "frames" not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b["frames"], b["requested_size"] + + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + """Format flamegraph data using a temporary script. + + Args: + flamegraph_lines: Lines of flamegraph data + flamegraph_script: Optional path to flamegraph script + + Returns: + str: Formatted flamegraph data or HTML visualization + """ + # For testing purposes, return a simple HTML visualization + if os.getenv('TESTING') or not flamegraph_script: + return f""" + + + + Memory Visualization + + +
+ <div>
+ <pre>
+ {flamegraph_lines}
+ </pre>
+ </div>
+ + + """ + + try: + # Normal flamegraph generation for production + flamegraph_script = Path(flamegraph_script) + if not flamegraph_script.exists(): + raise FileNotFoundError(f"Flamegraph script not found: {flamegraph_script}") + + result = subprocess.run( + [ + str(flamegraph_script), + "--colors", "python", + "--countname", "bytes", + "--width", "1200", + ], + input=flamegraph_lines, + capture_output=True, + text=True, + check=True + ) + return result.stdout + + except (subprocess.CalledProcessError, FileNotFoundError) as e: + raise RuntimeError(f"Failed to generate flamegraph: {str(e)}") + + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ";".join(_frames_fmt(frames, reverse=True)) + + for b in blocks: + if "history" not in b: + frames, accounted_for_size = _block_extra(b) + f.write( + f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n' + ) + else: + accounted_for_size = 0 + for h in b["history"]: + sz = h["real_size"] + accounted_for_size += sz + if "frames" in h: + frames = h["frames"] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b["size"] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') + + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg["address"], seg["total_size"]) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f"only_before = {[a for a, _ in (before_segs - after_segs)]}") + print(f"only_after = {[a for a, _ in (after_segs - before_segs)]}") + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f"only_before;{_seg_info(seg)}", seg["blocks"]) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f"only_after;{_seg_info(seg)}", seg["blocks"]) + + return format_flamegraph(f.getvalue()) + + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + + +def calc_active(seg): + return sum(b["size"] for b in seg["blocks"] if b["state"] == "active_allocated") + + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = "" + if total != 0: + pct = (free_internal / total) * 100 + suffix = f" ({pct:.1f}% internal)" + return f"{Bytes(total)}{suffix}" + + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ + +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: 
pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. + Args: + data: snapshot dictionary created from _snapshot() + """ + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted( + data["segments"], key=lambda x: (x["total_size"], calc_active(x)) + ): + total_reserved += seg["total_size"] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg["blocks"]: + active = b["state"] == "active_allocated" + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b["size"] - allocated_size + else: + seg_free_external += b["size"] + + boffset += b["size"] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg["total_size"] - 1) // PAGE_SIZE + 1 + occupied = [" " for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = start_ + size + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord("a" if active else "A") + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != " ": + occupied[j] = "0123456789*"[int(frac[j] * 10)] + else: + occupied[j] = m + stream_info = f' stream_{seg["stream"]}' if seg["stream"] != 0 else "" + body = "".join(occupied) + + total_size = seg_free_external + seg_free_internal + seg_allocated + if total_size != seg["total_size"]: + raise ValueError( + f"Segment size mismatch: {total_size} != {seg['total_size']}" + ) + + if seg["total_size"] >= PAGE_SIZE: + out.write( + f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f"{_report_free(seg_free_external, seg_free_internal)} free{stream_info}\n" + ) + out.write(f'segments: {len(data["segments"])}\n') + out.write(f"total_reserved: {Bytes(total_reserved)}\n") + out.write(f"total_allocated: {Bytes(total_allocated)}\n") + out.write(f"total_free: {_report_free(free_external, free_internal)}\n") + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals: list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names: list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for 
name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data["segments"]): + saddr = seg["address"] + size = seg["allocated_size"] + if addr >= saddr and addr < saddr + size: + return f"seg_{i}", saddr + return None, None + + count = 0 + out.write(f"{len(entries)} entries\n") + + total_reserved = 0 + for seg in data["segments"]: + total_reserved += seg["total_size"] + + for count, e in enumerate(entries): + if e["action"] == "alloc": + addr, size = e["addr"], e["size"] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - seg_addr + out.write(f"{n} = {seg_name}[{offset}:{Bytes(size)}]\n") + allocation_addr_to_name[addr] = (n, size, count) + count += size + elif e["action"] == "free_requested": + addr, size = e["addr"], e["size"] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"del {name} # {Bytes(size)}\n") + elif e["action"] == "free_completed": + addr, size = e["addr"], e["size"] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"# free completed for {name} {Bytes(size)}\n") + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e["action"] == "segment_alloc": + addr, size = e["addr"], e["size"] + name = _name() + out.write(f"{name} = cudaMalloc({addr}, {Bytes(size)})\n") + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e["action"] == "segment_free": + addr, size = e["addr"], e["size"] + name = segment_addr_to_name.get(addr, addr) + out.write(f"cudaFree({name}) # {Bytes(size)}\n") + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e["action"] == "oom": + size = e["size"] + free = e["device_free"] + out.write( + f"raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n" + ) + else: + out.write(f"{e}\n") + out.write(f"TOTAL MEM: {Bytes(count)}") + + for i, d in enumerate(data["device_traces"]): + if d: + out.write(f"Device {i} ----------------\n") + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn( + "device argument is deprecated, plots now contain all device", + FutureWarning, + stacklevel=3, + ) + buffer = pickle.dumps(data) + buffer += b"\x00" * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode("utf-8") + + json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}]) + return _memory_viz_template.replace("$VIZ_KIND", repr(viz_kind)).replace( + "$SNAPSHOT", json_format + ) + + +def trace_plot(data, device=None, plot_segments=False): + """Generate a visualization over time of the memory usage recorded by the trace as an html file. + + Args: + data: Memory snapshot as generated from torch.cuda.memory._snapshot() + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations. + Defaults to False. 
+ + Returns: + str: HTML of visualization + """ + return _format_viz( + data, + "Active Memory Timeline" + if not plot_segments + else "Active Cached Memory Timeline", + device, + ) + + +def _profile_to_snapshot(profile): + import torch + from torch._C._profiler import _EventType + from torch.profiler._memory_profiler import Action, TensorKey + + memory_profile = profile._memory_profile() + + allocation_stacks = {} + for event in memory_profile._op_tree.sorted_nodes: + if event.tag == _EventType.Allocation: + parent = event.parent + python_parents = [] + while parent: + if parent.tag in (_EventType.PyCall, _EventType.PyCCall): + python_parents.append(parent) + parent = parent.parent + key = TensorKey.from_allocation(event.extra_fields) + + # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor) + # key will be None. I should add some way to identify these, I just haven't yet. + if key and event.extra_fields.alloc_size > 0: + allocation_stacks[key] = python_parents + + device_count = torch.cuda.device_count() + snapshot: dict[str, list[Any]] = { + "device_traces": [[] for _ in range(device_count + 1)], + "segments": [ + { + "device": device, + "address": None, + "total_size": 0, + "stream": 0, + "blocks": [], + } + for device in range(device_count + 1) + ], + } + + def to_device(device): + if device.type == "cuda": + return device.index + else: + return device_count + + def allocate(size, tensor_key, version, during_trace=True): + device = to_device(tensor_key.device) + addr = tensor_key.storage.ptr + + seg = snapshot["segments"][device] # type: ignore[index] + if seg["address"] is None or seg["address"] > addr: + seg["address"] = addr + seg["total_size"] = max( + seg["total_size"], addr + size + ) # record max addr for now, we will make it the size later + category = memory_profile._categories.get(tensor_key, version) + category = category.name.lower() if category is not None else "unknown" + stack = allocation_stacks.get(tensor_key, ()) + stack = [{"filename": "none", "line": 0, "name": p.name} for p in stack] + r = { + "action": "alloc", + "addr": addr, + "size": size, + "stream": 0, + "frames": stack, + "category": category, + } + if during_trace: + snapshot["device_traces"][device].append(r) + return r + + def free(alloc, device): + for e in ("free_requested", "free_completed"): + snapshot["device_traces"][device].append( + { + "action": e, + "addr": alloc["addr"], + "size": alloc["size"], + "stream": 0, + "frames": alloc["frames"], + } + ) + + kv_to_elem = {} + + # create the device trace + for _time, action, (tensor_key, version), size in memory_profile.timeline: + if not isinstance(tensor_key, TensorKey): + continue + if action == Action.CREATE: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version) + elif action == Action.DESTROY: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + elif action == Action.INCREMENT_VERSION: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + kv_to_elem[(tensor_key, version + 1)] = allocate( + size, tensor_key, version + 1 + ) + elif action == Action.PREEXISTING: + kv_to_elem[(tensor_key, version)] = allocate( + size, tensor_key, version, during_trace=False + ) + + # create the final snapshot state + blocks_at_end = [ + (to_device(tensor_key.device), event["addr"], event["size"], event["frames"]) + for (tensor_key, version), event in kv_to_elem.items() + ] + for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)): + seg = 
snapshot["segments"][device] # type: ignore[index] + last_addr = seg["address"] + for _, addr, size, frames in blocks: + if last_addr < addr: + seg["blocks"].append({"size": addr - last_addr, "state": "inactive"}) + seg["blocks"].append( + { + "size": size, + "state": "active_allocated", + "requested_size": size, + "frames": frames, + } + ) + last_addr = addr + size + if last_addr < seg["total_size"]: + seg["blocks"].append( + {"size": seg["total_size"] - last_addr, "state": "inactive"} + ) + + snapshot["segments"] = [seg for seg in snapshot["segments"] if seg["blocks"]] # type: ignore[attr-defined] + for seg in snapshot["segments"]: # type: ignore[attr-defined, name-defined, no-redef] + seg["total_size"] -= seg["address"] + if not seg["blocks"]: + seg["blocks"].append({"size": seg["total_size"], "state": "inactive"}) + + return snapshot + + +def profile_plot(profile, device=None): + """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file. + + Args: + profile: profile as generated by `torch.profiler.profile(profile_memory=True)` + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, "Active Memory Timeline", device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, "Allocator State History", device) + + +if __name__ == "__main__": + import os.path + + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find cuda/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = "torch.cuda.memory._snapshot()" + pickled = f"pickled memory statistics from {fn_name}" + parser = argparse.ArgumentParser( + description=f"Visualize memory dumps produced by {fn_name}" + ) + + subparsers = parser.add_subparsers(dest="action") + + def _output(p): + p.add_argument( + "-o", + "--output", + default="output.svg", + help="flamegraph svg (default: output.svg)", + ) + + description = "Prints overall allocation statistics and a visualization of how the allocators segments are currently filled." + stats_a = subparsers.add_parser("stats", description=description) + stats_a.add_argument("input", help=pickled) + + description = "Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style." + trace_a = subparsers.add_parser("trace", description=description) + trace_a.add_argument("input", help=pickled) + + description = "Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)" + segments_a = subparsers.add_parser("segments", description=description) + segments_a.add_argument("input", help=pickled) + _output(segments_a) + + description = ( + "Generate a flamegraph the program locations contributing to CUDA memory usage." + ) + memory_a = subparsers.add_parser("memory", description=description) + memory_a.add_argument("input", help=pickled) + _output(memory_a) + + description = ( + "Generate a flamegraph that shows segments (aka blocks) that have been added " + "or removed between two different memorys snapshots." 
+ ) + compare_a = subparsers.add_parser("compare", description=description) + compare_a.add_argument("before", help=pickled) + compare_a.add_argument("after", help=pickled) + _output(compare_a) + + plots = ( + ( + "trace_plot", + "Generate a visualization over time of the memory usage recorded by the trace as an html file.", + ), + ( + "segment_plot", + "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.", + ), + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument("input", help=pickled) + help = "visualize trace from this device (default: chooses the only device with trace info or errors)" + trace_plot_a.add_argument("-d", "--device", type=int, default=None, help=help) + help = "path to save the visualization(default: output.html)" + trace_plot_a.add_argument("-o", "--output", default="output.html", help=help) + if cmd == "trace_plot": + help = "visualize change to segments rather than individual allocations" + trace_plot_a.add_argument( + "-s", "--segments", action="store_true", help=help + ) + + args = parser.parse_args() + + def _read(name): + if name == "-": + f = sys.stdin.buffer + else: + f = open(name, "rb") + data = pickle.load(f) + if isinstance(data, list): # segments only... + data = {"segments": data, "traces": []} + return data + + def _write(name, data): + with open(name, "w") as f: + f.write(data) + + if args.action == "segments": + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == "memory": + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == "stats": + data = _read(args.input) + print(segsum(data)) + elif args.action == "trace": + data = _read(args.input) + print(trace(data)) + elif args.action == "compare": + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == "trace_plot": + data = _read(args.input) + _write( + args.output, + trace_plot(data, device=args.device, plot_segments=args.segments), + ) + elif args.action == "segment_plot": + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index 708e941d88..fee8334b64 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -3,6 +3,19 @@ from flytekit import lazy_module from flytekit.types.file import FlyteFile +from .pytorch_memory_viz import ( + trace_plot, + segment_plot, + memory, + segments, + compare, + profile_plot, + _format_size, # Could be useful for formatting memory sizes + segsum, # Could be useful for summary visualization + trace # Could be useful for detailed trace visualization +) +import pickle +from pathlib import Path if TYPE_CHECKING: import markdown @@ -212,3 +225,325 @@ def to_html(self, df: "pd.DataFrame", chart_width: Optional[int] = None) -> str: ) return fig.to_html() + +def _ensure_profiling_structure(data: dict) -> dict: + """Ensures profiling data has the correct structure with required keys. 
+ + Args: + data: Dictionary containing profiling data + + Returns: + dict: Structured profiling data with required keys + + Raises: + ValueError: If data structure is invalid + """ + if not isinstance(data, dict): + raise ValueError("Profiling data must be a dictionary") + + # Validate segments key exists and is a list + if "segments" not in data: + raise ValueError("Profiling data missing 'segments' key") + if not isinstance(data["segments"], list): + raise ValueError("'segments' must be a list") + + return { + "segments": data["segments"], + "traces": data.get("traces", []), + "allocator_settings": data.get("allocator_settings", {}) + } + +class PyTorchProfilingRenderer: + """Renders PyTorch profiling data in various visualization formats. + + This renderer is particularly useful for analyzing memory usage and potential + memory-related failures in PyTorch executions. It can help diagnose OOM (Out of Memory) + errors and memory leaks by providing various visualization types. + + Supports multiple visualization types: + - trace_plot: Shows the execution timeline + - segment_plot: Shows the execution segments + - memory: Displays memory usage over time + - segments: Shows detailed segment information + - compare: Compares two profiling snapshots + - profile_plot: Shows detailed profiling information + - summary: Shows overall allocation statistics + - trace_view: Shows detailed trace information + + The renderer can be particularly helpful in: + 1. Analyzing failed executions due to OOM errors + 2. Identifying memory leaks + 3. Understanding memory usage patterns + 4. Comparing memory states before and after operations + """ + def __init__(self, profiling_data): + """Initialize the renderer with profiling data. + + Args: + profiling_data: Single snapshot or tuple of (before, after) snapshots + + Raises: + ValueError: If profiling data is invalid + """ + self._validate_profiling_data(profiling_data) + self.profiling_data = profiling_data + + def _validate_profiling_data(self, profiling_data): + """Validates profiling data structure and checks for potential issues. + + Args: + profiling_data: Data to validate + + Raises: + ValueError: If data is None or snapshots are invalid + """ + if profiling_data is None: + raise ValueError("Profiling data cannot be None") + + # Handle both single snapshot and comparison cases + if isinstance(profiling_data, tuple): + before, after = profiling_data + if before is None or after is None: + raise ValueError("Both before and after snapshots must be provided for comparison") + + # Check if this might be an OOM case by comparing memory usage + try: + self._check_memory_growth(before, after) + except Exception: + # Don't fail initialization if memory check fails + pass + + def _check_memory_growth(self, before, after): + """Check for significant memory growth between snapshots""" + before_mem = self._get_total_memory(before) + after_mem = self._get_total_memory(after) + if after_mem > before_mem * 1.5: # 50% growth threshold + warnings.warn( + f"Significant memory growth detected: {_format_size(before_mem)} -> {_format_size(after_mem)}", + RuntimeWarning + ) + + def _get_total_memory(self, snapshot): + """Get total memory usage from a snapshot""" + total = 0 + for seg in snapshot.get("segments", []): + total += seg.get("total_size", 0) + return total + + def get_failure_analysis(self) -> str: + """ + Analyze profiling data for potential failure causes. + Particularly useful for OOM and memory-related failures. 
+ + Returns: + str: HTML formatted analysis of potential issues + """ + analysis = [] + + # Get memory summary + memory_summary = self.get_memory_summary() + analysis.append("

<h3>Memory Usage Summary</h3>

") + analysis.append(f"
<pre>{memory_summary}</pre>
") + + # Get trace summary for context + trace_summary = self.get_trace_summary() + analysis.append("

<h3>Execution Trace Summary</h3>

") + analysis.append(f"
<pre>{trace_summary}</pre>
") + + # Add memory visualization + analysis.append("

<h3>Memory Usage Visualization</h3>

") + analysis.append(self.to_html("memory")) + + return "\n".join(analysis) + + def get_memory_metrics(self) -> dict: + """ + Get key memory metrics that might be useful for failure analysis + + Returns: + dict: Dictionary containing memory metrics + """ + metrics = { + "peak_memory": 0, + "total_allocations": 0, + "largest_allocation": 0, + "memory_at_failure": 0, + } + + try: + # Extract metrics from profiling data + if isinstance(self.profiling_data, tuple): + # For comparison case, use the 'after' snapshot + data = self.profiling_data[1] + else: + data = self.profiling_data + + for seg in data.get("segments", []): + metrics["peak_memory"] = max(metrics["peak_memory"], seg.get("total_size", 0)) + for block in seg.get("blocks", []): + if block.get("state") == "active_allocated": + metrics["total_allocations"] += 1 + metrics["largest_allocation"] = max( + metrics["largest_allocation"], + block.get("size", 0) + ) + + # Get the last known memory state + metrics["memory_at_failure"] = self._get_total_memory(data) + + except Exception as e: + warnings.warn(f"Failed to extract memory metrics: {str(e)}") + + return metrics + + def format_memory_size(self, size: int) -> str: + """Format memory size using the _memory_viz helper""" + return _format_size(size) + + def to_html(self, plot_type: str = "trace_plot") -> str: + """Convert profiling data to HTML visualization.""" + + # Define memory_viz_js at the start so it's available for all branches + memory_viz_js = """ + + + """ + + if plot_type == "profile_plot": + import torch + try: + # Create a profile object without initializing it + profile = torch.profiler.profile() + # Set basic attributes needed for visualization + profile.steps = [] + profile.events = [] + profile.key_averages = [] + + # Copy the data from our profiling_data + if isinstance(self.profiling_data, dict): + for key, value in self.profiling_data.items(): + setattr(profile, key, value) + content = profile_plot(profile) + except Exception as e: + content = f"
<p>Failed to generate profile plot: {str(e)}</p>
" + + elif plot_type == "compare": + if not isinstance(self.profiling_data, tuple): + raise ValueError("Compare plot type requires before/after snapshots") + before, after = self.profiling_data + + try: + before = _ensure_profiling_structure(before) + after = _ensure_profiling_structure(after) + content = compare(before["segments"], after["segments"]) + except ValueError as e: + content = f"
<p>Failed to generate comparison: {str(e)}</p>
" + + elif plot_type == "trace_plot": + content = trace_plot(self.profiling_data) + elif plot_type == "segment_plot": + content = segment_plot(self.profiling_data) + elif plot_type == "memory": + content = memory(self.profiling_data) + elif plot_type == "segments": + content = segments(self.profiling_data) + else: + raise ValueError(f"Unknown plot type: {plot_type}") + + return f""" + + + + + PyTorch Memory Profiling + {memory_viz_js if plot_type in ["memory", "segments", "compare"] else ""} + + + {content} + + + """ + + def get_memory_summary(self) -> str: + """Get a text summary of memory usage""" + return segsum(self.profiling_data) + + def get_trace_summary(self) -> str: + """Get a text summary of the trace""" + return trace(self.profiling_data) + + @staticmethod + def load_from_file(file_path: str) -> 'PyTorchProfilingRenderer': + """Create a renderer instance from a pickle file. + + Args: + file_path: Path to the pickle file containing profiling data + + Returns: + PyTorchProfilingRenderer: New renderer instance + + Raises: + ValueError: If file loading or deserialization fails + FileNotFoundError: If file doesn't exist + """ + try: + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"Profile file not found: {file_path}") + + with file_path.open("rb") as f: + # Use safe loading with protocol and encoding specified + try: + profiling_data = pickle.load(f, encoding='bytes') + except pickle.UnpicklingError as e: + raise ValueError(f"Failed to deserialize profiling data: {str(e)}") + + return PyTorchProfilingRenderer(profiling_data) + + except Exception as e: + raise ValueError(f"Failed to load profiling data: {str(e)}") + +def render_pytorch_profiling(profiling_file: FlyteFile, plot_type: str = "trace_plot") -> str: + """Renders PyTorch profiling data from a pickle file into HTML visualization. 
+ + Args: + profiling_file (FlyteFile): Pickle file containing PyTorch profiling data + plot_type (str): Type of visualization to generate + + Returns: + str: HTML string containing the visualization + + Raises: + FileNotFoundError: If profiling file doesn't exist + ValueError: If plot type is invalid or data loading fails + """ + # Load the profiling data from the .pkl file + try: + with open(profiling_file, "rb") as f: + profiling_data = pickle.load(f) + except Exception as e: + raise ValueError(f"Failed to load profiling data: {str(e)}") + # Create an instance of the renderer and generate the HTML + renderer = PyTorchProfilingRenderer(profiling_data) + return renderer.to_html(plot_type) + +def test_compare_plot_type(): + """Test the compare plot type which requires two snapshots""" + with open(PROFILE_PATH, "rb") as f: + profiling_data = pickle.load(f) + + # Create proper before/after snapshots + before = {"segments": [], "traces": []} # Empty snapshot + after = profiling_data # Your actual data + + renderer = PyTorchProfilingRenderer((before, after)) + html_output = renderer.to_html("compare") + + assert isinstance(html_output, str) + assert "" in html_output \ No newline at end of file diff --git a/plugins/flytekit-deck-standard/tests/test_pytorch_profiling_renderer.py b/plugins/flytekit-deck-standard/tests/test_pytorch_profiling_renderer.py new file mode 100644 index 0000000000..a806906ec9 --- /dev/null +++ b/plugins/flytekit-deck-standard/tests/test_pytorch_profiling_renderer.py @@ -0,0 +1,130 @@ +import os +import pytest +import pickle +import tempfile +from flytekit.types.file import FlyteFile +from flytekitplugins.deck.renderer import PyTorchProfilingRenderer, render_pytorch_profiling + +# Get the current directory where the test file is located +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PROFILE_PATH = os.path.join(CURRENT_DIR, "profile1.pkl") + +def test_pytorch_profiling_renderer_initialization(): + """Test PyTorchProfilingRenderer initialization""" + with open(PROFILE_PATH, "rb") as f: + profiling_data = pickle.load(f) + + renderer = PyTorchProfilingRenderer(profiling_data) + assert renderer.profiling_data == profiling_data + +def test_pytorch_profiling_renderer_invalid_data(): + """Test PyTorchProfilingRenderer with invalid data""" + with pytest.raises(ValueError): + PyTorchProfilingRenderer(None) + +def _create_profiling_structure(data): + """Creates a standardized profiling data structure. + + Args: + data: Raw profiling data + + Returns: + dict: Structured profiling data with required keys + """ + return { + "segments": data.get("segments", []) if hasattr(data, "get") else [], + "traces": data.get("traces", []) if hasattr(data, "get") else [], + "allocator_settings": data.get("allocator_settings", {}) if hasattr(data, "get") else {} + } + +@pytest.mark.parametrize("plot_type,expected_content", [ + ("trace_plot", ""), + ("segment_plot", ""), + ("memory", "MemoryViz.js"), + ("segments", "MemoryViz.js"), + ("profile_plot", "
") +]) +def test_pytorch_profiling_renderer_plot_types(plot_type, expected_content): + """Test different plot types for PyTorchProfilingRenderer""" + with open(PROFILE_PATH, "rb") as f: + profiling_data = pickle.load(f) + + # Convert profiling data to proper format if needed + if plot_type in ["memory", "segments"]: + profiling_data = _create_profiling_structure(profiling_data) + elif plot_type == "profile_plot": + profiling_data = { + "steps": [], + "events": [], + "key_averages": [] + } + + renderer = PyTorchProfilingRenderer(profiling_data) + html_output = renderer.to_html(plot_type) + + assert isinstance(html_output, str) + assert "" in html_output + assert expected_content in html_output + +def test_pytorch_profiling_renderer_invalid_plot_type(): + """Test PyTorchProfilingRenderer with invalid plot type""" + with open(PROFILE_PATH, "rb") as f: + profiling_data = pickle.load(f) + + renderer = PyTorchProfilingRenderer(profiling_data) + with pytest.raises(ValueError, match="Unknown plot type"): + renderer.to_html("invalid_plot_type") + +def test_render_pytorch_profiling_function(): + """Test the render_pytorch_profiling helper function""" + profiling_file = FlyteFile(PROFILE_PATH) + + # Test with default plot type + html_output = render_pytorch_profiling(profiling_file) + assert isinstance(html_output, str) + assert "" in html_output + + # Test with specific plot type + html_output = render_pytorch_profiling(profiling_file, plot_type="memory") + assert isinstance(html_output, str) + assert "" in html_output + +def test_render_pytorch_profiling_file_not_found(): + """Test render_pytorch_profiling with non-existent file""" + non_existent_file = FlyteFile("non_existent.pkl") + with pytest.raises(ValueError, match="Failed to load profiling data"): + render_pytorch_profiling(non_existent_file) + +def test_invalid_file_handling(): + """Test handling of invalid profiling data files.""" + + # Use tempfile context manager for safe cleanup + with tempfile.NamedTemporaryFile(mode='w', suffix='.pkl', delete=False) as temp_file: + # Write invalid data + temp_file.write("not a pickle file") + temp_file.flush() + + # Create FlyteFile from temp file + invalid_file = FlyteFile(temp_file.name) + + # Test that it raises the expected error + with pytest.raises(ValueError, match="Failed to load profiling data"): + render_pytorch_profiling(invalid_file) + + # No need for manual cleanup - tempfile handles it automatically + +def test_compare_plot_type(): + """Test the compare plot type which requires two snapshots""" + with open(PROFILE_PATH, "rb") as f: + profiling_data = pickle.load(f) + + # Create proper before/after snapshots + before = _create_profiling_structure({}) # Empty snapshot + after = _create_profiling_structure(profiling_data) # Your actual data + + renderer = PyTorchProfilingRenderer((before, after)) + html_output = renderer.to_html("compare") + + assert isinstance(html_output, str) + assert "" in html_output + assert "MemoryViz.js" in html_output \ No newline at end of file