diff --git a/.coveragerc b/.coveragerc index 5e874df8..e69de29b 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +0,0 @@ -[run] -omit = - /private/var/* - tmp/* - tests/* - diff --git a/.github/workflows/ci-test-example.yml b/.github/workflows/ci-test-example.yml index d24ae221..65abb893 100644 --- a/.github/workflows/ci-test-example.yml +++ b/.github/workflows/ci-test-example.yml @@ -15,15 +15,24 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v2 + with: + submodules: 'recursive' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Make compile script executable - run: chmod +x ./mithril/cores/c/compile.sh + run: | + chmod +x ./mithril/cores/c/raw_c/compile.sh + chmod +x ./mithril/cores/c/ggml/compile.sh + chmod +x ./mithril/cores/c/ggml/build_ggml.sh - name: Compile C code run: | - pushd ./mithril/cores/c + pushd ./mithril/cores/c/raw_c + ./compile.sh + popd + pushd ./mithril/cores/c/ggml + ./build_ggml.sh ./compile.sh popd - name: Install Python dependencies diff --git a/.github/workflows/ci-test-macos.yaml b/.github/workflows/ci-test-macos.yaml index b436e502..7f52cdbf 100644 --- a/.github/workflows/ci-test-macos.yaml +++ b/.github/workflows/ci-test-macos.yaml @@ -15,15 +15,24 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v2 + with: + submodules: 'recursive' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Make compile script executable - run: chmod +x ./mithril/cores/c/compile.sh + - name: Make compile scripts executable + run: | + chmod +x ./mithril/cores/c/raw_c/compile.sh + chmod +x ./mithril/cores/c/ggml/compile.sh + chmod +x ./mithril/cores/c/ggml/build_ggml.sh - name: Compile C code run: | - pushd ./mithril/cores/c + pushd ./mithril/cores/c/raw_c + ./compile.sh + popd + pushd ./mithril/cores/c/ggml + ./build_ggml.sh ./compile.sh popd - name: Install Python dependencies @@ -38,7 +47,6 @@ jobs: python3 -m pip install mypy python3 -m pip install pre-commit pre-commit run --all-files - python3 license_checker.py - name: Execute testcase unit tests run: pytest --cov --cov-report=xml -s tests/ - name: Upload results to Codecov diff --git a/.github/workflows/ci-test-ubuntu.yaml b/.github/workflows/ci-test-ubuntu.yaml index 7d44f70e..bd3cb44c 100644 --- a/.github/workflows/ci-test-ubuntu.yaml +++ b/.github/workflows/ci-test-ubuntu.yaml @@ -13,15 +13,24 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v2 + with: + submodules: 'recursive' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Make compile script executable - run: chmod +x ./mithril/cores/c/compile.sh + - name: Make compile scripts executable + run: | + chmod +x ./mithril/cores/c/raw_c/compile.sh + chmod +x ./mithril/cores/c/ggml/compile.sh + chmod +x ./mithril/cores/c/ggml/build_ggml.sh - name: Compile C code run: | - pushd ./mithril/cores/c + pushd ./mithril/cores/c/raw_c + ./compile.sh + popd + pushd ./mithril/cores/c/ggml + ./build_ggml.sh ./compile.sh popd - name: Install Python dependencies @@ -35,7 +44,6 @@ jobs: python3 -m pip install mypy python3 -m pip install pre-commit pre-commit run --all-files - python3 license_checker.py - name: Execute testcase unit tests run: pytest -s tests/ diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 
2755dbd0..ecdec40c 100644
--- a/.github/workflows/pypi-release.yml
+++ b/.github/workflows/pypi-release.yml
@@ -15,6 +15,8 @@ jobs:
     steps:
       - name: Check out code
         uses: actions/checkout@v3
+        with:
+          submodules: 'recursive'
       - name: Set up Python
         uses: actions/setup-python@v4
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..098cf55a
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "mithril/cores/c/ggml/ggml"]
+    path = mithril/cores/c/ggml/ggml
+    url = https://github.com/ggml-org/ggml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 66dc43c4..f2bae8f6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,3 +1,4 @@
+exclude: '^mithril/cores/c/ggml/ggml/'
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
@@ -15,12 +16,14 @@ repos:
         name: license checker
         entry: python3 license_checker.py
         language: python
+        args: ["--exclude=mithril/cores/c/ggml/ggml"]
       - id: mypy
         name: mypy
         entry: mypy .
         language: system
         always_run: true
         pass_filenames: false
+        args: ["--exclude=mithril/cores/c/ggml/ggml"]
 
 # - repo: https://github.com/pre-commit/mirrors-mypy
 #   rev: v1.12.0
diff --git a/MANIFEST.in b/MANIFEST.in
index e99726e7..e1ee745f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,3 +1,3 @@
 include mithril/cores/c/libmithrilc.so
 recursive-include mithril/cores *
-recursive-include mithril/backends/with_manualgrad/c_backend/src *
\ No newline at end of file
+recursive-include mithril/backends/with_manualgrad/c_backend/src *
diff --git a/license_checker.py b/license_checker.py
index f17acfc4..e83a25a8 100755
--- a/license_checker.py
+++ b/license_checker.py
@@ -14,6 +14,14 @@
 
 import os
+import sys
+
+# Take the --exclude argument from the command line (parsed by hand from
+# sys.argv rather than with argparse)
+exclude = ""
+for arg in sys.argv:
+    if arg.startswith("--exclude="):
+        exclude = arg.split("=")[1]
+        break
 
 license_py = """# Copyright 2022 Synnada, Inc.
 #
@@ -45,8 +53,9 @@
 // limitations under the License.
 """
 
 current_directory = os.getcwd()
-# Walk through the directory recursively
+
+# Walk through the directory recursively
 for root, _, files in os.walk(current_directory):
     if os.path.basename(root) == "tmp":
         continue
@@ -54,10 +63,17 @@
         if filename.endswith((".py", ".c", ".h")):
             # Check for .py .h and .c files
             file_path = os.path.join(root, filename)
+            # Guard against the empty default: "" is a substring of every
+            # path, which would otherwise skip all files when no --exclude
+            # argument is given.
+            if exclude and exclude in file_path:
+                continue
+
             # Check if it's a file
             if os.path.isfile(file_path):
                 with open(file_path, encoding="utf-8", errors="ignore") as file:
-                    file_license = "".join(next(file) for _ in range(13))
+                    lines = file.readlines()
+                    if len(lines) < 13:
+                        raise Exception(f"No license found in {file_path}")
+
+                    file_license = "".join(lines[:13])
 
                 license = license_py if filename.endswith(".py") else license_ch
diff --git a/mithril/__init__.py b/mithril/__init__.py
index 4a1bbbdb..e9b02065 100644
--- a/mithril/__init__.py
+++ b/mithril/__init__.py
@@ -46,6 +46,7 @@
     "MlxBackend",
     "TorchBackend",
     "CBackend",
+    "GGMLBackend",
     "NumpyBackend",
     "compile",
     "DataType",
@@ -92,6 +93,11 @@
 except Exception:
     CBackend = UnavailableBackend  # type: ignore
 
+try:
+    from .backends.with_manualgrad.ggml_backend.backend import GGMLBackend
+except ImportError:
+    GGMLBackend = UnavailableBackend  # type: ignore
+
 try:
     from .backends.with_manualgrad.numpy_backend.backend import NumpyBackend
 except ImportError:
diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py
index 23255a31..c21010e1 100644
--- a/mithril/backends/backend.py
+++ b/mithril/backends/backend.py
@@ -19,6 +19,7 @@
 from typing import Any, Generic, overload
 
 from .. import types
+from ..common import CGenConfig, PythonGenConfig
 from ..types import DataType
 from .parallel import Parallel
 from .utils import DtypeBits, StaticScalar
@@ -49,6 +50,7 @@ class Backend(ABC, Generic[DataType]):
     registered_primitives: dict[str, Callable[..., DataType]]
     array_creation_funcs: list[str]
     primitive_fn_path: str
+    CODEGEN_CONFIG: PythonGenConfig | CGenConfig
 
     def __init__(self, dtype: types.Dtype = types.float32, device: str = "cpu") -> None:
         # Check if given dtype is a valid one.
@@ -90,10 +92,6 @@ def e(self) -> float:
     def is_manualgrad(self) -> bool:
         raise NotImplementedError("is_manualgrad is not implemented")
 
-    @property
-    def codegen_config(self) -> dict[str, bool]:
-        raise NotImplementedError("codegen_config is not implemented")
-
     def get_backend_array_type(self) -> type[DataType]:
         raise NotImplementedError("get_backend_array_type is not implemented")
diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py
index 461a4e24..9fa9c703 100644
--- a/mithril/backends/with_autograd/jax_backend/backend.py
+++ b/mithril/backends/with_autograd/jax_backend/backend.py
@@ -27,7 +27,6 @@
 from ...utils import DtypeSubTypes, StaticScalar, process_shape
 from .
import utils from .parallel import JaxParallel -from .utils import CODEGEN_CONFIG __all__ = ["JaxBackend"] @@ -52,6 +51,7 @@ class JaxBackend(ParallelBackend[jax.numpy.ndarray]): backend_type = "jax" registered_primitives: dict[str, Callable[..., jax.numpy.ndarray]] = {} primitive_fn_path = "mithril.cores.python.jax.ops" + CODEGEN_CONFIG = utils.CODEGEN_CONFIG def __init__( self, @@ -107,10 +107,6 @@ def get_device(self) -> Any: def DataType(self) -> type[jax.Array]: # noqa: N802 return utils.ArrayType - @property - def codegen_config(self) -> dict[str, bool]: - return CODEGEN_CONFIG - @staticmethod def get_available_devices() -> list[str]: """Static method to get a list of available devices. diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index 3e4bdadb..1a5e29c9 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -19,13 +19,11 @@ import numpy as np from .... import types -from ....common import find_dominant_type +from ....common import PythonGenConfig, find_dominant_type from ....cores.python.jax.utils import dtype_map from ...utils import DtypeSubTypes -CODEGEN_CONFIG: dict[str, bool] = { - "specify_device": True, -} +CODEGEN_CONFIG = PythonGenConfig(SPECIFY_DEVICE=True) ArrayType = jax.Array diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index 277f6871..65e36309 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -37,6 +37,7 @@ class MlxBackend(Backend[mx.array]): supported_dtypes = [Dtype.float16, Dtype.bfloat16, Dtype.float32] registered_primitives: dict[str, Callable[..., mx.array]] = {} primitive_fn_path = "mithril.cores.python.mlx.ops" + CODEGEN_CONFIG = utils.CODEGEN_CONFIG def __init__( self, @@ -75,10 +76,6 @@ def nan(self) -> float: def device(self) -> Any: utils.get_device(self._device) - @property - def codegen_config(self) -> dict[str, bool]: - return utils.CODEGEN_CONFIG - def get_device(self) -> Any: return self._device diff --git a/mithril/backends/with_autograd/mlx_backend/utils.py b/mithril/backends/with_autograd/mlx_backend/utils.py index c99ef2fe..8e1a034f 100644 --- a/mithril/backends/with_autograd/mlx_backend/utils.py +++ b/mithril/backends/with_autograd/mlx_backend/utils.py @@ -21,13 +21,12 @@ import numpy as np from .... 
import types -from ....common import find_dominant_type +from ....common import PythonGenConfig, find_dominant_type from ....cores.python.mlx.utils import dtype_map from ...utils import DtypeSubTypes -CODEGEN_CONFIG: dict[str, bool] = { - "specify_device": True, -} +CODEGEN_CONFIG = PythonGenConfig(SPECIFY_DEVICE=True) + ArrayType = mx.array diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index 549c381a..9f1e24c5 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -51,6 +51,7 @@ class TorchBackend(ParallelBackend[torch.Tensor]): backend_type = "torch" registered_primitives = {} primitive_fn_path = "mithril.cores.python.torch.ops" + CODEGEN_CONFIG = utils.CODEGEN_CONFIG def __init__( self, @@ -93,10 +94,6 @@ def nan(self) -> float: def DataType(self) -> type[torch.Tensor]: # noqa: N802 return utils.ArrayType - @property - def codegen_config(self) -> dict[str, bool]: - return utils.CODEGEN_CONFIG - @property def device(self) -> torch.device: return utils.get_device(self._device) diff --git a/mithril/backends/with_autograd/torch_backend/utils.py b/mithril/backends/with_autograd/torch_backend/utils.py index e63181c8..21d6a0a7 100644 --- a/mithril/backends/with_autograd/torch_backend/utils.py +++ b/mithril/backends/with_autograd/torch_backend/utils.py @@ -30,13 +30,11 @@ from torch.distributed._tensor import DeviceMesh from .... import types -from ....common import find_dominant_type +from ....common import PythonGenConfig, find_dominant_type from ....cores.python.torch.utils import dtype_map from ...utils import DtypeSubTypes -CODEGEN_CONFIG: dict[str, bool] = { - "specify_device": True, -} +CODEGEN_CONFIG = PythonGenConfig(SPECIFY_DEVICE=True) AVAILABLE_BACKEND_TYPES = ["cpu", "cuda"] diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py index 830ac134..44ac609c 100644 --- a/mithril/backends/with_manualgrad/c_backend/backend.py +++ b/mithril/backends/with_manualgrad/c_backend/backend.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes import os from typing import Any import numpy as np from .... import types -from ....cores.c import array from ....cores.c.array import PyArray +from ....cores.c.raw_c import array from ...backend import Backend from ...utils import process_shape from . 
import utils @@ -29,7 +30,10 @@ class CBackend(Backend[PyArray]): backend_type = "c" - SRC_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", "cores", "c") + SRC_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "..", "cores", "c", "raw_c" + ) + CODEGEN_CONFIG = utils.CODEGEN_CONFIG def __init__(self) -> None: self._device = "cpu" @@ -90,3 +94,6 @@ def array( assert dtype is None, "dtype is not supported in CBackend" input = input.astype(np.float32) return utils.from_numpy(input) + + def get_struct_cls(self) -> type[ctypes.Structure]: + return array.Array diff --git a/mithril/backends/with_manualgrad/c_backend/utils.py b/mithril/backends/with_manualgrad/c_backend/utils.py index b16a8a27..c93175bb 100644 --- a/mithril/backends/with_manualgrad/c_backend/utils.py +++ b/mithril/backends/with_manualgrad/c_backend/utils.py @@ -16,17 +16,33 @@ import numpy as np -from ....cores.c.array import ( +from ....common import CGenConfig +from ....cores.c.array import PyArray +from ....cores.c.raw_c.array import ( Array, - PyArray, lib, to_c_float_array, to_c_int_array, ) +CODEGEN_CONFIG = CGenConfig() + +# File configs +CODEGEN_CONFIG.HEADER_NAME = "cbackend.h" + +# Array configs +CODEGEN_CONFIG.ARRAY_NAME = "Array" + +# Function configs +CODEGEN_CONFIG.RETURN_OUTPUT = False +CODEGEN_CONFIG.USE_OUTPUT_AS_INPUT = True + +# Memory Management configs +CODEGEN_CONFIG.ALLOCATE_INTERNALS = True + def to_numpy(array: PyArray) -> np.ndarray[Any, Any]: - return np.ctypeslib.as_array(array.arr.contents.data, shape=(array.shape)) + return np.ctypeslib.as_array(array.arr.data, shape=(array.shape)) def from_numpy(array: np.ndarray[Any, Any]) -> PyArray: @@ -36,4 +52,4 @@ def from_numpy(array: np.ndarray[Any, Any]) -> PyArray: c_shape = to_c_int_array(shape) c_data = to_c_float_array(array) # type: ignore arr: Array = lib.create_struct(c_data, ndim, c_shape) - return PyArray(arr, shape) + return PyArray(arr.contents, shape) diff --git a/mithril/backends/with_manualgrad/ggml_backend/__init__.py b/mithril/backends/with_manualgrad/ggml_backend/__init__.py new file mode 100644 index 00000000..5f383378 --- /dev/null +++ b/mithril/backends/with_manualgrad/ggml_backend/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .backend import GGMLBackend + +__all__ = ["GGMLBackend"] diff --git a/mithril/backends/with_manualgrad/ggml_backend/backend.py b/mithril/backends/with_manualgrad/ggml_backend/backend.py new file mode 100644 index 00000000..f588bff8 --- /dev/null +++ b/mithril/backends/with_manualgrad/ggml_backend/backend.py @@ -0,0 +1,103 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ctypes
+import os
+from typing import Any
+
+import numpy as np
+
+from .... import types
+from ....cores.c.array import PyArray
+from ....cores.c.ggml.ggml_core import ggml_struct
+from ....cores.c.raw_c import array
+from ...backend import Backend
+from ...utils import process_shape
+from ..c_backend.utils import from_numpy
+from . import utils
+
+__all__ = ["GGMLBackend"]
+
+
+class GGMLBackend(Backend[PyArray]):
+    backend_type = "c"
+    SRC_PATH = os.path.join(
+        os.path.dirname(__file__), "..", "..", "..", "cores", "c", "ggml"
+    )
+    CODEGEN_CONFIG = utils.CODEGEN_CONFIG
+
+    def __init__(self) -> None:
+        self._device = "cpu"
+        self.primitive_function_dict = {}
+
+    @property
+    def is_manualgrad(self) -> bool:
+        return True
+
+    @property
+    def precision(self) -> int:
+        return 32
+
+    def set_seed(self, seed: int) -> None:
+        raise NotImplementedError("set_seed is not supported in GGML Backend")
+
+    def get_backend_array_type(self) -> type[PyArray]:
+        return PyArray
+
+    def get_struct_cls(self) -> type[ctypes.Structure]:
+        return ggml_struct
+
+    def to_numpy(self, array: PyArray) -> np.ndarray[Any, Any]:
+        return np.ctypeslib.as_array(
+            ctypes.cast(array.arr.data, ctypes.POINTER(ctypes.c_float)),
+            shape=(array.shape),
+        )
+
+    def array(
+        self, input: np.ndarray[Any, Any], *, dtype: types.Dtype | None = None
+    ) -> PyArray:
+        assert dtype is None, "dtype is not supported in GGML Backend"
+        input = input.astype(np.float32)
+        data_ptr = ctypes.cast(from_numpy(input).arr.data, ctypes.c_void_p)
+        return PyArray(ggml_struct(data=data_ptr), input.shape)
+
+    def ones(
+        self,
+        *shape: int | tuple[int, ...] | list[int],
+        dtype: types.Dtype | None = None,
+    ) -> PyArray:
+        assert dtype is None, "dtype is not supported in GGML Backend"
+        _shape = process_shape(shape)
+        data_ptr = ctypes.cast(array.ones(_shape).arr.data, ctypes.c_void_p)
+        return PyArray(ggml_struct(data=data_ptr), _shape)
+
+    def zeros(
+        self,
+        *shape: int | tuple[int, ...] | list[int],
+        dtype: types.Dtype | None = None,
+    ) -> PyArray:
+        assert dtype is None, "dtype is not supported in GGML Backend"
+        _shape = process_shape(shape)
+        data_ptr = ctypes.cast(array.zeros(_shape).arr.data, ctypes.c_void_p)
+        return PyArray(ggml_struct(data=data_ptr), _shape)
+
+    def empty(
+        self,
+        *shape: int | tuple[int, ...] | list[int],
+        dtype: types.Dtype | None = None,
+    ) -> PyArray:
+        assert dtype is None, "dtype is not supported in GGML Backend"
+        _shape = process_shape(shape)
+        data_ptr = ctypes.cast(array.empty(_shape).arr.data, ctypes.c_void_p)
+        return PyArray(ggml_struct(data=data_ptr), _shape)
diff --git a/mithril/backends/with_manualgrad/ggml_backend/utils.py b/mithril/backends/with_manualgrad/ggml_backend/utils.py
new file mode 100644
index 00000000..5266a557
--- /dev/null
+++ b/mithril/backends/with_manualgrad/ggml_backend/utils.py
@@ -0,0 +1,31 @@
+# Copyright 2022 Synnada, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ....common import CGenConfig + +CODEGEN_CONFIG = CGenConfig() + +# File configs +CODEGEN_CONFIG.HEADER_NAME = "ggml_backend.h" + + +# Array configs +CODEGEN_CONFIG.ARRAY_NAME = "struct ggml_tensor" + +# Function configs +CODEGEN_CONFIG.USE_OUTPUT_AS_INPUT = False +CODEGEN_CONFIG.RETURN_OUTPUT = True + +# Memory management +CODEGEN_CONFIG.ALLOCATE_INTERNALS = False diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 79d6d755..201e17d2 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -46,6 +46,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): primitive_fn_path = "mithril.cores.python.numpy.ops" primitive_grad_fn_path = "mithril.cores.python.numpy.ops_grad" registered_primitives_grad_fn: dict[str, Callable[..., np.ndarray[Any, Any]]] = {} + CODEGEN_CONFIG = utils.CODEGEN_CONFIG def __init__(self, device: str = "cpu", dtype: Dtype = Dtype.float32) -> None: self._dtype = dtype @@ -84,10 +85,6 @@ def nan(self) -> float: def DataType(self) -> type[np.ndarray[Any, Any]]: # noqa: N802 return utils.ArrayType - @property - def codegen_config(self) -> dict[str, bool]: - return utils.CODEGEN_CONFIG - def get_backend_array_type(self) -> type[np.ndarray[Any, Any]]: return np.ndarray diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index dc6673e3..aaab8dfe 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -18,13 +18,12 @@ import numpy as np from .... import types -from ....common import find_dominant_type +from ....common import PythonGenConfig, find_dominant_type from ....cores.python.numpy.utils import dtype_map from ...utils import DtypeSubTypes -CODEGEN_CONFIG: dict[str, bool] = { - "specify_device": False, -} +CODEGEN_CONFIG = PythonGenConfig(SPECIFY_DEVICE=False) + ArrayType = np.ndarray diff --git a/mithril/common.py b/mithril/common.py index f61f46b1..ea404dff 100644 --- a/mithril/common.py +++ b/mithril/common.py @@ -13,10 +13,33 @@ # limitations under the License. from collections.abc import Callable, Iterator, MutableMapping +from dataclasses import dataclass from enum import IntEnum from typing import Any, TypeVar +@dataclass +class CGenConfig: + # Import configs + HEADER_NAME: str = "" + + # Array configs + ARRAY_NAME: str = "" + + # Function call configs + USE_OUTPUT_AS_INPUT: bool = False + RETURN_OUTPUT: bool = False + + # Memory Management + ALLOCATE_INTERNALS: bool = False + + +@dataclass +class PythonGenConfig: + # Import configs + SPECIFY_DEVICE: bool = False + + class PaddingType(IntEnum): VALID = 0 SAME = 1 diff --git a/mithril/cores/c/array.py b/mithril/cores/c/array.py index bb7c102b..24a9e81e 100644 --- a/mithril/cores/c/array.py +++ b/mithril/cores/c/array.py @@ -13,81 +13,40 @@ # limitations under the License. 
 import ctypes
-import os
-from typing import Any
-
-import numpy as np
-
-current_file_path = os.path.abspath(__file__)
-
-lib = ctypes.CDLL(os.path.join(os.path.dirname(current_file_path), "libmithrilc.so"))
-
-
-class Array(ctypes.Structure):
-    _fields_ = [
-        ("data", ctypes.POINTER(ctypes.c_float)),
-        ("shape", ctypes.POINTER(ctypes.c_int)),
-        ("strides", ctypes.POINTER(ctypes.c_int)),
-        ("ndim", ctypes.c_int),
-        ("size", ctypes.c_int),
-    ]
-
-
-lib.create_struct.restype = ctypes.POINTER(Array)
-lib.create_struct.argtypes = [
-    ctypes.POINTER(ctypes.c_float),
-    ctypes.c_int,
-    ctypes.POINTER(ctypes.c_int),
-]
-
-lib.create_empty_struct.restype = ctypes.POINTER(Array)
-lib.create_empty_struct.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
-
-lib.create_full_struct.restype = ctypes.POINTER(Array)
-lib.create_full_struct.argtypes = [
-    ctypes.c_float,
-    ctypes.c_int,
-    ctypes.POINTER(ctypes.c_int),
-]
-
-lib.delete_struct.argtypes = [ctypes.POINTER(Array)]
-
-
-def to_c_int_array(lst: list[int] | tuple[int, ...]) -> ctypes.Array[ctypes.c_int]:
-    return (ctypes.c_int * len(lst))(*lst)
-
-
-def to_c_float_array(
-    arr: list[float] | tuple[float, ...] | np.ndarray[Any, Any],
-) -> ctypes.Array[ctypes.c_float]:
-    return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+from collections.abc import Sequence
 
 
 class PyArray:
-    def __init__(self, arr: Array, shape: tuple[int, ...] | list[int]):
+    def __init__(self, arr: ctypes.Structure, shape: tuple[int, ...] | list[int]):
+        # TODO: PyArray needs to store strides
+        self.arr = arr
         if isinstance(shape, list):
             shape = tuple(shape)
         self.shape = shape
         self.ndim = len(shape)
 
-    def __del__(self):
-        lib.delete_struct(self.arr)
+    # TODO: Implement __del__ method for deleting the struct
+    # def __del__(self):
+    #     lib.delete_struct(self.arr)
 
     @property
-    def data(self):
+    def data(self) -> Sequence[float | Sequence[float | Sequence[float]]]:
         total_elements = 1
         for dim in self.shape:
             total_elements *= dim
 
         # Convert the array into a Python list
-        data_ptr = self.arr.contents.data
+        data_ptr = ctypes.cast(self.arr.data, ctypes.POINTER(ctypes.c_float))
         data_list = [data_ptr[i] for i in range(total_elements)]
 
         # Reshape the flat list based on the shape
-        def reshape(data: PyArray, shape: tuple[int, ...]):
+        def reshape(
+            data: Sequence[float], shape: tuple[int, ...]
+        ) -> Sequence[float | Sequence[float | Sequence[float]]]:
             if len(shape) == 1:
                 return data
+
             size = shape[0]
             return [
                 reshape(data[i * size : (i + 1) * size], shape[1:])
@@ -101,21 +60,3 @@ def __repr__(self):
 
     def __str__(self):
         return f"PyArray(shape={self.shape})\n{self.data}"
-
-
-def empty(shape: tuple[int, ...] | list[int]):
-    c_shape = to_c_int_array(shape)
-    arr = lib.create_empty_struct(len(shape), c_shape)
-    return PyArray(arr, shape)
-
-
-def ones(shape: tuple[int, ...] | list[int]):
-    c_shape = to_c_int_array(shape)
-    arr = lib.create_full_struct(1.0, len(shape), c_shape)
-    return PyArray(arr, shape)
-
-
-def zeros(shape: tuple[int, ...] | list[int]):
-    c_shape = to_c_int_array(shape)
-    arr = lib.create_full_struct(0.0, len(shape), c_shape)
-    return PyArray(arr, shape)
diff --git a/mithril/cores/c/array.pyi b/mithril/cores/c/array.pyi
index 7331645c..2527147e 100644
--- a/mithril/cores/c/array.pyi
+++ b/mithril/cores/c/array.pyi
@@ -27,15 +27,15 @@ def to_c_float_array(
     arr: list[float] | tuple[float, ...],
 ) -> ctypes.Array[ctypes.c_float]: ...
 
-class Array(ctypes.Structure): ...
-
 class PyArray:
-    arr: Array
+    arr: ctypes.Structure
    shape: tuple[int, ...]
ndim: int def data(self) -> NestedList: ... - def __init__(self, arr: Array, shape: tuple[int, ...] | list[int]) -> None: ... + def __init__( + self, arr: ctypes.Structure, shape: tuple[int, ...] | list[int] + ) -> None: ... def __gt__(self, other: PyArray) -> PyArray: ... def __ge__(self, other: PyArray) -> PyArray: ... def __lt__(self, other: PyArray) -> PyArray: ... @@ -214,7 +214,3 @@ class PyArray: def __iter__(self) -> Any: ... def __repr__(self) -> str: ... def __str__(self) -> str: ... - -def empty(shape: tuple[builtins.int, ...] | list[builtins.int]) -> PyArray: ... -def ones(shape: tuple[builtins.int, ...] | list[builtins.int]) -> PyArray: ... -def zeros(shape: tuple[builtins.int, ...] | list[builtins.int]) -> PyArray: ... diff --git a/mithril/cores/c/ggml/README.md b/mithril/cores/c/ggml/README.md new file mode 100644 index 00000000..ecb5dc4d --- /dev/null +++ b/mithril/cores/c/ggml/README.md @@ -0,0 +1,66 @@ +# Mithril GGML Bindings + +This directory contains bindings for the GGML tensor library for use in Mithril. + +## Building the Binaries + +### Option 1: Use Pre-built GGML Libraries + +If you already have GGML libraries (`libggml-base.so`, `libggml-base.dylib`, etc.) available, you can just build the bindings: + +```bash +# Make the script executable if needed +chmod +x compile.sh + +# Build the bindings +./compile.sh +``` + +This will generate the `libmithrilggml.[so/dylib/dll]` file that contains the Mithril custom operations for GGML. + +### Option 2: Build GGML from Source + +If you don't have GGML libraries available, you can download and build them from source: + +```bash +# Make the script executable if needed +chmod +x build_ggml.sh + +# Build GGML and the bindings +./build_ggml.sh +``` + +This script will: +1. Clone the GGML repository (if not already present) +2. Build GGML with CMake +3. Copy the resulting libraries to this directory +4. 
Build the Mithril GGML bindings
+
+## Using the Bindings in Python
+
+The GGML bindings can be imported and used in Python:
+
+```python
+from mithril.cores.c.ggml import ggml_struct
+# Additional imports may be needed depending on your usage
+```
+
+## Custom Operations
+
+This binding provides the following custom operations:
+
+- `add`: Adds two GGML tensors
+- `multiplication`: Multiplies two GGML tensors
+
+Each operation also has a corresponding gradient function for use in backpropagation:
+
+- `add_grad`: Gradient function for addition
+- `multiplication_grad`: Gradient function for multiplication
+
+## Platform Support
+
+The build scripts handle cross-platform compilation and will generate the appropriate library type for your system:
+
+- Linux: `.so` extension
+- macOS: `.dylib` extension
+- Windows: `.dll` extension
\ No newline at end of file
diff --git a/mithril/cores/c/ggml/build_ggml.sh b/mithril/cores/c/ggml/build_ggml.sh
new file mode 100755
index 00000000..757da883
--- /dev/null
+++ b/mithril/cores/c/ggml/build_ggml.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+# Script to download and build GGML from source if needed
+
+set -e # Exit on any error
+
+# Detect OS and set appropriate library extension
+UNAME=$(uname)
+if [ "$UNAME" = "Darwin" ]; then
+    LIB_EXT="dylib"
+elif [ "$UNAME" = "Windows" ] || [[ "$UNAME" == MINGW* ]] || [[ "$UNAME" == MSYS* ]]; then
+    LIB_EXT="dll"
+else
+    LIB_EXT="so"
+fi
+
+# Set directories
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BUILD_DIR="${SCRIPT_DIR}/ggml/build"
+GGML_REPO="https://github.com/ggerganov/ggml.git"
+GGML_DIR="${SCRIPT_DIR}/ggml"
+
+# Check if we need to download GGML
+if [ ! -d "${GGML_DIR}" ]; then
+    echo "GGML not found. Cloning from repository..."
+    git clone --depth 1 "${GGML_REPO}" "${GGML_DIR}"
+else
+    echo "GGML directory exists. Using existing files."
+fi
+
+# Create build directory if it doesn't exist
+mkdir -p "${BUILD_DIR}"
+cd "${BUILD_DIR}"
+
+# Configure CMake build
+echo "Configuring GGML build..."
+cmake .. -DBUILD_SHARED_LIBS=ON
+
+# Build GGML
+echo "Building GGML..."
+cmake --build . --config Release -j$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 1)
+
+# Copy the libraries to our directory
+echo "Copying GGML libraries to ${SCRIPT_DIR}..."
+find . -name "libggml*.${LIB_EXT}*" -exec cp {} "${SCRIPT_DIR}/" \;
+
+# Build our own bindings
+cd "${SCRIPT_DIR}"
+echo "Building Mithril GGML bindings..."
+./compile.sh
+
+echo "Build completed successfully!"
+echo "The following libraries are available:"
+ls -l "${SCRIPT_DIR}"/*.${LIB_EXT}*
\ No newline at end of file
diff --git a/mithril/cores/c/ggml/compile.sh b/mithril/cores/c/ggml/compile.sh
new file mode 100755
index 00000000..2b34c608
--- /dev/null
+++ b/mithril/cores/c/ggml/compile.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Build GGML binaries with proper names for different platforms
+
+set -e # Exit on any error
+
+# Detect OS and set appropriate library extension
+UNAME=$(uname)
+if [ "$UNAME" = "Darwin" ]; then
+    LIB_EXT="dylib"
+elif [ "$UNAME" = "Windows" ] || [[ "$UNAME" == MINGW* ]] || [[ "$UNAME" == MSYS* ]]; then
+    LIB_EXT="dll"
+else
+    LIB_EXT="so"
+fi
+
+# Basic compiler settings
+CC=${CC:-cc}
+CFLAGS="-O3 -fPIC"
+
+echo "Building GGML binaries for platform: $UNAME (extension: .$LIB_EXT)"
+
+# Compile the main library
+echo "Compiling libmithrilggml.$LIB_EXT..."
+${CC} ${CFLAGS} ops.c -L. -lggml-base -shared -o "libmithrilggml.$LIB_EXT"
+
+# Make the library executable if needed
+chmod +x "libmithrilggml.$LIB_EXT"
+
+echo "Done! Created libmithrilggml.$LIB_EXT"
\ No newline at end of file
diff --git a/mithril/cores/c/ggml/ggml b/mithril/cores/c/ggml/ggml
new file mode 160000
index 00000000..9a4acb37
--- /dev/null
+++ b/mithril/cores/c/ggml/ggml
@@ -0,0 +1 @@
+Subproject commit 9a4acb374565f4146b8d6eb1cffdcd7d437d1ba2
diff --git a/mithril/cores/c/ggml/ggml_backend.h b/mithril/cores/c/ggml/ggml_backend.h
new file mode 100644
index 00000000..b7292c30
--- /dev/null
+++ b/mithril/cores/c/ggml/ggml_backend.h
@@ -0,0 +1,23 @@
+// Copyright 2022 Synnada, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef GGMLBACKEND_H
+#define GGMLBACKEND_H
+
+#include "ops.h"
+#include "ggml/include/ggml-cpu.h"
+#include
+
+#endif
\ No newline at end of file
diff --git a/mithril/cores/c/ggml/ggml_core.py b/mithril/cores/c/ggml/ggml_core.py
new file mode 100644
index 00000000..bbd99a3a
--- /dev/null
+++ b/mithril/cores/c/ggml/ggml_core.py
@@ -0,0 +1,69 @@
+# Copyright 2022 Synnada, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["ggml_struct"]
+
+import ctypes
+
+GGML_MAX_DIMS = 4
+GGML_MAX_OP_PARAMS = 64
+GGML_MAX_SRC = 10
+GGML_MAX_NAME = 64
+
+
+class ggml_struct(ctypes.Structure):  # noqa: N801
+    """n-dimensional tensor
+
+    Attributes:
+        type (int): ggml_type
+        buffer (ctypes.pointer[ggml_backend_buffer]): pointer to backend buffer
+        ne (ctypes.Array[ctypes.c_int64]): number of elements in each dimension
+        nb (ctypes.Array[ctypes.c_size_t]): stride in bytes for each dimension
+        op (int): ggml operation
+        op_params (ctypes.Array[ctypes.c_int32]): `GGML_MAX_OP_PARAMS`-length array of
+            operation parameters
+        flags (int): tensor flags
+        src (ctypes.Array[ggml_struct_p]): `GGML_MAX_SRC`-length array of source tensors
+        view_src (ggml_struct_p): pointer to tensor if this tensor is a view, None if
+            the tensor is not a view
+        view_offs (ctypes.c_size_t): offset into the data pointer of the view tensor
+        data (ctypes.c_void_p): reference to raw tensor data
+        name (bytes): name of tensor
+        extra (ctypes.c_void_p): extra data (e.g. for CUDA)
+        padding (bytes): padding to match the size of the C struct
+    """
+
+
+ggml_struct._fields_ = [
+    ("type", ctypes.c_int),
+    ("buffer", ctypes.c_void_p),
+    ("ne", ctypes.c_int64 * GGML_MAX_DIMS),
+    ("nb", ctypes.c_size_t * GGML_MAX_DIMS),
+    ("op", ctypes.c_int),
+    (
+        "op_params",
+        ctypes.c_int32 * (GGML_MAX_OP_PARAMS // ctypes.sizeof(ctypes.c_int32)),
+    ),
+    ("flags", ctypes.c_int),
+    ("src", ctypes.POINTER(ggml_struct) * GGML_MAX_SRC),
+    ("view_src", ctypes.POINTER(ggml_struct)),
+    ("view_offs", ctypes.c_size_t),
+    ("data", ctypes.c_void_p),
+    ("name", ctypes.c_char * GGML_MAX_NAME),
+    ("extra", ctypes.c_void_p),
+    ("padding", ctypes.c_char * 8),
+]
diff --git a/mithril/cores/c/ggml/ggml_mithril.md b/mithril/cores/c/ggml/ggml_mithril.md
new file mode 100644
index 00000000..282e74aa
--- /dev/null
+++ b/mithril/cores/c/ggml/ggml_mithril.md
@@ -0,0 +1,32 @@
+# Program flow
+
+1) Determine the context size:
+    * Tensor data
+    * Tensor overhead * number of tensors
+    * Graph overhead
+    * Some additional overhead?
+2) Initialize params with the context size and some additional parameters
+3) Initialize the context with the params
+4) Create tensors via the context and set their data
+    * We have to set dtype and shape here.
+5) Create the graph via the context
+6) Apply operations to the input tensors and obtain the result tensor
+    * This step is lazy; it actually just builds the computational graph
+
+
+### Problems
+* Operations are lazy; how are we going to do the backend operations?
+* Shape and dtype must be determined.
+
+### Questions
+* Are we going to create the context, etc.?
+
+### Installation
+
+
+### Allocation in GGML
+1) The context is created with the params. The params contain:
+    * The size of the entire graph
+    * buffer_mem (if NULL it will be allocated)
+
+2)
diff --git a/mithril/cores/c/ggml/libggml-base.dylib b/mithril/cores/c/ggml/libggml-base.dylib
new file mode 100755
index 00000000..fa98f428
Binary files /dev/null and b/mithril/cores/c/ggml/libggml-base.dylib differ
diff --git a/mithril/cores/c/ggml/libggml-cpu.dylib b/mithril/cores/c/ggml/libggml-cpu.dylib
new file mode 100755
index 00000000..b8035c8c
Binary files /dev/null and b/mithril/cores/c/ggml/libggml-cpu.dylib differ
diff --git a/mithril/cores/c/ggml/libggml.dylib b/mithril/cores/c/ggml/libggml.dylib
new file mode 100755
index 00000000..58b642e3
Binary files /dev/null and b/mithril/cores/c/ggml/libggml.dylib differ
diff --git a/mithril/cores/c/ggml/ops.c b/mithril/cores/c/ggml/ops.c
new file mode 100644
index 00000000..42f26094
--- /dev/null
+++ b/mithril/cores/c/ggml/ops.c
@@ -0,0 +1,45 @@
+// Copyright 2022 Synnada, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
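+
+// Usage sketch: the ops below only record nodes in GGML's lazy compute
+// graph, so a caller still has to build and run the graph through GGML's
+// own API (this assumes ggml_new_graph, ggml_build_forward_expand and
+// ggml_graph_compute_with_ctx from the vendored ggml headers), roughly:
+//
+//   struct ggml_tensor * out = add(ctx, left, right);
+//   struct ggml_cgraph * gf = ggml_new_graph(ctx);
+//   ggml_build_forward_expand(gf, out);
+//   ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);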
+ +#include "ops.h" + +struct ggml_tensor * add(struct ggml_context * ctx, struct ggml_tensor * left, struct ggml_tensor * right) { + struct ggml_tensor * res = ggml_add(ctx, left, right); + return res; +} + + +struct ggml_tensor * multiplication(struct ggml_context * ctx, struct ggml_tensor * left, struct ggml_tensor * right) { + struct ggml_tensor * res = ggml_mul(ctx, left, right); + return res; +} + + + + +struct ggml_tensor * add_grad(struct ggml_context * ctx, struct ggml_tensor * gradient, int idx, struct ggml_tensor * output, struct ggml_tensor * left, struct ggml_tensor * right) +{ + return gradient; +} + + +struct ggml_tensor * multiplication_grad(struct ggml_context * ctx, struct ggml_tensor * gradient, int idx, struct ggml_tensor * output, struct ggml_tensor * left, struct ggml_tensor * right) +{ + if (idx == 0){ + return multiplication(ctx, gradient, right); + } + else{ + return multiplication(ctx, gradient, left); + } +} diff --git a/mithril/cores/c/ggml/ops.h b/mithril/cores/c/ggml/ops.h new file mode 100644 index 00000000..282c02a2 --- /dev/null +++ b/mithril/cores/c/ggml/ops.h @@ -0,0 +1,28 @@ +// Copyright 2022 Synnada, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MITHRIL_GGML_OPS_H +#define MITHRIL_GGML_OPS_H + +#include "ggml/include/ggml.h" + + +struct ggml_tensor * add(struct ggml_context * ctx, struct ggml_tensor * left, struct ggml_tensor * right); +struct ggml_tensor * multiplication(struct ggml_context * ctx, struct ggml_tensor * left, struct ggml_tensor * right); + + +struct ggml_tensor * add_grad(struct ggml_context * ctx, struct ggml_tensor * gradient, int idx, struct ggml_tensor * output, struct ggml_tensor * left, struct ggml_tensor * right); +struct ggml_tensor * multiplication_grad(struct ggml_context * ctx, struct ggml_tensor * gradient, int idx, struct ggml_tensor * output, struct ggml_tensor * left, struct ggml_tensor * right); + +#endif diff --git a/mithril/cores/c/array.c b/mithril/cores/c/raw_c/array.c similarity index 100% rename from mithril/cores/c/array.c rename to mithril/cores/c/raw_c/array.c diff --git a/mithril/cores/c/array.h b/mithril/cores/c/raw_c/array.h similarity index 100% rename from mithril/cores/c/array.h rename to mithril/cores/c/raw_c/array.h diff --git a/mithril/cores/c/raw_c/array.py b/mithril/cores/c/raw_c/array.py new file mode 100644 index 00000000..2295da7f --- /dev/null +++ b/mithril/cores/c/raw_c/array.py @@ -0,0 +1,83 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import ctypes +import os +from typing import Any + +import numpy as np + +from ..array import PyArray + +current_file_path = os.path.abspath(__file__) + +lib = ctypes.CDLL(os.path.join(os.path.dirname(current_file_path), "libmithrilc.so")) + + +class Array(ctypes.Structure): + _fields_ = [ + ("data", ctypes.POINTER(ctypes.c_float)), + ("shape", ctypes.POINTER(ctypes.c_int)), + ("strides", ctypes.POINTER(ctypes.c_int)), + ("ndim", ctypes.c_int), + ("size", ctypes.c_int), + ] + + +lib.create_struct.restype = ctypes.POINTER(Array) +lib.create_struct.argtypes = [ + ctypes.POINTER(ctypes.c_float), + ctypes.c_int, + ctypes.POINTER(ctypes.c_int), +] + +lib.create_empty_struct.restype = ctypes.POINTER(Array) +lib.create_empty_struct.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_int)] + +lib.create_full_struct.restype = ctypes.POINTER(Array) +lib.create_full_struct.argtypes = [ + ctypes.c_float, + ctypes.c_int, + ctypes.POINTER(ctypes.c_int), +] + +lib.delete_struct.argtypes = [ctypes.POINTER(Array)] + + +def to_c_int_array(lst: list[int] | tuple[int, ...]) -> ctypes.Array[ctypes.c_int]: + return (ctypes.c_int * len(lst))(*lst) + + +def to_c_float_array( + arr: list[float] | tuple[float, ...] | np.ndarray[Any, Any], +) -> ctypes.Array[ctypes.c_float]: + return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + + +def empty(shape: tuple[int, ...] | list[int]): + c_shape = to_c_int_array(shape) + arr = lib.create_empty_struct(len(shape), c_shape).contents + return PyArray(arr, shape) + + +def ones(shape: tuple[int, ...] | list[int]): + c_shape = to_c_int_array(shape) + arr = lib.create_full_struct(1.0, len(shape), c_shape).contents + return PyArray(arr, shape) + + +def zeros(shape: tuple[int, ...] | list[int]): + c_shape = to_c_int_array(shape) + arr = lib.create_full_struct(0.0, len(shape), c_shape).contents + return PyArray(arr, shape) diff --git a/mithril/cores/c/raw_c/array.pyi b/mithril/cores/c/raw_c/array.pyi new file mode 100644 index 00000000..d0a21e30 --- /dev/null +++ b/mithril/cores/c/raw_c/array.pyi @@ -0,0 +1,33 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import ctypes + +from ..array import PyArray + +NestedList = list[float | "NestedList"] + +lib: ctypes.CDLL + +def to_c_int_array(lst: list[int] | tuple[int, ...]) -> ctypes.Array[ctypes.c_int]: ... +def to_c_float_array( + arr: list[float] | tuple[float, ...], +) -> ctypes.Array[ctypes.c_float]: ... + +class Array(ctypes.Structure): ... + +def empty(shape: tuple[builtins.int, ...] | list[builtins.int]) -> PyArray: ... +def ones(shape: tuple[builtins.int, ...] | list[builtins.int]) -> PyArray: ... +def zeros(shape: tuple[builtins.int, ...] | list[builtins.int]) -> PyArray: ... 
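For reference, a minimal round-trip through the raw C array layer above; this is a sketch, assuming `libmithrilc.so` has already been built with `raw_c/compile.sh`, and it only uses helpers defined in this patch (`CBackend.array`, `raw_c.array.ones`, `PyArray.data`, `c_backend.utils.to_numpy`):

```python
import numpy as np

from mithril import CBackend
from mithril.backends.with_manualgrad.c_backend import utils as c_utils
from mithril.cores.c.raw_c import array as raw_array

backend = CBackend()

# NumPy -> PyArray; CBackend.array() asserts dtype is None and casts to float32
x = backend.array(np.arange(6, dtype=np.float32).reshape(2, 3))

# PyArray created directly through libmithrilc's create_full_struct
y = raw_array.ones((2, 3))

print(x.shape, y.shape)     # (2, 3) (2, 3)
print(y.data)               # nested Python lists read back through ctypes
print(c_utils.to_numpy(x))  # view the C buffer as a NumPy array again
```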
diff --git a/mithril/cores/c/cbackend.h b/mithril/cores/c/raw_c/cbackend.h similarity index 100% rename from mithril/cores/c/cbackend.h rename to mithril/cores/c/raw_c/cbackend.h diff --git a/mithril/cores/c/compile.sh b/mithril/cores/c/raw_c/compile.sh similarity index 100% rename from mithril/cores/c/compile.sh rename to mithril/cores/c/raw_c/compile.sh diff --git a/mithril/cores/c/main.c b/mithril/cores/c/raw_c/main.c similarity index 100% rename from mithril/cores/c/main.c rename to mithril/cores/c/raw_c/main.c diff --git a/mithril/cores/c/ops.c b/mithril/cores/c/raw_c/ops.c similarity index 100% rename from mithril/cores/c/ops.c rename to mithril/cores/c/raw_c/ops.c diff --git a/mithril/cores/c/ops.h b/mithril/cores/c/raw_c/ops.h similarity index 100% rename from mithril/cores/c/ops.h rename to mithril/cores/c/raw_c/ops.h diff --git a/mithril/cores/c/utils.h b/mithril/cores/c/raw_c/utils.h similarity index 100% rename from mithril/cores/c/utils.h rename to mithril/cores/c/raw_c/utils.h diff --git a/mithril/framework/codegen/__init__.py b/mithril/framework/codegen/__init__.py index 86bb3b5a..b65758c9 100644 --- a/mithril/framework/codegen/__init__.py +++ b/mithril/framework/codegen/__init__.py @@ -42,11 +42,19 @@ pass try: from ...backends.with_manualgrad.c_backend import CBackend - from .c_gen import CGen + from .raw_c_gen import RawCGen - code_gen_map[CBackend] = CGen + code_gen_map[CBackend] = RawCGen except Exception: pass + +try: + from ...backends.with_manualgrad.ggml_backend import GGMLBackend + from .ggml_gen import GGMLCodeGen + + code_gen_map[GGMLBackend] = GGMLCodeGen +except Exception as e: + raise e try: from ...backends.with_manualgrad.numpy_backend import NumpyBackend from .numpy_gen import NumpyCodeGen @@ -63,4 +71,5 @@ "PythonCodeGen", "NumpyCodeGen", "TorchCodeGen", + "GGMLCodeGen", ] diff --git a/mithril/framework/codegen/c_ast.py b/mithril/framework/codegen/c_ast.py index 04414ac5..c7533a96 100644 --- a/mithril/framework/codegen/c_ast.py +++ b/mithril/framework/codegen/c_ast.py @@ -13,7 +13,7 @@ # limitations under the License. 
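# A hedged usage sketch, composing only classes defined in this module:
#   FunctionDef("void", "noop", [], [MakeStmt(Call("printf", ['"hi"']))]).to_str()
# renders roughly as:  void noop() { printf("hi"); }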
from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass @@ -34,22 +34,29 @@ class Stmt(AST): pass +@dataclass +class MakeStmt(Stmt): + expr: Expr + + def to_str(self) -> str: + return self.expr.to_str() + ";" + + @dataclass class Call(Expr): name: str - args: list[str] | list[Expr] + args: Sequence[str | Expr] def to_str(self) -> str: args_str = ", ".join( [arg.to_str() if isinstance(arg, Expr) else arg for arg in self.args] ) - # args_str = ", ".join(self.args) return f"{self.name}({args_str})" @dataclass class Constant(Expr): - value: int | float + value: int | float | str def to_str(self) -> str: return str(self.value) @@ -58,13 +65,34 @@ def __str__(self) -> str: return self.to_str() +@dataclass +class Variable(Expr): + name: str + + def to_str(self) -> str: + return self.name + + +@dataclass +class Assign(Stmt): + target: Expr + source: Expr | Stmt + + def to_str(self) -> str: + result_str = f"{self.target.to_str()} = {self.source.to_str()}" + if not isinstance(self.source, Stmt): + result_str += ";" + return result_str + + @dataclass class Parameter: - type: str + type: str | Expr name: str def to_str(self) -> str: - return f"{self.type} {self.name}" + type_str = self.type.to_str() if isinstance(self.type, Expr) else self.type + return f"{type_str} {self.name}" @dataclass @@ -76,10 +104,12 @@ class FunctionDef(Stmt): def to_str(self) -> str: params_str = ( - "\n\t" + ",\n\t".join([param.to_str() for param in self.params]) + "\n" + ("\n\t" + ",\n\t".join([param.to_str() for param in self.params]) + "\n") + if len(self.params) > 0 + else "" ) - body_str = "\n ".join([stmt.to_str() + ";" for stmt in self.body]) - return f"{self.return_type} {self.name}({params_str})\n{{\n {body_str}\n}}" + body_str = "\n ".join([stmt.to_str() for stmt in self.body]) + return f"\n{self.return_type} {self.name}({params_str})\n{{\n {body_str}\n}}" @dataclass @@ -102,6 +132,43 @@ def to_str(self) -> str: return f'#include "{self.header}"' +@dataclass +class Comment(Stmt): + text: str + multiline: bool = False # True for /* */ comments, False for // comments + + def to_str(self) -> str: + if self.multiline: + # Format multiline comments with proper line breaks + lines = self.text.split("\n") + if len(lines) == 1: + return f"/* {self.text} */" + formatted_lines = [f" * {line}" for line in lines] + return "/*\n" + "\n".join(formatted_lines) + "\n */" + else: + return f"// {self.text}" + + +@dataclass +class StructField: + type: str | Expr + name: str + + def to_str(self) -> str: + type_str = self.type.to_str() if isinstance(self.type, Expr) else self.type + return f" {type_str} {self.name};" + + +@dataclass +class StructDef(Stmt): + name: str + fields: list[StructField] + + def to_str(self) -> str: + fields_str = "\n".join(field.to_str() for field in self.fields) + return f"\nstruct {self.name} {{\n{fields_str}\n}};\n" + + @dataclass class FILE(AST): includes: list[Include] @@ -115,3 +182,83 @@ def to_str(self) -> str: globals_str = "\n".join(stmt.to_str() for stmt in self.globals) declarations_str = "\n\n".join(decl.to_str() for decl in self.declarations) return f"{includes_str}\n\n{globals_str}\n\n{declarations_str}" + + +@dataclass +class StructInit(Stmt): + struct_name: str + field_values: Mapping[str, Expr | str] + static: bool = False + + def to_str(self) -> str: + field_inits = [ + f".{field} = {value.to_str() if isinstance(value, Expr) else value}" + for field, value in self.field_values.items() + ] + 
fields_str = ", ".join(field_inits) + + stmt = f"struct {self.struct_name} = {{ {fields_str} }};" + if self.static: + stmt = f"static {stmt}" + + return stmt + + +@dataclass +class StaticVariable(Stmt): + type: str | Expr + name: str + initial_value: Expr | None = None + + def to_str(self) -> str: + type_str = self.type.to_str() if isinstance(self.type, Expr) else self.type + if self.initial_value is None: + return f"static {type_str} {self.name};" + return f"static {type_str} {self.name} = {self.initial_value.to_str()};" + + +@dataclass +class If(Stmt): + condition: Expr + body: list[Stmt] + else_body: list[Stmt] | None = None + + def to_str(self) -> str: + body_str = "\n ".join([stmt.to_str() for stmt in self.body]) + if self.else_body is None: + return f"if ({self.condition.to_str()}) {{\n {body_str}\n}}" + else: + else_str = "\n ".join([stmt.to_str() for stmt in self.else_body]) + return ( + f"if ({self.condition.to_str()}) {{\n {body_str}\n}} else " + f"{{\n {else_str}\n}}" + ) + + +@dataclass +class Arrow(Expr): + target: Expr + field: str + + def to_str(self) -> str: + return f"{self.target.to_str()}->{self.field}" + + +@dataclass +class Dot(Expr): + target: Variable + field: str + + def to_str(self) -> str: + return f"{self.target.to_str()}.{self.field}" + + +@dataclass +class Pointer(Expr): + target: str | Expr + + def to_str(self) -> str: + target_str = ( + self.target.to_str() if isinstance(self.target, Expr) else self.target + ) + return f"{target_str} *" diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index 275fda79..e23fe0a4 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -19,8 +19,10 @@ from functools import partial from ...backends.with_manualgrad.c_backend import CBackend, backend -from ...cores.c import array -from ...cores.c.array import Array, PyArray +from ...backends.with_manualgrad.ggml_backend import GGMLBackend +from ...common import CGenConfig +from ...cores.c.array import PyArray +from ...cores.c.raw_c import array from ...framework.common import ( EvaluateAllType, EvaluateGradientsType, @@ -31,17 +33,30 @@ from . import c_ast from .code_gen import CodeGen -FinalCost = "final_cost" +ast_block_type = list[c_ast.Stmt] | list[c_ast.Expr] | list[c_ast.Stmt | c_ast.Expr] class CGen(CodeGen[PyArray]): BACKWARD_FN_SUFFIX = "_grad" + EVALUATE_INPUT_STRUCT_NAME = "eval_inputs" + EVALUATE_GRAD_INPUT_STRUCT_NAME = "eval_grad_inputs" + EVALUATE_OUTPUT_STRUCT_NAME = "eval_outputs" + EVALUATE_GRAD_OUTPUT_STRUCT_NAME = "eval_grad_outputs" + CACHE_STRUCT_NAME = "cache_keys" + GRAD_STRUCT_NAME = "grad_keys" + CACHE_NAME = "cache" + + dynamic_links: list[str] = [] def __init__(self, pm: PhysicalModel[PyArray]) -> None: super().__init__(pm) - assert isinstance(self.pm.backend, CBackend) - self.backend: CBackend = self.pm.backend + assert isinstance(self.pm.backend, CBackend | GGMLBackend), ( + f"Invalid backend '{self.pm.backend.backend_type}'! 
Must be CBackend" + " or GGMLBackend" + ) + + self.backend: CBackend | GGMLBackend = self.pm.backend self.imports: list[c_ast.AST] = [] self.globals: list[c_ast.AST] = [] @@ -49,22 +64,52 @@ def __init__(self, pm: PhysicalModel[PyArray]) -> None: # This will be used to store the keys of the argument of the functions self.func_arg_keys: dict[str, list[str]] = {} + self.configs: CGenConfig = self.backend.CODEGEN_CONFIG - def generate_imports(self) -> list[c_ast.Include]: - header_path = os.path.join(self.backend.SRC_PATH, "cbackend.h") - return [c_ast.Include(header_path, system=False)] + # Ignored grad keys + self.ignored_grad_keys: set[str] = self._infer_ignored_grad_keys() + + # Determine struct keys + self.determined_struct_keys: dict[str, list[str]] = ( + self._determine_struct_keys() + ) def generate_code(self, file_path: str | None = None) -> None: self.file_path = file_path - self.imports = self.generate_imports() # type: ignore - eval_fn, eval_used_keys = self.generate_evaluate() + self.imports += self.generate_imports() + + # Functions + eval_fn = self.generate_evaluate() self.functions.append(eval_fn) - self.func_arg_keys["evaluate"] = sorted(eval_used_keys) + self.func_arg_keys["evaluate"] = sorted(self.pm.input_keys) + if not self.pm.inference: - eval_grad_fn, eval_grad_used_keys = self.generate_evaluate_gradients() + eval_grad_fn = self.generate_evaluate_gradients() self.functions.append(eval_grad_fn) - self.func_arg_keys["evaluate_gradients"] = sorted(eval_grad_used_keys) + + # Structs + self._generate_structs() + + # Init cache struct + cache_struct = c_ast.StructInit( + f"{self.CACHE_STRUCT_NAME} {self.CACHE_NAME}", + {key: "NULL" for key in self.determined_struct_keys["eval_cache_keys"]}, + static=True, + ) + self.globals.append(cache_struct) + + if not self.pm.inference: + # Init grad struct + grad_struct = c_ast.StructInit( + f"{self.EVALUATE_GRAD_OUTPUT_STRUCT_NAME} {self.GRAD_STRUCT_NAME}", + { + key: "NULL" + for key in self.determined_struct_keys["eval_grad_output_keys"] + }, + static=True, + ) + self.globals.append(grad_struct) generated_code = c_ast.FILE(self.imports, self.globals, self.functions).to_str() # type: ignore @@ -89,7 +134,6 @@ def compile_code( assert not jit, "JIT is not yet supported for CBackend" assert self.file_path is not None, "Code has not been generated yet!" 
- eval_arg_keys = self.func_arg_keys["evaluate"] so_file_path = self.file_path.replace(".c", ".so") default_compile_flags = ["cc", self.file_path, "-shared", "-fPIC"] @@ -100,7 +144,7 @@ def compile_code( [ *default_compile_flags, f"-L{self.backend.SRC_PATH}", - "-lmithrilc", + *self.dynamic_links, f"-Wl,-rpath,{self.backend.SRC_PATH}", "-o", so_file_path, @@ -110,14 +154,51 @@ def compile_code( if so_file_path[0] != "/": so_file_path = "./" + so_file_path + # Load dynamic links + for link in self.dynamic_links: + link_path = os.path.join(self.backend.SRC_PATH, link.replace("-l", "lib")) + if os.path.exists(link_path + ".so"): + link_path += ".so" + elif os.path.exists(link_path + ".dylib"): + link_path += ".dylib" + + ctypes.CDLL(link_path) + # We need backend subtype lib = ctypes.CDLL(so_file_path) - lib.evaluate.argtypes = [ctypes.POINTER(Array)] * len(eval_arg_keys) + + # Input and output structs + class Inputs(ctypes.Structure): + _fields_ = [ + (key, ctypes.POINTER(self.backend.get_struct_cls())) + for key in self.determined_struct_keys["eval_input_keys"] + ] + + class Outputs(ctypes.Structure): + _fields_ = [ + (key, ctypes.POINTER(self.backend.get_struct_cls())) + for key in self.determined_struct_keys["eval_output_keys"] + ] + + class GradInputs(ctypes.Structure): + _fields_ = [ + (key, ctypes.POINTER(self.backend.get_struct_cls())) + for key in self.determined_struct_keys["eval_grad_input_keys"] + ] + + class GradOutputs(ctypes.Structure): + _fields_ = [ + (key, ctypes.POINTER(self.backend.get_struct_cls())) + for key in self.determined_struct_keys["eval_grad_output_keys"] + ] + + # Set the return type and argument types + lib.evaluate.argtypes = [ctypes.POINTER(Inputs)] + lib.evaluate.restype = Outputs + if not self.pm.inference: - eval_grad_arg_keys = self.func_arg_keys["evaluate_gradients"] - lib.evaluate_gradients.argtypes = [ctypes.POINTER(Array)] * len( - eval_grad_arg_keys - ) + lib.evaluate_gradients.argtypes = [ctypes.POINTER(GradInputs)] + lib.evaluate_gradients.restype = GradOutputs # we need backend data types! 
# include_internals flag is used to get internal values for backpropagation @@ -135,21 +216,38 @@ def evaluate_wrapper( if isinstance(cache, dict): inputs |= cache - # Allocate output arrays - for arg_key in eval_arg_keys: - if arg_key in inputs: - continue + if self.configs.ALLOCATE_INTERNALS: + # Allocate output arrays + for arg_key in self.determined_struct_keys["eval_input_keys"]: + if arg_key in inputs: + continue - arr_shape = self._get_array_shape(arg_key) - inputs[arg_key] = self.backend.empty(*arr_shape) + arr_shape = self._get_array_shape(arg_key) + inputs[arg_key] = self.backend.empty(*arr_shape) - inputs_ordered = [inputs[arg].arr for arg in eval_arg_keys] - lib.evaluate(*inputs_ordered) + inputs_struct = Inputs( + **{ + key: ctypes.pointer(inputs[key].arr) + for key in self.determined_struct_keys["eval_input_keys"] + } + ) + inputs_struct_ptr = ctypes.pointer(inputs_struct) - if not include_internals: - return {key: inputs[key] for key in self.pm.output_keys} - else: - return inputs + output_struct = lib.evaluate(inputs_struct_ptr) + + outputs = {} + return_keys = ( + self.determined_struct_keys["eval_output_keys"] + if include_internals + else self.pm.output_keys + ) + for key in return_keys: + array_ptr = getattr(output_struct, key) + outputs[key] = PyArray( + array_ptr.contents, shape=self._get_tensor_shape(key) + ) + + return outputs def evaluate_gradients_wrapper( params: dict[str, PyArray], @@ -160,38 +258,58 @@ def evaluate_gradients_wrapper( if data is None: data = {} - if output_gradients is None and FinalCost not in self.pm._output_keys: + if output_gradients is None and self.FinalCost not in self.pm._output_keys: raise ValueError( "Requires output gradients if final loss is not attached!" ) elif output_gradients is None: - output_gradients = {FinalCost: array.ones((1,))} + output_gradients = {self.FinalCost: array.ones((1,))} gradients = {key: value for key, value in output_gradients.items()} forward_pass = evaluate_wrapper( - params=params, data=data, cache={}, include_internals=True + params=params, + data=data, + cache={}, + include_internals=self.configs.ALLOCATE_INTERNALS, ) # Create gradients for all params - for key in ( - self.pm.flat_graph.all_source_keys - - self.pm.flat_graph.all_static_keys - - self.pm.flat_graph.unused_keys - - self.pm.ignore_grad_keys - ): - # In CBackend we are creating all internal gradients with zeros. - if key not in gradients: - arr_shape = self._get_array_shape(key) - gradients[key] = self.backend.zeros(*arr_shape) - - gradients = {key + "_grad": value for key, value in gradients.items()} + if self.configs.ALLOCATE_INTERNALS: + for key in ( + self.pm.flat_graph.all_source_keys + - self.pm.flat_graph.all_static_keys + - self.pm.flat_graph.unused_keys + - self.ignored_grad_keys + ): + # In CBackend we are creating all internal gradients with zeros.
+ if key not in gradients: + arr_shape = self._get_array_shape(key) + gradients[key] = self.backend.zeros(*arr_shape) + + gradients = { + key + self.BACKWARD_FN_SUFFIX: value for key, value in gradients.items() + } inputs = params | data | gradients | forward_pass - inputs_ordered = [inputs[arg].arr for arg in sorted(inputs.keys())] - lib.evaluate_gradients(*inputs_ordered) + inputs_struct = GradInputs( + **{ + key: ctypes.pointer(inputs[key].arr) + for key in self.determined_struct_keys["eval_grad_input_keys"] + } + ) + inputs_struct_ptr = ctypes.pointer(inputs_struct) + + output_struct = lib.evaluate_gradients(inputs_struct_ptr) + outputs = {} + for grad_key in self.determined_struct_keys["eval_grad_output_keys"]: + key = grad_key.replace(self.BACKWARD_FN_SUFFIX, "") + array_ptr = getattr(output_struct, grad_key) + outputs[key] = PyArray( + array_ptr.contents, shape=self._get_tensor_shape(key) + ) - return {key: inputs[key + "_grad"] for key in params} + return outputs return ( # type: ignore evaluate_wrapper, @@ -199,86 +317,209 @@ def evaluate_gradients_wrapper( partial(evaluate_gradients_wrapper, include_output=True), # type: ignore ) - def create_primitive_call(self, formula_name: str, args: list[str]) -> c_ast.Expr: - return c_ast.Call(formula_name, args) + def generate_imports(self) -> list[c_ast.Include]: + header_path = os.path.join(self.backend.SRC_PATH, self.configs.HEADER_NAME) + return [c_ast.Include(header_path, system=False)] + + def create_primitive_call( + self, formula_name: str, args: list[c_ast.Expr], context: str + ) -> c_ast.Expr: + return c_ast.Call(formula_name, args=args) + + def assign_primitive_output( + self, target: str, source: c_ast.Expr, context: str + ) -> c_ast.Assign: + return self.assign_array( + self.create_key_ref(target, context=context, load=False), source + ) + + def create_key_ref(self, key: str, context: str, load: bool = True) -> c_ast.Expr: + if key in self.determined_struct_keys["eval_cache_keys"]: + return c_ast.Variable(f"{self.CACHE_NAME}.{key}") + + elif ( + context == "eval" and key in self.determined_struct_keys["eval_input_keys"] + ): + return c_ast.Arrow(c_ast.Variable("inputs"), key) + + elif context == "eval_grad": + if key in self.determined_struct_keys["eval_grad_input_keys"]: + return c_ast.Arrow(c_ast.Variable("inputs"), key) + + if ( + key in self.pm.flat_graph.all_keys + or key.replace(self.BACKWARD_FN_SUFFIX, "") + in self.pm.flat_graph.all_keys + ) and not load: + return c_ast.Variable(f"{self.configs.ARRAY_NAME} * {key}") + + return c_ast.Variable(key) + + def assign_array( + self, target: c_ast.Variable | c_ast.Expr, source: c_ast.Expr + ) -> c_ast.Assign: + return c_ast.Assign(target, source) + + def define_function( + self, + return_type: str, + name: str, + params: list[c_ast.Parameter], + pre_process: ast_block_type, + operations: ast_block_type, + post_process: ast_block_type, + ) -> c_ast.FunctionDef: + body = pre_process + operations + post_process + return c_ast.FunctionDef(return_type, name, params, body) + + def create_output_struct(self, context: str) -> c_ast.StructInit: + output_keys = ( + self.determined_struct_keys["eval_output_keys"] + if context == "eval" + else self.determined_struct_keys["eval_grad_output_keys"] + ) + output_struct_init: dict[str, c_ast.Expr] = { + key: self.create_key_ref(key, context=context) for key in output_keys + } + + output_struct_name = ( + self.EVALUATE_OUTPUT_STRUCT_NAME + if context == "eval" + else self.EVALUATE_GRAD_OUTPUT_STRUCT_NAME + ) + + return c_ast.StructInit( + 
f"{output_struct_name} output_struct", output_struct_init + ) - def generate_evaluate(self) -> tuple[c_ast.FunctionDef, set[str]]: - fn_body: list[c_ast.Expr] = [] - used_keys: set[str] = set() + def generate_evaluate(self) -> c_ast.FunctionDef: + # Function body + pre_process: ast_block_type = [] + operations: ast_block_type = [] + post_process: ast_block_type = [] + + # Define function arguments + arguments = [ + c_ast.Parameter( + c_ast.Pointer(f"struct {self.EVALUATE_INPUT_STRUCT_NAME}"), "inputs" + ) + ] for output_key in self.pm.flat_graph.topological_order: model = self.pm.flat_graph.get_model(output_key) inputs = self.pm.flat_graph.get_source_keys(output_key) - # In C backend we need to pass output array as first argument - inputs = [output_key] + inputs + if self.configs.USE_OUTPUT_AS_INPUT: + # In raw_c backend we need to pass output array as first argument + inputs = [output_key] + inputs + + input_vars: list[c_ast.Expr] = [ + self.create_key_ref(key, context="eval", load=True) for key in inputs + ] # Create primitive call - p_call = self.create_primitive_call(model.formula_key, inputs) - fn_body.append(p_call) + p_call = self.create_primitive_call( + model.formula_key, + input_vars, + context="eval", + ) - used_keys.add(output_key) - used_keys |= set(inputs) + p_call_stmts: c_ast.Stmt = self.assign_primitive_output( + output_key, p_call, context="eval" + ) - arguments: list[c_ast.Parameter] = [] - for used_key in sorted(used_keys): - arguments.append(c_ast.Parameter("Array *", used_key)) + operations.append(p_call_stmts) # type: ignore - evaluate_fn = c_ast.FunctionDef("void", "evaluate", arguments, fn_body) + # Prepare output + post_process.append(self.create_output_struct(context="eval")) # type: ignore + post_process.append(c_ast.Return(c_ast.Variable("output_struct"))) # type: ignore - return evaluate_fn, used_keys + evaluate_fn = self.define_function( + f"struct {self.EVALUATE_OUTPUT_STRUCT_NAME}", + "evaluate", + arguments, + pre_process, + operations, + post_process, + ) - def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: - fn_body: list[c_ast.Expr] = [] - used_keys: set[str] = set() + return evaluate_fn - all_ignored_keys = ( - self.pm.ignore_grad_keys - | self.pm.flat_graph.all_static_keys - | self.pm.flat_graph.unused_keys - ) - all_ignored_keys, _ = self.pm.flat_graph.infer_ignore( - set(), self.pm._output_keys, all_ignored_keys, update_graph=False - ) + def generate_evaluate_gradients(self) -> c_ast.FunctionDef: + # Function body + pre_process: ast_block_type = [] + operations: ast_block_type = [] + post_process: ast_block_type = [] + + # Define function arguments + arguments = [ + c_ast.Parameter( + c_ast.Pointer(f"struct {self.EVALUATE_GRAD_INPUT_STRUCT_NAME}"), + "inputs", + ) + ] for output_key in reversed(self.pm.flat_graph.topological_order): # Staticly infered and unused model will not be added - if output_key in all_ignored_keys: + if output_key in self.ignored_grad_keys: continue model = self.pm.flat_graph.get_model(output_key) inputs = self.pm.flat_graph.get_source_keys(output_key) # Assume all inputs are Array - grad_inputs = [input_key + "_grad" for input_key in inputs] - for idx in range(len(grad_inputs)): - fn_inputs: list[str] = ( - [output_key + "_grad", c_ast.Constant(idx).to_str(), output_key] - + inputs - + grad_inputs - ) + for idx in range(len(inputs)): + if inputs[idx] in self.ignored_grad_keys: + continue + + fn_inputs: list[c_ast.Expr] = [ + self.create_key_ref( + output_key + self.BACKWARD_FN_SUFFIX, 
context="eval_grad" + ), + c_ast.Constant(idx), + self.create_key_ref(output_key, context="eval_grad"), + ] + [ + self.create_key_ref(input_key, context="eval_grad") + for input_key in inputs + ] + + if self.configs.USE_OUTPUT_AS_INPUT: + fn_inputs += [ + self.create_key_ref( + input_key + self.BACKWARD_FN_SUFFIX, context="eval_grad" + ) + if input_key not in self.ignored_grad_keys + else c_ast.Variable("NULL") + for input_key in inputs + ] # Create primitive call p_call = self.create_primitive_call( - model.formula_key + self.BACKWARD_FN_SUFFIX, fn_inputs + model.formula_key + self.BACKWARD_FN_SUFFIX, + fn_inputs, + context="eval_grad", ) - fn_body.append(p_call) - used_keys.add(output_key) - used_keys.add(output_key + "_grad") - used_keys |= set(inputs) - used_keys |= set(grad_inputs) + p_call_stmts: c_ast.Stmt = self.assign_primitive_output( + inputs[idx] + self.BACKWARD_FN_SUFFIX, p_call, context="eval_grad" + ) - arguments: list[c_ast.Parameter] = [] + operations.append(p_call_stmts) # type: ignore - for used_key in sorted(used_keys): - arguments.append(c_ast.Parameter("Array *", used_key)) + # Prepare output + post_process.append(self.create_output_struct(context="eval_grad")) # type: ignore + post_process.append(c_ast.Return(c_ast.Variable("output_struct"))) # type: ignore - evaluate_grad_fn = c_ast.FunctionDef( - "void", "evaluate_gradients", arguments, fn_body + evaluate_grad_fn = self.define_function( + f"struct {self.EVALUATE_GRAD_OUTPUT_STRUCT_NAME}", + "evaluate_gradients", + arguments, + pre_process, + operations, + post_process, ) - return evaluate_grad_fn, used_keys + return evaluate_grad_fn def _get_backend_path(self) -> str: backend_path = backend.__file__ @@ -293,3 +534,102 @@ def _get_array_shape(self, key: str) -> tuple[int, ...]: return tuple(shape) else: raise ValueError(f"Unexpected shape: {shape}") + + def _generate_structs(self) -> None: + # Generate structs + eval_input_struct = self._generate_struct( + self.EVALUATE_INPUT_STRUCT_NAME, + self.determined_struct_keys["eval_input_keys"], + ) + eval_outputs_struct = self._generate_struct( + self.EVALUATE_OUTPUT_STRUCT_NAME, + self.determined_struct_keys["eval_output_keys"], + ) + + cache_struct = self._generate_struct( + self.CACHE_STRUCT_NAME, self.determined_struct_keys["eval_cache_keys"] + ) + + structs = [eval_input_struct, eval_outputs_struct, cache_struct] + + if not self.pm.inference: + eval_grad_input_struct = self._generate_struct( + self.EVALUATE_GRAD_INPUT_STRUCT_NAME, + self.determined_struct_keys["eval_grad_input_keys"], + ) + + eval_grad_outputs_struct = self._generate_struct( + self.EVALUATE_GRAD_OUTPUT_STRUCT_NAME, + self.determined_struct_keys["eval_grad_output_keys"], + ) + + structs += [eval_grad_input_struct, eval_grad_outputs_struct] + + self.globals = structs + self.globals + + def _generate_struct(self, name: str, field_keys: list[str]) -> c_ast.Stmt: + fields = [ + c_ast.StructField( + c_ast.Pointer(c_ast.Variable(self.configs.ARRAY_NAME)), key + ) + for key in sorted(field_keys) + ] + struct = c_ast.StructDef(name, fields) + return struct + + def _infer_ignored_grad_keys(self) -> set[str]: + all_ignored_keys = ( + self.pm.ignore_grad_keys + | self.pm.flat_graph.all_static_keys + | self.pm.flat_graph.unused_keys + ) + all_ignored_keys, _ = self.pm.flat_graph.infer_ignore( + set(), self.pm._output_keys, all_ignored_keys, update_graph=False + ) + + return all_ignored_keys + + def _determine_struct_keys(self) -> dict[str, list[str]]: + eval_input_keys = sorted(self.pm.input_keys) + if 
self.configs.USE_OUTPUT_AS_INPUT: + eval_input_keys = sorted(self.pm.flat_graph.all_keys) + + eval_output_keys = sorted(self.pm.output_keys) + eval_cache_keys = sorted(self.pm.flat_graph.all_keys - self.pm.input_keys) + + eval_grad_input_keys = sorted( + ( + self.pm.input_keys + | set(self.pm.output_keys) + | { + key + self.BACKWARD_FN_SUFFIX + for key in set(self.pm.output_keys) - self.ignored_grad_keys + } + ) + - set(eval_cache_keys) + ) + + eval_grad_output_keys = sorted( + [ + key + self.BACKWARD_FN_SUFFIX + for key in set(self.pm.input_keys) - self.ignored_grad_keys + ] + ) + + determined_struct_keys = { + "eval_input_keys": eval_input_keys, + "eval_output_keys": eval_output_keys, + "eval_cache_keys": eval_cache_keys, + "eval_grad_input_keys": eval_grad_input_keys, + "eval_grad_output_keys": eval_grad_output_keys, + } + + return determined_struct_keys + + def _get_tensor_shape(self, key: str) -> tuple[int, ...]: + if key in self.pm.shapes: + return self.pm.shapes[key] # type: ignore + elif key.replace(self.BACKWARD_FN_SUFFIX, "") in self.pm.shapes: + return self.pm.shapes[key.replace(self.BACKWARD_FN_SUFFIX, "")] # type: ignore + else: + raise ValueError(f"Shape for key {key} not found") diff --git a/mithril/framework/codegen/code_gen.py b/mithril/framework/codegen/code_gen.py index a0ca649b..5ea453d0 100644 --- a/mithril/framework/codegen/code_gen.py +++ b/mithril/framework/codegen/code_gen.py @@ -21,6 +21,8 @@ class CodeGen(ABC, Generic[DataType]): + FinalCost = "final_cost" + def __init__(self, pm: PhysicalModel[DataType]) -> None: self.pm: PhysicalModel[DataType] = pm self.code: str | None = None diff --git a/mithril/framework/codegen/ggml_gen.py b/mithril/framework/codegen/ggml_gen.py new file mode 100644 index 00000000..e7b538fb --- /dev/null +++ b/mithril/framework/codegen/ggml_gen.py @@ -0,0 +1,298 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import override + +from ...cores.c.array import PyArray +from ..physical.model import PhysicalModel +from . 
import c_ast +from .c_gen import CGen + +ast_block_type = list[c_ast.Stmt] | list[c_ast.Expr] | list[c_ast.Stmt | c_ast.Expr] + + +class GGMLCodeGen(CGen): + dynamic_links = ["-lggml-base", "-lggml-cpu", "-lmithrilggml"] + + def __init__(self, pm: PhysicalModel[PyArray]) -> None: + super().__init__(pm) + + self.defined_tmp_vars: set[str] = set() + + def generate_code(self, file_path: str | None = None) -> None: + # Add stdlib.h include for atexit + stdlib_include = c_ast.Include("stdlib.h", system=True) + self.imports.append(stdlib_include) + + # Generate static context variable at file scope + eval_static_ctx = c_ast.StaticVariable( + c_ast.Pointer("struct ggml_context"), + "eval_static_ctx", + c_ast.Constant("NULL"), + ) + + eval_static_gf = c_ast.StaticVariable( + c_ast.Pointer("struct ggml_cgraph"), + "eval_static_gf", + c_ast.Constant("NULL"), + ) + + eval_grad_static_ctx = c_ast.StaticVariable( + c_ast.Pointer("struct ggml_context"), + "eval_grad_static_ctx", + c_ast.Constant("NULL"), + ) + + eval_grad_static_gf = c_ast.StaticVariable( + c_ast.Pointer("struct ggml_cgraph"), + "eval_grad_static_gf", + c_ast.Constant("NULL"), + ) + + cleanup_fn = self.generate_cleanup_fn() + + self.globals.extend( + [ + eval_static_ctx, + eval_grad_static_ctx, + eval_static_gf, + eval_grad_static_gf, + cleanup_fn, + ] + ) + + super().generate_code(file_path) + + def generate_cleanup_fn(self) -> c_ast.Stmt: + fn_body: list[c_ast.Stmt] = [] + + # Add if check for static_ctx + if_check1 = c_ast.If( + c_ast.Variable("eval_static_ctx != NULL"), + [ + c_ast.MakeStmt(c_ast.Call("ggml_free", ["eval_static_ctx"])), + c_ast.Assign(c_ast.Variable("eval_static_ctx"), c_ast.Constant("NULL")), + ], + ) + + if_check2 = c_ast.If( + c_ast.Variable("eval_grad_static_ctx != NULL"), + [ + c_ast.MakeStmt(c_ast.Call("ggml_free", ["eval_grad_static_ctx"])), + c_ast.Assign( + c_ast.Variable("eval_grad_static_ctx"), c_ast.Constant("NULL") + ), + ], + ) + + fn_body.append(if_check1) + fn_body.append(if_check2) + + return c_ast.FunctionDef("void", "cleanup", [], fn_body) + + @override + def define_function( + self, + return_type: str, + name: str, + params: list[c_ast.Parameter], + pre_process: ast_block_type, + operations: ast_block_type, + post_process: ast_block_type, + ) -> c_ast.FunctionDef: + if name in ["evaluate", "evaluate_gradients"]: + return self.update_function( + name, return_type, params, pre_process, operations, post_process + ) + + return super().define_function( + return_type, name, params, pre_process, operations, post_process + ) + + @override + def create_primitive_call( + self, formula_name: str, args: list[c_ast.Expr], context: str + ) -> c_ast.Expr: + # Add context as input for all primitive calls + context_var = "eval_static_ctx" if context == "eval" else "eval_grad_static_ctx" + return c_ast.Call( + formula_name, + [context_var] + args, # type: ignore + ) + + def update_function( + self, + name: str, + return_type: str, + params: list[c_ast.Parameter], + pre_process: ast_block_type, + operations: ast_block_type, + post_process: ast_block_type, + ) -> c_ast.FunctionDef: + # Define static variables at function scope + static_vars: list[c_ast.Stmt] = [] + + fn_ref_name = "eval" if name == "evaluate" else "eval_grad" + ctx_name = f"{fn_ref_name}_static_ctx" + + # Add static tensors + for key in self.determined_struct_keys[f"{fn_ref_name}_input_keys"]: + static_vars.append( + c_ast.StaticVariable( + c_ast.Pointer("struct ggml_tensor"), key, c_ast.Constant("NULL") + ) + ) + + pre_process = static_vars + 
pre_process + + # Create initialization block + init_block: ast_block_type = [] + + # Initialize context if NULL + init_block.append(c_ast.Comment("One-time initialization")) # type: ignore + init_block.append( + c_ast.StructInit( # type: ignore + "ggml_init_params params", + { + "mem_size": c_ast.Constant(1024 * 1024 * 512), + "mem_buffer": c_ast.Constant("NULL"), + "no_alloc": c_ast.Constant("false"), + }, + ) + ) + init_block.append( + c_ast.Assign( # type: ignore + c_ast.Variable(f"{fn_ref_name}_static_ctx"), + c_ast.Call("ggml_init", ["params"]), + ) + ) + + # Create tensors + init_block.append(c_ast.Comment("Create tensors only once")) # type: ignore + for key in self.determined_struct_keys[f"{fn_ref_name}_input_keys"]: + shape = self._get_tensor_shape(key) + if shape is not None: + tensor = c_ast.Call( + f"ggml_new_tensor_{len(shape)}d", + [ctx_name, "GGML_TYPE_F32"] + [str(size) for size in shape], + ) + init_block.append(c_ast.Assign(c_ast.Variable(key), tensor)) # type: ignore + + # Create and build graph + init_block.extend( + [ + c_ast.Comment("Create graph object only once"), # type: ignore + c_ast.Assign( # type: ignore + c_ast.Variable("eval_static_gf"), + c_ast.Call("ggml_new_graph", [ctx_name]), + ), + ] + ) + + # Add the original body operations + init_block += operations # type: ignore + + # Build graph + for out_key in self.determined_struct_keys[f"{fn_ref_name}_output_keys"]: + init_block.append( + c_ast.MakeStmt( # type: ignore + c_ast.Call( + "ggml_build_forward_expand", + [ + "eval_static_gf", + self.create_key_ref(out_key, context=fn_ref_name), + ], + ) + ) + ) + + init_block.append(c_ast.MakeStmt(c_ast.Call("atexit", ["cleanup"]))) # type: ignore + + # Wrap initialization in if check + if_init = [c_ast.If(c_ast.Variable(f"{ctx_name} == NULL"), init_block)] # type: ignore + + # Update input data + update_ptr_block: ast_block_type = [] + update_ptr_block.append(c_ast.Comment("Update tensor data for each call")) # type: ignore + for key in self.determined_struct_keys[f"{fn_ref_name}_input_keys"]: + update_ptr_block.append( + c_ast.Assign( # type: ignore + c_ast.Arrow(c_ast.Variable(f"{key}"), "data"), + c_ast.Arrow(c_ast.Arrow(c_ast.Variable("inputs"), key), "data"), + ) + ) + + # Initialization function + init_fn = super().define_function( + "void", + f"init_{fn_ref_name}", + params, + static_vars, + if_init, # type: ignore + update_ptr_block, + ) + + self.functions.append(init_fn) + + # Call initialization function + call_init_fn = c_ast.MakeStmt( + c_ast.Call( + f"init_{fn_ref_name}", + ["inputs"], + ) + ) + + pre_process = [call_init_fn] # type: ignore + + # Compute graph + compute_block = [ + c_ast.Comment("Compute graph"), + c_ast.MakeStmt( + c_ast.Call( + "ggml_graph_compute_with_ctx", + [ctx_name, "eval_static_gf", c_ast.Constant(1)], + ) + ), + ] + + post_process = compute_block + post_process + + return super().define_function( + return_type, name, params, pre_process, [], post_process + ) + + @override + def create_key_ref( + self, key: str, context: str, load: bool = True + ) -> c_ast.Variable | c_ast.Expr: + # TODO: Refactor this logic + if ( + key not in self.determined_struct_keys["eval_cache_keys"] + and context == "eval" + ): + return c_ast.Variable(key) + + elif ( + key not in self.determined_struct_keys["eval_cache_keys"] + and context == "eval_grad" + ): + if key in self.determined_struct_keys["eval_grad_output_keys"]: + return c_ast.Dot(c_ast.Variable(f"{self.GRAD_STRUCT_NAME}"), key) + elif not load: + return 
c_ast.Variable(f"{self.configs.ARRAY_NAME} * {key}") + else: + return c_ast.Variable(key) + + return super().create_key_ref(key, context, load) diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 90c39889..1ed3c53f 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -22,6 +22,7 @@ from typing import Any, Generic, Literal, Protocol, overload from ...backends.backend import ParallelBackend +from ...common import PythonGenConfig from ...types import DataType, Dtype from ...utils.func_utils import prepare_function_args from ..common import ( @@ -106,6 +107,9 @@ def __init__(self, pm: PhysicalModel[DataType]) -> None: self.functions: list[ast.stmt] = [] self.backend = self.pm.backend + assert isinstance(self.backend.CODEGEN_CONFIG, PythonGenConfig) + self.configs = self.backend.CODEGEN_CONFIG + def generate_code(self, file_path: str | None = None) -> None: self.file_path = file_path self.imports += self.generate_imports() diff --git a/mithril/framework/codegen/raw_c_gen.py b/mithril/framework/codegen/raw_c_gen.py new file mode 100644 index 00000000..c3036f0e --- /dev/null +++ b/mithril/framework/codegen/raw_c_gen.py @@ -0,0 +1,53 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import override + +from . 
import c_ast +from .c_gen import CGen + + +class RawCGen(CGen): + dynamic_links = ["-lmithrilc"] + + @override + def _determine_struct_keys(self) -> dict[str, list[str]]: + determined_struct_keys = super()._determine_struct_keys() + determined_struct_keys["eval_grad_input_keys"] = sorted( + { + key + "_grad" + for key in self.pm.flat_graph.all_keys + if key not in self.ignored_grad_keys + } + | self.pm.flat_graph.all_keys + ) + determined_struct_keys["eval_output_keys"] = sorted( + determined_struct_keys["eval_input_keys"] + ) + + return determined_struct_keys + + def create_key_ref( + self, key: str, context: str, load: bool = True + ) -> c_ast.Variable | c_ast.Expr: + if key in self.determined_struct_keys["eval_input_keys"]: + return c_ast.Variable(f"inputs->{key}") + + else: + return super().create_key_ref(key, context, load) + + def assign_primitive_output( + self, target: str, source: c_ast.Expr, context: str + ) -> c_ast.Assign: + return c_ast.MakeStmt(source) # type: ignore diff --git a/mithril/framework/codegen/utils.py b/mithril/framework/codegen/utils.py index 9835784c..2cdeb1b1 100644 --- a/mithril/framework/codegen/utils.py +++ b/mithril/framework/codegen/utils.py @@ -15,6 +15,7 @@ import ast from ...backends.backend import Backend +from ...common import PythonGenConfig from ...types import DataType from ..common import ShapeNode @@ -23,11 +24,13 @@ def partial_array_creation_func( backend: Backend[DataType], formula_key: str ) -> ast.stmt: + assert isinstance(backend.CODEGEN_CONFIG, PythonGenConfig) + kwargs = [ ast.keyword(arg="default_dtype", value=ast.Constant(value=backend._dtype.name)) ] - if backend.codegen_config["specify_device"]: + if backend.CODEGEN_CONFIG.SPECIFY_DEVICE: kwargs.append( ast.keyword(arg="device", value=ast.Constant(value=backend.get_device())) ) diff --git a/mithril/framework/physical/flat_graph.py b/mithril/framework/physical/flat_graph.py index 64cbd564..876d012d 100644 --- a/mithril/framework/physical/flat_graph.py +++ b/mithril/framework/physical/flat_graph.py @@ -21,7 +21,7 @@ import mithril as ml -from ...common import BiMap +from ...common import BiMap, PythonGenConfig from ...types import DataType, GenericDataType from ...utils.func_utils import is_make_array_required, prepare_function_args from ..common import ( @@ -684,7 +684,9 @@ def infer_static_keys(self) -> Updates: # If function needs backend specific args if model.formula_key in self.backend.array_creation_funcs: kwargs["default_dtype"] = self.backend._dtype.name - if self.backend.codegen_config["specify_device"]: + # TODO: Add support for C backends + assert isinstance(self.backend.CODEGEN_CONFIG, PythonGenConfig) + if self.backend.CODEGEN_CONFIG.SPECIFY_DEVICE: kwargs["device"] = self.backend.get_device() static_value = fn(*args, **kwargs) diff --git a/setup.py b/setup.py index f10a3268..822bb43d 100644 --- a/setup.py +++ b/setup.py @@ -17,19 +17,60 @@ import setuptools from setuptools.command.build_ext import build_ext +from setuptools.extension import Extension class CustomBuildExt(build_ext): def run(self): - shell = os.getenv("SHELL", "sh") - script_path = os.path.join( - os.path.dirname(__file__), - "mithril", - "cores", - "c", - "compile.sh", - ) - subprocess.check_call([shell, script_path]) + # Use bash explicitly instead of relying on SHELL environment variable + shell = "/bin/bash" + + # Define script paths + scripts = [ + os.path.join( + os.path.dirname(__file__), + "mithril", + "cores", + "c", + "raw_c", + "compile.sh", + ), + os.path.join( + os.path.dirname(__file__), + 
"mithril", + "cores", + "c", + "ggml", + "build_ggml.sh", + ), + os.path.join( + os.path.dirname(__file__), + "mithril", + "cores", + "c", + "ggml", + "compile.sh", + ), + ] + + print("Running compilation scripts...") + + # Save current working directory + original_dir = os.getcwd() + + try: + # Run each script from its own directory + for script_path in scripts: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + print(f"Running {script_name} in {script_dir}") + os.chdir(script_dir) + subprocess.check_call([shell, f"./{script_name}"]) + os.chdir(original_dir) # Return to original directory + finally: + # Make sure we return to the original directory even if an error occurs + os.chdir(original_dir) + # Continue with the normal build super().run() @@ -54,6 +95,6 @@ def run(self): python_requires=">=3.12", install_requires=[], cmdclass={"build_ext": CustomBuildExt}, - package_data={"mithril.cores.c": ["libmithrilc.so"]}, + ext_modules=[Extension("mithril.c_sources", sources=[])], include_package_data=True, ) diff --git a/tests/scripts/test_c_backend.py b/tests/scripts/test_c_backend.py index 644bcd03..057d0863 100644 --- a/tests/scripts/test_c_backend.py +++ b/tests/scripts/test_c_backend.py @@ -17,7 +17,7 @@ import numpy as np -from mithril import CBackend, NumpyBackend, compile +from mithril import CBackend, GGMLBackend, NumpyBackend, compile from mithril.cores.c.array import PyArray from mithril.framework.common import Tensor from mithril.models import Add, IOKey, Model, Multiply @@ -30,9 +30,11 @@ def test_cbackend_1(): model += Add()(left="left", right="right", output="output") model.set_types(left=Tensor, right=Tensor) + model.set_differentiability(left=True, right=True) c_backend = CBackend() np_backend = NumpyBackend() + ggml_backend = GGMLBackend() c_pm = compile( model, @@ -49,15 +51,26 @@ def test_cbackend_1(): jit=False, ) + ggml_pm = compile( + model, + ggml_backend, + shapes={"left": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "right"}, + jit=False, + ) + left = np_backend.ones(5, 5) right = np_backend.ones(5, 5) output_grad = np_backend.rand(5, 5) + # Numpy Backend + np_outputs = np_pm.evaluate({"left": left, "right": right}) np_grads = np_pm.evaluate_gradients( {"left": left, "right": right}, {}, output_gradients={"output": output_grad} ) + # Raw C Backend c_left = c_backend.array(left) c_right = c_backend.array(right) c_output_grad = c_backend.array(output_grad) @@ -69,25 +82,42 @@ def test_cbackend_1(): output_gradients={"output": c_output_grad}, ) + # GGML Backend + ggml_left = ggml_backend.array(left) + ggml_right = ggml_backend.array(right) + ggml_output_grad = ggml_backend.array(output_grad) + + ggml_outputs = ggml_pm.evaluate({"left": ggml_left, "right": ggml_right}) + ggml_grads = ggml_pm.evaluate_gradients( + {"left": ggml_left, "right": ggml_right}, + {}, + output_gradients={"output": ggml_output_grad}, + ) + + # Assertions for key in np_outputs: out = c_outputs[key] + out_ggml = ggml_outputs[key] out_np = np_outputs[key] assert isinstance(out_np, np.ndarray) assert isinstance(out, PyArray) + assert isinstance(out_ggml, PyArray) assert np.allclose(c_backend.to_numpy(out), out_np) + assert np.allclose(ggml_backend.to_numpy(out_ggml), out_np) for key in np_grads: assert np.allclose(c_backend.to_numpy(c_grads[key]), np_grads[key]) + assert np.allclose(ggml_backend.to_numpy(ggml_grads[key]), np_grads[key]) @with_temp_file(suffix=".c") def test_cbackend_2(file_path: str): model = Model() - add = Add() - model |= 
add(left="left", right="right", output=IOKey(name="output")) + model |= Add()(left="left", right="right", output=IOKey(name="output")) model |= Add()(left="left2", right="output", output=IOKey(name="output2")) model.set_types(left=Tensor, left2=Tensor, right=Tensor) + model.set_differentiability(left=True, left2=True, right=True) c_backend = CBackend() np_backend = NumpyBackend() @@ -113,6 +143,7 @@ def test_cbackend_2(file_path: str): right = np_backend.ones(5, 5) # type: ignore # (check after DataTypes Update) output_grad = np_backend.rand(5, 5) + # Numpy np_outputs = np_pm.evaluate({"left": left, "right": right, "left2": left2}) np_grads = np_pm.evaluate_gradients( {"left": left, "right": right, "left2": left2}, @@ -123,6 +154,7 @@ def test_cbackend_2(file_path: str): }, ) + # Raw C Backend c_left = c_backend.array(left) c_left2 = c_backend.array(left2) c_right = c_backend.array(right) @@ -135,6 +167,7 @@ def test_cbackend_2(file_path: str): output_gradients={"output": c_output_grad, "output2": c_output_grad}, ) + # Assertions for key in np_outputs: out = c_outputs[key] out_np = np_outputs[key] @@ -144,7 +177,6 @@ def test_cbackend_2(file_path: str): for key in np_grads: assert np.allclose(c_backend.to_numpy(c_grads[key]), np_grads[key]) - os.remove(file_path.replace(".c", ".so")) diff --git a/tests/scripts/test_functions.py b/tests/scripts/test_functions.py index 71134cef..6d237dcf 100644 --- a/tests/scripts/test_functions.py +++ b/tests/scripts/test_functions.py @@ -732,22 +732,24 @@ def test_code_generator_8(file_path: str): evaluate_gradient_code = "".join(code[start_line:end_line]) reference_eval_code = ( - "void evaluate(\n\tArray * left,\n\tArray * output,\n\tArray * output_0" - ",\n\tArray * right,\n\tArray * right2\n)\n{\n add(output_0, left, " - "right);\n multiplication(output, output_0, right2);\n}\n" + "struct eval_outputs evaluate(\n\tstruct eval_inputs * inputs\n)\n{\n " + "add(inputs->output_0, inputs->left, inputs->right);\n multiplication(i" + "nputs->output, inputs->output_0, inputs->right2);\n struct eval_outputs " + "output_struct = { .left = inputs->left, .output = inputs->output, .output_0 " + "= inputs->output_0, .right = inputs->right, .right2 = inputs->right2 };\n " + "return output_struct;\n}\n" ) reference_eval_grad_code = ( - "void evaluate_gradients(\n\tArray * left,\n\tArray * " - "left_grad,\n\tArray * output,\n\tArray * output_0,\n\t" - "Array * output_0_grad,\n\tArray * output_grad,\n\tArray * right,\n\tArray " - "* right2,\n\tArray * right2_grad,\n\tArray * right_grad\n)\n{\n " - "multiplication_grad(output_grad, 0, output, output_0, right2, " - "output_0_grad, right2_grad);\n multiplication_grad(output_grad" - ", 1, output, output_0, right2, output_0_grad, right2_grad);\n" - " add_grad(output_0_grad, 0, output_0, left, right, left_grad," - " right_grad);\n add_grad(output_0_grad, 1, output_0, left, " - "right, left_grad, right_grad);\n}" + "struct eval_grad_outputs evaluate_gradients(\n\tstruct eval_grad_inputs " + "* inputs\n)\n{\n multiplication_grad(inputs->output_grad, 0, inputs->" + "output, inputs->output_0, inputs->right2, inputs->output_0_grad, NULL);\n" + " add_grad(inputs->output_0_grad, 0, inputs->output_0, inputs->left, inputs" + "->right, inputs->left_grad, inputs->right_grad);\n add_grad(inputs->" + "output_0_grad, 1, inputs->output_0, inputs->left, inputs->right, inputs->" + "left_grad, inputs->right_grad);\n struct eval_grad_outputs output_struct" + " = { .left_grad = inputs->left_grad, .right_grad = inputs->right_grad };\n " + "return 
output_struct;\n}" ) assert eval_code == reference_eval_code
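Note: unescaped, the two reference strings above correspond to the following generated C (whitespace approximate, long lines wrapped for readability), which shows the struct-based calling convention end to end:

    struct eval_outputs evaluate(
        struct eval_inputs * inputs
    )
    {
        add(inputs->output_0, inputs->left, inputs->right);
        multiplication(inputs->output, inputs->output_0, inputs->right2);
        struct eval_outputs output_struct = { .left = inputs->left,
            .output = inputs->output, .output_0 = inputs->output_0,
            .right = inputs->right, .right2 = inputs->right2 };
        return output_struct;
    }

    struct eval_grad_outputs evaluate_gradients(
        struct eval_grad_inputs * inputs
    )
    {
        multiplication_grad(inputs->output_grad, 0, inputs->output,
            inputs->output_0, inputs->right2, inputs->output_0_grad, NULL);
        add_grad(inputs->output_0_grad, 0, inputs->output_0, inputs->left,
            inputs->right, inputs->left_grad, inputs->right_grad);
        add_grad(inputs->output_0_grad, 1, inputs->output_0, inputs->left,
            inputs->right, inputs->left_grad, inputs->right_grad);
        struct eval_grad_outputs output_struct = { .left_grad = inputs->left_grad,
            .right_grad = inputs->right_grad };
        return output_struct;
    }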