diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py
index 60dbd1591d..0c8b581b0e 100644
--- a/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py
+++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py
@@ -14,6 +14,7 @@
MarkdownRenderer
SourceCodeRenderer
TableRenderer
+ PyTorchProfilingRenderer
"""
from .renderer import (
@@ -24,4 +25,16 @@
MarkdownRenderer,
SourceCodeRenderer,
TableRenderer,
+ PyTorchProfilingRenderer
)
+
+__all__ = [
+ "BoxRenderer",
+ "FrameProfilingRenderer",
+ "GanttChartRenderer",
+ "ImageRenderer",
+ "MarkdownRenderer",
+ "SourceCodeRenderer",
+ "TableRenderer",
+ "PyTorchProfilingRenderer"
+]
diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/pytorch_memory_viz.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/pytorch_memory_viz.py
new file mode 100644
index 0000000000..5a7a7851c0
--- /dev/null
+++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/pytorch_memory_viz.py
@@ -0,0 +1,762 @@
+# mypy: allow-untyped-defs
+import base64
+import io
+import json
+import operator
+import os
+import pickle
+import subprocess
+import sys
+import warnings
+from functools import lru_cache
+from itertools import groupby
+from typing import Any
+from pathlib import Path
+
+
+cache = lru_cache(None)
+
+__all__ = ["format_flamegraph", "segments", "memory", "compare"]
+
+
+def _frame_fmt(f, full_filename=False):
+ i = f["line"]
+ fname = f["filename"]
+ if not full_filename:
+ fname = fname.split("/")[-1]
+ func = f["name"]
+ return f"{fname}:{i}:{func}"
+
+
+@cache
+def _frame_filter(name, filename):
+ omit_functions = [
+ "unwind::unwind",
+ "CapturedTraceback::gather",
+ "gather_with_cpp",
+ "_start",
+ "__libc_start_main",
+ "PyEval_",
+ "PyObject_",
+ "PyFunction_",
+ ]
+ omit_filenames = [
+ "core/boxing",
+ "/Register",
+ "/Redispatch",
+ "pythonrun.c",
+ "Modules/main.c",
+ "Objects/call.c",
+ "Objects/methodobject.c",
+ "pycore_ceval.h",
+ "ceval.c",
+ "cpython/abstract.h",
+ ]
+ for of in omit_functions:
+ if of in name:
+ return False
+ for of in omit_filenames:
+ if of in filename:
+ return False
+ return True
+
+
+def _frames_fmt(frames, full_filename=False, reverse=False):
+ if reverse:
+ frames = reversed(frames)
+ return [
+ _frame_fmt(f, full_filename)
+ for f in frames
+ if _frame_filter(f["name"], f["filename"])
+ ]
+
+
+def _block_extra_legacy(b):
+ if "history" in b:
+ frames = b["history"][0].get("frames", [])
+ real_size = b["history"][0]["real_size"]
+ else:
+ real_size = b.get("requested_size", b["size"])
+ frames = []
+ return frames, real_size
+
+
+def _block_extra(b):
+ if "frames" not in b:
+ # old snapshot format made it more complicated to get frames/allocated size
+ return _block_extra_legacy(b)
+ return b["frames"], b["requested_size"]
+
+
+def format_flamegraph(flamegraph_lines, flamegraph_script=None):
+ """Format flamegraph data using a temporary script.
+
+ Args:
+ flamegraph_lines: Lines of flamegraph data
+ flamegraph_script: Optional path to flamegraph script
+
+ Returns:
+ str: Formatted flamegraph data or HTML visualization
+ """
+ # For testing purposes, return a simple HTML visualization
+ if os.getenv('TESTING') or not flamegraph_script:
+        return f"""
+        <html>
+        <head>
+            <title>Memory Visualization</title>
+        </head>
+        <body>
+            <pre>{flamegraph_lines}</pre>
+        </body>
+        </html>
+        """
+
+ try:
+ # Normal flamegraph generation for production
+ flamegraph_script = Path(flamegraph_script)
+ if not flamegraph_script.exists():
+ raise FileNotFoundError(f"Flamegraph script not found: {flamegraph_script}")
+
+ result = subprocess.run(
+ [
+ str(flamegraph_script),
+ "--colors", "python",
+ "--countname", "bytes",
+ "--width", "1200",
+ ],
+ input=flamegraph_lines,
+ capture_output=True,
+ text=True,
+ check=True
+ )
+ return result.stdout
+
+ except (subprocess.CalledProcessError, FileNotFoundError) as e:
+ raise RuntimeError(f"Failed to generate flamegraph: {str(e)}")
+
+
+def _write_blocks(f, prefix, blocks):
+ def frames_fragment(frames):
+ if not frames:
+ return ""
+ return ";".join(_frames_fmt(frames, reverse=True))
+
+ for b in blocks:
+ if "history" not in b:
+ frames, accounted_for_size = _block_extra(b)
+ f.write(
+ f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n'
+ )
+ else:
+ accounted_for_size = 0
+ for h in b["history"]:
+ sz = h["real_size"]
+ accounted_for_size += sz
+ if "frames" in h:
+ frames = h["frames"]
+ f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
+ else:
+ f.write(f'{prefix};{b["state"]}; {sz}\n')
+ gaps = b["size"] - accounted_for_size
+ if gaps:
+ f.write(f'{prefix};{b["state"]}; {gaps}\n')
+
+
+def segments(snapshot, format_flamegraph=format_flamegraph):
+ f = io.StringIO()
+ for seg in snapshot["segments"]:
+ prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
+ _write_blocks(f, prefix, seg["blocks"])
+ return format_flamegraph(f.getvalue())
+
+
+def memory(snapshot, format_flamegraph=format_flamegraph):
+ f = io.StringIO()
+ for seg in snapshot["segments"]:
+ prefix = f'stream_{seg["stream"]}'
+ _write_blocks(f, prefix, seg["blocks"])
+ return format_flamegraph(f.getvalue())
+
+
+def compare(before, after, format_flamegraph=format_flamegraph):
+ def _seg_key(seg):
+ return (seg["address"], seg["total_size"])
+
+ def _seg_info(seg):
+ return f'stream_{seg["stream"]};seg_{seg["address"]}'
+
+ f = io.StringIO()
+
+ before_segs = {_seg_key(seg) for seg in before}
+ after_segs = {_seg_key(seg) for seg in after}
+
+ print(f"only_before = {[a for a, _ in (before_segs - after_segs)]}")
+ print(f"only_after = {[a for a, _ in (after_segs - before_segs)]}")
+
+ for seg in before:
+ if _seg_key(seg) not in after_segs:
+ _write_blocks(f, f"only_before;{_seg_info(seg)}", seg["blocks"])
+
+ for seg in after:
+ if _seg_key(seg) not in before_segs:
+ _write_blocks(f, f"only_after;{_seg_info(seg)}", seg["blocks"])
+
+ return format_flamegraph(f.getvalue())
+
+
+def _format_size(num):
+ # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
+ if abs(num) < 1024.0:
+ return f"{num:3.1f}{unit}B"
+ num /= 1024.0
+ return f"{num:.1f}YiB"
+
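+# e.g. _format_size(1536) == "1.5KiB" and _format_size(3 * 1024**3) == "3.0GiB"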
+
+class Bytes:
+ def __init__(self, value):
+ self.value = value
+
+ def __add__(self, rhs):
+ return Bytes(self.value + rhs)
+
+ def __repr__(self):
+ return _format_size(self.value)
+
+
+def calc_active(seg):
+ return sum(b["size"] for b in seg["blocks"] if b["state"] == "active_allocated")
+
+
+def _report_free(free_external, free_internal):
+ total = free_external + free_internal
+ suffix = ""
+ if total != 0:
+ pct = (free_internal / total) * 100
+ suffix = f" ({pct:.1f}% internal)"
+ return f"{Bytes(total)}{suffix}"
+
+
+PAGE_SIZE = 1024 * 1024 * 20
+legend = f"""\
+
+Legend:
+ [a ] - a segment in the allocator
+ ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
+ a-z: pages filled with a single block's content
+ ' ': page is completely free
+    *: page is completely full with multiple blocks
+ 0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
+ (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
+"""
+
+
+def segsum(data):
+ r"""Visually reports how the allocator has filled its segments.
+
+ This printout can help debug fragmentation issues since free fragments
+ will appear as gaps in this printout. The amount of free space is reported
+ for each segment.
+ We distinguish between internal free memory which occurs because the
+ allocator rounds the allocation size, and external free memory, which are
+ the gaps between allocations in a segment.
+ Args:
+ data: snapshot dictionary created from _snapshot()
+ """
+ out = io.StringIO()
+ out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
+ total_reserved = 0
+ total_allocated = 0
+ free_external = 0
+ free_internal = 0
+ for seg in sorted(
+ data["segments"], key=lambda x: (x["total_size"], calc_active(x))
+ ):
+ total_reserved += seg["total_size"]
+
+ seg_free_external = 0
+ seg_free_internal = 0
+ seg_allocated = 0
+ all_ranges = []
+ boffset = 0
+ for b in seg["blocks"]:
+ active = b["state"] == "active_allocated"
+ if active:
+ _, allocated_size = _block_extra(b)
+ all_ranges.append((boffset, allocated_size, True))
+ seg_allocated += allocated_size
+ seg_free_internal += b["size"] - allocated_size
+ else:
+ seg_free_external += b["size"]
+
+ boffset += b["size"]
+
+ total_allocated += seg_allocated
+ free_external += seg_free_external
+ free_internal += seg_free_internal
+
+ nseg = (seg["total_size"] - 1) // PAGE_SIZE + 1
+ occupied = [" " for _ in range(nseg)]
+ frac = [0.0 for _ in range(nseg)]
+ active_size = 0
+ for i, (start_, size, active) in enumerate(all_ranges):
+ active_size += size
+ finish_ = start_ + size
+ start = start_ // PAGE_SIZE
+ finish = (finish_ - 1) // PAGE_SIZE + 1
+ m = chr(ord("a" if active else "A") + (i % 26))
+ for j in range(start, finish):
+ s = max(start_, j * PAGE_SIZE)
+ e = min(finish_, (j + 1) * PAGE_SIZE)
+ frac[j] += (e - s) / PAGE_SIZE
+ if occupied[j] != " ":
+ occupied[j] = "0123456789*"[int(frac[j] * 10)]
+ else:
+ occupied[j] = m
+ stream_info = f' stream_{seg["stream"]}' if seg["stream"] != 0 else ""
+ body = "".join(occupied)
+
+ total_size = seg_free_external + seg_free_internal + seg_allocated
+ if total_size != seg["total_size"]:
+ raise ValueError(
+ f"Segment size mismatch: {total_size} != {seg['total_size']}"
+ )
+
+ if seg["total_size"] >= PAGE_SIZE:
+ out.write(
+ f'[{body}] {Bytes(seg["total_size"])} allocated, '
+ f"{_report_free(seg_free_external, seg_free_internal)} free{stream_info}\n"
+ )
+ out.write(f'segments: {len(data["segments"])}\n')
+ out.write(f"total_reserved: {Bytes(total_reserved)}\n")
+ out.write(f"total_allocated: {Bytes(total_allocated)}\n")
+ out.write(f"total_free: {_report_free(free_external, free_internal)}\n")
+ out.write(legend)
+ assert free_internal + free_external + total_allocated == total_reserved
+ return out.getvalue()
+
+
+def trace(data):
+ out = io.StringIO()
+
+ def format(entries):
+ segment_intervals: list = []
+ segment_addr_to_name = {}
+ allocation_addr_to_name = {}
+
+ free_names: list = []
+ next_name = 0
+
+ def _name():
+ nonlocal next_name
+ if free_names:
+ return free_names.pop()
+ r, m = next_name // 26, next_name % 26
+ next_name += 1
+ return f'{chr(ord("a") + m)}{"" if r == 0 else r}'
+
+ def find_segment(addr):
+ for name, saddr, size in segment_intervals:
+ if addr >= saddr and addr < saddr + size:
+ return name, saddr
+ for i, seg in enumerate(data["segments"]):
+ saddr = seg["address"]
+ size = seg["allocated_size"]
+ if addr >= saddr and addr < saddr + size:
+ return f"seg_{i}", saddr
+ return None, None
+
+ count = 0
+ out.write(f"{len(entries)} entries\n")
+
+ total_reserved = 0
+ for seg in data["segments"]:
+ total_reserved += seg["total_size"]
+
+ for count, e in enumerate(entries):
+ if e["action"] == "alloc":
+ addr, size = e["addr"], e["size"]
+ n = _name()
+ seg_name, seg_addr = find_segment(addr)
+ if seg_name is None:
+ seg_name = "MEM"
+ offset = addr
+ else:
+ offset = addr - seg_addr
+ out.write(f"{n} = {seg_name}[{offset}:{Bytes(size)}]\n")
+ allocation_addr_to_name[addr] = (n, size, count)
+ count += size
+ elif e["action"] == "free_requested":
+ addr, size = e["addr"], e["size"]
+ name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
+ out.write(f"del {name} # {Bytes(size)}\n")
+ elif e["action"] == "free_completed":
+ addr, size = e["addr"], e["size"]
+ count -= size
+ name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
+ out.write(f"# free completed for {name} {Bytes(size)}\n")
+ if name in allocation_addr_to_name:
+ free_names.append(name)
+ del allocation_addr_to_name[name]
+ elif e["action"] == "segment_alloc":
+ addr, size = e["addr"], e["size"]
+ name = _name()
+ out.write(f"{name} = cudaMalloc({addr}, {Bytes(size)})\n")
+ segment_intervals.append((name, addr, size))
+ segment_addr_to_name[addr] = name
+ elif e["action"] == "segment_free":
+ addr, size = e["addr"], e["size"]
+ name = segment_addr_to_name.get(addr, addr)
+ out.write(f"cudaFree({name}) # {Bytes(size)}\n")
+ if name in segment_addr_to_name:
+ free_names.append(name)
+ del segment_addr_to_name[name]
+ elif e["action"] == "oom":
+ size = e["size"]
+ free = e["device_free"]
+ out.write(
+ f"raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n"
+ )
+ else:
+ out.write(f"{e}\n")
+ out.write(f"TOTAL MEM: {Bytes(count)}")
+
+ for i, d in enumerate(data["device_traces"]):
+ if d:
+ out.write(f"Device {i} ----------------\n")
+ format(d)
+ return out.getvalue()
+
+
+_memory_viz_template = r"""
+
+
+
+
+
+
+
+"""
+
+
+def _format_viz(data, viz_kind, device):
+ if device is not None:
+ warnings.warn(
+            "device argument is deprecated, plots now contain all devices",
+ FutureWarning,
+ stacklevel=3,
+ )
+ buffer = pickle.dumps(data)
+ buffer += b"\x00" * (3 - len(buffer) % 3)
+ # Encode the buffer with base64
+ encoded_buffer = base64.b64encode(buffer).decode("utf-8")
+
+ json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}])
+ return _memory_viz_template.replace("$VIZ_KIND", repr(viz_kind)).replace(
+ "$SNAPSHOT", json_format
+ )
+
+
+def trace_plot(data, device=None, plot_segments=False):
+ """Generate a visualization over time of the memory usage recorded by the trace as an html file.
+
+ Args:
+ data: Memory snapshot as generated from torch.cuda.memory._snapshot()
+ device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
+ plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
+ Defaults to False.
+
+ Returns:
+ str: HTML of visualization
+ """
+ return _format_viz(
+ data,
+ "Active Memory Timeline"
+ if not plot_segments
+ else "Active Cached Memory Timeline",
+ device,
+ )
+
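+# The sketch below illustrates how a snapshot is typically captured and handed to
+# trace_plot. It is illustrative only: _record_memory_history/_snapshot are private
+# torch APIs and require a CUDA-enabled build, so it is kept as a comment here.
+#
+#     import torch
+#
+#     torch.cuda.memory._record_memory_history()
+#     x = torch.randn(4096, 4096, device="cuda")   # do some work that allocates memory
+#     snapshot = torch.cuda.memory._snapshot()
+#     html = trace_plot(snapshot)                  # returns a self-contained HTML page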
+
+def _profile_to_snapshot(profile):
+ import torch
+ from torch._C._profiler import _EventType
+ from torch.profiler._memory_profiler import Action, TensorKey
+
+ memory_profile = profile._memory_profile()
+
+ allocation_stacks = {}
+ for event in memory_profile._op_tree.sorted_nodes:
+ if event.tag == _EventType.Allocation:
+ parent = event.parent
+ python_parents = []
+ while parent:
+ if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
+ python_parents.append(parent)
+ parent = parent.parent
+ key = TensorKey.from_allocation(event.extra_fields)
+
+ # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
+ # key will be None. I should add some way to identify these, I just haven't yet.
+ if key and event.extra_fields.alloc_size > 0:
+ allocation_stacks[key] = python_parents
+
+ device_count = torch.cuda.device_count()
+ snapshot: dict[str, list[Any]] = {
+ "device_traces": [[] for _ in range(device_count + 1)],
+ "segments": [
+ {
+ "device": device,
+ "address": None,
+ "total_size": 0,
+ "stream": 0,
+ "blocks": [],
+ }
+ for device in range(device_count + 1)
+ ],
+ }
+
+ def to_device(device):
+ if device.type == "cuda":
+ return device.index
+ else:
+ return device_count
+
+ def allocate(size, tensor_key, version, during_trace=True):
+ device = to_device(tensor_key.device)
+ addr = tensor_key.storage.ptr
+
+ seg = snapshot["segments"][device] # type: ignore[index]
+ if seg["address"] is None or seg["address"] > addr:
+ seg["address"] = addr
+ seg["total_size"] = max(
+ seg["total_size"], addr + size
+ ) # record max addr for now, we will make it the size later
+ category = memory_profile._categories.get(tensor_key, version)
+ category = category.name.lower() if category is not None else "unknown"
+ stack = allocation_stacks.get(tensor_key, ())
+ stack = [{"filename": "none", "line": 0, "name": p.name} for p in stack]
+ r = {
+ "action": "alloc",
+ "addr": addr,
+ "size": size,
+ "stream": 0,
+ "frames": stack,
+ "category": category,
+ }
+ if during_trace:
+ snapshot["device_traces"][device].append(r)
+ return r
+
+ def free(alloc, device):
+ for e in ("free_requested", "free_completed"):
+ snapshot["device_traces"][device].append(
+ {
+ "action": e,
+ "addr": alloc["addr"],
+ "size": alloc["size"],
+ "stream": 0,
+ "frames": alloc["frames"],
+ }
+ )
+
+ kv_to_elem = {}
+
+ # create the device trace
+ for _time, action, (tensor_key, version), size in memory_profile.timeline:
+ if not isinstance(tensor_key, TensorKey):
+ continue
+ if action == Action.CREATE:
+ kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
+ elif action == Action.DESTROY:
+ free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
+ elif action == Action.INCREMENT_VERSION:
+ free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
+ kv_to_elem[(tensor_key, version + 1)] = allocate(
+ size, tensor_key, version + 1
+ )
+ elif action == Action.PREEXISTING:
+ kv_to_elem[(tensor_key, version)] = allocate(
+ size, tensor_key, version, during_trace=False
+ )
+
+ # create the final snapshot state
+ blocks_at_end = [
+ (to_device(tensor_key.device), event["addr"], event["size"], event["frames"])
+ for (tensor_key, version), event in kv_to_elem.items()
+ ]
+ for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
+ seg = snapshot["segments"][device] # type: ignore[index]
+ last_addr = seg["address"]
+ for _, addr, size, frames in blocks:
+ if last_addr < addr:
+ seg["blocks"].append({"size": addr - last_addr, "state": "inactive"})
+ seg["blocks"].append(
+ {
+ "size": size,
+ "state": "active_allocated",
+ "requested_size": size,
+ "frames": frames,
+ }
+ )
+ last_addr = addr + size
+ if last_addr < seg["total_size"]:
+ seg["blocks"].append(
+ {"size": seg["total_size"] - last_addr, "state": "inactive"}
+ )
+
+ snapshot["segments"] = [seg for seg in snapshot["segments"] if seg["blocks"]] # type: ignore[attr-defined]
+ for seg in snapshot["segments"]: # type: ignore[attr-defined, name-defined, no-redef]
+ seg["total_size"] -= seg["address"]
+ if not seg["blocks"]:
+ seg["blocks"].append({"size": seg["total_size"], "state": "inactive"})
+
+ return snapshot
+
+
+def profile_plot(profile, device=None):
+ """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.
+
+ Args:
+ profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
+ device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
+
+ Returns:
+ str: HTML of visualization
+ """
+ snapshot = _profile_to_snapshot(profile)
+ return _format_viz(snapshot, "Active Memory Timeline", device)
+
+
+def segment_plot(data: Any, device=None):
+ return _format_viz(data, "Allocator State History", device)
+
+
+if __name__ == "__main__":
+ import os.path
+
+ thedir = os.path.realpath(os.path.dirname(__file__))
+ if thedir in sys.path:
+ # otherwise we find cuda/random.py as random...
+ sys.path.remove(thedir)
+ import argparse
+
+ fn_name = "torch.cuda.memory._snapshot()"
+ pickled = f"pickled memory statistics from {fn_name}"
+ parser = argparse.ArgumentParser(
+ description=f"Visualize memory dumps produced by {fn_name}"
+ )
+
+ subparsers = parser.add_subparsers(dest="action")
+
+ def _output(p):
+ p.add_argument(
+ "-o",
+ "--output",
+ default="output.svg",
+ help="flamegraph svg (default: output.svg)",
+ )
+
+    description = "Prints overall allocation statistics and a visualization of how the allocator's segments are currently filled."
+ stats_a = subparsers.add_parser("stats", description=description)
+ stats_a.add_argument("input", help=pickled)
+
+ description = "Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style."
+ trace_a = subparsers.add_parser("trace", description=description)
+ trace_a.add_argument("input", help=pickled)
+
+ description = "Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)"
+ segments_a = subparsers.add_parser("segments", description=description)
+ segments_a.add_argument("input", help=pickled)
+ _output(segments_a)
+
+ description = (
+        "Generate a flamegraph of the program locations contributing to CUDA memory usage."
+ )
+ memory_a = subparsers.add_parser("memory", description=description)
+ memory_a.add_argument("input", help=pickled)
+ _output(memory_a)
+
+ description = (
+ "Generate a flamegraph that shows segments (aka blocks) that have been added "
+        "or removed between two different memory snapshots."
+ )
+ compare_a = subparsers.add_parser("compare", description=description)
+ compare_a.add_argument("before", help=pickled)
+ compare_a.add_argument("after", help=pickled)
+ _output(compare_a)
+
+ plots = (
+ (
+ "trace_plot",
+ "Generate a visualization over time of the memory usage recorded by the trace as an html file.",
+ ),
+ (
+ "segment_plot",
+ "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.",
+ ),
+ )
+ for cmd, description in plots:
+ trace_plot_a = subparsers.add_parser(cmd, description=description)
+ trace_plot_a.add_argument("input", help=pickled)
+ help = "visualize trace from this device (default: chooses the only device with trace info or errors)"
+ trace_plot_a.add_argument("-d", "--device", type=int, default=None, help=help)
+        help = "path to save the visualization (default: output.html)"
+ trace_plot_a.add_argument("-o", "--output", default="output.html", help=help)
+ if cmd == "trace_plot":
+ help = "visualize change to segments rather than individual allocations"
+ trace_plot_a.add_argument(
+ "-s", "--segments", action="store_true", help=help
+ )
+
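+    # Example invocations (illustrative; assumes a snapshot pickled from
+    # torch.cuda.memory._snapshot() has been saved to snapshot.pickle):
+    #
+    #   python pytorch_memory_viz.py stats snapshot.pickle
+    #   python pytorch_memory_viz.py trace_plot snapshot.pickle -o memory.html
+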
+ args = parser.parse_args()
+
+ def _read(name):
+ if name == "-":
+ f = sys.stdin.buffer
+ else:
+ f = open(name, "rb")
+ data = pickle.load(f)
+ if isinstance(data, list): # segments only...
+ data = {"segments": data, "traces": []}
+ return data
+
+ def _write(name, data):
+ with open(name, "w") as f:
+ f.write(data)
+
+ if args.action == "segments":
+ data = _read(args.input)
+ _write(args.output, segments(data))
+ elif args.action == "memory":
+ data = _read(args.input)
+ _write(args.output, memory(data))
+ elif args.action == "stats":
+ data = _read(args.input)
+ print(segsum(data))
+ elif args.action == "trace":
+ data = _read(args.input)
+ print(trace(data))
+ elif args.action == "compare":
+ before = _read(args.before)
+ after = _read(args.after)
+ _write(args.output, compare(before, after))
+ elif args.action == "trace_plot":
+ data = _read(args.input)
+ _write(
+ args.output,
+ trace_plot(data, device=args.device, plot_segments=args.segments),
+ )
+ elif args.action == "segment_plot":
+ data = _read(args.input)
+ _write(args.output, segment_plot(data, device=args.device))
diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
index 708e941d88..fee8334b64 100644
--- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
+++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
@@ -3,6 +3,19 @@
from flytekit import lazy_module
from flytekit.types.file import FlyteFile
+import pickle
+import warnings
+from pathlib import Path
+
+from .pytorch_memory_viz import (
+    _format_size,
+    compare,
+    memory,
+    profile_plot,
+    segment_plot,
+    segments,
+    segsum,
+    trace,
+    trace_plot,
+)
if TYPE_CHECKING:
import markdown
@@ -212,3 +225,325 @@ def to_html(self, df: "pd.DataFrame", chart_width: Optional[int] = None) -> str:
)
return fig.to_html()
+
+def _ensure_profiling_structure(data: dict) -> dict:
+ """Ensures profiling data has the correct structure with required keys.
+
+ Args:
+ data: Dictionary containing profiling data
+
+ Returns:
+ dict: Structured profiling data with required keys
+
+ Raises:
+ ValueError: If data structure is invalid
+ """
+ if not isinstance(data, dict):
+ raise ValueError("Profiling data must be a dictionary")
+
+ # Validate segments key exists and is a list
+ if "segments" not in data:
+ raise ValueError("Profiling data missing 'segments' key")
+ if not isinstance(data["segments"], list):
+ raise ValueError("'segments' must be a list")
+
+ return {
+ "segments": data["segments"],
+ "traces": data.get("traces", []),
+ "allocator_settings": data.get("allocator_settings", {})
+ }
+
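+# For reference, a minimal dictionary accepted by _ensure_profiling_structure could
+# look like the following (illustrative; real snapshots from
+# torch.cuda.memory._snapshot() carry many more per-block fields):
+#
+#     {
+#         "segments": [
+#             {"address": 0, "total_size": 1024, "stream": 0,
+#              "blocks": [{"size": 1024, "state": "active_allocated",
+#                          "requested_size": 1024, "frames": []}]},
+#         ],
+#         "traces": [],
+#     }
+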
+class PyTorchProfilingRenderer:
+ """Renders PyTorch profiling data in various visualization formats.
+
+ This renderer is particularly useful for analyzing memory usage and potential
+ memory-related failures in PyTorch executions. It can help diagnose OOM (Out of Memory)
+ errors and memory leaks by providing various visualization types.
+
+ Supports multiple visualization types:
+ - trace_plot: Shows the execution timeline
+ - segment_plot: Shows the execution segments
+ - memory: Displays memory usage over time
+ - segments: Shows detailed segment information
+ - compare: Compares two profiling snapshots
+ - profile_plot: Shows detailed profiling information
+ - summary: Shows overall allocation statistics
+ - trace_view: Shows detailed trace information
+
+ The renderer can be particularly helpful in:
+ 1. Analyzing failed executions due to OOM errors
+ 2. Identifying memory leaks
+ 3. Understanding memory usage patterns
+ 4. Comparing memory states before and after operations
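+
+    Example (illustrative; assumes ``snapshot`` is an unpickled dictionary produced
+    by ``torch.cuda.memory._snapshot()``):
+
+        renderer = PyTorchProfilingRenderer(snapshot)
+        html = renderer.to_html("memory")          # flamegraph-style memory view
+        summary = renderer.get_memory_summary()    # plain-text segment summary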
+ """
+ def __init__(self, profiling_data):
+ """Initialize the renderer with profiling data.
+
+ Args:
+ profiling_data: Single snapshot or tuple of (before, after) snapshots
+
+ Raises:
+ ValueError: If profiling data is invalid
+ """
+ self._validate_profiling_data(profiling_data)
+ self.profiling_data = profiling_data
+
+ def _validate_profiling_data(self, profiling_data):
+ """Validates profiling data structure and checks for potential issues.
+
+ Args:
+ profiling_data: Data to validate
+
+ Raises:
+ ValueError: If data is None or snapshots are invalid
+ """
+ if profiling_data is None:
+ raise ValueError("Profiling data cannot be None")
+
+ # Handle both single snapshot and comparison cases
+ if isinstance(profiling_data, tuple):
+ before, after = profiling_data
+ if before is None or after is None:
+ raise ValueError("Both before and after snapshots must be provided for comparison")
+
+ # Check if this might be an OOM case by comparing memory usage
+ try:
+ self._check_memory_growth(before, after)
+ except Exception:
+ # Don't fail initialization if memory check fails
+ pass
+
+ def _check_memory_growth(self, before, after):
+ """Check for significant memory growth between snapshots"""
+ before_mem = self._get_total_memory(before)
+ after_mem = self._get_total_memory(after)
+ if after_mem > before_mem * 1.5: # 50% growth threshold
+ warnings.warn(
+ f"Significant memory growth detected: {_format_size(before_mem)} -> {_format_size(after_mem)}",
+ RuntimeWarning
+ )
+
+ def _get_total_memory(self, snapshot):
+ """Get total memory usage from a snapshot"""
+ total = 0
+ for seg in snapshot.get("segments", []):
+ total += seg.get("total_size", 0)
+ return total
+
+ def get_failure_analysis(self) -> str:
+ """
+ Analyze profiling data for potential failure causes.
+ Particularly useful for OOM and memory-related failures.
+
+ Returns:
+ str: HTML formatted analysis of potential issues
+ """
+ analysis = []
+
+ # Get memory summary
+ memory_summary = self.get_memory_summary()
+        analysis.append("<h3>Memory Usage Summary</h3>")
+        analysis.append(f"<pre>{memory_summary}</pre>")
+
+        # Get trace summary for context
+        trace_summary = self.get_trace_summary()
+        analysis.append("<h3>Execution Trace Summary</h3>")
+        analysis.append(f"<pre>{trace_summary}</pre>")
+
+        # Add memory visualization
+        analysis.append("<h3>Memory Usage Visualization</h3>")
+ analysis.append(self.to_html("memory"))
+
+ return "\n".join(analysis)
+
+ def get_memory_metrics(self) -> dict:
+ """
+ Get key memory metrics that might be useful for failure analysis
+
+ Returns:
+ dict: Dictionary containing memory metrics
+ """
+ metrics = {
+ "peak_memory": 0,
+ "total_allocations": 0,
+ "largest_allocation": 0,
+ "memory_at_failure": 0,
+ }
+
+ try:
+ # Extract metrics from profiling data
+ if isinstance(self.profiling_data, tuple):
+ # For comparison case, use the 'after' snapshot
+ data = self.profiling_data[1]
+ else:
+ data = self.profiling_data
+
+ for seg in data.get("segments", []):
+ metrics["peak_memory"] = max(metrics["peak_memory"], seg.get("total_size", 0))
+ for block in seg.get("blocks", []):
+ if block.get("state") == "active_allocated":
+ metrics["total_allocations"] += 1
+ metrics["largest_allocation"] = max(
+ metrics["largest_allocation"],
+ block.get("size", 0)
+ )
+
+ # Get the last known memory state
+ metrics["memory_at_failure"] = self._get_total_memory(data)
+
+ except Exception as e:
+ warnings.warn(f"Failed to extract memory metrics: {str(e)}")
+
+ return metrics
+
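+    # Illustrative example of the shape returned by get_memory_metrics
+    # (values are bytes and will vary with the snapshot):
+    #     {"peak_memory": 1048576, "total_allocations": 12,
+    #      "largest_allocation": 524288, "memory_at_failure": 786432}
+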
+ def format_memory_size(self, size: int) -> str:
+ """Format memory size using the _memory_viz helper"""
+ return _format_size(size)
+
+ def to_html(self, plot_type: str = "trace_plot") -> str:
+ """Convert profiling data to HTML visualization."""
+
+ # Define memory_viz_js at the start so it's available for all branches
+        memory_viz_js = """
+        <script type="text/javascript"
+                src="https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"></script>
+        """
+
+ if plot_type == "profile_plot":
+ import torch
+ try:
+ # Create a profile object without initializing it
+ profile = torch.profiler.profile()
+ # Set basic attributes needed for visualization
+ profile.steps = []
+ profile.events = []
+ profile.key_averages = []
+
+ # Copy the data from our profiling_data
+ if isinstance(self.profiling_data, dict):
+ for key, value in self.profiling_data.items():
+ setattr(profile, key, value)
+ content = profile_plot(profile)
+ except Exception as e:
+                content = f"<p>Failed to generate profile plot: {str(e)}</p>"
+
+ elif plot_type == "compare":
+ if not isinstance(self.profiling_data, tuple):
+ raise ValueError("Compare plot type requires before/after snapshots")
+ before, after = self.profiling_data
+
+ try:
+ before = _ensure_profiling_structure(before)
+ after = _ensure_profiling_structure(after)
+ content = compare(before["segments"], after["segments"])
+ except ValueError as e:
+                content = f"<p>Failed to generate comparison: {str(e)}</p>"
+
+ elif plot_type == "trace_plot":
+ content = trace_plot(self.profiling_data)
+ elif plot_type == "segment_plot":
+ content = segment_plot(self.profiling_data)
+ elif plot_type == "memory":
+ content = memory(self.profiling_data)
+ elif plot_type == "segments":
+ content = segments(self.profiling_data)
+ else:
+ raise ValueError(f"Unknown plot type: {plot_type}")
+
+        return f"""
+        <!DOCTYPE html>
+        <html>
+        <head>
+            <meta charset="utf-8">
+            <title>PyTorch Memory Profiling</title>
+            {memory_viz_js if plot_type in ["memory", "segments", "compare"] else ""}
+        </head>
+        <body>
+            {content}
+        </body>
+        </html>
+        """
+
+ def get_memory_summary(self) -> str:
+ """Get a text summary of memory usage"""
+ return segsum(self.profiling_data)
+
+ def get_trace_summary(self) -> str:
+ """Get a text summary of the trace"""
+ return trace(self.profiling_data)
+
+ @staticmethod
+ def load_from_file(file_path: str) -> 'PyTorchProfilingRenderer':
+ """Create a renderer instance from a pickle file.
+
+ Args:
+ file_path: Path to the pickle file containing profiling data
+
+ Returns:
+ PyTorchProfilingRenderer: New renderer instance
+
+ Raises:
+ ValueError: If file loading or deserialization fails
+ FileNotFoundError: If file doesn't exist
+ """
+ try:
+ file_path = Path(file_path)
+ if not file_path.exists():
+ raise FileNotFoundError(f"Profile file not found: {file_path}")
+
+ with file_path.open("rb") as f:
+                # encoding="bytes" keeps older snapshot pickles loadable; note that
+                # pickle should only ever be used on trusted files
+ try:
+ profiling_data = pickle.load(f, encoding='bytes')
+ except pickle.UnpicklingError as e:
+ raise ValueError(f"Failed to deserialize profiling data: {str(e)}")
+
+ return PyTorchProfilingRenderer(profiling_data)
+
+ except Exception as e:
+ raise ValueError(f"Failed to load profiling data: {str(e)}")
+
+def render_pytorch_profiling(profiling_file: FlyteFile, plot_type: str = "trace_plot") -> str:
+ """Renders PyTorch profiling data from a pickle file into HTML visualization.
+
+ Args:
+ profiling_file (FlyteFile): Pickle file containing PyTorch profiling data
+ plot_type (str): Type of visualization to generate
+
+ Returns:
+ str: HTML string containing the visualization
+
+ Raises:
+ FileNotFoundError: If profiling file doesn't exist
+ ValueError: If plot type is invalid or data loading fails
+ """
+ # Load the profiling data from the .pkl file
+ try:
+ with open(profiling_file, "rb") as f:
+ profiling_data = pickle.load(f)
+ except Exception as e:
+ raise ValueError(f"Failed to load profiling data: {str(e)}")
+ # Create an instance of the renderer and generate the HTML
+ renderer = PyTorchProfilingRenderer(profiling_data)
+ return renderer.to_html(plot_type)
+
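+# A minimal sketch of how this renderer might be surfaced in a Flyte deck
+# (illustrative only; the task and deck names below are placeholders):
+#
+#     import flytekit
+#     from flytekit import task
+#     from flytekit.types.file import FlyteFile
+#
+#     @task(enable_deck=True)
+#     def show_memory_profile(profile: FlyteFile) -> None:
+#         html = render_pytorch_profiling(profile, plot_type="memory")
+#         flytekit.Deck("pytorch-memory", html)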
diff --git a/plugins/flytekit-deck-standard/tests/test_pytorch_profiling_renderer.py b/plugins/flytekit-deck-standard/tests/test_pytorch_profiling_renderer.py
new file mode 100644
index 0000000000..a806906ec9
--- /dev/null
+++ b/plugins/flytekit-deck-standard/tests/test_pytorch_profiling_renderer.py
@@ -0,0 +1,130 @@
+import os
+import pytest
+import pickle
+import tempfile
+from flytekit.types.file import FlyteFile
+from flytekitplugins.deck.renderer import PyTorchProfilingRenderer, render_pytorch_profiling
+
+# Get the current directory where the test file is located
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+PROFILE_PATH = os.path.join(CURRENT_DIR, "profile1.pkl")
+
+def test_pytorch_profiling_renderer_initialization():
+ """Test PyTorchProfilingRenderer initialization"""
+ with open(PROFILE_PATH, "rb") as f:
+ profiling_data = pickle.load(f)
+
+ renderer = PyTorchProfilingRenderer(profiling_data)
+ assert renderer.profiling_data == profiling_data
+
+def test_pytorch_profiling_renderer_invalid_data():
+ """Test PyTorchProfilingRenderer with invalid data"""
+ with pytest.raises(ValueError):
+ PyTorchProfilingRenderer(None)
+
+def _create_profiling_structure(data):
+ """Creates a standardized profiling data structure.
+
+ Args:
+ data: Raw profiling data
+
+ Returns:
+ dict: Structured profiling data with required keys
+ """
+ return {
+ "segments": data.get("segments", []) if hasattr(data, "get") else [],
+ "traces": data.get("traces", []) if hasattr(data, "get") else [],
+ "allocator_settings": data.get("allocator_settings", {}) if hasattr(data, "get") else {}
+ }
+
+@pytest.mark.parametrize("plot_type,expected_content", [
+    ("trace_plot", "<html>"),
+    ("segment_plot", "<html>"),
+    ("memory", "MemoryViz.js"),
+    ("segments", "MemoryViz.js"),
+    ("profile_plot", "<html>"),
+])
+def test_pytorch_profiling_renderer_plot_types(plot_type, expected_content):
+ """Test different plot types for PyTorchProfilingRenderer"""
+ with open(PROFILE_PATH, "rb") as f:
+ profiling_data = pickle.load(f)
+
+ # Convert profiling data to proper format if needed
+ if plot_type in ["memory", "segments"]:
+ profiling_data = _create_profiling_structure(profiling_data)
+ elif plot_type == "profile_plot":
+ profiling_data = {
+ "steps": [],
+ "events": [],
+ "key_averages": []
+ }
+
+ renderer = PyTorchProfilingRenderer(profiling_data)
+ html_output = renderer.to_html(plot_type)
+
+ assert isinstance(html_output, str)
+    assert "<html>" in html_output
+ assert expected_content in html_output
+
+def test_pytorch_profiling_renderer_invalid_plot_type():
+ """Test PyTorchProfilingRenderer with invalid plot type"""
+ with open(PROFILE_PATH, "rb") as f:
+ profiling_data = pickle.load(f)
+
+ renderer = PyTorchProfilingRenderer(profiling_data)
+ with pytest.raises(ValueError, match="Unknown plot type"):
+ renderer.to_html("invalid_plot_type")
+
+def test_render_pytorch_profiling_function():
+ """Test the render_pytorch_profiling helper function"""
+ profiling_file = FlyteFile(PROFILE_PATH)
+
+ # Test with default plot type
+ html_output = render_pytorch_profiling(profiling_file)
+ assert isinstance(html_output, str)
+    assert "<html>" in html_output
+
+ # Test with specific plot type
+ html_output = render_pytorch_profiling(profiling_file, plot_type="memory")
+ assert isinstance(html_output, str)
+    assert "<html>" in html_output
+
+def test_render_pytorch_profiling_file_not_found():
+ """Test render_pytorch_profiling with non-existent file"""
+ non_existent_file = FlyteFile("non_existent.pkl")
+ with pytest.raises(ValueError, match="Failed to load profiling data"):
+ render_pytorch_profiling(non_existent_file)
+
+def test_invalid_file_handling():
+ """Test handling of invalid profiling data files."""
+
+ # Use tempfile context manager for safe cleanup
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.pkl', delete=False) as temp_file:
+ # Write invalid data
+ temp_file.write("not a pickle file")
+ temp_file.flush()
+
+ # Create FlyteFile from temp file
+ invalid_file = FlyteFile(temp_file.name)
+
+ # Test that it raises the expected error
+ with pytest.raises(ValueError, match="Failed to load profiling data"):
+ render_pytorch_profiling(invalid_file)
+
+    # delete=False above means the file is not removed automatically; clean it up here
+    os.remove(temp_file.name)
+
+def test_compare_plot_type():
+ """Test the compare plot type which requires two snapshots"""
+ with open(PROFILE_PATH, "rb") as f:
+ profiling_data = pickle.load(f)
+
+ # Create proper before/after snapshots
+ before = _create_profiling_structure({}) # Empty snapshot
+ after = _create_profiling_structure(profiling_data) # Your actual data
+
+ renderer = PyTorchProfilingRenderer((before, after))
+ html_output = renderer.to_html("compare")
+
+ assert isinstance(html_output, str)
+    assert "<html>" in html_output
+ assert "MemoryViz.js" in html_output
\ No newline at end of file