Skip to content

Commit

Permalink
add static inference support for rawc and ggml backends
Browse files Browse the repository at this point in the history
  • Loading branch information
emrecakmakyurdu committed Mar 9, 2025
1 parent 048686e commit 9b2f289
Show file tree
Hide file tree
Showing 8 changed files with 508 additions and 12 deletions.
15 changes: 11 additions & 4 deletions mithril/backends/with_manualgrad/c_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@

from .... import types
from ....cores.c.array import PyArray
from ....cores.c.raw_c import array
from ....cores.c.raw_c import array, ops
from ...backend import Backend
from ...utils import process_shape
from ....common import BiMap
from . import utils

__all__ = ["CBackend"]


# Maps dtype names (strings) to the NumPy scalar types this backend exposes;
# CBackend.__init__ stores this object as ``self.dtype_map``.
# Only "float32" is registered.
# NOTE(review): BiMap presumably allows lookup in both directions - confirm
# against its definition in ``common``.
dtype_map: BiMap[str, Any] = BiMap(
{
"float32": np.float32,
}
)
class CBackend(Backend[PyArray]):
backend_type = "c"
SRC_PATH = os.path.join(
Expand All @@ -37,8 +42,10 @@ class CBackend(Backend[PyArray]):

def __init__(self) -> None:
self._device = "cpu"
self.primitive_function_dict = {}

self.primitive_function_dict = ops.primitive_func_dict
self.dtype_map = dtype_map
self.registered_primitives = {}
self.array_creation_funcs = {}
@property
def is_manualgrad(self) -> bool:
return True
Expand Down
12 changes: 11 additions & 1 deletion mithril/backends/with_manualgrad/ggml_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@
from ...utils import process_shape
from ..c_backend.utils import from_numpy
from . import utils
from ....cores.c.ggml import ops
from ....common import BiMap

__all__ = ["GGMLBackend"]

# Maps dtype names (strings) to the NumPy scalar types this backend exposes;
# GGMLBackend.__init__ stores this object as ``self.dtype_map``.
# Only "float32" is registered.
# NOTE(review): BiMap presumably allows lookup in both directions - confirm
# against its definition in ``common``.
dtype_map: BiMap[str, Any] = BiMap(
{
"float32": np.float32,
}
)

class GGMLBackend(Backend[PyArray]):
backend_type = "c"
Expand All @@ -39,7 +46,10 @@ class GGMLBackend(Backend[PyArray]):

def __init__(self) -> None:
self._device = "cpu"
self.primitive_function_dict = {}
self.primitive_function_dict = ops.primitive_func_dict
self.dtype_map = dtype_map
self.registered_primitives = {}
self.array_creation_funcs = {}

@property
def is_manualgrad(self) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion mithril/cores/c/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import ctypes
from collections.abc import Sequence

import numpy as np

class PyArray:
def __init__(self, arr: ctypes.Structure, shape: tuple[int, ...] | list[int]):
Expand All @@ -30,6 +30,10 @@ def __init__(self, arr: ctypes.Structure, shape: tuple[int, ...] | list[int]):
# def __del__(self):
# lib.delete_struct(self.arr)

# Element type reported for this array. The property returns the constant
# np.float32 and does not inspect the wrapped struct.
@property
def dtype(self) -> type:
return np.float32

@property
def data(self) -> Sequence[int | Sequence[int | Sequence[int]]]:
total_elements = 1
Expand Down
53 changes: 53 additions & 0 deletions mithril/cores/c/ggml/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import ctypes
import os
from ....cores.c.array import PyArray
from ....cores.c.raw_c.array import (
Array,
zeros,
lib
)
from ....backends.with_manualgrad.c_backend.backend import array
from ....backends.with_manualgrad.c_backend.utils import from_numpy
from ....cores.c.ggml.ggml_core import ggml_struct
import numpy as np

__all__ = [
"add",
"multiplication"
]

def convert_to_c_array(input: PyArray) -> PyArray:
    """Rebuild ``input`` as an array produced by ``from_numpy``.

    The contents of ``input.data`` are first copied into a NumPy array of
    dtype ``input.dtype``; that NumPy array is then handed to ``from_numpy``.

    Args:
        input: Array whose ``data`` and ``dtype`` attributes are read.

    Returns:
        Whatever ``from_numpy`` returns for the intermediate NumPy array.
    """
    return from_numpy(np.array(input.data, dtype=input.dtype))

def add(left: PyArray, right: PyArray) -> PyArray:
    """Run the C library's ``add`` routine and wrap the result in a ggml struct.

    Args:
        left: First operand; its shape is used to allocate the result buffer.
        right: Second operand.

    Returns:
        A ``PyArray`` wrapping a ``ggml_struct`` whose ``data`` field points at
        the buffer the C routine wrote into, carrying the result's shape.
    """
    # The C routine takes its destination as the first argument.
    result = zeros(left.shape)
    # Both operands are converted with ``convert_to_c_array`` before the call.
    operands = (result, convert_to_c_array(left), convert_to_c_array(right))
    lib.add(*(ctypes.byref(operand.arr) for operand in operands))
    result_shape = result.shape
    buffer_ptr = ctypes.cast(result.arr.data, ctypes.c_void_p)
    return PyArray(ggml_struct(data=buffer_ptr), result_shape)

def multiplication(left: PyArray, right: PyArray) -> PyArray:
    """Run the C library's ``multiplication`` routine and wrap the result.

    Args:
        left: First operand; its shape is used to allocate the result buffer.
        right: Second operand.

    Returns:
        A ``PyArray`` wrapping a ``ggml_struct`` whose ``data`` field points at
        the buffer the C routine wrote into, carrying the result's shape.
    """
    # The C routine takes its destination as the first argument.
    result = zeros(left.shape)
    # Both operands are converted with ``convert_to_c_array`` before the call.
    operands = (result, convert_to_c_array(left), convert_to_c_array(right))
    lib.multiplication(*(ctypes.byref(operand.arr) for operand in operands))
    result_shape = result.shape
    buffer_ptr = ctypes.cast(result.arr.data, ctypes.c_void_p)
    return PyArray(ggml_struct(data=buffer_ptr), result_shape)



# Registry of primitive name -> implementation.
# Only the names declared in ``__all__`` are registered. Scanning every
# callable in ``globals()`` would also expose imported helpers and classes
# (``zeros``, ``from_numpy``, ``PyArray``, ``Array``, ``ggml_struct``) and the
# ``convert_to_c_array`` helper as if they were primitives.
primitive_func_dict = {name: globals()[name] for name in __all__}
32 changes: 32 additions & 0 deletions mithril/cores/c/raw_c/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import ctypes
import os
from ....cores.c.array import PyArray
from ....cores.c.raw_c.array import (
Array,
zeros,
lib
)
__all__ = [
"add",
"multiplication"
]

def add(left: PyArray, right: PyArray) -> PyArray:
    """Run the C library's ``add`` routine on two arrays.

    Args:
        left: First operand; its shape is used to allocate the result.
        right: Second operand.

    Returns:
        The newly allocated array that the C routine wrote its result into.
    """
    # The C routine takes its destination as the first argument.
    result = zeros(left.shape)
    lib.add(*(ctypes.byref(operand.arr) for operand in (result, left, right)))
    return result

def multiplication(left: PyArray, right: PyArray) -> PyArray:
    """Run the C library's ``multiplication`` routine on two arrays.

    Args:
        left: First operand; its shape is used to allocate the result.
        right: Second operand.

    Returns:
        The newly allocated array that the C routine wrote its result into.
    """
    # The C routine takes its destination as the first argument.
    result = zeros(left.shape)
    lib.multiplication(
        *(ctypes.byref(operand.arr) for operand in (result, left, right))
    )
    return result

# Registry of primitive name -> implementation.
# Only the names declared in ``__all__`` are registered. Scanning every
# callable in ``globals()`` would also expose imported helpers and classes
# (``zeros``, ``PyArray``, ``Array``) as if they were primitives.
primitive_func_dict = {name: globals()[name] for name in __all__}
34 changes: 29 additions & 5 deletions mithril/framework/codegen/ggml_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def update_function(
# Create tensors
init_block.append(c_ast.Comment("Create tensors only once")) # type: ignore
for key in self.determined_struct_keys[f"{fn_ref_name}_input_keys"]:
# If key is statically inferred, skip tensor creation
if key in self.determined_struct_keys[f"{fn_ref_name}_output_keys"]:
continue
shape = self._get_tensor_shape(key)
if shape is not None:
tensor = c_ast.Call(
Expand All @@ -206,6 +209,9 @@ def update_function(

# Build graph
for out_key in self.determined_struct_keys[f"{fn_ref_name}_output_keys"]:
# If key is statically inferred, skip marking
if out_key in self.determined_struct_keys[f"{fn_ref_name}_input_keys"]:
continue
init_block.append(
c_ast.MakeStmt( # type: ignore
c_ast.Call(
Expand All @@ -227,12 +233,21 @@ def update_function(
update_ptr_block: ast_block_type = []
update_ptr_block.append(c_ast.Comment("Update tensor data for each call")) # type: ignore
for key in self.determined_struct_keys[f"{fn_ref_name}_input_keys"]:
update_ptr_block.append(
c_ast.Assign( # type: ignore
c_ast.Arrow(c_ast.Variable(f"{key}"), "data"),
c_ast.Arrow(c_ast.Arrow(c_ast.Variable("inputs"), key), "data"),
# If key is statically inferred, assign to output directly
if key in self.determined_struct_keys[f"{fn_ref_name}_output_keys"]:
update_ptr_block.append(
c_ast.Assign( # type: ignore
self.create_key_ref(key, context=fn_ref_name),
c_ast.Arrow(c_ast.Variable("inputs"), f"{key}")
)
)
else:
update_ptr_block.append(
c_ast.Assign( # type: ignore
c_ast.Arrow(c_ast.Variable(f"{key}"), "data"),
c_ast.Arrow(c_ast.Arrow(c_ast.Variable("inputs"), key), "data"),
)
)
)

# Initialization function
init_fn = super().define_function(
Expand Down Expand Up @@ -296,3 +311,12 @@ def create_key_ref(
return c_ast.Variable(key)

return super().create_key_ref(key, context, load)

@override
def _determine_struct_keys(self) -> dict[str, list[str]]:
    """Return the parent's struct-key mapping, adjusted for static keys.

    When the flat graph reports any static keys, the ``"eval_input_keys"``
    entry is replaced with those keys in sorted order. Otherwise the
    parent's mapping is returned unchanged.
    """
    struct_keys = super()._determine_struct_keys()
    if static_keys := sorted(self.pm.flat_graph.all_static_keys):
        struct_keys["eval_input_keys"] = static_keys
    return struct_keys
2 changes: 1 addition & 1 deletion mithril/framework/physical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ def evaluate(
):
outputs = self.backend._run_callable(params, data, fn_name="eval_fn")
else:
outputs = self._generated_eval_fn(params, data)
outputs = self._generated_eval_fn(params, data, cache=self.flat_graph.cached_data)

outputs, state_outputs = self._extract_state_outputs(outputs)
if len(state_outputs) == 0:
Expand Down
Loading

0 comments on commit 9b2f289

Please sign in to comment.