From 747ad5765a78ee7ea0019928c817d1f90cdc6dcf Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 5 Apr 2024 09:41:59 +0000
Subject: [PATCH] chore: adapt xla_model_parallel style to avoid CI complaints

---
 optimum/tpu/xla_model_parallel.py | 126 ++++++++++--------------------
 1 file changed, 40 insertions(+), 86 deletions(-)

diff --git a/optimum/tpu/xla_model_parallel.py b/optimum/tpu/xla_model_parallel.py
index ed7b3441..6ca3b359 100644
--- a/optimum/tpu/xla_model_parallel.py
+++ b/optimum/tpu/xla_model_parallel.py
@@ -28,7 +28,7 @@
 
 EPS = torch.finfo(torch.float32).eps
 
-USE_CUDA = os.environ.get('USE_CUDA', False)
+USE_CUDA = os.environ.get("USE_CUDA", False)
 
 if not USE_CUDA:
     import torch_xla.core.xla_model as xm
@@ -188,27 +188,20 @@ def backward(ctx, grad_output): # type: ignore
 # -----------------
 
 
-def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                  rank) -> torch.Tensor:
+def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
     return _CopyToModelParallelRegion.apply(input_, groups, world_size, rank)
 
 
-def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                      rank) -> torch.Tensor:
-    return _ReduceFromModelParallelRegion.apply(input_, groups, world_size,
-                                                rank)
+def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
+    return _ReduceFromModelParallelRegion.apply(input_, groups, world_size, rank)
 
 
-def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                     rank) -> torch.Tensor:
-    return _ScatterToModelParallelRegion.apply(input_, groups, world_size,
-                                               rank)
+def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
+    return _ScatterToModelParallelRegion.apply(input_, groups, world_size, rank)
 
 
-def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
-                                      rank) -> torch.Tensor:
-    return _GatherFromModelParallelRegion.apply(input_, groups, world_size,
-                                                rank)
+def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
+    return _GatherFromModelParallelRegion.apply(input_, groups, world_size, rank)
 
 
 def ensure_divisibility(numerator: int, denominator: int) -> None:
@@ -244,6 +237,7 @@ def split_tensor_along_last_dim(
     return tensor_list
 
 
+
 # Below copied from fairscale/nn/model_parallel/layers.py
 
 
@@ -255,8 +249,7 @@ def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
     # All-reduce.
     if USE_CUDA:
-        input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG,
-                                                      RANKSET, GROUP_SIZE)
+        input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG, RANKSET, GROUP_SIZE)
     else:
         input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)
 
@@ -299,9 +292,7 @@ def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
         idx = input_.dim() - 1 - last_dim
         padding[2 * idx] = left * size
         padding[2 * idx + 1] = right * size
-        output = torch.ops.c10d_functional.all_reduce(F.pad(input_,
-                                                      padding), "sum",
-                                                      TAG, RANKSET, GROUP_SIZE)
+        output = torch.ops.c10d_functional.all_reduce(F.pad(input_, padding), "sum", TAG, RANKSET, GROUP_SIZE)
     else:
         output = xm.all_gather(input_, dim=-1, groups=groups)
 
@@ -334,18 +325,12 @@ def _initialize_affine_weight(
         return None
 
     # Initialize master weight
-    master_weight = torch.empty(out_features,
-                                in_features,
-                                dtype=weight.dtype,
-                                requires_grad=False)
+    master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
     init_method(master_weight)
 
     # Split and copy
-    per_partition_per_stride_size = divide_and_check_no_remainder(
-        per_partition_size, stride)
-    weight_list = torch.split(master_weight,
-                              per_partition_per_stride_size,
-                              dim=partition_dim)
+    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
+    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
     my_weight_list = weight_list[rank::world_size]
 
     with torch.no_grad():
@@ -375,8 +360,7 @@ def __init__(
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
-        init_method: Callable[[torch.Tensor],
-                              torch.Tensor] = init.xavier_normal_,
+        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
         keep_master_weight_for_test: bool = False,
         world_size: Optional[int] = None,
         rank: Optional[int] = None,
@@ -405,8 +389,7 @@ def __init__(
         self._weight = None
         self.quant = quant
         # Divide the weight matrix along the embedding dimension.
-        self.embedding_dim_per_partition = divide_and_check_no_remainder(
-            self.embedding_dim, self.world_size)
+        self.embedding_dim_per_partition = divide_and_check_no_remainder(self.embedding_dim, self.world_size)
 
         # Allocate weights.
         if quant:
@@ -419,9 +402,7 @@ def __init__(
             )
             self.weight_scaler = Parameter(torch.Tensor(self.num_embeddings))
         else:
-            self.weight = Parameter(
-                torch.Tensor(self.num_embeddings,
-                             self.embedding_dim_per_partition))
+            self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim_per_partition))
 
         # And initialize.
         _initialize_affine_weight(
@@ -438,14 +419,11 @@ def __init__(
         )
 
     def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
-        input_parallel = copy_to_model_parallel_region(input_, self.groups,
-                                                       self.world_size,
-                                                       self.rank)
+        input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
         # PyTorch eager and inductor do not accept negative values in the input to embedding
         # layers. Take the modulus to avoid this error.
         if USE_CUDA:
-            input_parallel = torch.remainder(input_parallel,
-                                             self.weight.shape[0])
+            input_parallel = torch.remainder(input_parallel, self.weight.shape[0])
         weight = self.weight
         if self.quant:
             weight = weight * self.weight_scaler.unsqueeze(-1)
@@ -458,9 +436,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
             self.scale_grad_by_freq,
             self.sparse,
         )
-        output = gather_from_model_parallel_region(output_parallel,
-                                                   self.groups,
-                                                   self.world_size, self.rank)
+        output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
         return output
 
 
@@ -489,8 +465,7 @@ def __init__(
         out_features: int,
         bias: bool = True,
         gather_output: bool = True,
-        init_method: Callable[[torch.Tensor],
-                              torch.Tensor] = init.xavier_normal_,
+        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
         stride: int = 1,
         keep_master_weight_for_test: bool = False,
         world_size: Optional[int] = None,
@@ -515,8 +490,7 @@ def __init__(
         self.gather_output = gather_output
         self.quant = quant
        # Divide the weight matrix along the last dimension.
-        self.output_size_per_partition = divide_and_check_no_remainder(
-            out_features, self.world_size)
+        self.output_size_per_partition = divide_and_check_no_remainder(out_features, self.world_size)
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -529,18 +503,16 @@ def __init__(
                 ),
                 requires_grad=False,
             )
-            self.weight_scaler = Parameter(
-                torch.Tensor(self.output_size_per_partition))
+            self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
         else:
-            self.weight = Parameter(
-                torch.Tensor(self.output_size_per_partition, self.in_features))
+            self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
         # Initialize weight.
         self.master_weight = _initialize_affine_weight(
@@ -567,12 +539,10 @@ def get_master_weight(self) -> torch.Tensor:
     def set_quantize(self):
         assert not self.quant
         self.weight = Parameter(
-            torch.empty((self.output_size_per_partition, self.in_features),
-                        dtype=torch.int8),
+            torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8),
             requires_grad=False,
         )
-        self.weight_scaler = Parameter(
-            torch.Tensor(self.output_size_per_partition))
+        self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
         self.quant = True
 
     def quantize(self):
@@ -581,12 +551,10 @@ def quantize(self):
         orig_dtype = fp_w.dtype
         fp_w = fp_w.to(torch.float32)
         self.weight = Parameter(
-            torch.empty((self.output_size_per_partition, self.in_features),
-                        dtype=torch.int8),
+            torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8),
            requires_grad=False,
         )
-        self.weight_scaler = Parameter(
-            torch.Tensor(self.output_size_per_partition))
+        self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
         qconfig = TensorQConfig(axis=0)
         self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
         self.weight_scaler.data = scale.to(orig_dtype)
@@ -594,9 +562,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
         # Set up backprop all-reduce.
-        input_parallel = copy_to_model_parallel_region(input_, self.groups,
-                                                       self.world_size,
-                                                       self.rank)
+        input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
         # Matrix multiply.
         if self.quant and USE_CUDA:
             # GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
@@ -609,10 +575,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
             output_parallel = F.linear(input_parallel, self.weight, self.bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_model_parallel_region(output_parallel,
-                                                       self.groups,
-                                                       self.world_size,
-                                                       self.rank)
+            output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
         else:
             output = output_parallel
         return output
@@ -649,8 +612,7 @@ def __init__(
         out_features: int,
         bias: bool = True,
         input_is_parallel: bool = False,
-        init_method: Callable[[torch.Tensor],
-                              torch.Tensor] = init.xavier_normal_,
+        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
         stride: int = 1,
         keep_master_weight_for_test: bool = False,
         world_size: Optional[int] = None,
@@ -675,8 +637,7 @@ def __init__(
         self.input_is_parallel = input_is_parallel
         self.quant = quant
         # Divide the weight matrix along the last dimension.
-        self.input_size_per_partition = divide_and_check_no_remainder(
-            in_features, self.world_size)
+        self.input_size_per_partition = divide_and_check_no_remainder(in_features, self.world_size)
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -691,15 +652,14 @@ def __init__(
             )
             self.weight_scaler = Parameter(torch.Tensor(self.out_features))
         else:
-            self.weight = Parameter(
-                torch.Tensor(self.out_features, self.input_size_per_partition))
+            self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
         if bias:
             self.bias = Parameter(torch.Tensor(self.out_features))
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
         # Initialize weight.
         self.master_weight = _initialize_affine_weight(
@@ -716,14 +676,12 @@ def __init__(
         )
 
     def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data, self.groups,
-                                                 self.world_size, self.rank)
+        return gather_from_model_parallel_region(self.weight.data, self.groups, self.world_size, self.rank)
 
     def set_quantize(self):
         assert not self.quant
         self.weight = Parameter(
-            torch.empty((self.out_features, self.input_size_per_partition),
-                        dtype=torch.int8),
+            torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8),
             requires_grad=False,
         )
         self.weight_scaler = Parameter(torch.Tensor(self.out_features))
@@ -735,8 +693,7 @@ def quantize(self):
         orig_dtype = fp_w.dtype
         fp_w = fp_w.to(torch.float32)
         self.weight = Parameter(
-            torch.empty((self.out_features, self.input_size_per_partition),
-                        dtype=torch.int8),
+            torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8),
             requires_grad=False,
         )
         self.weight_scaler = Parameter(torch.Tensor(self.out_features))
@@ -750,8 +707,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            input_parallel = scatter_to_model_parallel_region(
-                input_, self.groups, self.world_size, self.rank)
+            input_parallel = scatter_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
         # Matrix multiply.
         if self.quant and USE_CUDA:
             # GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
@@ -763,9 +719,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
         else:
             output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        output_ = reduce_from_model_parallel_region(output_parallel,
-                                                    self.groups,
-                                                    self.world_size, self.rank)
+        output_ = reduce_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
         if self.bias is not None:
             output = output_ + self.bias
         else:
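
Note (not part of the patch): the layers touched above follow the usual Megatron-style tensor-parallel split — ColumnParallelLinear shards its weight along the output dimension, RowParallelLinear along the input dimension, and reduce_from_model_parallel_region sums the partial outputs. The sketch below is purely illustrative: it reproduces that math with plain tensors and a simulated world_size of 2, with no XLA or CUDA collectives, and all names in it are chosen here for illustration only.

    # Illustrative sketch, not the patched module's API.
    import torch

    torch.manual_seed(0)
    world_size = 2
    x = torch.randn(4, 8)    # (batch, in_features)
    w1 = torch.randn(16, 8)  # first projection, out_features=16
    w2 = torch.randn(8, 16)  # second projection, back to 8 features

    # Unsharded reference: two back-to-back linears (bias and activation omitted).
    ref = (x @ w1.t()) @ w2.t()

    # Column parallel: shard w1 along its output dimension, so each "rank"
    # produces a slice of the intermediate activation (gather_output=False).
    w1_shards = torch.chunk(w1, world_size, dim=0)
    # Row parallel: shard w2 along its input dimension to match those slices
    # (input_is_parallel=True).
    w2_shards = torch.chunk(w2, world_size, dim=1)

    # Each rank computes a partial result; the final sum plays the role of the
    # all-reduce done by reduce_from_model_parallel_region.
    partials = [(x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards)]
    out = sum(partials)

    torch.testing.assert_close(out, ref)
    print("sharded result matches unsharded reference")

This is why a ColumnParallelLinear with gather_output=False can feed a RowParallelLinear with input_is_parallel=True using a single all-reduce at the end; the quantized path in the patch additionally rescales the int8 weights with weight_scaler before the matmul.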