Commit c6c3882

fix all optional types in train config

Parent: 512b52b

2 files changed: +19 -19

dalle2_pytorch/train_configs.py (+18 -18)
@@ -115,7 +115,7 @@ def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool =
 class AdapterConfig(BaseModel):
     make: str = "openai"
     model: str = "ViT-L/14"
-    base_model_kwargs: Dict[str, Any] = None
+    base_model_kwargs: Optional[Dict[str, Any]] = None
 
     def create(self):
         if self.make == "openai":
@@ -134,8 +134,8 @@ def create(self):
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int
     depth: int
-    max_text_len: int = None
-    num_timesteps: int = None
+    max_text_len: Optional[int] = None
+    num_timesteps: Optional[int] = None
     num_time_embeds: int = 1
     num_image_embeds: int = 1
     num_text_embeds: int = 1
@@ -158,7 +158,7 @@ def create(self):
         return DiffusionPriorNetwork(**kwargs)
 
 class DiffusionPriorConfig(BaseModel):
-    clip: AdapterConfig = None
+    clip: Optional[AdapterConfig] = None
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int
@@ -195,7 +195,7 @@ class DiffusionPriorTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
-    warmup_steps: int = None # number of warmup steps
+    warmup_steps: Optional[int] = None # number of warmup steps
     save_every_seconds: int = 3600 # how often to save
     eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
     best_validation_loss: float = 1e9 # the current best valudation loss observed
@@ -228,10 +228,10 @@ def from_json_path(cls, json_path):
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple[int]
-    image_embed_dim: int = None
-    text_embed_dim: int = None
-    cond_on_text_encodings: bool = None
-    cond_dim: int = None
+    image_embed_dim: Optional[int] = None
+    text_embed_dim: Optional[int] = None
+    cond_on_text_encodings: Optional[bool] = None
+    cond_dim: Optional[int] = None
     channels: int = 3
     self_attn: ListOrTuple[int]
     attn_dim_head: int = 32
@@ -243,14 +243,14 @@ class Config:
 
 class DecoderConfig(BaseModel):
     unets: ListOrTuple[UnetConfig]
-    image_size: int = None
+    image_size: Optional[int] = None
     image_sizes: ListOrTuple[int] = None
     clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
     sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
     loss_type: str = 'l2'
-    beta_schedule: ListOrTuple[str] = None # None means all cosine
+    beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
     learned_variance: SingularOrIterable[bool] = True
     image_cond_drop_prob: float = 0.1
     text_cond_drop_prob: float = 0.5
@@ -320,20 +320,20 @@ class DecoderTrainConfig(BaseModel):
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
     cond_scale: Union[float, List[float]] = 1.0
     device: str = 'cuda:0'
-    epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-    validation_samples: int = None # Same as above but for validation.
+    epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: Optional[int] = None # Same as above but for validation.
     save_immediately: bool = False
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
-    unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
+    unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
-    FID: Dict[str, Any] = None
-    IS: Dict[str, Any] = None
-    KID: Dict[str, Any] = None
-    LPIPS: Dict[str, Any] = None
+    FID: Optional[Dict[str, Any]] = None
+    IS: Optional[Dict[str, Any]] = None
+    KID: Optional[Dict[str, Any]] = None
+    LPIPS: Optional[Dict[str, Any]] = None
 
 class TrainDecoderConfig(BaseModel):
     decoder: DecoderConfig
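Every change in this file follows the same pattern: a field annotated with a bare type but defaulted to None becomes an explicit Optional[...]. A minimal sketch of why this matters, assuming Pydantic v1 (the major version current when this commit landed); the Before/After model names are illustrative, not from the repo:

from typing import Optional
from pydantic import BaseModel, ValidationError

class Before(BaseModel):
    max_text_len: int = None  # mismatched: None is not an int; type checkers flag this

class After(BaseModel):
    max_text_len: Optional[int] = None  # None is an explicitly allowed value

Before()  # passes: Pydantic v1 does not validate defaults unless validate_all is set
try:
    Before(max_text_len=None)  # e.g. an explicit null in a JSON train config
except ValidationError:
    print("rejected: None is not allowed for a bare int field")
print(After(max_text_len=None))  # accepted after the fix

So the change is not purely cosmetic: a train config that explicitly passes null for any of these fields only validates once the annotation says Optional, and the annotations now match the None defaults for static type checkers as well.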

dalle2_pytorch/version.py (+1 -1)

@@ -1 +1 @@
-__version__ = '1.15.2'
+__version__ = '1.15.3'
