chore: Flat graph data store communication (synnada-ai#181)
aturker-synnada authored and mehmetozsoy-synnada committed Feb 12, 2025
1 parent e07bf80 commit c80f1bf
Showing 26 changed files with 1,352 additions and 1,144 deletions.
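
Every hunk in this commit makes the same mechanical substitution: key bookkeeping that the code generators previously read from `pm.data_store` (`all_static_keys`, `unused_keys`, `runtime_static_keys`, `cached_data`, `data_values`) is now read from `pm.flat_graph`, with `data_store.data_values` renamed to `flat_graph.cached_data` and `pm.infer_ignore` moving onto `pm.flat_graph` as well. A minimal before/after sketch of the access paths — the `SimpleNamespace` object and key names below are hypothetical stand-ins for illustration, not mithril's real `PhysicalModel`:

    from types import SimpleNamespace

    # Hypothetical stand-in for a compiled PhysicalModel (assumed shape,
    # illustration only); the real class lives in mithril.
    pm = SimpleNamespace(
        flat_graph=SimpleNamespace(
            all_source_keys={"w", "b", "x", "target"},
            all_static_keys={"b"},       # previously pm.data_store.all_static_keys
            unused_keys={"x"},           # previously pm.data_store.unused_keys
            cached_data={"b": 3.0},      # previously pm.data_store.data_values
            runtime_static_keys={"target"},
        ),
        ignore_grad_keys={"target"},
    )

    # The gradient loops in c_gen.py and numpy_gen.py below now compute:
    grad_keys = (
        pm.flat_graph.all_source_keys
        - pm.flat_graph.all_static_keys
        - pm.flat_graph.unused_keys
        - pm.ignore_grad_keys
    )
    print(grad_keys)  # {'w'}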
17 changes: 5 additions & 12 deletions mithril/framework/codegen/c_gen.py
@@ -175,8 +175,8 @@ def evaluate_gradients_wrapper(
         # Create gradients for all params
         for key in (
             self.pm.flat_graph.all_source_keys
-            - self.pm.data_store.all_static_keys
-            - self.pm.data_store.unused_keys
+            - self.pm.flat_graph.all_static_keys
+            - self.pm.flat_graph.unused_keys
             - self.pm.ignore_grad_keys
         ):
             # In CBackend we are creating all internal gradients with zeros.
@@ -206,14 +206,7 @@ def generate_evaluate(self) -> tuple[c_ast.FunctionDef, set[str]]:
         fn_body: list[c_ast.Expr] = []
         used_keys: set[str] = set()
 
-        unused_keys = self.pm.data_store.unused_keys
-        cached_data_keys = self.pm.data_store.cached_data.keys()
-
         for output_key in self.pm.flat_graph.topological_order:
-            # Staticly infered and unused model will not be added
-            if output_key in (cached_data_keys | unused_keys):
-                continue
-
             model = self.pm.flat_graph.get_model(output_key)
             inputs = self.pm.flat_graph.get_source_keys(output_key)
 
@@ -241,10 +234,10 @@ def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]:
 
         all_ignored_keys = (
             self.pm.ignore_grad_keys
-            | self.pm.data_store.all_static_keys
-            | self.pm.data_store.unused_keys
+            | self.pm.flat_graph.all_static_keys
+            | self.pm.flat_graph.unused_keys
         )
-        all_ignored_keys, _ = self.pm.infer_ignore(
+        all_ignored_keys, _ = self.pm.flat_graph.infer_ignore(
             set(), self.pm._output_keys, all_ignored_keys, update_graph=False
         )
 
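
The last hunk above also moves `infer_ignore` from the physical model onto the flat graph. mithril's actual rule is not shown in this diff; the sketch below is only a plausible shape for such a pass — a fixed-point expansion of the ignored-key set — with the signature, names, and propagation rule all assumed:

    # Illustrative only: grow the ignored set with keys whose consumers are
    # all already ignored, never ignoring a key that is a model output.
    def infer_ignore(
        weak: set[str],
        output_keys: set[str],
        ignored: set[str],
        consumers: dict[str, set[str]],  # key -> keys computed from it
    ) -> set[str]:
        changed = True
        while changed:
            changed = False
            for key in weak:
                if (
                    key not in ignored
                    and key not in output_keys
                    and consumers.get(key, set()) <= ignored
                ):
                    ignored.add(key)
                    changed = True
        return ignored

    # "tmp" feeds only the already-ignored "loss_grad", so it gets ignored too.
    print(infer_ignore({"tmp"}, {"loss"}, {"loss_grad"}, {"tmp": {"loss_grad"}}))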
34 changes: 11 additions & 23 deletions mithril/framework/codegen/numpy_gen.py
@@ -142,7 +142,7 @@ def evaluate_gradients_wrapper_manualgrad(
         data = {}
         # TODO: Consider not unioning batch data (data) into self.data
         # If evaluate_gradients called directly, first call evaluate.
-        cached_data = self.pm.data_store.data_values
+        cached_data = self.pm.flat_graph.cached_data
 
         output: dict[str, np.ndarray[Any, Any]] = eval_fn(
             params=params, data=data, cache=cached_data
@@ -151,8 +151,8 @@
         gradients: dict[str, np.ndarray[Any, Any]] = {}
         for key in (
             self.pm.flat_graph.all_keys
-            - self.pm.data_store.all_static_keys
-            - self.pm.data_store.unused_keys
+            - self.pm.flat_graph.all_static_keys
+            - self.pm.flat_graph.unused_keys
             - self.pm.ignore_grad_keys
         ):
             key_cache = cached_data.get(key + "_cache", {})
@@ -285,7 +285,7 @@ def create_primitive_call_targets(
 
     def get_cache_name(self, output_key: str, model: PrimitiveModel) -> str:
         cache_name = "_".join([output_key, model.cache_name])
-        if cache_name not in self.pm.data_store._all_data:
+        if cache_name not in self.pm.flat_graph.all_data:
             self.add_cache(model, output_key)
 
         return cache_name
@@ -294,7 +294,7 @@ def add_cache(self, model: PrimitiveModel, output_key: str) -> None:
         cache_name = "_".join([output_key, model.cache_name])
         cache_value: dict[str, Any] | None = None if self.pm.inference else {}
         # Create a scalar for caches in manualgrad backend.
-        self.pm.data_store.update_data(
+        self.pm.flat_graph.update_data(
             {cache_name: IOHyperEdge(dict | None, cache_value)}
         )
 
@@ -305,11 +305,7 @@ def generate_evaluate_gradients(
         function_body: list[ast.stmt] = []
         used_keys: set[str] = set()
 
-        all_ignored_keys = (
-            ignore_grad_keys
-            | self.pm.data_store.all_static_keys
-            | self.pm.data_store.unused_keys
-        )
+        all_ignored_keys = ignore_grad_keys | self.pm.flat_graph.all_static_keys
 
         # TODO: Is this should be here?
         # Seperate ignored keys into two types of weak and strict ignored keys.
@@ -321,17 +317,9 @@
             and find_intersection_type(self.pm.data[key].value_type, float)
         }
 
-        # weak_ignored_keys = set()
-        # for key in all_ignored_keys:
-        #     if key in self.pm.data:
-        #         edge = self.pm.data[key]
-        #         if isinstance(edge._value, Tensor):
-        #             if find_intersection_type(edge._value.type, float):
-        #                 weak_ignored_keys |= {key}
-
         strict_ignored_keys = all_ignored_keys - weak_ignored_keys
 
-        ignore_grad_keys, _ = self.pm.infer_ignore(
+        ignore_grad_keys, _ = self.pm.flat_graph.infer_ignore(
             weak_ignored_keys,
             self.pm._output_keys,
             strict_ignored_keys,
@@ -413,7 +401,7 @@ def generate_evaluate_gradients(
             )
         }
         args, kwargs = prepare_function_args(
-            self.pm.data_store.data_values,
+            self.pm.flat_graph.cached_data,
             primitive_function,
             local_to_global_dict,
             self.backend.array_creation_funcs,
@@ -450,7 +438,7 @@ def generate_evaluate_gradients(
         for idx, global_input_key in enumerate(global_input_keys[:-2]):
             if (
                 global_input_key
-                in ignore_grad_keys | self.pm.data_store.runtime_static_keys
+                in ignore_grad_keys | self.pm.flat_graph.runtime_static_keys
             ):
                 continue
 
@@ -538,10 +526,10 @@ def generate_evaluate_gradients(
             if (
                 key
                 in self.pm.flat_graph.all_target_keys
-                | self.pm.data_store.cached_data.keys()
+                | self.pm.flat_graph.cached_data.keys()
             ):
                 dict_type = "cache"
-            elif key in self.pm.data_store.runtime_static_keys:
+            elif key in self.pm.flat_graph.runtime_static_keys:
                 dict_type = "data"
             else:
                 dict_type = "params"
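
In the manual-grad backend each primitive output gets an auxiliary `<output_key>_<model.cache_name>` entry (see `key + "_cache"` in the gradient loop above); the hunks above change only where that entry is registered and looked up: `flat_graph.all_data` and `flat_graph.update_data` instead of the data store. A toy sketch of the naming convention, with a plain dict standing in for the graph's data registry and the key names invented:

    all_data: dict[str, dict] = {}  # stand-in for flat_graph.all_data

    def get_cache_name(output_key: str, model_cache_name: str) -> str:
        # Same "_".join convention as get_cache_name above; registration via
        # flat_graph.update_data is reduced here to a plain dict insert.
        cache_name = "_".join([output_key, model_cache_name])
        if cache_name not in all_data:
            all_data[cache_name] = {}
        return cache_name

    assert get_cache_name("output1", "cache") == "output1_cache"
    assert "output1_cache" in all_data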
30 changes: 13 additions & 17 deletions mithril/framework/codegen/python_gen.py
@@ -213,7 +213,7 @@ def post_process_fns(
         eval_fn: EvaluateType[DataType] | partial[Any] = partial(
             self.compute_evaluate,
             fn=raw_eval_fn,
-            cache=self.pm.data_store.data_values,
+            cache=self.pm.flat_graph.cached_data,
         )
         grad_fn = None
         evaluate_all_fn = None
@@ -299,10 +299,10 @@ def generate_imports(self) -> list[ast.stmt]:
 
     def is_static_scalar(self, key: str) -> bool:
         return (
-            key in self.pm.data_store.cached_data
+            key in self.pm.flat_graph.cached_data
             and not self.pm.data[key].is_tensor
             and self.pm.data[key].edge_type != Dtype
-            and not isinstance(self.pm.data_store.cached_data[key], enum.Enum)
+            and not isinstance(self.pm.flat_graph.cached_data[key], enum.Enum)
         )
 
     def get_primitive_details(
@@ -343,19 +343,15 @@ def generate_evaluate(self) -> ast.FunctionDef:
         used_keys: set[str] = set()
         used_keys |= set(self.pm.flat_graph.output_dict.values())
 
-        unused_keys = self.pm.data_store.unused_keys
-        cached_data_keys = self.pm.data_store.cached_data.keys()
+        unused_keys = self.pm.flat_graph.unused_keys
+        cached_data_keys = self.pm.flat_graph.cached_data.keys()
         discarded_keys = self.pm.discarded_keys  # TODO: Consider is this necessary?
 
         deleted_vars: set[str] = set()
         assigned_output_keys: set[str] = set()
 
         # Iterate over Primitive models in topological order to add their formula.
         for output_key in self.pm.flat_graph.topological_order:
-            # Staticly infered and unused model will not be added
-            if output_key in (cached_data_keys | unused_keys | discarded_keys):
-                continue
-
             model, g_input_keys, l_input_keys = self.get_primitive_details(output_key)
             formula_key = model.formula_key
@@ -386,7 +382,7 @@ def generate_evaluate(self) -> ast.FunctionDef:
                 or used_key in deleted_vars
                 or (
                     used_key in self.pm.input_keys  # Inputs shouldn't deleted
-                    or used_key in self.pm.data_store.all_static_keys
+                    or used_key in self.pm.flat_graph.all_static_keys
                 )
             ):
                 continue
@@
         for key in sorted(used_keys):
             if key in cached_data_keys:
                 dict_type = "cache"
-            elif key in self.pm.data_store.runtime_static_keys:
+            elif key in self.pm.flat_graph.runtime_static_keys:
                 dict_type = "data"
             elif key not in self.pm.flat_graph.all_target_keys:
                 dict_type = "params"
@@
             # TODO: give an api to get outputdict
             if self.is_static_scalar(output_key):
                 return_values.append(
-                    ast.Constant(self.pm.data_store.cached_data[output_key])
+                    ast.Constant(self.pm.flat_graph.cached_data[output_key])
                 )
             else:
                 return_values.append(
@@ -489,7 +485,7 @@ def append_inputs(
         else:
             # If key is an output of a function, then get the corresponding
             # function cache from general cache and then get "output" from there.
-            cached_data = self.pm.data_store.cached_data
+            cached_data = self.pm.flat_graph.cached_data
             data_dict: ast.Subscript | ast.Name
             if key not in cached_data:
                 cache_name = key + "_cache"
@@ -520,14 +516,14 @@ def create_primitive_call(
         """Generates a single function call AST (Abstract Syntax Tree)."""
         if default_args is None:
             default_args = {}
-        cache = self.pm.data_store.cached_data
+        cache = self.pm.flat_graph.cached_data
         formula_key = function.__name__
         inputs = {
             key: value for key, value in zip(local_keys, global_keys, strict=False)
         }
         # Prepare function arguments
         fn_args_mapping, fn_kwarg_dict = prepare_function_args(
-            self.pm.data_store.data_values,
+            self.pm.flat_graph.cached_data,
             function,
             inputs,
             self.pm.backend.array_creation_funcs,
@@ -616,14 +612,14 @@ def create_gradient_fn(
         grad_fn = partial(
             self.compute_gradients,
             raw_evaluate_fn=raw_evaluate_fn,
-            cache=self.pm.data_store.data_values,
+            cache=self.pm.flat_graph.cached_data,
             include_output=False,
         )
         # Fix fn_all for mlx support!!
         fn_all = partial(
             self.compute_gradients,
             raw_evaluate_fn=raw_evaluate_fn,
-            cache=self.pm.data_store.data_values,
+            cache=self.pm.flat_graph.cached_data,
             include_output=True,
         )
         return grad_fn, fn_all  # type: ignore
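
A detail worth noting in python_gen.py: statically known scalars are baked into the generated source as literals (`ast.Constant`) rather than read from the cache at run time, which is why `is_static_scalar` consults `flat_graph.cached_data`. A self-contained sketch of that inlining decision, with a plain dict as the cache and the key name invented:

    import ast

    cached_data = {"lr": 0.01}  # hypothetical statically known scalar

    def make_return(output_key: str) -> ast.expr:
        # Inline statically known scalars as literals; fetch everything
        # else from the runtime cache dict instead.
        value = cached_data.get(output_key)
        if isinstance(value, (bool, int, float)):
            return ast.Constant(value)
        return ast.Subscript(
            value=ast.Name(id="cache", ctx=ast.Load()),
            slice=ast.Constant(output_key),
            ctx=ast.Load(),
        )

    print(ast.dump(make_return("lr")))  # Constant(value=0.01)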
(The remaining 23 changed files are not shown here.)
