Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle] Optimize memory usage when training in pipeline parallel #580

Merged
merged 8 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions tests/paddle/parallel_tests/linear_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@
import transformer_engine.paddle as te


class TELinear(te.Linear):
"""To pass is_first_microbatch"""

def __init__(self, *args, **kwargs):
assert 'accumulate_steps' in kwargs
self.accumulate_steps = kwargs['accumulate_steps']
del kwargs['accumulate_steps']
self._micro_batch_id = 0
super().__init__(*args, **kwargs)

def forward(self, *args, **kwargs):
kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0
if paddle.is_grad_enabled() and self.training:
self._micro_batch_id += 1
return super().forward(*args, **kwargs)


class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test"""

Expand All @@ -28,17 +45,30 @@ def __init__(self,
weight_attrs,
use_te=True,
use_fp8=False,
accumulate_steps=1,
**kwargs):
self.in_features = in_features
self.hidden_features = hidden_features
self.fp8 = use_fp8
hcg = fleet.get_hybrid_communicate_group()
self.dp_group = hcg.get_data_parallel_group()

Linear = te.Linear if use_te else paddle.nn.Linear
Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {}
if use_te:
extra_kwargs['accumulate_steps'] = accumulate_steps

model_desc = [
LayerDesc(Linear, self.in_features, self.hidden_features, weight_attr=weight_attrs[0]),
LayerDesc(Linear, self.hidden_features, self.in_features, weight_attr=weight_attrs[1]),
LayerDesc(Linear,
self.in_features,
self.hidden_features,
weight_attr=weight_attrs[0],
**extra_kwargs),
LayerDesc(Linear,
self.hidden_features,
self.in_features,
weight_attr=weight_attrs[1],
**extra_kwargs),
]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)

Expand Down Expand Up @@ -84,8 +114,9 @@ def init_dist_env(self):
"mp_degree": 1,
"pp_degree": self.pipeline_parallel_size,
}
self.accumulate_steps = self.batch_size // self.micro_batch_size
strategy.pipeline_configs = {
"accumulate_steps": self.batch_size // self.micro_batch_size,
"accumulate_steps": self.accumulate_steps,
"micro_batch_size": self.micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
Expand Down Expand Up @@ -128,6 +159,7 @@ def test_pipeline_train(self):
use_fp8=self.fp8,
seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size,
accumulate_steps=self.accumulate_steps,
)

# Check if model is split across ranks as expected
Expand Down
Loading