From 1760d8a37197f85235b5315debe88f5e105f336e Mon Sep 17 00:00:00 2001
From: dorotat-nv
Date: Tue, 21 Jan 2025 20:48:53 +0100
Subject: [PATCH] added tflops callback and flag

---
 ci/benchmarks/perf/esm2_pretrain.yaml      |  1 +
 .../src/bionemo/esm2/scripts/train_esm2.py | 92 ++++++++++++-------
 2 files changed, 58 insertions(+), 35 deletions(-)

diff --git a/ci/benchmarks/perf/esm2_pretrain.yaml b/ci/benchmarks/perf/esm2_pretrain.yaml
index a27604733..32c143cfa 100644
--- a/ci/benchmarks/perf/esm2_pretrain.yaml
+++ b/ci/benchmarks/perf/esm2_pretrain.yaml
@@ -62,4 +62,5 @@ script: |-
     --accumulate-grad-batches=${acc_grad} \
     --pipeline-model-parallel-size=${pp} \
     --tensor-model-parallel-size={tp} \
+    --log-tflops-per-sec-per-gpu \
     --disable-checkpointing;
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
index 928ff81fa..c17c92d18 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import argparse
+from dataclasses import asdict
 from pathlib import Path
 from typing import List, Optional, Sequence, get_args
 
@@ -24,6 +25,7 @@
 from nemo.collections import llm
 from nemo.lightning import resume
 from nemo.lightning.pytorch import callbacks as nl_callbacks
+from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule
 
 from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
@@ -98,6 +100,7 @@ def main(
     overlap_param_gather: bool = True,
     average_in_collective: bool = True,
     grad_reduce_in_fp32: bool = False,
+    log_tflops_per_sec_per_gpu: bool = False,
 ) -> None:
     """Train an ESM2 model on UR data.
 
@@ -159,6 +162,7 @@ def main(
         overlap_param_gather (bool): overlap parameter gather
         average_in_collective (bool): average in collective
         grad_reduce_in_fp32 (bool): gradient reduction in fp32
+        log_tflops_per_sec_per_gpu (bool): Enables FLOP tracking callback to measure teraFLOPs/second performance per GPU device
     """
     # Create the result directory if it does not exist.
     result_dir.mkdir(parents=True, exist_ok=True)
@@ -210,41 +214,6 @@ def main(
         )
     )
 
-    callbacks = [
-        PerplexityLoggingCallback(log_train=False, log_val=True),
-        RichModelSummary(max_depth=4),
-        LearningRateMonitor(),
-        nl_callbacks.PreemptionCallback(),
-    ]
-    if nsys_profiling:
-        if nsys_end_step is None:
-            nsys_end_step = num_steps
-        callbacks.append(
-            nl_callbacks.NsysCallback(
-                start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
-            )
-        )
-
-    trainer = nl.Trainer(
-        devices=devices,
-        max_steps=num_steps,
-        accelerator="gpu",
-        strategy=strategy,
-        limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
-        val_check_interval=val_check_interval,
-        log_every_n_steps=log_every_n_steps,
-        num_nodes=num_nodes,
-        callbacks=callbacks,
-        plugins=nl.MegatronMixedPrecision(
-            precision=precision,
-            params_dtype=get_autocast_dtype(precision),
-            pipeline_dtype=get_autocast_dtype(precision),
-            grad_reduce_in_fp32=grad_reduce_in_fp32,
-            autocast_enabled=False,
-        ),
-        enable_checkpointing=create_checkpoint_callback,
-    )
-
     tokenizer = get_tokenizer()
 
     # Initialize the data module.
@@ -303,6 +272,50 @@ def main(
         ),
     )
 
+    callbacks = [
+        PerplexityLoggingCallback(log_train=False, log_val=True),
+        RichModelSummary(max_depth=4),
+        LearningRateMonitor(),
+        nl_callbacks.PreemptionCallback(),
+    ]
+    if nsys_profiling:
+        if nsys_end_step is None:
+            nsys_end_step = num_steps
+        callbacks.append(
+            nl_callbacks.NsysCallback(
+                start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
+            )
+        )
+
+    if log_tflops_per_sec_per_gpu:
+        # Add callback that logs the tera-FLOPS per second per GPU during training.
+        flop_meas_callback = FLOPsMeasurementCallback(
+            asdict(esm2_config),
+            data,
+            "bert",
+        )
+        callbacks.append(flop_meas_callback)
+
+    trainer = nl.Trainer(
+        devices=devices,
+        max_steps=num_steps,
+        accelerator="gpu",
+        strategy=strategy,
+        limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
+        val_check_interval=val_check_interval,
+        log_every_n_steps=log_every_n_steps,
+        num_nodes=num_nodes,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(
+            precision=precision,
+            params_dtype=get_autocast_dtype(precision),
+            pipeline_dtype=get_autocast_dtype(precision),
+            grad_reduce_in_fp32=grad_reduce_in_fp32,
+            autocast_enabled=False,
+        ),
+        enable_checkpointing=create_checkpoint_callback,
+    )
+
     # Configure our custom Checkpointer
     if create_checkpoint_callback:
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
@@ -398,6 +411,7 @@ def train_esm2_entrypoint():
         overlap_param_gather=not args.no_overlap_param_gather,
         average_in_collective=not args.no_average_in_collective,
         grad_reduce_in_fp32=args.grad_reduce_in_fp32,
+        log_tflops_per_sec_per_gpu=args.log_tflops_per_sec_per_gpu,
     )
 
 
@@ -739,6 +753,14 @@ def get_parser():
         action="store_true",
         default=False,
     )
+
+    parser.add_argument(
+        "--log-tflops-per-sec-per-gpu",
+        action="store_true",
+        default=False,
+        help="Enables FLOP tracking callback to measure teraFLOPs/second performance per GPU device",
+    )
+
     return parser
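
Note (not part of the diff): the snippet below isolates the wiring this patch adds, as a minimal sketch for anyone who wants to reuse it outside train_esm2.py. `maybe_add_flops_callback` is a hypothetical helper name, and `esm2_config` / `data` stand in for the ESM2 config dataclass and the data module that main() builds before this point; only the `FLOPsMeasurementCallback(asdict(...), data, "bert")` call itself is taken from the diff above.

    from dataclasses import asdict
    from typing import Any, List

    from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback


    def maybe_add_flops_callback(
        callbacks: List[Any], esm2_config: Any, data: Any, enabled: bool
    ) -> List[Any]:
        """Append the TFLOPs/sec/GPU callback when --log-tflops-per-sec-per-gpu is set."""
        if enabled:
            callbacks.append(
                FLOPsMeasurementCallback(
                    asdict(esm2_config),  # model hyperparameters as a plain dict
                    data,  # data module, passed through to the FLOPs estimate
                    "bert",  # ESM2 is BERT-style, hence the "bert" FLOPs formula
                )
            )
        return callbacks

In CI, the same branch is exercised by passing --log-tflops-per-sec-per-gpu, as the esm2_pretrain.yaml change above does; the flag defaults to False, so existing runs are unaffected.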