-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
Copy pathmodels.py
1964 lines (1766 loc) · 91.7 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from diffusers import DiffusionPipeline
from diffusers.loaders import LoraLoaderMixin
from diffusers.models import (
AutoencoderKL,
AutoencoderKLTemporalDecoder,
ControlNetModel,
UNet2DConditionModel,
UNetSpatioTemporalConditionModel,
StableCascadeUNet
)
from diffusers.pipelines.wuerstchen import PaellaVQModel
import json
import numpy as np
import onnx
from onnx import numpy_helper, shape_inference
import onnx_graphsurgeon as gs
from safetensors import safe_open
import os
from polygraphy.backend.onnx.loader import fold_constants
import re
import tempfile
import torch
import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from huggingface_hub import hf_hub_download
from utilities import merge_loras
from utils_sd3.sd3_impls import BaseModel as BaseModelSD3
from utils_sd3.sd3_impls import SDVAE
from utils_sd3.other_impls import load_into, SDClipModel, SDXLClipG, T5XXLModel
from utils_modelopt import (
convert_zp_fp8,
cast_resize_io,
convert_fp16_io,
cast_fp8_mha_io,
)
from onnxmltools.utils.float16_converter import convert_float_to_float16
class Optimizer():
    """Thin wrapper around an onnx-graphsurgeon graph that applies the
    clean-up, constant-folding and shape-inference passes used after export.

    Args:
        onnx_graph: ONNX model to import with ``gs.import_onnx``.
        verbose: when True, ``info`` prints node/tensor/IO counts.
    """

    def __init__(
        self,
        onnx_graph,
        verbose=False
    ):
        self.graph = gs.import_onnx(onnx_graph)
        self.verbose = verbose

    def info(self, prefix):
        """Print graph statistics prefixed by ``prefix`` (only when verbose)."""
        if self.verbose:
            print(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")

    def cleanup(self, return_onnx=False):
        """Drop unused nodes/tensors and topologically sort the graph.

        Returns:
            The exported ONNX model when ``return_onnx`` is True, otherwise the
            graphsurgeon graph.
        """
        self.graph.cleanup().toposort()
        return gs.export_onnx(self.graph) if return_onnx else self.graph

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at indices ``keep`` (in that order) and
        optionally rename the first ``len(names)`` of the kept outputs."""
        self.graph.outputs = [self.graph.outputs[o] for o in keep]
        if names:
            for i, name in enumerate(names):
                self.graph.outputs[i].name = name

    def fold_constants(self, return_onnx=False):
        """Run Polygraphy constant folding and re-import the result."""
        # `fold_constants` here resolves to the module-level Polygraphy import:
        # method names are not in scope inside a method body.
        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph

    def infer_shapes(self, return_onnx=False):
        """Run ONNX shape inference and re-import the result.

        Models larger than 2 GB cannot be serialized as a single protobuf,
        so they are round-tripped through a temporary directory with
        external data and ``infer_shapes_path``. The directory is deleted
        once the inferred model has been loaded.
        """
        onnx_graph = gs.export_onnx(self.graph)
        if onnx_graph.ByteSize() > 2147483648:
            # Using the context manager guarantees the (potentially multi-GB)
            # temporary files are removed; creating TemporaryDirectory() and
            # keeping only `.name` leaked them.
            with tempfile.TemporaryDirectory() as temp_dir:
                onnx_orig_path = os.path.join(temp_dir, 'model.onnx')
                onnx_inferred_path = os.path.join(temp_dir, 'inferred.onnx')
                onnx.save_model(onnx_graph,
                    onnx_orig_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    convert_attribute=False)
                onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path)
                # onnx.load reads external tensor data into memory by default,
                # so the directory can be removed after this call.
                onnx_graph = onnx.load(onnx_inferred_path)
        else:
            onnx_graph = shape_inference.infer_shapes(onnx_graph)

        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph

    def clip_add_hidden_states(self, hidden_layer_offset, return_onnx=False):
        """Rename one CLIP encoder layer's ``Add_1`` output to ``hidden_states``.

        The highest layer index is parsed from node output names containing
        ``"layers"``. The tensor
        ``/text_model/encoder/layers.<max + hidden_layer_offset>/Add_1_output_0``
        is then renamed wherever it is produced or consumed. ``self.graph``
        is not modified; the edited model is returned only when
        ``return_onnx`` is True.
        """
        hidden_layers = -1
        onnx_graph = gs.export_onnx(self.graph)
        for node in onnx_graph.graph.node:
            for name in node.output:
                if "layers" in name:
                    hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers)

        target = "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers+hidden_layer_offset)
        for node in onnx_graph.graph.node:
            for j in range(len(node.output)):
                if node.output[j] == target:
                    node.output[j] = "hidden_states"
            for j in range(len(node.input)):
                if node.input[j] == target:
                    node.input[j] = "hidden_states"
        if return_onnx:
            return onnx_graph

    def fuse_mha_qkv_int8_sq(self):
        """Share one dequantized input between attention Q/K/V projections.

        Only down-block attention tensors are matched (the patterns are
        anchored at ``/down_blocks.``). For self-attention (``attn1``) the
        first consumers of the K and V dequantized inputs are rewired to read
        Q's; for cross-attention (``attn2``) K's consumer is rewired to read V's.
        The bypassed tensors are dropped from the tensor map; unused nodes are
        left for a later ``cleanup``.

        Returns:
            Number of bypassed dequantize outputs.
        """
        tensors = self.graph.tensors()
        keys = tensors.keys()

        # mha : fuse QKV QDQ nodes
        # mhca : fuse KV QDQ nodes
        q_pat = (
            "/down_blocks.\\d+/attentions.\\d+/transformer_blocks"
            ".\\d+/attn\\d+/to_q/input_quantizer/DequantizeLinear_output_0"
        )
        k_pat = (
            "/down_blocks.\\d+/attentions.\\d+/transformer_blocks"
            ".\\d+/attn\\d+/to_k/input_quantizer/DequantizeLinear_output_0"
        )
        v_pat = (
            "/down_blocks.\\d+/attentions.\\d+/transformer_blocks"
            ".\\d+/attn\\d+/to_v/input_quantizer/DequantizeLinear_output_0"
        )

        def _matching(pattern):
            # Sorted matched prefixes of every tensor name that matches `pattern`.
            return sorted(m.group(0) for m in (re.match(pattern, key) for key in keys) if m is not None)

        qs = _matching(q_pat)
        ks = _matching(k_pat)
        vs = _matching(v_pat)

        removed = 0
        assert len(qs) == len(ks) == len(vs), "Failed to collect tensors"
        for q, k, v in zip(qs, ks, vs):
            is_mha = all(["attn1" in tensor for tensor in [q, k, v]])
            is_mhca = all(["attn2" in tensor for tensor in [q, k, v]])
            assert (is_mha or is_mhca) and (not (is_mha and is_mhca))

            if is_mha:
                tensors[k].outputs[0].inputs[0] = tensors[q]
                tensors[v].outputs[0].inputs[0] = tensors[q]
                del tensors[k]
                del tensors[v]
                removed += 2
            else: # is_mhca
                tensors[k].outputs[0].inputs[0] = tensors[v]
                del tensors[k]
                removed += 1
        print(f"Removed {removed} QDQ nodes")
        return removed # expected 72 for L2.5

    def modify_fp8_graph(self):
        """Apply the FP8 graph rewrites from ``utils_modelopt``.

        The rewrites convert INT8 zero-points to FP8, convert the graph to
        FP16 while keeping I/O types, then insert Casts around Resize, switch
        the model I/O to FP16 and cast the MHA BMM I/O.
        """
        onnx_graph = gs.export_onnx(self.graph)
        # Convert INT8 Zero to FP8.
        onnx_graph = convert_zp_fp8(onnx_graph)
        # Convert weights and activations to FP16 and insert Cast nodes in FP8 MHA.
        onnx_graph = convert_float_to_float16(onnx_graph, keep_io_types=True, disable_shape_infer=True)
        self.graph = gs.import_onnx(onnx_graph)
        # Add cast nodes to Resize I/O.
        cast_resize_io(self.graph)
        # Convert model inputs and outputs to fp16 I/O.
        convert_fp16_io(self.graph)
        # Add cast nodes to MHA's BMM1 and BMM2's I/O.
        cast_fp8_mha_io(self.graph)
def get_path(version, pipeline, controlnets=None):
    """Resolve the Hugging Face Hub repository id for a model version.

    Args:
        version: model version string such as ``"1.5"``, ``"xl-1.0"`` or
            ``"cascade"``.
        pipeline: pipeline descriptor queried through its ``is_*()`` predicates
            (and ``.name`` for the error message).
        controlnets: optional list of ControlNet modalities. When given, a list
            with one ``lllyasviel/sd-controlnet-<modality>`` repo per modality is
            returned and ``version``/``pipeline`` are ignored.

    Raises:
        ValueError: if ``version`` combined with ``pipeline`` is not supported.
    """
    if controlnets is not None:
        return ["lllyasviel/sd-controlnet-" + modality for modality in controlnets]

    # SD 1.x: one shared inpainting repo, otherwise per-version repos.
    if version in ("1.4", "1.5"):
        if pipeline.is_inpaint():
            return "benjamin-paine/stable-diffusion-v1-5-inpainting"
        return "CompVis/stable-diffusion-v1-4" if version == "1.4" else "benjamin-paine/stable-diffusion-v1-5"

    # SD 2.0: same layout as 1.x.
    if version in ("2.0-base", "2.0"):
        if pipeline.is_inpaint():
            return "stabilityai/stable-diffusion-2-inpainting"
        return "stabilityai/stable-diffusion-2-base" if version == "2.0-base" else "stabilityai/stable-diffusion-2"

    # Versions whose repo does not depend on the pipeline type.
    for known_version, repo_id in (
        ('dreamshaper-7', 'Lykon/dreamshaper-7'),
        ("2.1-base", "stabilityai/stable-diffusion-2-1-base"),
        ("2.1", "stabilityai/stable-diffusion-2-1"),
        ('sd3', "stabilityai/stable-diffusion-3-medium"),
    ):
        if version == known_version:
            return repo_id

    if version == 'xl-1.0':
        if pipeline.is_sd_xl_base():
            return "stabilityai/stable-diffusion-xl-base-1.0"
        if pipeline.is_sd_xl_refiner():
            return "stabilityai/stable-diffusion-xl-refiner-1.0"
    # TODO SDXL turbo with refiner
    elif version == 'xl-turbo':
        if pipeline.is_sd_xl_base():
            return "stabilityai/sdxl-turbo"
    elif version == 'svd-xt-1.1':
        if pipeline.is_img2vid():
            return "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
    elif version == 'cascade':
        return "stabilityai/stable-cascade" if pipeline.is_cascade_decoder() else "stabilityai/stable-cascade-prior"

    raise ValueError(f"Unsupported version {version} + pipeline {pipeline.name}")
def get_clip_embedding_dim(version, pipeline):
    """Return the text-encoder embedding width for ``version``.

    Returns 768 for SD 1.x/dreamshaper and the SDXL base encoder, 1024 for
    SD 2.x, and 4096 for ``"sd3"``.

    Raises:
        ValueError: for any other version/pipeline combination.
    """
    if version in ("1.4", "1.5", "dreamshaper-7"):
        return 768
    elif version in ("2.0", "2.0-base", "2.1", "2.1-base"):
        return 1024
    elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_base():
        return 768
    # Exact comparison: the former `version in ("sd3")` was a substring test
    # on a plain string, so "3", "sd" or "" were wrongly accepted.
    elif version == "sd3":
        return 4096
    else:
        raise ValueError(f"Invalid version {version} + pipeline {pipeline}")
def get_clipwithproj_embedding_dim(version, pipeline):
    """Return the projected text-embedding width (1280) for ``version``.

    Only ``"xl-1.0"``, ``"xl-turbo"`` and ``"cascade"`` are supported.

    Raises:
        ValueError: for any other version.
    """
    supported = ("xl-1.0", "xl-turbo", "cascade")
    if version not in supported:
        raise ValueError(f"Invalid version {version} + pipeline {pipeline}")
    return 1280
def get_unet_embedding_dim(version, pipeline):
    """Return the text-embedding width consumed by the denoiser for ``version``.

    The values are 768 (SD 1.x/dreamshaper), 1024 (SD 2.x and img2vid
    pipelines), 2048 (SDXL base), and 1280 (SDXL refiner and cascade).

    Raises:
        ValueError: for any other version/pipeline combination.
    """
    if version in ("1.4", "1.5", "dreamshaper-7"):
        return 768
    elif version in ("2.0", "2.0-base", "2.1", "2.1-base"):
        return 1024
    elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_base():
        return 2048
    # Exact comparison: the former `version in ("cascade")` was a substring test
    # on a plain string, so e.g. "c" or "cas" were wrongly accepted.
    elif version == "cascade":
        return 1280
    elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_refiner():
        return 1280
    elif pipeline.is_img2vid():
        return 1024
    else:
        raise ValueError(f"Invalid version {version} + pipeline {pipeline}")
# FIXME serialization not supported for torch.compile
def get_checkpoint_dir(framework_model_dir, version, pipeline, subfolder):
    """Return ``<framework_model_dir>/<version>/<pipeline>/<subfolder>``.

    Joined with ``os.path.join``; an empty ``subfolder`` yields a trailing
    separator.
    """
    components = (framework_model_dir, version, pipeline, subfolder)
    return os.path.join(*components)
torch_inference_modes = ['default', 'reduce-overhead', 'max-autotune']


# FIXME update callsites after serialization support for torch.compile is added
def optimize_checkpoint(model, torch_inference):
    """Optionally wrap ``model`` with ``torch.compile``.

    Args:
        model: the module to compile.
        torch_inference: a falsy value or ``'eager'`` returns ``model``
            unchanged; otherwise one of ``torch_inference_modes``, passed to
            ``torch.compile`` as ``mode`` (static shapes, graph breaks allowed).

    Raises:
        ValueError: if ``torch_inference`` is not a recognised mode.
    """
    if not torch_inference or torch_inference == 'eager':
        return model
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if torch_inference not in torch_inference_modes:
        raise ValueError(
            f"Unsupported torch_inference mode {torch_inference!r}; "
            f"expected 'eager' or one of {torch_inference_modes}")
    return torch.compile(model, mode=torch_inference, dynamic=False, fullgraph=False)
class LoraLoader(LoraLoaderMixin):
    """Load one or more LoRA checkpoints and split them per sub-model.

    Args:
        paths: checkpoint locations accepted by
            ``LoraLoaderMixin.lora_state_dict``.

    Raises:
        ValueError: if a checkpoint contains a key without ``"lora"`` in it.
    """
    def __init__(self,
            paths,
    ):
        self.paths = paths
        self.state_dict = dict()
        self.network_alphas = dict()

        for path in paths:
            state_dict, network_alphas = self.lora_state_dict(path)
            is_correct_format = all("lora" in key for key in state_dict.keys())
            if not is_correct_format:
                raise ValueError("Invalid LoRA checkpoint.")

            self.state_dict[path] = state_dict
            self.network_alphas[path] = network_alphas

    def get_dicts(self,
            prefix='unet',
            convert_to_diffusers=False,
    ):
        """Return the LoRA weights and alphas that belong to ``prefix``.

        A key is selected only when its first dotted component equals
        ``prefix`` exactly, and that leading ``"<prefix>."`` is stripped.
        Without the exact match, ``'text_encoder'`` would also pick up
        ``'text_encoder_2.*'`` keys.

        Args:
            prefix: sub-model name, e.g. ``'unet'``.
            convert_to_diffusers: unused; kept for API compatibility.

        Returns:
            ``(state_dict, network_alphas)``, both keyed by checkpoint path.
            Paths with no matching weights are omitted. A path's alphas are
            ``None`` when the checkpoint provided none.
        """
        state_dict = dict()
        network_alphas = dict()
        strip = len(prefix) + 1  # length of the leading "<prefix>."

        def _owned(key):
            # True when the key's first dotted component is exactly `prefix`.
            return key.split(".", 1)[0] == prefix

        for path in self.paths:
            path_state = self.state_dict[path]
            if all(key.startswith(('unet', 'text_encoder')) for key in path_state):
                # Dict comprehension keeps checkpoint order and avoids the
                # quadratic `k in list` membership test.
                selected = {k: v for k, v in path_state.items() if _owned(k)}
                if selected:
                    print(f"Processing {prefix} LoRA: {path}")
                    state_dict[path] = {k[strip:]: v for k, v in selected.items()}
                    network_alphas[path] = None
                    path_alphas = self.network_alphas.get(path)
                    if path_alphas is not None:
                        network_alphas[path] = {
                            k[strip:]: v for k, v in path_alphas.items() if _owned(k)
                        }
            else:
                # Otherwise, we're dealing with the old format.
                warn_message = "You have saved the LoRA weights using the old format. To convert LoRA weights to the new format, first load them in a dictionary and then create a new dictionary as follows: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
                print(warn_message)

        return state_dict, network_alphas
class BaseModel():
    """Shared configuration and export helpers for the pipeline's sub-models.

    Holds the version/precision settings and implements ONNX export, ONNX
    optimization, PyTorch-to-ONNX weight-name mapping and min/opt/max shape
    helpers. Subclasses override the stub methods (``get_model``,
    ``get_input_names``, ``get_output_names``, ``get_dynamic_axes``,
    ``get_sample_input``, ``get_input_profile``, ``get_shape_dict``) to
    describe a concrete network.
    """
    def __init__(self,
            version='1.5',
            pipeline=None,
            device='cuda',
            hf_token='',
            verbose=True,
            framework_model_dir='pytorch_model',
            fp16=False,
            bf16=False,
            int8=False,
            fp8=False,
            max_batch_size=16,
            text_maxlen=77,
            embedding_dim=768,
            compression_factor=8
    ):
        # NOTE: `pipeline` is required in practice despite its None default;
        # its `.name` and `.is_inpaint()` are used below.
        self.name = self.__class__.__name__
        self.pipeline = pipeline.name
        self.version = version
        self.path = get_path(version, pipeline)
        self.device = device
        self.hf_token = hf_token
        # SD 1.x inpainting weights are loaded without safetensors (".bin").
        self.hf_safetensor = not (pipeline.is_inpaint() and version in ("1.4", "1.5"))
        self.verbose = verbose
        self.framework_model_dir = framework_model_dir

        self.fp16 = fp16
        self.bf16 = bf16
        self.int8 = int8
        self.fp8 = fp8

        self.compression_factor = compression_factor
        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256 # min image resolution: 256x256
        self.max_image_shape = 1024 # max image resolution: 1024x1024
        self.min_latent_shape = self.min_image_shape // self.compression_factor
        self.max_latent_shape = self.max_image_shape // self.compression_factor

        self.text_maxlen = text_maxlen
        self.embedding_dim = embedding_dim
        self.extra_output_names = []

        self.lora_dict = None
        self.do_constant_folding = True

    def get_pipeline(self):
        """Load the diffusers pipeline for ``self.path`` onto ``self.device``.

        bf16 takes precedence over fp16 when both flags are set.
        """
        model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {}
        model_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16} if self.bf16 else model_opts
        return DiffusionPipeline.from_pretrained(
            self.path,
            use_safetensors=self.hf_safetensor,
            use_auth_token=self.hf_token,
            **model_opts,
            ).to(self.device)

    def get_model_path(self, model_dir, model_opts, model_name="diffusion_pytorch_model"):
        """Return ``<model_dir>/<model_name>[.<variant>].(safetensors|bin)``."""
        variant = "." + model_opts.get("variant") if "variant" in model_opts else ""
        suffix = ".safetensors" if self.hf_safetensor else ".bin"
        model_file = model_name + variant + suffix
        return os.path.join(model_dir, model_file)

    def get_model(self, torch_inference=''):
        """Return the framework model; overridden by subclasses."""
        pass

    def get_input_names(self):
        """Return ONNX input names; overridden by subclasses."""
        pass

    def get_output_names(self):
        """Return ONNX output names; overridden by subclasses."""
        pass

    def get_dynamic_axes(self):
        """Return the ``dynamic_axes`` mapping for ONNX export (None here)."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width, static_shape):
        """Return sample inputs for tracing; overridden by subclasses."""
        pass

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return per-input (min, opt, max) shapes (None here)."""
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return per-tensor shapes for a given batch/image size (None here)."""
        return None

    # Helper utility for ONNX export
    def export_onnx(
        self,
        onnx_path,
        onnx_opt_path,
        onnx_opset,
        opt_image_height,
        opt_image_width,
        custom_model=None,
        enable_lora_merge=False,
        static_shape=False,
    ):
        """Export the model to ONNX and write an optimized copy.

        Each step is skipped when its output file already exists. The model is
        traced with batch size 1 at ``opt_image_height`` x ``opt_image_width``.
        ``custom_model`` is exported under ``torch.inference_mode()`` only;
        otherwise ``get_model()`` is exported under CUDA autocast as well.
        When ``enable_lora_merge`` is set, ``self.lora_alphas`` and
        ``self.lora_scales`` must have been set on the instance. Optimized
        models above 2 GB are saved with external data.
        """
        onnx_opt_graph = None
        # Export optimized ONNX model (if missing)
        if not os.path.exists(onnx_opt_path):
            if not os.path.exists(onnx_path):
                print(f"[I] Exporting ONNX model: {onnx_path}")
                def _export(model):
                    if enable_lora_merge:
                        model = merge_loras(model, self.lora_dict, self.lora_alphas, self.lora_scales)
                    inputs = self.get_sample_input(1, opt_image_height, opt_image_width, static_shape)
                    torch.onnx.export(model,
                            inputs,
                            onnx_path,
                            export_params=True,
                            opset_version=onnx_opset,
                            do_constant_folding=self.do_constant_folding,
                            input_names=self.get_input_names(),
                            output_names=self.get_output_names(),
                            dynamic_axes=self.get_dynamic_axes(),
                    )
                # `is not None` rather than truthiness: some nn.Module containers
                # (e.g. an empty nn.Sequential) define __len__ and are falsy.
                if custom_model is not None:
                    with torch.inference_mode():
                        _export(custom_model)
                else:
                    with torch.inference_mode(), torch.autocast("cuda"):
                        _export(self.get_model())
            else:
                print(f"[I] Found cached ONNX model: {onnx_path}")

            print(f"[I] Optimizing ONNX model: {onnx_opt_path}")
            onnx_opt_graph = self.optimize(onnx.load(onnx_path))
            # A single protobuf cannot exceed 2 GB; larger models need external data.
            if onnx_opt_graph.ByteSize() > 2147483648:
                onnx.save_model(
                    onnx_opt_graph,
                    onnx_opt_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    convert_attribute=False)
            else:
                onnx.save(onnx_opt_graph, onnx_opt_path)
        else:
            print(f"[I] Found cached optimized ONNX model: {onnx_opt_path} ")

    # Helper utility for weights map
    def export_weights_map(self, onnx_opt_path, weights_map_path):
        """Write a JSON map from PyTorch state-dict names to ONNX initializers.

        Each weight is matched by hashing its fp16 bytes; the hash of its
        transpose is also tried because constant folding can transpose weights
        during export. The file contains
        ``[weights_name_mapping, weights_shape_mapping]`` where the latter maps
        each weight name to ``(initializer_shape, is_transpose)``. Skipped if
        ``weights_map_path`` already exists.
        """
        if not os.path.exists(weights_map_path):
            onnx_opt_dir = os.path.dirname(onnx_opt_path)
            onnx_opt_model = onnx.load(onnx_opt_path)
            state_dict = self.get_model().state_dict()
            # Create initializer data hashes. Python's bytes hash is salted per
            # process, which is fine: both sides are hashed within this call.
            initializer_hash_mapping = {}
            for initializer in onnx_opt_model.graph.initializer:
                initializer_data = numpy_helper.to_array(initializer, base_dir=onnx_opt_dir).astype(np.float16)
                initializer_hash = hash(initializer_data.data.tobytes())
                initializer_hash_mapping[initializer.name] = (initializer_hash, initializer_data.shape)

            weights_name_mapping = {}
            weights_shape_mapping = {}
            # set to keep track of initializers already added to the name_mapping dict
            initializers_mapped = set()
            for wt_name, wt in state_dict.items():
                # get weight hash
                wt = wt.cpu().detach().numpy().astype(np.float16)
                wt_hash = hash(wt.data.tobytes())
                wt_t_hash = hash(np.transpose(wt).data.tobytes())

                for initializer_name, (initializer_hash, initializer_shape) in initializer_hash_mapping.items():
                    # Due to constant folding, some weights are transposed during export
                    # To account for the transpose op, we compare the initializer hash to the
                    # hash for the weight and its transpose
                    if wt_hash == initializer_hash or wt_t_hash == initializer_hash:
                        # The assert below ensures there is a 1:1 mapping between
                        # PyTorch and ONNX weight names. It can be removed in cases where 1:many
                        # mapping is found and name_mapping[wt_name] = list()
                        assert initializer_name not in initializers_mapped
                        weights_name_mapping[wt_name] = initializer_name
                        initializers_mapped.add(initializer_name)
                        is_transpose = wt_hash != initializer_hash
                        weights_shape_mapping[wt_name] = (initializer_shape, is_transpose)

                # Sanity check: Were any weights not matched
                if wt_name not in weights_name_mapping:
                    print(f'[I] PyTorch weight {wt_name} not matched with any ONNX initializer')
            print(f'[I] {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers')
            assert weights_name_mapping.keys() == weights_shape_mapping.keys()
            with open(weights_map_path, 'w') as fp:
                json.dump([weights_name_mapping, weights_shape_mapping], fp)
        else:
            print(f"[I] Found cached weights map: {weights_map_path} ")

    def optimize(self, onnx_graph, return_onnx=True, **kwargs):
        """Run the default optimization passes over an exported ONNX model.

        With ``modify_fp8_graph=True`` the FP8 rewrite replaces constant
        folding and shape inference. With ``fuse_mha_qkv_int8=True`` the
        INT8 Q/K/V dequantize inputs are shared. The result is returned as an
        ONNX model when ``return_onnx`` is True, otherwise as a graphsurgeon
        graph.
        """
        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info(self.name + ': original')
        opt.cleanup()
        opt.info(self.name + ': cleanup')
        if kwargs.get('modify_fp8_graph', False):
            opt.modify_fp8_graph()
            opt.info(self.name + ': modify fp8 graph')
        else:
            opt.fold_constants()
            opt.info(self.name + ': fold constants')
            opt.infer_shapes()
            opt.info(self.name + ': shape inference')
            if kwargs.get('fuse_mha_qkv_int8', False):
                opt.fuse_mha_qkv_int8_sq()
                opt.info(self.name + ': fuse QKV nodes')
        onnx_opt_graph = opt.cleanup(return_onnx=return_onnx)
        opt.info(self.name + ': finished')
        return onnx_opt_graph

    def check_dims(self, batch_size, image_height, image_width):
        """Validate batch/image sizes and return ``(latent_height, latent_width)``.

        Raises:
            AssertionError: if the batch size is outside
                ``[min_batch, max_batch]``, either image side is not a multiple
                of 8, or a latent side is outside
                ``[min_latent_shape, max_latent_shape]``.
        """
        assert batch_size >= self.min_batch and batch_size <= self.max_batch
        # Both sides must be multiples of 8; `or` would let e.g. 516x512 pass.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // self.compression_factor
        latent_width = image_width // self.compression_factor
        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return min/max batch, image and latent sizes.

        The given batch size is pinned when ``static_batch`` is set, and the
        given image/latent sizes are pinned when ``static_shape`` is set.
        Otherwise the configured limits are used.

        Returns:
            ``(min_batch, max_batch, min_image_height, max_image_height,
            min_image_width, max_image_width, min_latent_height,
            max_latent_height, min_latent_width, max_latent_width)``
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // self.compression_factor
        latent_width = image_width // self.compression_factor
        min_image_height = image_height if static_shape else self.min_image_shape
        max_image_height = image_height if static_shape else self.max_image_shape
        min_image_width = image_width if static_shape else self.min_image_shape
        max_image_width = image_width if static_shape else self.max_image_shape
        min_latent_height = latent_height if static_shape else self.min_latent_shape
        max_latent_height = latent_height if static_shape else self.max_latent_shape
        min_latent_width = latent_width if static_shape else self.min_latent_shape
        max_latent_width = latent_width if static_shape else self.max_latent_shape
        return (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width)
class CLIPModel(BaseModel):
    """CLIP text encoder (``CLIPTextModel``) exported with one
    ``text_embeddings`` output and, optionally, an encoder layer's output
    renamed to ``hidden_states``.
    """
    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            max_batch_size,
            embedding_dim,
            fp16=False,
            bf16=False,
            output_hidden_states=False,
            subfolder="text_encoder",
            lora_dict=None,
            lora_alphas=None,
    ):
        # NOTE(review): lora_dict / lora_alphas are accepted but not stored here.
        super().__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
        self.subfolder = subfolder
        # Offset added to the highest encoder layer index when selecting the
        # hidden_states tensor: 0 (last layer) for cascade, -1 otherwise.
        self.hidden_layer_offset = 0 if pipeline.is_cascade() else -1

        # Output the final hidden state
        if output_hidden_states:
            self.extra_output_names = ['hidden_states']

    def get_model(self, torch_inference=''):
        """Load ``CLIPTextModel``, downloading and caching it on first use."""
        clip_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        if not os.path.exists(clip_model_dir):
            model = CLIPTextModel.from_pretrained(self.path,
                subfolder=self.subfolder,
                use_safetensors=self.hf_safetensor,
                use_auth_token=self.hf_token).to(self.device)
            model.save_pretrained(clip_model_dir)
        else:
            print(f"[I] Load CLIPTextModel model from: {clip_model_dir}")
            model = CLIPTextModel.from_pretrained(clip_model_dir).to(self.device)
        model = optimize_checkpoint(model, torch_inference)
        return model

    def get_input_names(self):
        return ['input_ids']

    def get_output_names(self):
        return ['text_embeddings']

    def get_dynamic_axes(self):
        """Batch dimension is dynamic for both input and output."""
        return {
            'input_ids': {0: 'B'},
            'text_embeddings': {0: 'B'}
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return (min, opt, max) shapes for ``input_ids``."""
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            'input_ids': [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return shapes of inputs/outputs, including ``hidden_states`` if enabled."""
        self.check_dims(batch_size, image_height, image_width)
        output = {
            'input_ids': (batch_size, self.text_maxlen),
            'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim)
        }
        if 'hidden_states' in self.extra_output_names:
            output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
        return output

    def get_sample_input(self, batch_size, image_height, image_width, static_shape):
        """Return an all-zero int32 ``input_ids`` tensor for tracing."""
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize(self, onnx_graph):
        """Keep only graph output 0 (renamed ``text_embeddings``), fold
        constants and infer shapes. If requested, rename the selected encoder
        layer output to ``hidden_states``. Returns an ONNX model.
        """
        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info(self.name + ': original')
        opt.select_outputs([0]) # delete graph output#1
        opt.cleanup()
        opt.info(self.name + ': remove output[1]')
        opt.fold_constants()
        opt.info(self.name + ': fold constants')
        opt.infer_shapes()
        opt.info(self.name + ': shape inference')
        opt.select_outputs([0], names=['text_embeddings']) # rename network output
        # This step renames the output; it does not remove it.
        opt.info(self.name + ': rename output[0]')
        opt_onnx_graph = opt.cleanup(return_onnx=True)
        if 'hidden_states' in self.extra_output_names:
            opt_onnx_graph = opt.clip_add_hidden_states(self.hidden_layer_offset, return_onnx=True)
            opt.info(self.name + ': added hidden_states')
        opt.info(self.name + ': finished')
        return opt_onnx_graph
class CLIPWithProjModel(CLIPModel):
    """Text encoder with projection head (``CLIPTextModelWithProjection``).

    Takes ``input_ids`` and ``attention_mask`` and exposes the projected
    embedding as ``text_embeddings`` with shape ``(batch, embedding_dim)``.
    ``embedding_dim`` comes from ``get_clipwithproj_embedding_dim``.
    """
    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            fp16=False,
            bf16=False,
            max_batch_size=16,
            output_hidden_states=False,
            subfolder="text_encoder_2",
            lora_dict=None,
            lora_alphas=None,
    ):
        projection_dim = get_clipwithproj_embedding_dim(version, pipeline)
        super(CLIPWithProjModel, self).__init__(
            version,
            pipeline,
            device=device,
            hf_token=hf_token,
            verbose=verbose,
            framework_model_dir=framework_model_dir,
            fp16=fp16,
            bf16=bf16,
            max_batch_size=max_batch_size,
            embedding_dim=projection_dim,
            output_hidden_states=output_hidden_states,
        )
        # The parent stores its own default subfolder; replace it with ours.
        self.subfolder = subfolder

    def get_model(self, torch_inference=''):
        """Load the projection text encoder, caching it locally on first use."""
        load_opts = {}
        if self.bf16:
            load_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16}
        cache_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        weights_file = self.get_model_path(cache_dir, load_opts, model_name='model')
        if os.path.exists(weights_file):
            print(f"[I] Load CLIPTextModelWithProjection model from: {weights_file}")
            model = CLIPTextModelWithProjection.from_pretrained(cache_dir, **load_opts).to(self.device)
        else:
            model = CLIPTextModelWithProjection.from_pretrained(self.path,
                subfolder=self.subfolder,
                use_safetensors=self.hf_safetensor,
                use_auth_token=self.hf_token,
                **load_opts).to(self.device)
            model.save_pretrained(cache_dir, **load_opts)
        return optimize_checkpoint(model, torch_inference)

    def get_input_names(self):
        return ['input_ids', 'attention_mask']

    def get_output_names(self):
        return ['text_embeddings']

    def get_dynamic_axes(self):
        """Every tensor has a dynamic batch dimension."""
        return {name: {0: 'B'} for name in ('input_ids', 'attention_mask', 'text_embeddings')}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return (min, opt, max) shapes for both token inputs."""
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        seq_len = self.text_maxlen
        return {
            name: [(min_batch, seq_len), (batch_size, seq_len), (max_batch, seq_len)]
            for name in ('input_ids', 'attention_mask')
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return input/output shapes, including ``hidden_states`` if enabled."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        shapes = {
            'input_ids': token_shape,
            'attention_mask': token_shape,
            'text_embeddings': (batch_size, self.embedding_dim),
        }
        if 'hidden_states' in self.extra_output_names:
            shapes["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width, static_shape):
        """Return all-zero int32 ``(input_ids, attention_mask)`` for tracing."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        return tuple(
            torch.zeros(*token_shape, dtype=torch.int32, device=self.device)
            for _ in range(2)
        )
