diff --git a/examples/flux/main.py b/examples/flux/main.py
index 9191d966..a7f3237e 100644
--- a/examples/flux/main.py
+++ b/examples/flux/main.py
@@ -19,7 +19,8 @@
 import numpy as np
 from PIL import Image
 from sampling import denoise, get_noise, get_schedule, prepare, rearrange, unpack
-from util import configs, load_clip, load_decoder, load_flow_model, load_t5
+from t5 import download_t5_encoder_weights, load_t5_encoder, load_t5_tokenizer
+from util import configs, load_clip, load_decoder, load_flow_model
 
 import mithril as ml
 
@@ -76,9 +77,11 @@ def run(
     backend = ml.TorchBackend(device="cuda", dtype=ml.bfloat16)
     backend.seed = seed
 
-    t5 = load_t5(
-        device=device, max_length=256 if model_name == "flux-schnell" else 512
-    ).to("cuda")
+    t5 = load_t5_encoder(backend)
+    t5_tokenizer = load_t5_tokenizer(backend)
+    t5_np_weights = download_t5_encoder_weights(backend)
+    t5_weights = {key: backend.array(value) for key, value in t5_np_weights.items()}
+
     clip = load_clip(device=device).to("cuda")
     flow_model, flow_params = load_flow_model(model_name, backend=backend)
 
@@ -94,7 +97,9 @@ def run(
     )
 
     noise = get_noise(1, opts.height, opts.width, backend)
-    inp = prepare(t5, clip, noise, prompt=opts.prompt, backend=backend)
+    inp = prepare(
+        t5, t5_weights, t5_tokenizer, clip, noise, prompt=opts.prompt, backend=backend
+    )
 
     timesteps = get_schedule(
         opts.num_steps,
@@ -112,7 +117,7 @@ def run(
     x = x.clamp(-1, 1)  # TODO: add to backend
     x = rearrange(x[0], "c h w -> h w c")
     img = Image.fromarray(np.array(127.5 * (x.cpu() + 1.0)).astype(np.uint8))
-    img.save("qwe.png")
+    img.save("img.png")
 
 
 if __name__ == "__main__":
diff --git a/examples/flux/sampling.py b/examples/flux/sampling.py
index c08cb7a2..9406caa2 100644
--- a/examples/flux/sampling.py
+++ b/examples/flux/sampling.py
@@ -39,6 +39,8 @@ def get_noise(
 
 def prepare(
-    t5: HFEmbedder,
+    t5: ml.models.PhysicalModel,
+    t5_weights: dict,
+    t5_tokenizer,
     clip: HFEmbedder,
     img: torch.Tensor,
     prompt: str | list[str],
@@ -61,7 +63,9 @@ def prepare(
     if isinstance(prompt, str):
         prompt = [prompt]
 
-    txt = t5(prompt)
+    # Run the compiled T5 encoder on the tokenized prompt to get text embeddings.
+    t5_prompt = t5_tokenizer.encode(prompt)
+    txt = t5(t5_weights, {"input": t5_prompt})["output"]
     if txt.shape[0] == 1 and bs > 1:
         txt = repeat(txt, "1 ... -> bs ...", bs=bs)
     txt_ids = backend.zeros(bs, txt.shape[1], 3)
diff --git a/examples/flux/t5.py b/examples/flux/t5.py
new file mode 100644
index 00000000..e4c78575
--- /dev/null
+++ b/examples/flux/t5.py
@@ -0,0 +1,217 @@
+# Copyright 2022 Synnada, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
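+#
+# T5 text-encoder utilities for the FLUX example: weight download and
+# sanitization, a SentencePiece tokenizer wrapper, and helpers that compile
+# the mithril T5 encoder for a given backend.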
+
+import json
+import sys
+from typing import Any
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from sentencepiece import SentencePieceProcessor
+
+import mithril as ml
+from examples.t5 import t5_encode
+
+# Compiling the deep T5 encoder graph can exceed Python's default recursion limit.
+sys.setrecursionlimit(3500)
+
+
+def sanitize(weights):
+    # Map Hugging Face T5 parameter names onto the mithril model's naming scheme.
+    shared_replacement_patterns = [
+        (".block.", ".layers."),
+        (".k.", ".key_proj."),
+        (".o.", ".out_proj."),
+        (".q.", ".query_proj."),
+        (".v.", ".value_proj."),
+        ("shared.", "wte."),
+        ("lm_head.", "lm_head.linear."),
+        (".layer.0.layer_norm.", ".ln1."),
+        (".layer.1.layer_norm.", ".ln2."),
+        (".layer.2.layer_norm.", ".ln3."),
+        (".final_layer_norm.", ".ln."),
+        (
+            "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+            "relative_attention_bias.embeddings.",
+        ),
+    ]
+
+    encoder_replacement_patterns = [
+        (".layer.0.SelfAttention.", ".attention."),
+        (".layer.1.DenseReluDense.", ".dense."),
+    ]
+
+    decoder_replacement_patterns = [
+        (".layer.0.SelfAttention.", ".self_attention."),
+        (".layer.1.EncDecAttention.", ".cross_attention."),
+        (".layer.2.DenseReluDense.", ".dense."),
+    ]
+
+    ignored_keys = ["decoder_layers_0_cross_attention_relative_attention_bias_weight"]
+
+    def replace_key(key: str) -> str:
+        for old, new in shared_replacement_patterns:
+            key = key.replace(old, new)
+        if key.startswith("encoder."):
+            for old, new in encoder_replacement_patterns:
+                key = key.replace(old, new)
+        elif key.startswith("decoder."):
+            for old, new in decoder_replacement_patterns:
+                key = key.replace(old, new)
+        return key.replace(".", "_")
+
+    weights = {replace_key(k): v for k, v in weights.items()}
+    for key in ignored_keys:
+        if key in weights:
+            del weights[key]
+    return weights
+
+
+class T5Tokenizer:
+    def __init__(self, model_file, backend: ml.Backend, max_length=512):
+        self._tokenizer = SentencePieceProcessor(model_file)
+        self.max_length = max_length
+        self.backend = backend
+
+    @property
+    def pad(self):
+        try:
+            return self._tokenizer.id_to_piece(self.pad_token)
+        except IndexError:
+            return None
+
+    @property
+    def pad_token(self):
+        return self._tokenizer.pad_id()
+
+    @property
+    def bos(self):
+        try:
+            return self._tokenizer.id_to_piece(self.bos_token)
+        except IndexError:
+            return None
+
+    @property
+    def bos_token(self):
+        return self._tokenizer.bos_id()
+
+    @property
+    def eos(self):
+        try:
+            return self._tokenizer.id_to_piece(self.eos_token)
+        except IndexError:
+            return None
+
+    @property
+    def eos_token(self):
+        return self._tokenizer.eos_id()
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
+
+        tokens = self._tokenizer.encode(text)
+
+        if prepend_bos and self.bos_token >= 0:
+            tokens = [self.bos_token] + tokens
+        if append_eos and self.eos_token >= 0:
+            tokens.append(self.eos_token)
+        if pad and len(tokens) < self.max_length and self.pad_token >= 0:
+            tokens += [self.pad_token] * (self.max_length - len(tokens))
+
+        return tokens
+
+    def encode(self, text, pad=True):
+        if not isinstance(text, list):
+            return self.encode([text], pad=pad)
+
+        pad_token = self.pad_token if self.pad_token >= 0 else 0
+        tokens = self.tokenize(text, pad=pad)
+        length = max(len(t) for t in tokens)
+        for t in tokens:
+            t.extend([pad_token] * (length - len(t)))
+
+        return self.backend.array(tokens)
+
+
+def download_t5_encoder_weights(
+    backend: ml.Backend, repo_id: str = "black-forest-labs/FLUX.1-schnell"
+) -> dict[str, np.ndarray[Any, Any]]:
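+    # The index file maps every parameter name to the safetensors shard that
+    # holds it; collect the unique shard filenames before downloading them.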
+    model_index = hf_hub_download(
+        repo_id, "text_encoder_2/model.safetensors.index.json"
+    )
+
+    weight_files = set()
+    with open(model_index) as f:
+        for w in json.load(f)["weight_map"].values():
+            weight_files.add(w)
+
+    if backend.backend_type == "torch":
+        target_lib = "pt"
+    elif backend.backend_type == "jax":
+        target_lib = "jax"
+    else:
+        # TODO: Fix here
+        raise NotImplementedError("T5 encoder is only supported for Jax and Torch!")
+
+    weights = {}
+    for w in weight_files:
+        w = hf_hub_download(repo_id, f"text_encoder_2/{w}")
+        safe_tensors = safe_open(w, target_lib)
+        for key in safe_tensors.keys():  # type: ignore # noqa
+            weights[key] = safe_tensors.get_tensor(key)  # type: ignore
+
+    return sanitize(weights)
+
+
+def load_t5_encoder(
+    backend: ml.Backend,
+    repo_id: str = "black-forest-labs/FLUX.1-schnell",
+    max_len: int = 256,
+) -> ml.models.PhysicalModel:
+    config_path = hf_hub_download(repo_id, "text_encoder_2/config.json")
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    t5 = t5_encode(config, name="encoder")
+
+    # Compile the logical encoder for the target backend with a fixed-size input.
+    encoder_pm = ml.compile(
+        t5,
+        backend,
+        data_keys={"input"},
+        shapes={"input": [1, max_len]},
+        jit=False,
+        use_short_namings=False,
+    )
+
+    return encoder_pm
+
+
+def load_t5_tokenizer(
+    backend: ml.Backend,
+    repo_id: str = "black-forest-labs/FLUX.1-schnell",
+):
+    model_file = hf_hub_download(repo_id, "tokenizer_2/spiece.model")
+    # Prompts are padded to max_length so they match the fixed input shape the
+    # encoder above is compiled with.
+    return T5Tokenizer(model_file, backend, 256 if "schnell" in repo_id else 512)
diff --git a/examples/t5.py b/examples/t5.py
index 605f47c0..a3293466 100644
--- a/examples/t5.py
+++ b/examples/t5.py
@@ -148,7 +148,7 @@ def rms_norm(dim: int, *, name: str | None = None):
 
 def dense_activation(config: dict[str, Any], *, name: str | None = None):
     mlp_dims = config["d_ff"] or config["d_model"] * 4
-    is_gated = hasattr(config, "feed_forward_proj")
+    is_gated = config.get("feed_forward_proj", "relu").startswith("gated")
     activation_name = (
         "relu" if not is_gated else config["feed_forward_proj"].removeprefix("gated-")
     )
@@ -166,12 +166,14 @@ def dense_activation(config: dict[str, Any], *, name: str | None = None):
     input = IOKey("input")
 
     if is_gated:
-        block += Linear(mlp_dims, use_bias=False, name="wi_0")(input)
+        block |= Linear(mlp_dims, use_bias=False, name="wi_0")(input)
         block += activation(output="hidden_act")
-        block += Linear(mlp_dims, use_bias=False, name="wi_1")(input, output="lin_out")
-        block += Multiply()(left="hidden_act", right="lin_out", output="hidden_out")
+        block |= Linear(mlp_dims, use_bias=False, name="wi_1")(input, output="lin_out")
+        block |= Multiply()(left="hidden_act", right="lin_out", output="hidden_out")
+        # `|=` does not update the block's canonical output, so set it explicitly.
+        block.set_cout("hidden_out")
     else:
-        block += Linear(mlp_dims, name="wi", use_bias=False)(input)
+        block |= Linear(mlp_dims, name="wi", use_bias=False)(input)
         block += Relu()(output="hidden_out")
 
     block += Linear(config["d_model"], name="wo", use_bias=False)(
@@ -475,7 +477,7 @@ def sanitize(weights):
         (".layer.2.DenseReluDense.", ".dense."),
     ]
 
-    ignored_keys = ["decoder.layers.0.cross_attention.relative_attention_bias.weight"]
+    ignored_keys = ["decoder_layers_0_cross_attention_relative_attention_bias_weight"]
 
     def replace_key(key: str) -> str:
         for old, new in shared_replacement_patterns:
@@ -486,7 +488,7 @@ def replace_key(key: str) -> str:
         elif key.startswith("decoder."):
            for old, new in decoder_replacement_patterns:
                 key = key.replace(old, new)
-        return key
+        return key.replace(".", "_")
 
     weights = {replace_key(k): v for k, v in weights.items()}
     for key in ignored_keys:
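
A minimal usage sketch of the pieces introduced above, mirroring the updated
examples/flux/main.py. It assumes a CUDA-capable Torch setup; the prompt string
is an illustrative placeholder.

    import mithril as ml
    from t5 import download_t5_encoder_weights, load_t5_encoder, load_t5_tokenizer

    backend = ml.TorchBackend(device="cuda", dtype=ml.bfloat16)

    # Compile the encoder once; its weights are passed explicitly on each call.
    t5 = load_t5_encoder(backend)
    t5_tokenizer = load_t5_tokenizer(backend)
    t5_np_weights = download_t5_encoder_weights(backend)
    t5_weights = {key: backend.array(value) for key, value in t5_np_weights.items()}

    # Tokenize a prompt (padded to the compiled max_len) and encode it.
    tokens = t5_tokenizer.encode("a photo of a forest clearing")
    txt = t5(t5_weights, {"input": tokens})["output"]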