chore: Integrate T5 in FLUX #212

Merged
17 changes: 11 additions & 6 deletions examples/flux/main.py
@@ -19,7 +19,8 @@
import numpy as np
from PIL import Image
from sampling import denoise, get_noise, get_schedule, prepare, rearrange, unpack
from util import configs, load_clip, load_decoder, load_flow_model, load_t5
from t5 import download_t5_encoder_weights, load_t5_encoder, load_t5_tokenizer
from util import configs, load_clip, load_decoder, load_flow_model

import mithril as ml

@@ -76,9 +77,11 @@ def run(
backend = ml.TorchBackend(device="cuda", dtype=ml.bfloat16)
backend.seed = seed

t5 = load_t5(
device=device, max_length=256 if model_name == "flux-schnell" else 512
).to("cuda")
t5 = load_t5_encoder(backend)
t5_tokenizer = load_t5_tokenizer(backend, pad=False)
t5_np_weights = download_t5_encoder_weights(backend)
t5_weights = {key: backend.array(value) for key, value in t5_np_weights.items()}

clip = load_clip(device=device).to("cuda")

flow_model, flow_params = load_flow_model(model_name, backend=backend)
@@ -94,7 +97,9 @@
)

noise = get_noise(1, opts.height, opts.width, backend)
inp = prepare(t5, clip, noise, prompt=opts.prompt, backend=backend)
inp = prepare(
t5, t5_weights, t5_tokenizer, clip, noise, prompt=opts.prompt, backend=backend
)

timesteps = get_schedule(
opts.num_steps,
@@ -112,7 +117,7 @@ def run(
x = x.clamp(-1, 1) # TODO: add to backend
x = rearrange(x[0], "c h w -> h w c")
img = Image.fromarray(np.array(127.5 * (x.cpu() + 1.0)).astype(np.uint8))
img.save("qwe.png")
img.save("img.png")


if __name__ == "__main__":
5 changes: 4 additions & 1 deletion examples/flux/sampling.py
@@ -39,6 +39,8 @@ def get_noise(

def prepare(
t5: HFEmbedder,  # NOTE: stale annotation; t5 is now a compiled Mithril encoder
t5_weights: dict,
t5_tokenizer,
clip: HFEmbedder,
img: torch.Tensor,
prompt: str | list[str],
@@ -61,7 +63,8 @@ def prepare(
if isinstance(prompt, str):
prompt = [prompt]

txt = t5(prompt)
t5_prompt = t5_tokenizer.encode(prompt)
txt = t5(t5_weights, {"input": t5_prompt})["output"]
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = backend.zeros(bs, txt.shape[1], 3)
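
With this change, `prepare` runs the T5 encoder as a compiled function: tokenize the prompts, then call the model with the weight dict plus a data dict and read the "output" key. A minimal sketch (the prompt text is illustrative):

prompts = ["a photo of a forest"]
t5_prompt = t5_tokenizer.encode(prompts)              # backend array of padded token ids
txt = t5(t5_weights, {"input": t5_prompt})["output"]  # one embedding sequence per prompt
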
211 changes: 211 additions & 0 deletions examples/flux/t5.py
@@ -0,0 +1,211 @@
# Copyright 2022 Synnada, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
from typing import Any

import numpy as np
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from sentencepiece import SentencePieceProcessor

import mithril as ml
from examples.t5 import t5_encode


def sanitize(weights):
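# Remap Hugging Face T5 checkpoint keys to this example's flattened naming,
# e.g. "encoder.block.0.layer.0.SelfAttention.q.weight"
# -> "encoder_layers_0_attention_query_proj_weight".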
shared_replacement_patterns = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]

encoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]

decoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]

ignored_keys = ["decoder_layers_0_cross_attention_relative_attention_bias_weight"]

def replace_key(key: str) -> str:
for old, new in shared_replacement_patterns:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in encoder_replacement_patterns:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in decoder_replacement_patterns:
key = key.replace(old, new)
return key.replace(".", "_")

weights = {replace_key(k): v for k, v in weights.items()}
for key in ignored_keys:
if key in weights:
del weights[key]
return weights


class T5Tokenizer:
def __init__(self, model_file, backend: ml.Backend, max_length=512):
self._tokenizer = SentencePieceProcessor(model_file)
self.max_length = max_length
self.backend = backend

@property
def pad(self):
try:
return self._tokenizer.id_to_piece(self.pad_token)
except IndexError:
return None

@property
def pad_token(self):
return self._tokenizer.pad_id()

@property
def bos(self):
try:
return self._tokenizer.id_to_piece(self.bos_token)
except IndexError:
return None

@property
def bos_token(self):
return self._tokenizer.bos_id()

@property
def eos(self):
try:
return self._tokenizer.id_to_piece(self.eos_token)
except IndexError:
return None

@property
def eos_token(self):
return self._tokenizer.eos_id()

def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]

tokens = self._tokenizer.encode(text)

if prepend_bos and self.bos_token >= 0:
tokens = [self.bos_token] + tokens
if append_eos and self.eos_token >= 0:
tokens.append(self.eos_token)
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
tokens += [self.pad_token] * (self.max_length - len(tokens))

return tokens

def encode(self, text, pad=True):
if not isinstance(text, list):
return self.encode([text], pad=pad)

pad_token = self.pad_token if self.pad_token >= 0 else 0
tokens = self.tokenize(text, pad=pad)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([pad_token] * (length - len(t)))

return self.backend.array(tokens)
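
# Illustrative usage (exact ids depend on spiece.model):
#   tok = T5Tokenizer("spiece.model", backend, max_length=8)
#   tok.encode("a forest")                    # -> [1, 8] array, padded to max_length
#   tok.encode(["a", "a forest"], pad=False)  # no max_length padding; rows are still
#                                             # right-padded to the longest sequence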


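# Building and compiling the full T5 encoder recurses deeply in Python; raise
# the default recursion limit (1000) so graph construction does not overflow.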
sys.setrecursionlimit(3500)


def download_t5_encoder_weights(
backend: ml.Backend, repo_id: str = "black-forest-labs/FLUX.1-schnell"
) -> dict[str, np.ndarray[Any, Any]]:
model_index = hf_hub_download(
repo_id, "text_encoder_2/model.safetensors.index.json"
)

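# The sharded-checkpoint index maps each tensor name to the shard file that
# stores it; collect the unique shards before downloading.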
weight_files = set()
with open(model_index) as f:
for _, w in json.load(f)["weight_map"].items():
weight_files.add(w)

if backend.backend_type == "torch":
target_lib = "pt"
elif backend.backend_type == "jax":
target_lib = "jax"
else:
# TODO Fix here
raise NotImplementedError("T5 encoder only supported for Jax and Torch!")

weights = {}
for w in weight_files:
w = hf_hub_download(repo_id, f"text_encoder_2/{w}")
safe_tensors = safe_open(w, target_lib)
for key in safe_tensors.keys(): # type: ignore # noqa
weights[key] = safe_tensors.get_tensor(key) # type: ignore

return sanitize(weights)


def load_t5_encoder(
backend: ml.Backend,
repo_id: str = "black-forest-labs/FLUX.1-schnell",
max_len: int = 256,
) -> ml.models.PhysicalModel:
config = hf_hub_download(repo_id, "text_encoder_2/config.json")

with open(config) as f:
config = json.load(f)

t5 = t5_encode(config, name="encoder")

# model = ml.models.Model()
# model |= t5(input="input", output="output")

encoder_pm = ml.compile(
t5,
backend,
data_keys={"input"},
shapes={"input": [1, max_len]},
jit=False,
use_short_namings=False,
)

return encoder_pm


def load_t5_tokenizer(
backend: ml.Backend,
repo_id: str = "black-forest-labs/FLUX.1-schnell",
pad: bool = True,  # NOTE: currently unused by this helper
):
model_file = hf_hub_download(repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, backend, 256 if "schnell" in repo_id else 512)
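
Taken together, a condensed sketch of how main.py uses these helpers (the prompt is illustrative; the calls mirror the diffs above):

import mithril as ml
from t5 import download_t5_encoder_weights, load_t5_encoder, load_t5_tokenizer

backend = ml.TorchBackend(device="cuda", dtype=ml.bfloat16)
encoder = load_t5_encoder(backend)            # compiled for input shape [1, 256]
tokenizer = load_t5_tokenizer(backend)
np_weights = download_t5_encoder_weights(backend)
weights = {k: backend.array(v) for k, v in np_weights.items()}

ids = tokenizer.encode("a photo of a forest")             # padded to max_length
embeddings = encoder(weights, {"input": ids})["output"]   # [1, 256, d_model]
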
15 changes: 8 additions & 7 deletions examples/t5.py
@@ -148,7 +148,7 @@ def rms_norm(dim: int, *, name: str | None = None):
def dense_activation(config: dict[str, Any], *, name: str | None = None):
mlp_dims = config["d_ff"] or config["d_model"] * 4

is_gated = hasattr(config, "feed_forward_proj")
is_gated = "feed_forward_proj" in config
activation_name = (
"relu" if not is_gated else config["feed_forward_proj"].removeprefix("gated-")
)
@@ -166,12 +166,13 @@ def dense_activation(config: dict[str, Any], *, name: str | None = None):
input = IOKey("input")

if is_gated:
block += Linear(mlp_dims, use_bias=False, name="wi_0")(input)
block |= Linear(mlp_dims, use_bias=False, name="wi_0")(input)
block += activation(output="hidden_act")
block += Linear(mlp_dims, use_bias=False, name="wi_1")(input, output="lin_out")
block += Multiply()(left="hidden_act", right="lin_out", output="hidden_out")
block |= Linear(mlp_dims, use_bias=False, name="wi_1")(input, output="lin_out")
block |= Multiply()(left="hidden_act", right="lin_out", output="hidden_out")
block.set_cout("hidden_out")
else:
block += Linear(mlp_dims, name="wi", use_bias=False)(input)
block |= Linear(mlp_dims, name="wi", use_bias=False)(input)
block += Relu()(output="hidden_out")

block += Linear(config["d_model"], name="wo", use_bias=False)(
@@ -475,7 +476,7 @@ def sanitize(weights):
(".layer.2.DenseReluDense.", ".dense."),
]

ignored_keys = ["decoder.layers.0.cross_attention.relative_attention_bias.weight"]
ignored_keys = ["decoder_layers_0_cross_attention_relative_attention_bias_weight"]

def replace_key(key: str) -> str:
for old, new in shared_replacement_patterns:
@@ -486,7 +487,7 @@ def replace_key(key: str) -> str:
elif key.startswith("decoder."):
for old, new in decoder_replacement_patterns:
key = key.replace(old, new)
return key
return key.replace(".", "_")

weights = {replace_key(k): v for k, v in weights.items()}
for key in ignored_keys:
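
A note on the `+=` to `|=` changes above: as used in this diff, `+=` chains the new submodel from the block's current canonical output, while `|=` adds it with only the explicitly named connections, which is why the gated branch now marks its output with `set_cout`. A minimal sketch of the gated wiring (import paths assumed; semantics read from this diff rather than from Mithril's docs):

from mithril.models import IOKey, Linear, Model, Multiply, Relu

block = Model()
input = IOKey("input")
block |= Linear(16, use_bias=False, name="wi_0")(input)  # "|=": explicit wiring only
block += Relu()(output="hidden_act")                     # "+=": chains from the canonical output
block |= Linear(16, use_bias=False, name="wi_1")(input, output="lin_out")
block |= Multiply()(left="hidden_act", right="lin_out", output="hidden_out")
block.set_cout("hidden_out")                             # declare the gated product as the output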