From 3f485dd32d380862defb0e97cd0349d368bcbf57 Mon Sep 17 00:00:00 2001
From: Ming-Xu Huang
Date: Tue, 16 Apr 2024 19:26:54 -0400
Subject: [PATCH] Support Low Rank Adaptation (LoRA). (#745)

Signed-off-by: Pawel Gadzinski
---
 tests/jax/test_functions.py                  |  68 ++++++++
 tests/jax/test_praxis_layers.py              |  44 +++++
 transformer_engine/jax/flax/module.py        | 169 ++++++++++++++++++-
 transformer_engine/jax/flax/transformer.py   | 103 +++++++++++
 transformer_engine/jax/praxis/module.py      |  18 ++
 transformer_engine/jax/praxis/transformer.py |  12 ++
 6 files changed, 412 insertions(+), 2 deletions(-)
 create mode 100644 tests/jax/test_functions.py

diff --git a/tests/jax/test_functions.py b/tests/jax/test_functions.py
new file mode 100644
index 0000000000..aaa6be77ac
--- /dev/null
+++ b/tests/jax/test_functions.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+
+import jax
+import jax.numpy as jnp
+
+from utils import assert_allclose
+from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
+from transformer_engine.jax.flax.module import _normalize_axes
+from transformer_engine.jax.flax.transformer import LoRAScope
+from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope
+
+
+class TestLoRA:
+
+    def reference(x, la, lb, pattern, scale):
+        out = jnp.einsum(pattern, x, la, lb)
+        return out * scale
+
+    @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
+    @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
+    @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
+                                                       ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
+    @pytest.mark.parametrize('rank', [32, 16])
+    @pytest.mark.parametrize('alpha', [None, 4, 8])
+    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
+        axis, features, pattern = axis_features_pattern
+        axis = _normalize_axes(axis, len(shape))
+        shape_in_axis = tuple(shape[ax] for ax in axis)
+
+        key = jax.random.key(1124)
+        key, x_key = jax.random.split(key)
+        x = jax.random.normal(x_key, shape, dtype)
+
+        key, la_key = jax.random.split(key)
+        la_shape = (*shape_in_axis, *features[:-1], rank)
+        la = jax.random.normal(la_key, la_shape, dtype)
+
+        key, lb_key = jax.random.split(key)
+        lb_shape = (*features[:-1], rank, features[-1])
+        lb = jax.random.normal(lb_key, lb_shape, dtype)
+
+        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
+        scale_ref = alpha / rank if alpha is not None else 1.0
+        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)
+
+        assert_allclose(out_target, out_ref, dtype=dtype)
+
+    @pytest.mark.parametrize('scope_ref_assert',
+                             [('none', LoRAScope(False, False, False), False),
+                              ('all', LoRAScope(True, True, True), False),
+                              ('qkv_proj', LoRAScope(True, False, False), False),
+                              ('output_proj', LoRAScope(False, True, False), False),
+                              ('mlp', LoRAScope(False, False, True), False),
+                              ('exclude_qkv_proj', LoRAScope(False, True, True), False),
+                              ('exclude_output_proj', LoRAScope(True, False, True), False),
+                              ('exclude_mlp', LoRAScope(True, True, False), False),
+                              ('messing_up', LoRAScope(), True)])
+    def test_lora_scope_generator(self, scope_ref_assert):
+        scope, reference, need_assert = scope_ref_assert
+        try:
+            lora_scope = _canonicalize_lora_scope(scope)
+            assert lora_scope == reference
+        except AssertionError as ae:
+            assert need_assert, f"{ae.args}"
diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py
index 
43581f1015..dce0263ac7 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -784,6 +784,7 @@ class MultiHeadAttnAttr: NUM_GQA_GROUPS = 'num_gqa_groups' ENABLE_ROPE = 'enable_rotary_pos_emb' ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' + LORA_SCOPE = 'low_rank_adaptation_scope' ATTRS = [{ USE_BIAS: True, LN_TYPE: 'layernorm', @@ -853,6 +854,22 @@ class MultiHeadAttnAttr: NUM_ATTN_HEADS: 8, NUM_GQA_GROUPS: 4, ATTN_MASK_TYPE: 'causal' + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + ATTN_MASK_TYPE: 'padding', + LORA_SCOPE: 'all' + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + ATTN_MASK_TYPE: 'causal', + LORA_SCOPE: 'all' }] @@ -883,6 +900,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE] rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD] + low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none') fuse_qkv_params = True transpose_batch_sequence = True scale_attn_logits = False @@ -905,6 +923,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): attn_mask_type=attn_mask_type, enable_rotary_pos_emb=enable_rotary_pos_emb, rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, fuse_qkv_params=fuse_qkv_params, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, @@ -926,6 +945,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): attn_mask_type=attn_mask_type, enable_rotary_pos_emb=enable_rotary_pos_emb, rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, fuse_qkv_params=fuse_qkv_params, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, @@ -969,6 +989,7 @@ class TransformerLayerAttr: TRANSPOSE_BS = 'transpose_batch_sequence' ENABLE_ROPE = 'enable_rotary_pos_emb' ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' + LORA_SCOPE = 'low_rank_adaptation_scope' ATTRS = [{ USE_BIAS: True, LN_TYPE: 'layernorm', @@ -1113,6 +1134,16 @@ class TransformerLayerAttr: ENABLE_ROPE: False, ROPE_GROUP_METHOD: 'consecutive', TRANSPOSE_BS: False + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ACTIVATION: ('gelu',), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + TRANSPOSE_BS: False, + LORA_SCOPE: 'all' }, { USE_BIAS: True, LN_TYPE: 'layernorm', @@ -1185,6 +1216,16 @@ class TransformerLayerAttr: ENABLE_ROPE: True, ROPE_GROUP_METHOD: 'consecutive', TRANSPOSE_BS: False + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ACTIVATION: ('gelu',), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + TRANSPOSE_BS: False, + LORA_SCOPE: 'all' }] @@ -1219,6 +1260,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): layer_type = attrs[TransformerLayerAttr.LYR_TYPE] enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD] + low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none') enable_relative_embedding = True relative_embedding = pax_fiddle.Config(RelativePositionBiases, dtype=dtype, @@ -1257,6 +1299,7 @@ def 
generate_praxis_p_and_flax_cls(self, dtype, attrs): enable_relative_embedding=enable_relative_embedding, enable_rotary_pos_emb=enable_rotary_pos_emb, rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, relative_embedding=relative_embedding, drop_path=drop_path, transpose_batch_sequence=transpose_batch_sequence) @@ -1282,6 +1325,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): rotary_pos_emb_group_method=rotary_pos_emb_group_method, enable_relative_embedding=enable_relative_embedding, relative_embedding=relative_embedding_flax_module, + low_rank_adaptation_scope=low_rank_adaptation_scope, drop_path=drop_path, transpose_batch_sequence=transpose_batch_sequence) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8ddc74ac2e..8ca8edcb0b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -104,6 +104,31 @@ def _combine_biases(*masks: List[Array]): return mask +def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha): + """Low Rank Adaptation Implementation""" + + assert len(axis) <= 5 + hidden_in_names = 'ijklm'[:len(axis)] + assert len(features) <= 5 + hidden_out_names = 'nopqr'[:len(features)] + rank_name = 's' + + assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2] + rank = lora_a_kernel.shape[-1] + scaling = alpha / rank if alpha is not None else 1.0 + + x_einsum_express = f"...{hidden_in_names}" + lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}" + lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}" + output_einsum_express = f"...{hidden_out_names}" + final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \ + f"->{output_einsum_express}" + + output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel) + output = output * scaling + return output + + class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. @@ -355,6 +380,14 @@ class DenseGeneral(TransformerEngineBase): bias_axes: Tuple[str, ...], default = () The name of axes used to shard bias with a corresponding mesh, only used when :attr:`use_bias=True`. + enable_low_rank_adaptation: bool, default = False + Indicate whether to enable low rank adaptation for each linear layer. + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. @@ -374,6 +407,9 @@ class DenseGeneral(TransformerEngineBase): use_bias: bool = True bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] 
= () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = False @@ -439,6 +475,32 @@ def __call__(self, inputs: Array) -> Array: fp8_meta_pkg=fp8_gemm_pkg, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel', + self.kernel_init, + lora_a_kernel_init_shape, + jnp.float32, + axes=lora_a_kernel_axes) + lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) + lora_a_kernel = lora_a_kernel.astype(self.dtype) + + lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) + lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) + lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel', + nn.initializers.zeros, + lora_b_kernel_shape, + jnp.float32, + axes=lora_b_kernel_axes) + lora_b_kernel = lora_b_kernel.astype(self.dtype) + + y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel, + self.low_rank_adaptation_alpha) + if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape y += jnp.reshape(bias, bias_shape) @@ -502,6 +564,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): return_layernorm_output: bool, default = True Indicate whether to return the output of layer normalization. If set False, return None as the second tensor in outputs. + enable_low_rank_adaptation: bool, default = False + Indicate whether to enable low rank adaptation for each linear layer. + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. layernorm_input_axes: Tuple[str, ...], default = None @@ -541,6 +611,9 @@ class LayerNormDenseGeneral(TransformerEngineBase): bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] 
= () return_layernorm_output: bool = True + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = True @@ -650,6 +723,32 @@ def __call__(self, inputs: Array) -> Array: fp8_meta_pkg=fp8_meta_package, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel', + self.kernel_init, + lora_a_kernel_init_shape, + jnp.float32, + axes=lora_a_kernel_axes) + lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) + lora_a_kernel = lora_a_kernel.astype(self.dtype) + + lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) + lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) + lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel', + nn.initializers.zeros, + lora_b_kernel_shape, + jnp.float32, + axes=lora_b_kernel_axes) + lora_b_kernel = lora_b_kernel.astype(self.dtype) + + z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel, + self.low_rank_adaptation_alpha) + bias = None if self.use_bias: bias = nn_partitioning.param_with_axes('bias', @@ -745,6 +844,14 @@ class LayerNormMLP(TransformerEngineBase): Dropout probability for the dropout op after the :attr:`activations`. intermediate_hidden_dropout_dims: Sequence[int], default = () Dimensions that will share the same dropout mask for hidden + enable_low_rank_adaptation: bool, default = False + Indicate whether to enable low rank adaptation for each linear layer. + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True`. + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. 
layernorm_input_axes: Tuple[str, ...], default = None @@ -791,6 +898,9 @@ class LayerNormMLP(TransformerEngineBase): intermediate_dropout_rng_name: str = 'dropout' intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = True @@ -856,11 +966,13 @@ def is_gelu(acts): use_fused_ln_geglu_mlp = fuse_layernorm \ and (not self.use_bias) and is_geglu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) + and (self.intermediate_dropout_rate < 1e-3) \ + and not self.enable_low_rank_adaptation use_fused_ln_gelu_mlp = fuse_layernorm \ and self.use_bias and is_gelu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) + and (self.intermediate_dropout_rate < 1e-3) \ + and not self.enable_low_rank_adaptation # LayerNorm if self.enable_layernorm: @@ -999,6 +1111,37 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): fp8_meta_pkg=gemm1_fp8_meta_package, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations, + self.low_rank_adaptation_dim) + wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations, + self.low_rank_adaptation_dim) + wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0], + self.low_rank_adaptation_dim) + wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) + wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel', + kernel_1_init, + num_activations, + -2, + wi_lora_a_kernel_init_each_shape, + jnp.float32, + axes=wi_lora_a_kernel_axes) + wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) + wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype) + + wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim, + self.intermediate_dim) + wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape) + wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel', + nn.initializers.zeros, + wi_lora_b_kernel_shape, + jnp.float32, + axes=wi_lora_b_kernel_axes) + wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) + + x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel, + wi_lora_b_kernel, self.low_rank_adaptation_alpha) + bias = None if self.use_bias: bias = nn_partitioning.param_with_axes('wi_bias', @@ -1042,6 +1185,28 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): fp8_meta_pkg=gemm2_fp8_meta_package, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim) + wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape) + wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel', + self.kernel_init, + wo_lora_a_kernel_shape, + jnp.float32, + axes=wo_lora_a_kernel_axes) + wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) + + wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size) + wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape) + wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel', + nn.initializers.zeros, + wo_lora_b_kernel_shape, + jnp.float32, + axes=wo_lora_b_kernel_axes) + wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) + + out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel, + 
wo_lora_b_kernel, self.low_rank_adaptation_alpha) + bias = None if self.use_bias: bias = nn_partitioning.param_with_axes('wo_bias', diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fcf06aa128..cacb360a27 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -637,6 +637,53 @@ def canonicalize_group_method(gm): return consecutive_impl() +class LoRAScope: # pylint: disable=too-few-public-methods + """LoRA Scope""" + + def __init__(self, qkv_proj=False, output_proj=False, mlp=False): + self.qkv_proj = qkv_proj + self.output_proj = output_proj + self.mlp = mlp + + def __eq__(self, other): + return (self.qkv_proj, self.output_proj, self.mlp) == \ + (other.qkv_proj, other.output_proj, other.mlp) + + +def _canonicalize_lora_scope(scope): + + SCOPE_NONE = 'none' + SCOPE_ALL = 'all' + SCOPE_QKV_PROJ = 'qkv_proj' + SCOPE_OUTPUT_PROJ = 'output_proj' + SCOPE_MLP = 'mlp' + SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj' + SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj' + SCOPE_EX_MLP = 'exclude_mlp' + + scope = SCOPE_NONE if scope is None else scope + + scope = scope.lower() + + assert scope in [ + SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ, + SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP + ] + + lora_scope = LoRAScope() + + if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]: + lora_scope.qkv_proj = True + + if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]: + lora_scope.output_proj = True + + if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]: + lora_scope.mlp = True + + return lora_scope + + class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods r""" Multi-head Attention (MHA), including Query, @@ -723,6 +770,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods Indicate the method to coupled the coordinates. It should be one of ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2` , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. + low_rank_adaptation_scope: str, default = 'none' + Indicate the scope to apply low rank adaptation. It should be one of + ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj'] + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. 
num_heads: int, default = None @@ -777,6 +833,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 fuse_qkv_params: bool = True transpose_batch_sequence: bool = True @@ -914,6 +973,8 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp) + lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) + if self.fuse_qkv_params: if is_qkvpack: qkv_proj, ln_out = LayerNormDenseGeneral( @@ -932,6 +993,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_JOINED_AXES, W_TP_AXES), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, name='qkv', @@ -954,6 +1018,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_TP_AXES,), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, @@ -972,6 +1039,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_JOINED_AXES, W_TP_AXES), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, name='kv', dtype=self.dtype)(inputs_kv) kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj') @@ -986,6 +1056,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_TP_AXES,), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype) query, ln_out = LayerNormDenseGeneral( enable_layernorm=self.input_layernorm, @@ -1002,6 +1075,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_TP_AXES,), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, @@ -1142,6 +1218,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_NO_SHARD_AXES,), + enable_low_rank_adaptation=lora_scope.output_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, name='out')(x) out = checkpoint_name(out, 'out_proj') @@ -1379,6 +1458,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Indicate the method to coupled the coordinates. It should be one of ['consecutive', 'alternate']. 
'alternate' is to pair index :math:`i` with :math:`i + d/2` , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. + low_rank_adaptation_scope: str, default = 'none' + Indicate the scope to apply low rank adaptation. It should be one of + ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj', + 'exclude_output_proj', 'exclude_mlp'] + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. @@ -1434,6 +1523,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 drop_path: float = 0.0 fuse_qkv_params: bool = True @@ -1579,6 +1671,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, fuse_qkv_params=self.fuse_qkv_params, kernel_init=self.mha_kernel_init, use_bias=self.use_bias, @@ -1646,6 +1741,9 @@ def hidden_dropout(x, deterministic): enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, @@ -1674,6 +1772,8 @@ def hidden_dropout(x, deterministic): mlp_input = with_sharding_constraint_by_logical_axes( mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) + # MlpBlock residual = mlp_input z, ln_out = LayerNormMLP( @@ -1697,6 +1797,9 @@ def hidden_dropout(x, deterministic): bias_init=self.bias_init, bias_axes_1=(W_JOINED_AXES, W_TP_AXES), bias_axes_2=(W_NO_SHARD_AXES,), + enable_low_rank_adaptation=lora_scope.mlp, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index 3688b62370..e6372b91dc 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -131,6 +131,9 @@ class Linear(TransformerEngineBaseLayer): use_bias: bool = True bias_init: WeightInit = WeightInit.Constant(0.0) bias_axes: Tuple[str, ...] 
= () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 transpose_batch_sequence: bool = False sharding_type: ShardingType = ShardingType.SINGLE @@ -147,6 +150,9 @@ def setup(self) -> None: use_bias=self.use_bias, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_axes=self.bias_axes, + enable_low_rank_adaptation=self.enable_low_rank_adaptation, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, axis=self.axis, dtype=self.dtype, transpose_batch_sequence=self.transpose_batch_sequence) @@ -174,6 +180,9 @@ class LayerNormLinear(TransformerEngineBaseLayer): use_bias: bool = False bias_init: WeightInit = WeightInit.Constant(0.0) bias_axes: Tuple[str, ...] = () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None return_layernorm_output: bool = True axis: Union[Iterable[int], int] = -1 transpose_batch_sequence: bool = False @@ -201,6 +210,9 @@ def setup(self) -> None: use_bias=self.use_bias, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_axes=self.bias_axes, + enable_low_rank_adaptation=self.enable_low_rank_adaptation, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, return_layernorm_output=self.return_layernorm_output, axis=self.axis, dtype=self.dtype, @@ -232,6 +244,9 @@ class LayerNormMLP(TransformerEngineBaseLayer): bias_init: WeightInit = WeightInit.Constant(0.0) bias_axes_1: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] = () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ('relu',) intermediate_dropout_rate: float = 0.1 @@ -263,6 +278,9 @@ def setup(self) -> None: bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_axes_1=self.bias_axes_1, bias_axes_2=self.bias_axes_2, + enable_low_rank_adaptation=self.enable_low_rank_adaptation, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, return_layernorm_output=self.return_layernorm_output, activations=self.activations, intermediate_dropout_rate=self.intermediate_dropout_rate, diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index d0a37e89b8..b68909190b 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -137,6 +137,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None fuse_qkv_params: bool = True transpose_batch_sequence: bool = True enable_sequence_parallel: bool = False @@ -208,6 +211,9 @@ def setup(self) -> None: enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + 
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             fuse_qkv_params=self.fuse_qkv_params,
             transpose_batch_sequence=self.transpose_batch_sequence,
             enable_sequence_parallel=self.enable_sequence_parallel,
@@ -262,6 +268,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     rotary_pos_emb_group_method: str = 'consecutive'
+    low_rank_adaptation_scope: str = 'none'
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     enable_relative_embedding: bool = True
     relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
     drop_path: float = 0.0
@@ -332,6 +341,9 @@ def setup(self) -> None:
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             enable_relative_embedding=self.enable_relative_embedding,
             relative_embedding=relative_embedding_flax_module,
             drop_path=self.drop_path,
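
A minimal usage sketch of the new knobs, against the flax DenseGeneral path modified
above. The layer and parameter names (DenseGeneral, enable_low_rank_adaptation,
low_rank_adaptation_dim, low_rank_adaptation_alpha, lora_a_kernel, lora_b_kernel) come
from this diff; the tensor shapes, the rng key, and the assumption that plain Flax
init/apply outside an fp8_autocast context is sufficient are illustrative only.

    import jax
    import jax.numpy as jnp

    from transformer_engine.jax.flax.module import DenseGeneral

    # LoRA augments the frozen dense kernel with two trainable low-rank factors:
    #   y = x @ kernel + (alpha / dim) * (x @ lora_a_kernel) @ lora_b_kernel
    # lora_b_kernel is zero-initialized, so the adapted layer initially matches
    # the base layer exactly; alpha=None skips the alpha/dim scaling.
    dense = DenseGeneral(
        features=1024,
        enable_low_rank_adaptation=True,
        low_rank_adaptation_dim=32,        # rank of the adaptation
        low_rank_adaptation_alpha=64.0,    # LoRA output scaled by 64 / 32 = 2.0
        dtype=jnp.bfloat16,
    )

    x = jnp.zeros((4, 128, 1024), jnp.bfloat16)       # illustrative [batch, seq, hidden]
    variables = dense.init(jax.random.PRNGKey(0), x)  # 'params' gains lora_a_kernel / lora_b_kernel
    y = dense.apply(variables, x)

At the MultiHeadAttention / TransformerLayer level the same factors are enabled per
projection through low_rank_adaptation_scope ('all', 'qkv_proj', 'exclude_mlp', ...),
which _canonicalize_lora_scope maps onto the LoRAScope flags covered by the new
tests in tests/jax/test_functions.py.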