@@ -2727,11 +2727,16 @@ def one_unet_in_gpu(self, unet_number = None, unet = None):
         if exists(unet_number):
             unet = self.get_unet(unet_number)
 
+        # devices
+
+        cuda, cpu = torch.device('cuda'), torch.device('cpu')
+
         self.cuda()
 
         devices = [module_device(unet) for unet in self.unets]
-        self.unets.cpu()
-        unet.cuda()
+
+        self.unets.to(cpu)
+        unet.to(cuda)
 
         yield
 
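As a usage sketch (not part of the diff): since `one_unet_in_gpu` yields, it is a context manager, so the device shuffle above can be exercised directly. The `decoder` instance and the unet index below are placeholder assumptions.

    # hypothetical example; assumes `decoder` is an already constructed Decoder on a CUDA machine
    with decoder.one_unet_in_gpu(unet_number = 1):
        # inside the block the unets are parked on CPU and only unet 1 sits on CUDA
        ...  # run inference for that unet here
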
@@ -3114,7 +3119,8 @@ def sample(
         distributed = False,
         inpaint_image = None,
         inpaint_mask = None,
-        inpaint_resample_times = 5
+        inpaint_resample_times = 5,
+        one_unet_in_gpu_at_time = True
     ):
         assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
 
@@ -3137,6 +3143,7 @@ def sample(
             assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
             prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
             img = resize_image_to(image, prev_unet_output_size, nearest = True)
+
         is_cuda = next(self.parameters()).is_cuda
 
         num_unets = self.num_unets
@@ -3146,7 +3153,7 @@ def sample(
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it
 
-            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()
 
             with context:
                 # prepare low resolution conditioning for upsamplers
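A usage sketch of the new flag (hypothetical call site; `decoder` and `image_embed` are assumed to exist already): leaving `one_unet_in_gpu_at_time` at its default `True` keeps the old behaviour of paging one unet onto the GPU at a time, while `False` falls back to `null_context()` so the unets stay on whatever devices they already occupy.

    # hypothetical example; keeps all unets resident, e.g. when the whole cascade fits in GPU memory
    images = decoder.sample(
        image_embed = image_embed,
        one_unet_in_gpu_at_time = False
    )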