Support Low Rank Adaptation (LoRA). (NVIDIA#745)
mingxu1067 authored and pggPL committed May 9, 2024
1 parent 3064988 commit d98aa6f
Showing 6 changed files with 412 additions and 2 deletions.
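
The change wires LoRA into the Transformer Engine JAX/Flax modules: each adapted projection adds a low-rank update einsum(x, A, B) scaled by alpha / rank on top of the frozen dense output. Below is a minimal standalone sketch of that update, assuming the alpha/rank scaling checked in the tests that follow; lora_delta and its shapes are illustrative only, not the library's _apply_low_rank_adaptation (which also supports multi-axis features).

import jax
import jax.numpy as jnp

def lora_delta(x, lora_a, lora_b, alpha=None):
    # Project down to the LoRA rank via lora_a, back up via lora_b,
    # then scale by alpha / rank (1.0 when alpha is None).
    rank = lora_a.shape[-1]
    scale = alpha / rank if alpha is not None else 1.0
    return jnp.einsum('...h,hr,rk->...k', x, lora_a, lora_b) * scale

# Hypothetical shapes: a (32, 1024) activation adapted with rank-16 factors.
key = jax.random.key(0)
k_x, k_a, k_b = jax.random.split(key, 3)
x = jax.random.normal(k_x, (32, 1024))
lora_a = jax.random.normal(k_a, (1024, 16))   # hidden -> rank
lora_b = jax.random.normal(k_b, (16, 1024))   # rank -> output features
print(lora_delta(x, lora_a, lora_b, alpha=8).shape)   # (32, 1024)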
68 changes: 68 additions & 0 deletions tests/jax/test_functions.py
@@ -0,0 +1,68 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp

from utils import assert_allclose
from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope


class TestLoRA:

    # Reference implementation: contract the input with the two low-rank
    # factors and apply the LoRA scale.
    @staticmethod
    def reference(x, la, lb, pattern, scale):
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
                                                       ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
    @pytest.mark.parametrize('rank', [32, 16])
    @pytest.mark.parametrize('alpha', [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        shape_in_axis = tuple(shape[ax] for ax in axis)

        key = jax.random.key(1124)
        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)

        key, la_key = jax.random.split(key)
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)

        key, lb_key = jax.random.split(key)
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)

        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)

        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize('scope_ref_assert',
                             [('none', LoRAScope(False, False, False), False),
                              ('all', LoRAScope(True, True, True), False),
                              ('qkv_proj', LoRAScope(True, False, False), False),
                              ('output_proj', LoRAScope(False, True, False), False),
                              ('mlp', LoRAScope(False, False, True), False),
                              ('exclude_qkv_proj', LoRAScope(False, True, True), False),
                              ('exclude_output_proj', LoRAScope(True, False, True), False),
                              ('exclude_mlp', LoRAScope(True, True, False), False),
                              ('messing_up', LoRAScope(), True)])
    def test_lora_scope_generator(self, scope_ref_assert):
        # Invalid scope names ('messing_up') are expected to raise an
        # AssertionError inside _canonicalize_lora_scope.
        scope, reference, need_assert = scope_ref_assert
        try:
            lora_scope = _canonicalize_lora_scope(scope)
            assert lora_scope == reference
        except AssertionError as ae:
            assert need_assert, f"{ae.args}"
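
For the second parametrization above, features=(3, 1024) with pattern '...h,hkr,krz->...kz' (e.g., a fused QKV-style projection), the low-rank factors carry the extra feature axis. A hedged shape check under the same einsum convention, with illustrative sizes only:

import jax
import jax.numpy as jnp

batch, seqlen, hidden, rank, out_dim = 32, 128, 1024, 16, 1024
key = jax.random.key(0)
k_x, k_a, k_b = jax.random.split(key, 3)

x = jax.random.normal(k_x, (batch, seqlen, hidden))   # '...h'
la = jax.random.normal(k_a, (hidden, 3, rank))        # 'hkr': hidden -> (3 projections, rank)
lb = jax.random.normal(k_b, (3, rank, out_dim))       # 'krz': (3 projections, rank) -> out_dim
delta = jnp.einsum('...h,hkr,krz->...kz', x, la, lb)
print(delta.shape)   # (32, 128, 3, 1024)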
44 changes: 44 additions & 0 deletions tests/jax/test_praxis_layers.py
@@ -784,6 +784,7 @@ class MultiHeadAttnAttr:
NUM_GQA_GROUPS = 'num_gqa_groups'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -853,6 +854,22 @@ class MultiHeadAttnAttr:
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
LORA_SCOPE: 'all'
}]


@@ -883,6 +900,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
fuse_qkv_params = True
transpose_batch_sequence = True
scale_attn_logits = False
@@ -905,6 +923,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
@@ -926,6 +945,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
@@ -969,6 +989,7 @@ class TransformerLayerAttr:
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -1113,6 +1134,16 @@ class TransformerLayerAttr:
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -1185,6 +1216,16 @@ class TransformerLayerAttr:
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}]


@@ -1219,6 +1260,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
@@ -1257,6 +1299,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
@@ -1282,6 +1325,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)

(Diffs for the remaining 4 changed files are not loaded here.)
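
Taken together, the new attribute threads a single low_rank_adaptation_scope string through both the Praxis and Flax layer configs. A hedged usage sketch of enabling it on the Flax TransformerLayer, assuming the layer is importable from transformer_engine.jax.flax and that hidden_size, mlp_hidden_size, and num_attention_heads are its usual sizing arguments (values here are illustrative):

from transformer_engine.jax.flax import TransformerLayer

# Scope values mirror those exercised in test_lora_scope_generator:
# 'none', 'all', 'qkv_proj', 'output_proj', 'mlp',
# 'exclude_qkv_proj', 'exclude_output_proj', 'exclude_mlp'.
layer = TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    low_rank_adaptation_scope='all',
)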
