-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsharding.py
138 lines (121 loc) · 5.39 KB
/
sharding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base sharding functions from big vision changed for our nets and optimizers."""
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
import flax.linen as nn
from utils import tree_flatten_with_names, write_note
def infer_sharding(params, mesh, op):
    """Infer sharding spec for the given parameters.

    Every leaf of ``params`` starts with an all-``None`` spec (one entry per
    array dimension); ``op`` is then called once per leaf to produce the
    final spec.

    Args:
        params: Pytree of parameters. ``nn.Partitioned`` boxes and Python
            lists are passed to ``op`` as a whole rather than traversed.
        mesh: Mesh forwarded to ``op`` and used to build each
            ``NamedSharding``.
        op: Callable invoked as ``op(cur_spec, mesh, name, x)``; it returns
            the spec to use for ``x``.

    Returns:
        A ``(sharding, specs)`` pair: a tree of ``NamedSharding`` objects and
        the tree of ``PartitionSpec`` objects they were built from.
    """
    named_leaves, treedef = tree_flatten_with_names(params)
    # The first element of every flattened entry is used as the leaf's name.
    leaf_names = list(zip(*named_leaves))[0]
    name_tree = treedef.unflatten(leaf_names)

    def _stop_at(node):
        # Preconditioners for PSGD and tearfree shampoo kept in lists
        return isinstance(node, (nn.Partitioned, list))

    def _apply_op(leaf, leaf_name, cur_spec):
        return op(cur_spec, mesh, leaf_name, leaf)

    blank_specs = jax.tree.map(lambda leaf: (None,) * leaf.ndim, params)
    raw_specs = jax.tree.map(
        _apply_op,
        params,
        name_tree,
        blank_specs,
        is_leaf=_stop_at,
    )

    # Two-level tree_map to prevent it from doing traversal inside the spec.
    specs = jax.tree.map(lambda _, raw: P(*raw), nn.unbox(params), raw_specs)
    sharding = jax.tree.map(lambda pspec: NamedSharding(mesh, pspec), specs)
    return sharding, specs
# Name fragments that select a fixed rule in `fsdp_sharding`. They are checked
# in this order and the first group with a match wins.
_FSDP_UNSHARDED_NAMES = ("scale", "bias")
_FSDP_LAST_DIM_NAMES = ("embedding", "out_kernel", "down_kernel")
_FSDP_SECOND_LAST_DIM_NAMES = (
    "q_kernel",
    "k_kernel",
    "v_kernel",
    "gate_kernel",
    "up_kernel",
)


def fsdp_sharding(axis, min_size_to_shard_mb=1):
    """Simple FSDP sharding rules.

    Args:
        axis: Mesh axis name, or an iterable of mesh axis names, to shard
            over. An iterable is converted to a tuple and placed in the spec
            as a single entry.
        min_size_to_shard_mb: Minimum array size, in units of 2**20 bytes,
            for an array to be considered for sharding.

    Returns:
        A function ``_update_spec(cur_spec, mesh, name, x)``.

        * If ``x`` is a list (preconditioners), it returns a list with one
          spec tuple per element; an element is sharded on dim -2 when it is
          large enough, has more than one dim and that dim is divisible by
          the axis size, and is otherwise left all-``None``.
        * Otherwise, for an array that is large enough and has more than one
          dim: names containing "scale"/"bias" are skipped; embedding/out/down
          kernels are sharded on the last dim; q/k/v/gate/up kernels on
          dim -2; anything else on its largest divisible dim whose entry in
          ``cur_spec`` is ``None``.
        * When no rule applies, a note is written and ``cur_spec`` is
          returned unchanged.
    """
    # TODO consider not overwriting already sharded dims
    axis = axis if isinstance(axis, str) else tuple(axis)
    axis_tuple = axis if isinstance(axis, tuple) else (axis,)
    min_bytes = min_size_to_shard_mb * (2**20)

    def _is_shardable(arr):
        """Return True if `arr` reaches the byte threshold and has >1 dims."""
        shape = arr.shape
        return bool(
            np.prod(shape) * arr.dtype.itemsize >= min_bytes and len(shape) > 1
        )

    def _shard_on(name, shape, dim):
        """Build, log and return a spec that shards only `dim` over `axis`."""
        new_sharding = [None] * len(shape)
        new_sharding[dim] = axis
        print(f"sharding {name}:{shape} to {new_sharding}")
        return tuple(new_sharding)

    def _shard_precond_list(name, preconds, axis_size):
        """Return one spec per preconditioner, sharding dim -2 when possible."""
        shard_dim = -2
        precond_specs = []
        for precond in preconds:
            shape = precond.shape
            # `_is_shardable` guarantees len(shape) > 1 before dim -2 is read.
            if _is_shardable(precond) and shape[shard_dim] % axis_size == 0:
                precond_specs.append(_shard_on(name, shape, shard_dim))
            else:
                precond_specs.append((None,) * len(shape))
        return precond_specs

    def _shard_fixed_dim(name, shape, dim, label, axis_size):
        """Shard `dim` if divisible by `axis_size`; else warn and return None."""
        if shape[dim] % axis_size == 0:
            return _shard_on(name, shape, dim)
        print(
            f"WARNING: Parameter {name}:{shape} is not sharded because "
            f"{label} dimension is not divisible by axis size {axis_size}. "
            f"Consider changing {label} dim to be divisible by axis size."
        )
        return None

    def _shard_largest_free_dim(cur_spec, name, shape, axis_size):
        """Shard the largest divisible dim that `cur_spec` leaves as None."""
        for i in np.argsort(shape)[::-1]:
            if shape[i] % axis_size == 0 and cur_spec[i] is None:
                return _shard_on(name, shape, i)
        return None

    def _update_spec(cur_spec, mesh, name, x):
        axis_size = np.prod([mesh.shape[a] for a in axis_tuple])

        if isinstance(x, list):
            # Preconditioners for PSGD and tearfree shampoo kept in lists
            return _shard_precond_list(name, x, axis_size)

        shape = x.shape
        new_spec = None

        # Partitioning rules, simple FSDP
        # indexed backwards from last dim for friendliness to scanned leading dims
        if _is_shardable(x):
            if any(s in name for s in _FSDP_UNSHARDED_NAMES):
                pass
            elif any(s in name for s in _FSDP_LAST_DIM_NAMES):
                # shard these on last dim (-1)
                new_spec = _shard_fixed_dim(name, shape, -1, "last", axis_size)
            elif any(s in name for s in _FSDP_SECOND_LAST_DIM_NAMES):
                # shard these on dim -2; the message calls it "first" because
                # it is the first dim of an unstacked 2-D kernel
                new_spec = _shard_fixed_dim(name, shape, -2, "first", axis_size)
            else:
                # If not explicitly sharded above, infer here by partitioning
                # along largest axis that is divisible and not taken.
                new_spec = _shard_largest_free_dim(
                    cur_spec, name, shape, axis_size
                )

        if new_spec is not None:
            return new_spec

        write_note(
            f"Parameter {name}:{shape} not sharded because did not meet rules "
            f"or already occupied by other sharding rules: {cur_spec}"
        )
        return cur_spec

    return _update_spec