Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Update handling of assigned attributes. #220

Merged
4 changes: 2 additions & 2 deletions examples/flux/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ def load_t5_encoder(
repo_id: str = "black-forest-labs/FLUX.1-schnell",
max_len: int = 256,
) -> ml.models.PhysicalModel:
config = hf_hub_download(repo_id, "text_encoder_2/config.json")
config_path = hf_hub_download(repo_id, "text_encoder_2/config.json")

with open(config) as f:
with open(config_path) as f:
config = json.load(f)

t5 = t5_encode(config, name="encoder")
Expand Down
12 changes: 6 additions & 6 deletions examples/flux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def load_flow_model(name: str, backend: ml.Backend, hf_download: bool = True):
ckpt_path = configs[name].ckpt_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and (r_id := configs[name].repo_id) is not None
and (r_flow := configs[name].repo_flow) is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
ckpt_path = hf_hub_download(r_id, r_flow)

flux_lm = flux(configs[name].params)
flux_pm = ml.compile(
Expand All @@ -187,11 +187,11 @@ def load_decoder(
ckpt_path = configs[name].ae_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and (r_id := configs[name].repo_id) is not None
and (r_ae := configs[name].repo_ae) is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
ckpt_path = hf_hub_download(r_id, r_ae)

# Loading the autoencoder
print("Init AE")
Expand Down
2 changes: 1 addition & 1 deletion mithril/cores/python/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def scaled_dot_product_attention(
)

L, S = query.shape[-2], key.shape[-2]
scale_factor = 1 / np.sqrt(query.shape[-1]) if scale is None else scale
scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
write_into_cache(cache, "scale_factor", scale_factor)
attn_bias = np.zeros((L, S), dtype=query.dtype)
if is_causal:
Expand Down
44 changes: 33 additions & 11 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ def __init__(
value: TensorValueType | ToBeDetermined = TBD,
type: _TensorTypes = int | float | bool,
shape: ShapeNode | None = None,
differentiable: bool = False,
differentiable: bool | None = None,
):
if shape is None:
# If shape is not provided, create a new shape with a Variadic root.
Expand Down Expand Up @@ -1119,8 +1119,16 @@ def match(self, other: Tensor[int | float | bool]) -> Updates:
updates |= non_valued.set_value(valued.value)
self.differentiable = False
other.differentiable = False
else:
self.differentiable |= other.differentiable
elif self.differentiable is None:
self.differentiable = other.differentiable
# Differentiable tensors can only be float type.
if self.differentiable:
updates |= self.set_type(float)
elif (
other.differentiable is not None
and self.differentiable != other.differentiable
):
raise ValueError("Differentiability mismatch!")
# Match shapes.
updates |= self.match_shapes(other.shape)
updates.shape_updates.discard(other)
Expand Down Expand Up @@ -1180,8 +1188,16 @@ def _temp_shape(self) -> ShapeRepr | None:
return None

@property
def differentiable(self) -> bool:
return isinstance(self._value, Tensor) and self._value.differentiable
def differentiable(self) -> bool | None:
if isinstance(self._value, Tensor):
return self._value.differentiable
elif self.is_scalar:
# Scalars are always non-differentiable.
return False
# Differentiability of polymorphic edges is defined
# as None. Depending on subsequent type updates, it can
# become True or False (e.g. Tensor or int type).
return None

@property
def tensors(self) -> set[Tensor[int | float | bool]]:
Expand Down Expand Up @@ -1405,7 +1421,7 @@ def set_value(
updates.value_updates.add(self)
# Update new type without automatic tensor value creation.
updates |= self.set_type(find_type(self._value), create_tensor=False)
if self.is_valued:
if self.is_tensor and self.is_valued:
self.set_differentiability(False)
return updates

Expand All @@ -1431,13 +1447,19 @@ def match(self, other: IOHyperEdge) -> Updates:

return updates

def set_differentiability(self, differentiable: bool) -> None:
if self.is_tensor:
assert isinstance(self._value, Tensor)
self._value.differentiable = differentiable
elif differentiable:
def set_differentiability(self, differentiable: bool) -> Updates:
if self.is_scalar and differentiable:
raise ValueError("Non-tensor edges cannot be differentiable.")

updates = Updates()
if differentiable:
# Differentiable edges can only be Tensor[float] type.
updates |= self.set_type(Tensor[float])
# Set differentiability of the _value if it is a Tensor.
if isinstance(self._value, Tensor):
self._value.differentiable = differentiable
return updates

def add_constraint(self, constraint: Constraint) -> None:
for type in constraint.types:
self.constraints[type].add(constraint)
Expand Down
Loading