
Commit

Reuse func from flax
Signed-off-by: Reese Wang <rewang@nvidia.com>
zlsh80826 committed Jan 19, 2024
1 parent 6b6b556 commit 41c2dcf
Showing 1 changed file with 3 additions and 30 deletions.
33 changes: 3 additions & 30 deletions transformer_engine/jax/flax/transformer.py
@@ -16,6 +16,7 @@
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
+from flax.linen.attention import combine_masks
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
@@ -108,33 +109,6 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    return tuple(extended_rules)


-def _merge_mask(func, *masks: Optional[Array]):
-    masks = [m for m in masks if m is not None]
-    if not masks:
-        return None
-    assert all(map(lambda x: x.ndim == masks[0].ndim,
-                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
-    mask, *other_masks = masks
-    for other_mask in other_masks:
-        mask = func(mask, other_mask)
-    return mask
-
-
-def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
-    """Combine attention masks."""
-    func = jnp.logical_and
-    return _merge_mask(func, *masks).astype(dtype)
-
-
-def combine_biases(*masks: Optional[Array]):
-    """Combine attention biases."""
-
-    def func(a, b):
-        return a + b
-
-    return _merge_mask(func, *masks)
-
-
def core_attention(query: Array,
                   key: Array,
                   value: Array,
@@ -242,9 +216,6 @@ def rope(x: Array, windows: Tuple[int, int], transpose_batch_seqlen: bool):
    return jnp.concatenate([part_1, part_2], axis=-1)


-dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
-
-
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r"""
Multi-head Attention (MHA), including Query,
@@ -709,6 +680,8 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))

                if bias is not None:
+                    dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
+                                                       in_axes=(None, 0, None, None))
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

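The import added in the first hunk swaps the local combine_masks helper for Flax's own. Below is a minimal sketch of the expected behavior, assuming the standard flax.linen.attention.combine_masks API; the mask shapes and dtype are illustrative and not taken from this diff:

import jax.numpy as jnp
from flax.linen.attention import combine_masks

# Illustrative masks in the usual [batch, heads, q_len, kv_len] layout.
causal = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_))[None, None, :, :]
padding = jnp.ones((1, 1, 1, 4), dtype=jnp.bool_).at[..., -1].set(False)

# Non-None masks are AND-combined and cast to the requested dtype, which is
# what the removed _merge_mask/combine_masks pair computed.
merged = combine_masks(causal, padding, dtype=jnp.float32)
print(merged.shape)               # (1, 1, 4, 4)

# Unlike the removed local helper, the Flax version returns None when every
# input mask is None rather than failing on `.astype`.
print(combine_masks(None, None))  # None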

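The last hunk moves the vmapped lax.dynamic_slice_in_dim into the decode branch where it is used. A small sketch of what that vmapped slice computes, with made-up shapes and indices rather than values from the repository:

import jax.numpy as jnp
from jax import lax, vmap

# Map over a vector of start indices while sharing the operand, slice size and axis.
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
                                   in_axes=(None, 0, None, None))

bias = jnp.arange(2 * 8 * 8, dtype=jnp.float32).reshape(2, 8, 8)  # [heads, q_len, kv_len]
cur_index = jnp.array([3])                                        # one decode position per batch entry

# Take a slice of size 1 along axis -2 (the query axis) at each start index,
# i.e. the attention-bias row for the token currently being decoded.
step_bias = dynamic_vector_slice_in_dim(bias, cur_index, 1, -2)
print(step_bias.shape)  # (1, 2, 1, 8): [batch, heads, 1, kv_len]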