diff --git a/backend/ttnn_visualizer/csv_queries.py b/backend/ttnn_visualizer/csv_queries.py index 4b119575..542b593d 100644 --- a/backend/ttnn_visualizer/csv_queries.py +++ b/backend/ttnn_visualizer/csv_queries.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # # SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. - +import csv import os +import tempfile +from io import StringIO from pathlib import Path from typing import List, Dict, Union, Optional import pandas as pd +from tt_perf_report import perf_report +from ttnn_visualizer.exceptions import DataFormatError from ttnn_visualizer.models import TabSession from ttnn_visualizer.ssh_client import get_client @@ -538,3 +542,66 @@ def get_all_entries( return self.runner.execute_query( columns=self.PERF_RESULTS_COLUMNS, as_dict=as_dict, limit=limit ) + + +class OpsPerformanceReportQueries: + REPORT_COLUMNS = [ + "id", + "total_percent", + "bound", + "op_code", + "device_time", + "op_to_op_gap", + "cores", + "dram", + "dram_percent", + "flops", + "flops_percent", + "math_fidelity", + "output_datatype", + "input_0_datatype", + "input_1_datatype", + "dram_sharded", + "input_0_memory", + "inner_dim_block_size", + "output_subblock_h", + "output_subblock_w" + ] + + DEFAULT_SIGNPOST = None + DEFAULT_IGNORE_SIGNPOSTS = None + DEFAULT_MIN_PERCENTAGE = 0.5 + DEFAULT_ID_RANGE = None + DEFAULT_NO_ADVICE = False + + @classmethod + def generate_report(cls, session): + raw_csv = OpsPerformanceQueries.get_raw_csv(session) + csv_file = StringIO(raw_csv) + csv_output_file = tempfile.mktemp(suffix=".csv") + perf_report.generate_perf_report( + csv_file, + cls.DEFAULT_SIGNPOST, + cls.DEFAULT_IGNORE_SIGNPOSTS, + cls.DEFAULT_MIN_PERCENTAGE, + cls.DEFAULT_ID_RANGE, + csv_output_file, + cls.DEFAULT_NO_ADVICE, + ) + + report = [] + + try: + with open(csv_output_file, newline="") as csvfile: + reader = csv.reader(csvfile, delimiter=",") + next(reader, None) + for row in reader: + report.append({ + column: row[index] for index, column in enumerate(cls.REPORT_COLUMNS) + }) + except csv.Error as e: + raise DataFormatError() from e + finally: + os.unlink(csv_output_file) + + return report diff --git a/backend/ttnn_visualizer/exceptions.py b/backend/ttnn_visualizer/exceptions.py index 03d53b74..066cd2bc 100644 --- a/backend/ttnn_visualizer/exceptions.py +++ b/backend/ttnn_visualizer/exceptions.py @@ -34,3 +34,7 @@ def __init__(self, message, status): class DatabaseFileNotFoundException(Exception): pass + + +class DataFormatError(Exception): + pass diff --git a/backend/ttnn_visualizer/requirements.txt b/backend/ttnn_visualizer/requirements.txt index c263df9e..29965c43 100644 --- a/backend/ttnn_visualizer/requirements.txt +++ b/backend/ttnn_visualizer/requirements.txt @@ -1,4 +1,3 @@ - gunicorn~=22.0.0 uvicorn==0.30.1 paramiko~=3.4.0 @@ -17,6 +16,7 @@ wheel build PyYAML==6.0.2 python-dotenv==1.0.1 +tt-perf-report==1.0.0 # Dev dependencies mypy diff --git a/backend/ttnn_visualizer/views.py b/backend/ttnn_visualizer/views.py index 9b4d69ee..39365b2e 100644 --- a/backend/ttnn_visualizer/views.py +++ b/backend/ttnn_visualizer/views.py @@ -14,8 +14,9 @@ from flask import Blueprint, Response, jsonify from flask import request, current_app -from ttnn_visualizer.csv_queries import DeviceLogProfilerQueries, OpsPerformanceQueries +from ttnn_visualizer.csv_queries import DeviceLogProfilerQueries, OpsPerformanceQueries, OpsPerformanceReportQueries from ttnn_visualizer.decorators import with_session +from ttnn_visualizer.exceptions import DataFormatError from ttnn_visualizer.enums import ConnectionTestStates from ttnn_visualizer.exceptions import RemoteConnectionException from ttnn_visualizer.file_uploads import ( @@ -387,6 +388,20 @@ def get_profiler_perf_results_data_raw(session: TabSession): ) +@api.route("/profiler/perf-results/report", methods=["GET"]) +@with_session +def get_profiler_perf_results_report(session: TabSession): + if not session.profiler_path: + return Response(status=HTTPStatus.NOT_FOUND) + + try: + report = OpsPerformanceReportQueries.generate_report(session) + except DataFormatError: + return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY) + + return jsonify(report), 200 + + @api.route("/profiler/device-log/raw", methods=["GET"]) @with_session def get_profiler_data_raw(session: TabSession): diff --git a/pyproject.toml b/pyproject.toml index 98c47994..15dd126e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,8 @@ dependencies = [ "flask-socketio==5.4.1", "flask-sqlalchemy==3.1.1", "PyYAML==6.0.2", - "python-dotenv==1.0.1" + "python-dotenv==1.0.1", + "tt-perf-report==1.0.0" ] classifiers = [