Compute the real prefill latency using the logits processor #150

Merged · 2 commits · Mar 13, 2024
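
This PR changes how text-generation benchmarks measure prefill latency. Instead of reusing the latency of a standalone `forward` call as the prefill figure, the logits processor now records an event immediately before `generate` runs and one event per invocation, so the first interval is the real prefill latency and the later intervals give the per-token and decode latencies. Backends that support per-token tracking (`PER_TOKEN_BACKENDS`) are routed to a new `run_fine_grained_text_generation_latency_tracking` method, and the memory, latency, and energy trackers are now created inside each tracking method instead of once in `run`.

As an illustration of the idea only (not the benchmark's own API; optimum-benchmark uses `LatencyLogitsProcessor`, backed by CUDA events on GPU), here is a minimal wall-clock sketch of a timing logits processor wired into a Hugging Face `generate` call. The class name, model, and prompt are assumptions made for the example:

```python
import time
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList


class TimestampLogitsProcessor:
    """Wall-clock stand-in for the benchmark's LatencyLogitsProcessor (illustrative only)."""

    def __init__(self) -> None:
        self.events: List[float] = []

    def mark_generate_call(self) -> None:
        # recorded right before generate(); the gap to the first __call__ is the prefill
        self.events.append(time.perf_counter())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # generate() invokes logits processors once per generated token,
        # right after that token's forward pass
        self.events.append(time.perf_counter())
        return scores

    def prefill_latency(self) -> float:
        return self.events[1] - self.events[0]

    def decode_latency(self) -> float:
        return self.events[-1] - self.events[1]


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Measuring the real prefill latency:", return_tensors="pt")

processor = TimestampLogitsProcessor()
processor.mark_generate_call()
model.generate(
    **inputs,
    max_new_tokens=8,
    do_sample=False,
    logits_processor=LogitsProcessorList([processor]),
)

print(f"prefill: {processor.prefill_latency() * 1e3:.1f} ms")
print(f"decode:  {processor.decode_latency() * 1e3:.1f} ms")
```

Because `generate` invokes logits processors once per generated token, right after that token's forward pass, the first recorded interval covers the prompt (prefill) forward pass and every later interval covers one decode step.
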
optimum_benchmark/benchmarks/inference/benchmark.py: 117 changes (70 additions, 47 deletions)
@@ -127,11 +127,14 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
else:
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

if backend.config.task in TEXT_GENERATION_TASKS:
LOGGER.info("\t+ Additional warmup for Text Generation")
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
LOGGER.info("\t+ Additional warmup for Image Diffusion")
_ = backend.call(self.call_inputs, self.config.call_kwargs)

if self.config.memory:
LOGGER.info("\t+ Creating inference memory tracker")
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_memory_tracking(backend)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
@@ -142,10 +145,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
self.report.log_memory()

if self.config.latency:
LOGGER.info("\t+ Creating inference latency tracker")
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_latency_tracking(backend)
if backend.config.name in PER_TOKEN_BACKENDS:
self.run_fine_grained_text_generation_latency_tracking(backend)
else:
self.run_text_generation_latency_tracking(backend)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
self.run_image_diffusion_latency_tracking(backend)
else:
@@ -155,8 +159,6 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
self.report.log_throughput()

if self.config.energy:
LOGGER.info("\t+ Creating inference energy tracker")
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_energy_tracking(backend)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
@@ -170,7 +172,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
## Memory tracking
def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)
self.memory_tracker.reset()

with self.memory_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

@@ -184,24 +190,56 @@ def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):

def run_image_diffusion_memory_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker.reset()
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)

with self.memory_tracker.track():
_ = backend.call(self.call_inputs, self.config.call_kwargs)

self.report.call.memory = self.memory_tracker.get_max_memory()

def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker.reset()
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)

with self.memory_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

self.report.forward.memory = self.memory_tracker.get_max_memory()

## Latency tracking
def run_fine_grained_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running fine-grained Text Generation latency tracking")
self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
self.config.generate_kwargs["logits_processor"] = LogitsProcessorList(
[self.logits_processor, *self.config.generate_kwargs.get("logits_processor", [])]
)

while self.logits_processor.get_elapsed_time() < self.config.duration:
with self.logits_processor.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)

self.report.per_token.latency = self.logits_processor.get_per_token_latency()
self.report.prefill.latency = self.logits_processor.get_prefill_latency()
self.report.decode.latency = self.logits_processor.get_decode_latency()

self.report.per_token.throughput = Throughput.from_latency(
self.report.per_token.latency, self.text_generation_per_token_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.prefill.throughput = Throughput.from_latency(
self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)

def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running latency tracking")
self.latency_tracker.reset()
LOGGER.info("\t+ Running Text Generation latency tracking")
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)

while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -212,40 +250,21 @@ def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)

if backend.config.name in PER_TOKEN_BACKENDS:
self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.logits_processor])
self.logits_processor.reset()

while self.logits_processor.get_elapsed_time() < self.config.duration:
with self.logits_processor.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
self.latency_tracker.reset()
while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
generate_latency = self.latency_tracker.get_latency()

self.report.decode.latency = self.logits_processor.get_decode_latency()
self.report.per_token.latency = self.logits_processor.get_per_token_latency()
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.per_token.throughput = Throughput.from_latency(
self.report.per_token.latency,
self.text_generation_per_token_volume,
unit=TEXT_GENERATION_THROUGHPUT_UNIT,
)
else:
self.latency_tracker.reset()
while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
generate_latency = self.latency_tracker.get_latency()

self.report.decode.latency = generate_latency - forward_latency
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.decode.latency = generate_latency - forward_latency
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)

def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running latency tracking")
self.latency_tracker.reset()
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)

while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.call(self.call_inputs, self.config.call_kwargs)
@@ -257,7 +276,8 @@ def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):

def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running latency tracking")
self.latency_tracker.reset()
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)

while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -270,7 +290,8 @@ def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
## Energy tracking
def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running energy tracking")
self.energy_tracker.reset()
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)

with self.energy_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
forward_energy = self.energy_tracker.get_energy()
@@ -292,7 +313,8 @@ def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):

def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running energy tracking")
self.energy_tracker.reset()
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)

with self.energy_tracker.track():
_ = backend.call(self.call_inputs, self.config.call_kwargs)

@@ -303,7 +325,8 @@ def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):

def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running energy tracking")
self.energy_tracker.reset()
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)

with self.energy_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

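To make the numbers reported by the fine-grained path above concrete, here is a worked example with hypothetical CPU timestamps (the real tracker stores `torch.cuda.Event`s on CUDA with the PyTorch backend, see the `latency.py` diff below):

```python
# Hypothetical timestamps for one generate() call that produced 4 tokens:
#   events[0]  recorded just before generate()  (start of prefill)
#   events[1]  first logits-processor call      (end of prefill / first token)
#   events[2:] one per subsequently decoded token
events = [0.000, 0.120, 0.145, 0.171, 0.198]

prefill_latency = events[1] - events[0]                                  # 0.120 s
per_token_latencies = [b - a for a, b in zip(events[1:-1], events[2:])]  # ~[0.025, 0.026, 0.027] s
decode_latency = events[-1] - events[1]                                  # 0.078 s

# in this sketch, decode covers every token after the first,
# so it equals the sum of the per-token latencies
assert abs(sum(per_token_latencies) - decode_latency) < 1e-9
```
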
optimum_benchmark/trackers/latency.py: 21 changes (19 additions, 2 deletions)
@@ -218,6 +218,11 @@ def track(self):

self.tok_events: List[Union[float, torch.cuda.Event]] = []

if self.device == "cuda" and self.backend == "pytorch":
prefill_event = torch.cuda.Event(enable_timing=True)
prefill_event.record()
self.tok_events.append(prefill_event)

yield # this is where generate is called, and for each token, we record an event

self.run_events.append(self.tok_events)
@@ -235,17 +240,29 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):

return scores

def get_prefill_latency(self) -> Latency:
if self.device == "cuda" and self.backend == "pytorch":
# synchronize the device to make sure all events have been recorded
torch.cuda.synchronize()
latencies_list = [
self.run_events[i][0].elapsed_time(self.run_events[i][1]) / 1e3 for i in range(len(self.run_events))
]
else:
latencies_list = [(self.run_events[i][1] - self.run_events[i][0]) for i in range(len(self.run_events))]

return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

def get_per_token_latency(self) -> Latency:
latencies_list = []
for tok_events in self.run_events:
if self.device == "cuda" and self.backend == "pytorch":
# synchronize the device to make sure all events have been recorded
torch.cuda.synchronize()
latencies_list.extend(
[tok_events[i - 1].elapsed_time(tok_events[i]) / 1e3 for i in range(1, len(tok_events))]
[tok_events[i].elapsed_time(tok_events[i + 1]) / 1e3 for i in range(1, len(tok_events) - 1)]
)
else:
latencies_list.extend([(tok_events[i] - tok_events[i - 1]) for i in range(1, len(tok_events))])
latencies_list.extend([(tok_events[i + 1] - tok_events[i]) for i in range(1, len(tok_events) - 1)])

return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

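
On CUDA with the PyTorch backend, the processor records `torch.cuda.Event`s instead of wall-clock timestamps, which keeps the timing asynchronous until a single `torch.cuda.synchronize()` at read time. A minimal sketch of that pattern, under the assumption that a CUDA device is available (the matrix multiply is a stand-in for the model's forward passes):

```python
import torch

# one event before "generate" plus one per generated token
events = [torch.cuda.Event(enable_timing=True) for _ in range(5)]
events[0].record()  # corresponds to the prefill event recorded in track()

a = torch.randn(512, 512, device="cuda")
for event in events[1:]:
    _ = a @ a       # stand-in for one generation step's forward pass
    event.record()  # corresponds to one __call__ of the logits processor

torch.cuda.synchronize()  # make sure all events have completed before reading them

prefill_s = events[0].elapsed_time(events[1]) / 1e3  # elapsed_time returns milliseconds
per_token_s = [events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(1, len(events) - 1)]
print(prefill_s, per_token_s)
```

Deferring the synchronize to `get_prefill_latency` / `get_per_token_latency` means the measurement itself does not stall the GPU between tokens.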