class SD3_CLIPGModel(CLIPModel):
    """SD3 CLIP-G text encoder built from ``text_encoders/clip_g.safetensors``.

    The optimized ONNX graph keeps its first two outputs, renamed
    ``text_embeddings`` and ``pooled_output``. ``pooled_output`` is listed in
    ``extra_output_names`` only when ``pooled_output=True``.
    """
    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            max_batch_size,
            embedding_dim=None,
            fp16=False,
            pooled_output=False,
    ):
        # Set before super().__init__ because it supplies the default width below.
        self.CLIPG_CONFIG = {
            "hidden_act": "gelu",
            "hidden_size": 1280,
            "intermediate_size": 5120,
            "num_attention_heads": 20,
            "num_hidden_layers": 32
        }
        super().__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, embedding_dim=self.CLIPG_CONFIG["hidden_size"] if embedding_dim is None else embedding_dim)
        self.subfolder = 'text_encoders'
        if pooled_output:
            self.extra_output_names = ['pooled_output']

    def get_model(self, torch_inference=''):
        """Load CLIP-G weights (downloading them on first use) into ``SDXLClipG``."""
        clip_g_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        clip_g_filename = "clip_g.safetensors"
        clip_g_model_path = os.path.join(clip_g_model_dir, clip_g_filename)
        if not os.path.exists(clip_g_model_path):
            hf_hub_download(
                repo_id=self.path,
                filename=clip_g_filename,
                local_dir=get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, ''),
                subfolder=self.subfolder,
                # Forward the configured token like the other loaders do; the repo
                # may require authentication. '' falls back to the cached login.
                token=self.hf_token or None,
            )
        dtype = torch.float16 if self.fp16 else torch.float32
        with safe_open(clip_g_model_path, framework="pt", device=self.device) as f:
            model = SDXLClipG(self.CLIPG_CONFIG, device=self.device, dtype=dtype)
            load_into(f, model.transformer, "", self.device, dtype)
        model = optimize_checkpoint(model, torch_inference)
        return model

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return input/output shapes, including ``pooled_output`` if enabled."""
        self.check_dims(batch_size, image_height, image_width)
        output = {
            'input_ids': (batch_size, self.text_maxlen),
            'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim)
        }
        if 'pooled_output' in self.extra_output_names:
            output["pooled_output"] = (batch_size, self.embedding_dim)
        return output

    def optimize(self, onnx_graph):
        """Keep and rename the first two outputs, fold constants and infer shapes."""
        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info(self.name + ': original')
        opt.select_outputs([0, 1])
        opt.cleanup()
        opt.fold_constants()
        opt.info(self.name + ': fold constants')
        opt.infer_shapes()
        opt.info(self.name + ': shape inference')
        opt.select_outputs([0, 1], names=['text_embeddings', 'pooled_output']) # rename network output
        opt.info(self.name + ': rename output[0] and output[1]')
        opt_onnx_graph = opt.cleanup(return_onnx=True)
        opt.info(self.name + ': finished')
        return opt_onnx_graph
class SD3_CLIPLModel(SD3_CLIPGModel):
    """CLIP-L text encoder for the SD3 pipeline.

    Inherits shape and ONNX-optimization behavior from ``SD3_CLIPGModel`` but
    loads ``clip_l.safetensors`` into an ``SDClipModel`` configured by
    ``CLIPL_CONFIG``.
    """

    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            max_batch_size,
            fp16=False,
            pooled_output=False,
        ):
        # Defined before super().__init__ because the embedding width passed to
        # the parent constructor is read from it.
        self.CLIPL_CONFIG = {
            "hidden_act": "quick_gelu",
            "hidden_size": 768,
            "intermediate_size": 3072,
            "num_attention_heads": 12,
            "num_hidden_layers": 12
        }
        super(SD3_CLIPLModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, embedding_dim=self.CLIPL_CONFIG["hidden_size"])
        self.subfolder = 'text_encoders'
        if pooled_output:
            self.extra_output_names = ['pooled_output']

    def get_model(self, torch_inference=''):
        """Return the CLIP-L text encoder loaded from ``clip_l.safetensors``.

        The checkpoint is downloaded from the Hugging Face repo ``self.path``
        the first time it is missing locally.

        Args:
            torch_inference: Value forwarded unchanged to ``optimize_checkpoint``.

        Returns:
            An ``SDClipModel`` (fp16 when ``self.fp16`` else fp32) with weights
            loaded, after ``optimize_checkpoint`` has been applied.
        """
        clip_l_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        clip_l_filename = "clip_l.safetensors"
        clip_l_model_path = os.path.join(clip_l_model_dir, clip_l_filename)
        if not os.path.exists(clip_l_model_path):
            # Forward the token supplied to the constructor so gated repositories
            # can be fetched without relying on a cached `huggingface-cli login`.
            # NOTE(review): assumes get_checkpoint_dir(..., '') joined with
            # self.subfolder equals clip_l_model_dir.
            hf_hub_download(
                repo_id=self.path,
                filename=clip_l_filename,
                local_dir=get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, ''),
                subfolder=self.subfolder,
                token=self.hf_token,
            )
        dtype = torch.float16 if self.fp16 else torch.float32
        with safe_open(clip_l_model_path, framework="pt", device=self.device) as f:
            # Hidden states are taken from layer index -2 without the final
            # layer norm, and no projected pooled output is produced.
            model = SDClipModel(layer="hidden", layer_idx=-2, device=self.device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=self.CLIPL_CONFIG)
            load_into(f, model.transformer, "", self.device, dtype)
        model = optimize_checkpoint(model, torch_inference)
        return model
class SD3_T5XXLModel(CLIPModel):
    """T5-XXL text encoder for the SD3 pipeline, loaded from ``t5xxl_fp16.safetensors``."""

    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            max_batch_size,
            embedding_dim,
            fp16=False,
        ):
        super(SD3_T5XXLModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
        # Architecture hyper-parameters handed to T5XXLModel in get_model().
        self.T5_CONFIG = {
            "d_ff": 10240,
            "d_model": 4096,
            "num_heads": 64,
            "num_layers": 24,
            "vocab_size": 32128
        }
        self.subfolder = 'text_encoders'

    def get_model(self, torch_inference=''):
        """Return the T5-XXL encoder loaded from ``t5xxl_fp16.safetensors``.

        The checkpoint is downloaded from the Hugging Face repo ``self.path``
        the first time it is missing locally.

        Args:
            torch_inference: Value forwarded unchanged to ``optimize_checkpoint``.

        Returns:
            A ``T5XXLModel`` built in fp16 when ``self.fp16`` is set (fp32
            otherwise), after ``optimize_checkpoint`` has been applied.
        """
        t5xxl_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        t5xxl_filename = "t5xxl_fp16.safetensors"
        t5xxl_model_path = os.path.join(t5xxl_model_dir, t5xxl_filename)
        if not os.path.exists(t5xxl_model_path):
            # Forward the token supplied to the constructor so gated repositories
            # can be fetched without relying on a cached `huggingface-cli login`.
            # NOTE(review): assumes get_checkpoint_dir(..., '') joined with
            # self.subfolder equals t5xxl_model_dir.
            hf_hub_download(
                repo_id=self.path,
                filename=t5xxl_filename,
                local_dir=get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, ''),
                subfolder=self.subfolder,
                token=self.hf_token,
            )
        dtype = torch.float16 if self.fp16 else torch.float32
        with safe_open(t5xxl_model_path, framework="pt", device=self.device) as f:
            model = T5XXLModel(self.T5_CONFIG, device=self.device, dtype=dtype)
            # NOTE(review): the file name says fp16; presumably load_into casts
            # the weights to `dtype` when fp32 is requested — confirm.
            load_into(f, model.transformer, "", self.device, dtype)
        model = optimize_checkpoint(model, torch_inference)
        return model
class CLIPVisionWithProjModel(BaseModel):
    """Provides a ``CLIPVisionModelWithProjection`` image encoder for the pipeline."""

    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            max_batch_size=1,
            subfolder="image_encoder",
        ):
        super().__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, max_batch_size=max_batch_size)
        self.subfolder = subfolder

    def get_model(self, torch_inference=''):
        """Load the image encoder onto ``self.device``.

        A locally saved copy is used when its directory exists; otherwise the
        model is pulled from ``self.path`` / ``self.subfolder`` and saved to
        that directory for later calls.

        Args:
            torch_inference: Value forwarded unchanged to ``optimize_checkpoint``.
        """
        local_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        if os.path.exists(local_dir):
            print(f"[I] Load CLIPVisionModelWithProjection model from: {local_dir}")
            encoder = CLIPVisionModelWithProjection.from_pretrained(local_dir).to(self.device)
        else:
            encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.path,
                subfolder=self.subfolder,
                use_safetensors=self.hf_safetensor,
                use_auth_token=self.hf_token,
            ).to(self.device)
            # Cache locally so subsequent calls take the branch above.
            encoder.save_pretrained(local_dir)
        return optimize_checkpoint(encoder, torch_inference)
class CLIPImageProcessorModel(BaseModel):
    """Provides the ``CLIPImageProcessor`` (feature extractor) used by the pipeline."""

    def __init__(self,
            version,
            pipeline,
            device,
            hf_token,
            verbose,
            framework_model_dir,
            max_batch_size=1,
            subfolder="feature_extractor",
        ):
        super().__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, max_batch_size=max_batch_size)
        self.subfolder = subfolder

    def get_model(self, torch_inference=''):
        """Load the image processor, caching a local copy on first use.

        Unlike the other models, the processor is not moved with ``.to(device)``
        because it does not support that call.

        Args:
            torch_inference: Value forwarded unchanged to ``optimize_checkpoint``.
        """
        local_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
        if os.path.exists(local_dir):
            print(f"[I] Load CLIPImageProcessor model from: {local_dir}")
            processor = CLIPImageProcessor.from_pretrained(local_dir)
        else:
            processor = CLIPImageProcessor.from_pretrained(
                self.path,
                subfolder=self.subfolder,
                use_safetensors=self.hf_safetensor,
                use_auth_token=self.hf_token,
            )
            processor.save_pretrained(local_dir)
        # NOTE(review): the processor is not a torch module; confirm that
        # optimize_checkpoint is a no-op for it under the torch_inference
        # modes actually used.
        return optimize_checkpoint(processor, torch_inference)
class UNet2DConditionControlNetModel(torch.nn.Module):
    """Combine a UNet with a list of ControlNets into one callable module.

    Each ControlNet's down-block and mid-block residuals are multiplied by its
    conditioning scale and summed across ControlNets. The sums are then passed
    to the UNet as additional residuals.
    """

    def __init__(self, unet, controlnets) -> None:
        super().__init__()
        self.unet = unet
        self.controlnets = controlnets

    def forward(self, sample, timestep, encoder_hidden_states, images, controlnet_scales):
        """Run every ControlNet, merge their residuals and return the UNet output.

        Args:
            sample: Latent input, passed to every ControlNet and to the UNet.
            timestep: Timestep, passed to every ControlNet and to the UNet.
            encoder_hidden_states: Conditioning passed to every ControlNet and to the UNet.
            images: Conditioning images, paired positionally with ``self.controlnets``.
            controlnet_scales: Per-ControlNet scales, paired positionally as well.
                Iteration stops at the shortest of the three sequences.

        Returns:
            Whatever ``self.unet`` returns.
        """
        down_block_res_samples = None
        mid_block_res_sample = None
        for image, conditioning_scale, controlnet in zip(images, controlnet_scales, self.controlnets):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                return_dict=False,
            )
            # Scale out of place so the tensors returned by the ControlNet are
            # never mutated.
            down_samples = [down_sample * conditioning_scale for down_sample in down_samples]
            mid_sample = mid_sample * conditioning_scale
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample = mid_block_res_sample + mid_sample
        # With no ControlNets the residuals stay None instead of being unbound.
        # NOTE(review): assumes self.unet treats None residuals as "no extra
        # conditioning", as diffusers' UNet2DConditionModel does by default.
        noise_pred = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample
        )
        return noise_pred
class UNetModel(BaseModel):
def __init__(self,
version,
pipeline,
device,
hf_token,
verbose,
framework_model_dir,
fp16 = False,
int8 = False,
fp8 = False,
max_batch_size = 16,
text_maxlen = 77,
controlnets = None,
lora_scales = None,
lora_dict = None,
lora_alphas = None,
do_classifier_free_guidance = False,
):