chore: adapt xla_model_parallel style to avoid CI complaints
tengomucho committed Apr 5, 2024
1 parent 3aaaf96 commit 747ad57
Showing 1 changed file with 40 additions and 86 deletions.
126 changes: 40 additions & 86 deletions optimum/tpu/xla_model_parallel.py
@@ -28,7 +28,7 @@

EPS = torch.finfo(torch.float32).eps

-USE_CUDA = os.environ.get('USE_CUDA', False)
+USE_CUDA = os.environ.get("USE_CUDA", False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm

@@ -188,27 +188,20 @@ def backward(ctx, grad_output):  # type: ignore
# -----------------


-def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                  rank) -> torch.Tensor:
+def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_, groups, world_size, rank)


-def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                      rank) -> torch.Tensor:
-    return _ReduceFromModelParallelRegion.apply(input_, groups, world_size,
-                                                rank)
+def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
+    return _ReduceFromModelParallelRegion.apply(input_, groups, world_size, rank)


-def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                     rank) -> torch.Tensor:
-    return _ScatterToModelParallelRegion.apply(input_, groups, world_size,
-                                               rank)
+def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
+    return _ScatterToModelParallelRegion.apply(input_, groups, world_size, rank)


-def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                      rank) -> torch.Tensor:
-    return _GatherFromModelParallelRegion.apply(input_, groups, world_size,
-                                                rank)
+def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
+    return _GatherFromModelParallelRegion.apply(input_, groups, world_size, rank)


def ensure_divisibility(numerator: int, denominator: int) -> None:
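Note: these thin wrappers route through torch.autograd.Function subclasses so each collective pairs with the right gradient: copy is identity in forward and all-reduce in backward, while reduce is the mirror image. A minimal single-process sketch of that duality (all_reduce_sum is a hypothetical stand-in for the real collective, not part of this file):

import torch


# Hypothetical stand-in for xm.all_reduce / torch.ops.c10d_functional.all_reduce;
# a real build would sum the tensor across ranks.
def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    return t


class _Copy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x  # identity forward: every rank sees the same input

    @staticmethod
    def backward(ctx, grad):
        return all_reduce_sum(grad)  # sum gradient contributions across ranks


class _Reduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return all_reduce_sum(x)  # sum partial activations across ranks

    @staticmethod
    def backward(ctx, grad):
        return grad  # identity backward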
@@ -244,6 +237,7 @@ def split_tensor_along_last_dim(

return tensor_list

+
# Below copied from fairscale/nn/model_parallel/layers.py


@@ -255,8 +249,7 @@ def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:

# All-reduce.
if USE_CUDA:
-        input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG,
-                                                      RANKSET, GROUP_SIZE)
+        input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG, RANKSET, GROUP_SIZE)
else:
input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)

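Note: on the XLA path this is a plain xm.all_reduce over the given replica groups. A minimal usage sketch, assuming torch_xla is installed and the code runs inside a multi-device worker (values here are illustrative):

import torch
import torch_xla.core.xla_model as xm

t = torch.ones(4, device=xm.xla_device())
# Sums `t` across all devices in the default replica group; pass `groups`
# to restrict the reduction to a model-parallel subgroup, as done above.
t = xm.all_reduce(xm.REDUCE_SUM, t)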
@@ -299,9 +292,7 @@ def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
idx = input_.dim() - 1 - last_dim
padding[2 * idx] = left * size
padding[2 * idx + 1] = right * size
-        output = torch.ops.c10d_functional.all_reduce(F.pad(input_,
-                                                            padding), "sum",
-                                                      TAG, RANKSET, GROUP_SIZE)
+        output = torch.ops.c10d_functional.all_reduce(F.pad(input_, padding), "sum", TAG, RANKSET, GROUP_SIZE)
else:
output = xm.all_gather(input_, dim=-1, groups=groups)

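Note: the CUDA branch emulates an all-gather with pad-then-sum: each rank zero-pads its shard to full width at its own column offset, so a sum across ranks is exactly a concatenation. A single-process sketch of the trick with two simulated ranks:

import torch
import torch.nn.functional as F

full = torch.arange(8.0).reshape(2, 4)
shards = full.chunk(2, dim=-1)  # pretend each of 2 ranks holds one shard
size = shards[0].shape[-1]

padded = []
for rank, shard in enumerate(shards):
    left, right = rank, len(shards) - 1 - rank
    padded.append(F.pad(shard, (left * size, right * size)))

# Summing the zero-padded shards (the all-reduce) reproduces the concatenation
# an all-gather along the last dim would return.
assert torch.equal(sum(padded), full)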
@@ -334,18 +325,12 @@ def _initialize_affine_weight(
return None

# Initialize master weight
-    master_weight = torch.empty(out_features,
-                                in_features,
-                                dtype=weight.dtype,
-                                requires_grad=False)
+    master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
init_method(master_weight)

# Split and copy
-    per_partition_per_stride_size = divide_and_check_no_remainder(
-        per_partition_size, stride)
-    weight_list = torch.split(master_weight,
-                              per_partition_per_stride_size,
-                              dim=partition_dim)
+    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
+    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
my_weight_list = weight_list[rank::world_size]

with torch.no_grad():
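Note: _initialize_affine_weight materializes the full master weight, splits it into stride-sized chunks along the partition dimension, and gives rank r every world_size-th chunk. A single-process sketch of the split-and-select pattern with illustrative sizes:

import torch

out_features, in_features = 8, 4
world_size, stride = 2, 1
per_partition_size = out_features // world_size  # rows each rank keeps

master = torch.randn(out_features, in_features)
chunks = torch.split(master, per_partition_size // stride, dim=0)
for rank in range(world_size):
    # rank r keeps chunks r, r + world_size, r + 2 * world_size, ...
    mine = torch.cat(chunks[rank::world_size], dim=0)
    assert mine.shape == (per_partition_size, in_features)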
@@ -375,8 +360,7 @@ def __init__(
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
-        init_method: Callable[[torch.Tensor],
-                              torch.Tensor] = init.xavier_normal_,
+        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
@@ -405,8 +389,7 @@ def __init__(
self._weight = None
self.quant = quant
# Divide the weight matrix along the embedding dimension.
-        self.embedding_dim_per_partition = divide_and_check_no_remainder(
-            self.embedding_dim, self.world_size)
+        self.embedding_dim_per_partition = divide_and_check_no_remainder(self.embedding_dim, self.world_size)

# Allocate weights.
if quant:
@@ -419,9 +402,7 @@ def __init__(
)
self.weight_scaler = Parameter(torch.Tensor(self.num_embeddings))
else:
-            self.weight = Parameter(
-                torch.Tensor(self.num_embeddings,
-                             self.embedding_dim_per_partition))
+            self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim_per_partition))

# And initialize.
_initialize_affine_weight(
@@ -438,14 +419,11 @@ def __init__(
)

def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
-        input_parallel = copy_to_model_parallel_region(input_, self.groups,
-                                                       self.world_size,
-                                                       self.rank)
+        input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
# PyTorch eager and inductor do not accept negative values in the input to embedding
# layers. Take the modulus to avoid this error.
if USE_CUDA:
-            input_parallel = torch.remainder(input_parallel,
-                                             self.weight.shape[0])
+            input_parallel = torch.remainder(input_parallel, self.weight.shape[0])
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
@@ -458,9 +436,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
self.scale_grad_by_freq,
self.sparse,
)
-        output = gather_from_model_parallel_region(output_parallel,
-                                                   self.groups,
-                                                   self.world_size, self.rank)
+        output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
return output
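Note: ParallelEmbedding shards the table along the embedding dimension: each rank looks up its slice of columns and the all-gather reassembles full-width vectors. A single-process sketch, assuming world_size=2 and illustrative sizes:

import torch
import torch.nn.functional as F

num_embeddings, embedding_dim, world_size = 10, 6, 2
per_rank = embedding_dim // world_size  # embedding_dim_per_partition
table = torch.randn(num_embeddings, embedding_dim)
ids = torch.tensor([1, 4, 7])

# Each rank embeds with its column slice; concatenation plays the all-gather.
parts = [F.embedding(ids, table[:, r * per_rank:(r + 1) * per_rank])
         for r in range(world_size)]
assert torch.equal(torch.cat(parts, dim=-1), F.embedding(ids, table))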


@@ -489,8 +465,7 @@ def __init__(
out_features: int,
bias: bool = True,
gather_output: bool = True,
-        init_method: Callable[[torch.Tensor],
-                              torch.Tensor] = init.xavier_normal_,
+        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
@@ -515,8 +490,7 @@ def __init__(
self.gather_output = gather_output
self.quant = quant
# Divide the weight matrix along the last dimension.
-        self.output_size_per_partition = divide_and_check_no_remainder(
-            out_features, self.world_size)
+        self.output_size_per_partition = divide_and_check_no_remainder(out_features, self.world_size)

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -529,18 +503,16 @@ def __init__(
),
requires_grad=False,
)
-            self.weight_scaler = Parameter(
-                torch.Tensor(self.output_size_per_partition))
+            self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
else:
-            self.weight = Parameter(
-                torch.Tensor(self.output_size_per_partition, self.in_features))
+            self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

# Initialize weight.
self.master_weight = _initialize_affine_weight(
@@ -567,12 +539,10 @@ def get_master_weight(self) -> torch.Tensor:
def set_quantize(self):
assert not self.quant
self.weight = Parameter(
-            torch.empty((self.output_size_per_partition, self.in_features),
-                        dtype=torch.int8),
+            torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8),
requires_grad=False,
)
-        self.weight_scaler = Parameter(
-            torch.Tensor(self.output_size_per_partition))
+        self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
self.quant = True

def quantize(self):
@@ -581,22 +551,18 @@ def quantize(self):
orig_dtype = fp_w.dtype
fp_w = fp_w.to(torch.float32)
self.weight = Parameter(
-            torch.empty((self.output_size_per_partition, self.in_features),
-                        dtype=torch.int8),
+            torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8),
requires_grad=False,
)
-        self.weight_scaler = Parameter(
-            torch.Tensor(self.output_size_per_partition))
+        self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
qconfig = TensorQConfig(axis=0)
self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
self.weight_scaler.data = scale.to(orig_dtype)
self.quant = True

def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce.
-        input_parallel = copy_to_model_parallel_region(input_, self.groups,
-                                                       self.world_size,
-                                                       self.rank)
+        input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
# Matrix multiply.
if self.quant and USE_CUDA:
# GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
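Note: as that comment says, the CUDA path cannot feed int8 weights into a bf16 matmul, so the weights are dequantized first: promoted to bf16 and rescaled per output channel by weight_scaler. A hedged sketch of that step with illustrative shapes (the names are not the module's exact locals):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, dtype=torch.bfloat16)                 # activations
w_int8 = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
scale = torch.rand(8, dtype=torch.bfloat16)                  # per-out-feature scale

w = w_int8.to(torch.bfloat16) * scale.unsqueeze(-1)          # dequantize
y = F.linear(x, w)                                           # (2, 8), all bf16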
@@ -609,10 +575,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
-            output = gather_from_model_parallel_region(output_parallel,
-                                                       self.groups,
-                                                       self.world_size,
-                                                       self.rank)
+            output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
else:
output = output_parallel
return output
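Note: the column-parallel layout splits the weight along out_features, so each rank produces a slice of the output and the optional gather concatenates the slices. A single-process equivalence check, assuming world_size=2:

import torch
import torch.nn.functional as F

x = torch.randn(3, 8)
w = torch.randn(6, 8)       # (out_features, in_features)
w0, w1 = w.chunk(2, dim=0)  # each rank's shard along out_features

# Partial outputs concatenated along the last dim (the all-gather)
# match the unsharded linear.
y = torch.cat([F.linear(x, w0), F.linear(x, w1)], dim=-1)
assert torch.allclose(y, F.linear(x, w))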
@@ -649,8 +612,7 @@ def __init__(
out_features: int,
bias: bool = True,
input_is_parallel: bool = False,
-        init_method: Callable[[torch.Tensor],
-                              torch.Tensor] = init.xavier_normal_,
+        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
@@ -675,8 +637,7 @@ def __init__(
self.input_is_parallel = input_is_parallel
self.quant = quant
# Divide the weight matrix along the last dimension.
-        self.input_size_per_partition = divide_and_check_no_remainder(
-            in_features, self.world_size)
+        self.input_size_per_partition = divide_and_check_no_remainder(in_features, self.world_size)

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -691,15 +652,14 @@ def __init__(
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
else:
-            self.weight = Parameter(
-                torch.Tensor(self.out_features, self.input_size_per_partition))
+            self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

# Initialize weight.
self.master_weight = _initialize_affine_weight(
@@ -716,14 +676,12 @@ def __init__(
)

def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data, self.groups,
-                                                 self.world_size, self.rank)
+        return gather_from_model_parallel_region(self.weight.data, self.groups, self.world_size, self.rank)

def set_quantize(self):
assert not self.quant
self.weight = Parameter(
-            torch.empty((self.out_features, self.input_size_per_partition),
-                        dtype=torch.int8),
+            torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
@@ -735,8 +693,7 @@ def quantize(self):
orig_dtype = fp_w.dtype
fp_w = fp_w.to(torch.float32)
self.weight = Parameter(
-            torch.empty((self.out_features, self.input_size_per_partition),
-                        dtype=torch.int8),
+            torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
@@ -750,8 +707,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
if self.input_is_parallel:
input_parallel = input_
else:
-            input_parallel = scatter_to_model_parallel_region(
-                input_, self.groups, self.world_size, self.rank)
+            input_parallel = scatter_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
# Matrix multiply.
if self.quant and USE_CUDA:
# GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
@@ -763,9 +719,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
else:
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
-        output_ = reduce_from_model_parallel_region(output_parallel,
-                                                    self.groups,
-                                                    self.world_size, self.rank)
+        output_ = reduce_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
if self.bias is not None:
output = output_ + self.bias
else:
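Note: RowParallelLinear is the complement of the column-parallel case: the weight is sharded along in_features, each rank multiplies its input shard by its weight shard, and the all-reduce sums the partial products, with the bias added once after the reduction. A single-process equivalence check, assuming world_size=2:

import torch
import torch.nn.functional as F

x = torch.randn(3, 8)
w = torch.randn(6, 8)
b = torch.randn(6)

x0, x1 = x.chunk(2, dim=-1)  # the scatter: each rank gets an input shard
w0, w1 = w.chunk(2, dim=-1)  # weight sharded along in_features

# The sum of partial matmuls plays the all-reduce; bias is added once after.
y = F.linear(x0, w0) + F.linear(x1, w1) + b
assert torch.allclose(y, F.linear(x, w, b), atol=1e-5)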
