feat: Add bfloat16 support (#136)
aturker-synnada authored Jan 20, 2025
1 parent a68d05b commit 2d15221
Showing 44 changed files with 1,381 additions and 1,179 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -56,9 +56,9 @@ from mithril.models import Model, Linear
# Build a simple linear model
model = Linear(16)

# Create backends, specify the precision
backend_jax = ml.JaxBackend(precision=64)
backend_numpy = ml.NumpyBackend(precision=32)
# Create backends, specify the default dtype
backend_jax = ml.JaxBackend(dtype=ml.float64)
backend_numpy = ml.NumpyBackend(dtype=ml.float32)

# Compile the model with different backends, optionally specify
# the file to write the generated code into and whether to use jit
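
For quick reference, a minimal sketch of the API change above, including the newly added bfloat16 dtype (names mirror the README; whether a given device can actually run bfloat16 is outside this diff and assumed here):

import mithril as ml
from mithril.models import Linear

# Build a simple linear model, as in the README example
model = Linear(16)

# Backends now take a Dtype object instead of a bit-width integer,
# and ml.bfloat16 is newly exported (see mithril/__init__.py below).
backend_jax = ml.JaxBackend(dtype=ml.float64)
backend_torch = ml.TorchBackend(dtype=ml.bfloat16)  # listed in Backend.supported_dtypes
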
32 changes: 17 additions & 15 deletions benchmarks/speed_benchmarks/benchmark.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import mithril as ml
from benchmarks.speed_benchmarks.jax_fns import mlp_v_jax
from benchmarks.speed_benchmarks.speed_helper import colorize_str
from benchmarks.speed_benchmarks.torch_fns import conv_v_torch, mlp_v_torch
from mithril.backends.utils import DtypeBits
from mithril.framework.common import Table
from mithril.models import Relu, Sigmoid, Tanh

# MLX is not included because the GitHub CI runners use Ubuntu
backends = ["Torch", "Jax"]
precisions = [64, 32, 16]
dtypes = [ml.float64, ml.float32, ml.float16]

iterations = 100
table = Table()
@@ -55,20 +57,20 @@

for backend in backends:
fn = mlp_v_jax if backend == "Jax" else mlp_v_torch
for precision in precisions:
if not (precision == 16 and backend == "Torch"):
for dtype in dtypes:
if not (DtypeBits[dtype.name].value == 16 and backend == "Torch"):
num_params, time_backend, time_mithril = fn(
activations=activations,
dimensions=dimensions,
input_shape=input_shape,
iterations=iterations,
precision=precision,
dtype=dtype,
)
table.add_row(
[
"MLP Large",
backend,
str(precision),
str(dtype),
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
@@ -82,20 +84,20 @@

for backend in backends:
fn = mlp_v_jax if backend == "Jax" else mlp_v_torch
for precision in precisions:
if not (precision == 16 and backend == "Torch"):
for dtype in dtypes:
if not (DtypeBits[dtype.name].value == 16 and backend == "Torch"):
num_params, time_backend, time_mithril = fn(
activations=activations,
dimensions=dimensions,
input_shape=(128, 128),
iterations=iterations,
precision=precision,
dtype=dtype,
)
table.add_row(
[
"MLP Small",
backend,
str(precision),
dtype.name,
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
@@ -107,21 +109,21 @@
dimensions = [12, 16, 32]
stride = (2, 2)
padding = 1
for precision in [32, 64]:
for dtype in [ml.float32, ml.float64]:
num_params, time_backend, time_mithril = conv_v_torch(
activations=activations,
dimensions=dimensions,
input_shape=(4, 4, 128, 128),
iterations=iterations,
precision=precision,
dtype=dtype,
stride=stride,
padding=padding,
)
table.add_row(
[
"Conv Small",
"Torch",
str(precision),
dtype.name,
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
@@ -134,21 +136,21 @@
dimensions = [1024, 1024, 1024, 256]
stride = (2, 2)
padding = 2
for precision in [32, 64]:
for dtype in [ml.float32, ml.float64]:
num_params, time_backend, time_mithril = conv_v_torch(
activations=activations,
dimensions=dimensions,
input_shape=(2, 1, 128, 128),
iterations=iterations,
precision=precision,
dtype=dtype,
stride=stride,
padding=padding,
)
table.add_row(
[
"Conv Large",
"Torch",
str(precision),
dtype.name,
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
9 changes: 5 additions & 4 deletions benchmarks/speed_benchmarks/jax_fns.py
@@ -25,7 +25,8 @@
create_compl_mlp,
measure_time_and_grads_mithril,
)
from mithril import JaxBackend
from mithril import JaxBackend, core
from mithril.backends.utils import DtypeBits
from mithril.models import (
AbsoluteError,
Gelu,
@@ -200,15 +201,15 @@ def mlp_v_jax(
activations: list,
dimensions: list[int],
input_shape: tuple[int, int],
precision: int,
dtype: core.Dtype,
iterations: int,
):
lr = 0.001
_input_shape, batch_size = input_shape
# batch_size, input_shape = input_shape[-1], input_shape[0]
output_shape = [_input_shape] + [dimensions[-1]]
device = "cpu"
dtype_jax = getattr(jnp, f"float{precision}")
dtype_jax = getattr(jnp, f"float{DtypeBits[dtype.name]}")
device = "cpu"
inputs = {
"input": jnp.array(np.random.randn(batch_size, *input_shape), dtype=dtype_jax),
@@ -231,7 +232,7 @@ def mlp_v_jax(
)
comp_ctx = mithril.compile(
model=ctx,
backend=JaxBackend(device=device, precision=precision),
backend=JaxBackend(device=device, dtype=dtype),
constant_keys=inputs,
)

15 changes: 8 additions & 7 deletions benchmarks/speed_benchmarks/torch_fns.py
@@ -24,7 +24,8 @@
create_compl_mlp,
measure_time_and_grads_mithril,
)
from mithril import TorchBackend
from mithril import TorchBackend, core
from mithril.backends.utils import DtypeBits
from mithril.models import (
AbsoluteError,
Gelu,
@@ -138,14 +139,14 @@ def mlp_v_torch(
activations: list,
dimensions: list[int],
input_shape: tuple[int, int],
precision: int,
dtype: core.Dtype,
iterations: int,
):
lr = 0.001
batch_size, _input_shape = input_shape[-1], input_shape[0]
output_shape = [_input_shape] + [dimensions[-1]]
device = "cpu"
dtype_torch = getattr(torch, f"float{precision}")
dtype_torch = getattr(torch, f"float{DtypeBits[dtype.name]}")
torch.set_default_dtype(dtype_torch)

inputs = {
@@ -170,7 +171,7 @@
)
comp_ctx = mithril.compile(
model=ctx,
backend=TorchBackend(device=device, precision=precision),
backend=TorchBackend(device=device, dtype=dtype),
constant_keys=inputs,
)
randomized_inputs = comp_ctx.randomize_params()
@@ -207,15 +208,15 @@ def conv_v_torch(
activations: list,
dimensions: list[int],
input_shape: tuple[int, int, int, int],
precision: int,
dtype: core.Dtype,
iterations: int,
stride: tuple[int, int] | int,
padding: int,
):
lr = 0.001
batch_size, in_shape, tensor_shape = input_shape[0], input_shape[1], input_shape[2:]
device = "cpu"
dtype_torch = getattr(torch, f"float{precision}")
dtype_torch = getattr(torch, f"float{DtypeBits[dtype.name]}")
torch.set_default_dtype(dtype_torch)
inputs = {
"input": torch.randn(*input_shape, device=device),
Expand Down Expand Up @@ -252,7 +253,7 @@ def conv_v_torch(
)
comp_ctx = mithril.compile(
model=ctx,
backend=TorchBackend(device=device, precision=precision),
backend=TorchBackend(device=device, dtype=dtype),
constant_keys=inputs,
)
randomized_inputs = comp_ctx.randomize_params()
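
The helpers above now derive the framework dtype from the Dtype's bit width; a standalone sketch of that lookup, assuming only float16/float32/float64 are passed (for bfloat16 the "float{bits}" name would not resolve to the right torch dtype):

import torch

import mithril as ml
from mithril.backends.utils import DtypeBits

dtype = ml.float32
# DtypeBits maps a dtype name to its bit width, e.g. DtypeBits["float32"] == 32
dtype_torch = getattr(torch, f"float{DtypeBits[dtype.name].value}")  # torch.float32
torch.set_default_dtype(dtype_torch)
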
2 changes: 1 addition & 1 deletion examples/gpt/run_sample.py
@@ -60,7 +60,7 @@ def run_sample(
)

# Create backend.
backend_obj = backend_map[backend](precision=32, device="cpu")
backend_obj = backend_map[backend](device="cpu")
# Set seed.
backend_obj.set_seed(seed)
# Compile gpt model.
2 changes: 1 addition & 1 deletion examples/model_api/cnn_forcast_sine_training.py
@@ -33,7 +33,7 @@
# TODO: Remove numpy dependencies from the code.

# Define backend. It would also work with any available backend you prefer.
backend = ml.TorchBackend(precision=32)
backend = ml.TorchBackend()


# Generate synthetic data: a sine wave
2 changes: 1 addition & 1 deletion examples/model_api/convolution_with_svm.py
@@ -74,7 +74,7 @@
)

# Set up device and precision of our backend of choice
backend = ml.TorchBackend(precision=32, device="mps")
backend = ml.TorchBackend(device="mps")

# Compile the model with given non-trainable keys
compiled_model = ml.compile(
2 changes: 1 addition & 1 deletion examples/model_api/many_to_one_any_backend_training.py
@@ -19,7 +19,7 @@
from mithril.models import ManyToOne, Mean, RNNCell, SquaredError, TrainModel

# Define backend. It would also work with any available backend you prefer.
backend = ml.JaxBackend(precision=64)
backend = ml.JaxBackend(dtype=ml.float64)

batch_size = 20
input_features = 10
2 changes: 1 addition & 1 deletion examples/model_api/variable_length_many_to_one_lstm.py
@@ -32,7 +32,7 @@
# given an ordered sequence of numbers.

# Define the backend
backend = ml.JaxBackend(precision=64)
backend = ml.JaxBackend(dtype=ml.float64)
backend.set_seed(42)

# Prepare training data. We will test the case for which the input data
2 changes: 2 additions & 0 deletions mithril/__init__.py
@@ -20,6 +20,7 @@
from .core import (
Constant,
DataType,
bfloat16,
bool,
double,
epsilon_table,
@@ -50,6 +51,7 @@
"bool",
"float",
"float16",
"bfloat16",
"float32",
"float64",
"int",
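
With bfloat16 added to both the import list and __all__, it is reachable from the top-level package; a minimal check (a sketch, assuming nothing beyond the re-export shown above):

import mithril as ml
from mithril import core

# The top-level name is the same object as the core Dtype member.
assert ml.bfloat16 is core.bfloat16
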
31 changes: 16 additions & 15 deletions mithril/backends/backend.py
@@ -21,6 +21,7 @@
from .. import core
from ..core import DataType
from .parallel import Parallel
from .utils import DtypeBits

__all__ = ["Backend"]

@@ -34,31 +35,31 @@ class Backend(ABC, Generic[DataType]):

backend_type = ""
device_type = None
supported_precisions = [16, 32, 64]
is_installed = True
_device: Any
_precision: int
_dtype: core.Dtype
supported_dtypes = [
core.Dtype.float16,
core.Dtype.bfloat16,
core.Dtype.float32,
core.Dtype.float64,
]
primitive_function_dict: dict[str, Callable[..., DataType | Any]]
registered_primitives: dict[str, Callable[..., DataType]]
array_creation_funcs: list[str]
primitive_fn_path: str

def __init__(self, precision: int = 32, device: str = "cpu") -> None:
# Check if given precision is a valid one.
if self.precision not in self.supported_precisions:
raise Exception(
f"'{self.precision}' bits precision is not available!"
" Available precisions: '{self.supported_precisions}'"
def __init__(self, dtype: core.Dtype = core.float32, device: str = "cpu") -> None:
# Check if given dtype is a valid one.
if dtype not in self.supported_dtypes:
raise ValueError(
f"Invalid dtype {dtype}. Supported dtypes are {self.supported_dtypes}."
)
self.seed = 10 # Can be set any integer.

# Initialize epsilon constants according to given precision.
# for key, value in core.epsilon_table[f"float{self.precision}"].items():
# setattr(self, key, value)

@property
def precision(self) -> int:
return self._precision
return DtypeBits[self._dtype.name].value

#!!
@property
@@ -1076,11 +1077,11 @@ def __repr__(self) -> str:


class ParallelBackend(Backend[DataType]):
def __init__(self, device_mesh: tuple[int, ...] | None) -> None:
def __init__(self, dtype: core.Dtype, device_mesh: tuple[int, ...] | None) -> None:
assert (
isinstance(device_mesh, tuple) or device_mesh is None
), "device_mesh must be a tuple or None."
super().__init__()
super().__init__(dtype=dtype)

self._raw_device_mesh = device_mesh
self.n_devices = math.prod(device_mesh) if device_mesh is not None else 1
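
A short sketch of the constructor behaviour introduced here, assuming a concrete backend such as TorchBackend forwards dtype to this base __init__ and stores it on _dtype (the storage line falls outside the hunk shown):

import mithril as ml

backend = ml.TorchBackend(dtype=ml.bfloat16)  # accepted: bfloat16 is in supported_dtypes
print(backend.precision)  # 16, now derived from DtypeBits[self._dtype.name]

# A dtype outside supported_dtypes now raises ValueError
# ("Invalid dtype ... Supported dtypes are ...") instead of the old generic Exception.
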
25 changes: 25 additions & 0 deletions mithril/backends/utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from collections.abc import Sequence

from ..utils.type_utils import is_tuple_int
@@ -36,3 +37,27 @@ def process_shape(
)

return _shape


class DtypeBits(enum.IntEnum):
bool = 8
int8 = 8
int16 = 16
int32 = 32
int64 = 64
float16 = 16
bfloat16 = 16
float32 = 32
float64 = 64


class DtypeSubTypes(enum.Enum):
bool = "bool"
int8 = "int"
int16 = "int"
int32 = "int"
int64 = "int"
float16 = "float"
bfloat16 = "bfloat"
float32 = "float"
float64 = "float"
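
For illustration, a hedged sketch of how the new enums behave, mirroring the lookups made elsewhere in this commit:

import mithril as ml
from mithril.backends.utils import DtypeBits, DtypeSubTypes

dtype = ml.bfloat16                     # Dtype member added by this commit
print(DtypeBits[dtype.name].value)      # 16, the same width as float16
print(DtypeSubTypes[dtype.name].value)  # "bfloat"
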