Skip to content

Commit

Permalink
refactor: Update handling of assigned attributes. (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
norhan-synnada authored Mar 7, 2025
1 parent ff34187 commit d506aae
Show file tree
Hide file tree
Showing 23 changed files with 1,561 additions and 757 deletions.
4 changes: 2 additions & 2 deletions examples/flux/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ def load_t5_encoder(
repo_id: str = "black-forest-labs/FLUX.1-schnell",
max_len: int = 256,
) -> ml.models.PhysicalModel:
config = hf_hub_download(repo_id, "text_encoder_2/config.json")
config_path = hf_hub_download(repo_id, "text_encoder_2/config.json")

with open(config) as f:
with open(config_path) as f:
config = json.load(f)

t5 = t5_encode(config, name="encoder")
Expand Down
12 changes: 6 additions & 6 deletions examples/flux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def load_flow_model(name: str, backend: ml.Backend, hf_download: bool = True):
ckpt_path = configs[name].ckpt_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and (r_id := configs[name].repo_id) is not None
and (r_flow := configs[name].repo_flow) is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
ckpt_path = hf_hub_download(r_id, r_flow)

flux_lm = flux(configs[name].params)
flux_pm = ml.compile(
Expand All @@ -187,11 +187,11 @@ def load_decoder(
ckpt_path = configs[name].ae_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and (r_id := configs[name].repo_id) is not None
and (r_ae := configs[name].repo_ae) is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
ckpt_path = hf_hub_download(r_id, r_ae)

# Loading the autoencoder
print("Init AE")
Expand Down
2 changes: 1 addition & 1 deletion mithril/cores/python/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def scaled_dot_product_attention(
)

L, S = query.shape[-2], key.shape[-2]
scale_factor = 1 / np.sqrt(query.shape[-1]) if scale is None else scale
scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
write_into_cache(cache, "scale_factor", scale_factor)
attn_bias = np.zeros((L, S), dtype=query.dtype)
if is_causal:
Expand Down
44 changes: 33 additions & 11 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ def __init__(
value: TensorValueType | ToBeDetermined = TBD,
type: _TensorTypes = int | float | bool,
shape: ShapeNode | None = None,
differentiable: bool = False,
differentiable: bool | None = None,
):
if shape is None:
# If shape is not provided, create a new shape with a Variadic root.
Expand Down Expand Up @@ -1119,8 +1119,16 @@ def match(self, other: Tensor[int | float | bool]) -> Updates:
updates |= non_valued.set_value(valued.value)
self.differentiable = False
other.differentiable = False
else:
self.differentiable |= other.differentiable
elif self.differentiable is None:
self.differentiable = other.differentiable
# Differentiable tensors can only be float type.
if self.differentiable:
updates |= self.set_type(float)
elif (
other.differentiable is not None
and self.differentiable != other.differentiable
):
raise ValueError("Differentiability mismatch!")
# Match shapes.
updates |= self.match_shapes(other.shape)
updates.shape_updates.discard(other)
Expand Down Expand Up @@ -1180,8 +1188,16 @@ def _temp_shape(self) -> ShapeRepr | None:
return None

@property
def differentiable(self) -> bool:
return isinstance(self._value, Tensor) and self._value.differentiable
def differentiable(self) -> bool | None:
    """Return the differentiability of this edge, or None if undecided.

    Returns:
        - The underlying tensor's ``differentiable`` flag when the edge
          currently holds a ``Tensor`` value.
        - ``False`` when the edge is a scalar (scalars are never
          differentiable).
        - ``None`` for polymorphic edges whose concrete type is not yet
          known.
    """
    if isinstance(self._value, Tensor):
        return self._value.differentiable
    elif self.is_scalar:
        # Scalars are always non-differentiable.
        return False
    # Differentiability of polymorphic edges is defined as None.
    # Depending on later type updates (e.g. resolving to a Tensor or
    # an int type), it can become True or False.
    return None

@property
def tensors(self) -> set[Tensor[int | float | bool]]:
Expand Down Expand Up @@ -1405,7 +1421,7 @@ def set_value(
updates.value_updates.add(self)
# Update new type without automatic tensor value creation.
updates |= self.set_type(find_type(self._value), create_tensor=False)
if self.is_valued:
if self.is_tensor and self.is_valued:
self.set_differentiability(False)
return updates

Expand All @@ -1431,13 +1447,19 @@ def match(self, other: IOHyperEdge) -> Updates:

return updates

def set_differentiability(self, differentiable: bool) -> None:
if self.is_tensor:
assert isinstance(self._value, Tensor)
self._value.differentiable = differentiable
elif differentiable:
def set_differentiability(self, differentiable: bool) -> Updates:
    """Set the differentiability of this edge.

    Args:
        differentiable: Desired differentiability state.

    Returns:
        Updates accumulated from any type constraint applied here
        (marking an edge differentiable forces it to ``Tensor[float]``).

    Raises:
        ValueError: If the edge is a scalar and ``differentiable`` is
            True — non-tensor edges can never be differentiable.
    """
    if self.is_scalar and differentiable:
        raise ValueError("Non-tensor edges cannot be differentiable.")

    updates = Updates()
    if differentiable:
        # Differentiable edges can only be Tensor[float] type, so
        # propagate that type constraint through the graph.
        updates |= self.set_type(Tensor[float])
    # Set differentiability of the _value if it is a Tensor.
    if isinstance(self._value, Tensor):
        self._value.differentiable = differentiable
    return updates

def add_constraint(self, constraint: Constraint) -> None:
for type in constraint.types:
self.constraints[type].add(constraint)
Expand Down
Loading

0 comments on commit d506aae

Please sign in to comment.