From b846e9ea994c4802e5f3c0eb72986f56ad2d04c7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 9 Dec 2024 11:55:48 +0100 Subject: [PATCH] fix --- optimum_benchmark/trackers/latency.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 29bb1bf9..f3850e1c 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -246,8 +246,6 @@ def elapsed(self): @contextmanager def track(self): - assert self.start_time is not None, "This method can only be called inside of a '.session()' context" - if self.is_pytorch_cuda: start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -338,8 +336,6 @@ def elapsed(self): @contextmanager def track(self): - assert self.start_time is not None, "This method can only be called inside of a '.session()' context" - if self.is_pytorch_cuda: start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -359,8 +355,6 @@ def track(self): self.per_token_end_events.extend(self.per_token_events[1:]) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): - assert self.start_time is not None, "This method can only be called inside of a '.session()' context" - if self.is_pytorch_cuda: event = torch.cuda.Event(enable_timing=True) event.record() @@ -486,8 +480,6 @@ def elapsed(self): @contextmanager def track(self): - assert self.start_time is not None, "This method can only be called inside of a '.session()' context" - if self.is_pytorch_cuda: start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -507,8 +499,6 @@ def track(self): self.per_step_end_events.extend(self.per_step_events[1:]) def __call__(self, pipeline, step_index, timestep, callback_kwargs): - assert self.start_time is not None, "This method can only be called inside of a '.session()' context" - if self.is_pytorch_cuda: event = torch.cuda.Event(enable_timing=True) event.record()