1
1
import json
2
2
from torchvision import transforms as T
3
- from pydantic import BaseModel , validator , root_validator
3
+ from pydantic import BaseModel , validator , model_validator
4
4
from typing import List , Optional , Union , Tuple , Dict , Any , TypeVar
5
5
6
6
from x_clip import CLIP as XCLIP
@@ -38,12 +38,12 @@ class TrainSplitConfig(BaseModel):
38
38
val : float = 0.15
39
39
test : float = 0.1
40
40
41
- @root_validator
42
- def validate_all (cls , fields ):
43
- actual_sum = sum ([* fields .values ()])
41
+ @model_validator ( mode = 'after' )
42
+ def validate_all (self , m ):
43
+ actual_sum = sum ([* dict ( self ) .values ()])
44
44
if actual_sum != 1. :
45
- raise ValueError (f'{ fields .keys ()} must sum to 1.0. Found: { actual_sum } ' )
46
- return fields
45
+ raise ValueError (f'{ dict ( self ) .keys ()} must sum to 1.0. Found: { actual_sum } ' )
46
+ return self
47
47
48
48
class TrackerLogConfig (BaseModel ):
49
49
log_type : str = 'console'
@@ -59,6 +59,7 @@ def create(self, data_path: str):
59
59
kwargs = self .dict ()
60
60
return create_logger (self .log_type , data_path , ** kwargs )
61
61
62
+
62
63
class TrackerLoadConfig (BaseModel ):
63
64
load_from : Optional [str ] = None
64
65
only_auto_resume : bool = False # Only attempt to load if the logger is auto-resuming
@@ -277,9 +278,9 @@ class Config:
277
278
extra = "allow"
278
279
279
280
class DecoderDataConfig (BaseModel ):
280
- webdataset_base_url : str # path to a webdataset with jpg images
281
- img_embeddings_url : Optional [str ] # path to .npy files with embeddings
282
- text_embeddings_url : Optional [str ] # path to .npy files with embeddings
281
+ webdataset_base_url : str # path to a webdataset with jpg images
282
+ img_embeddings_url : Optional [str ] = None # path to .npy files with embeddings
283
+ text_embeddings_url : Optional [str ] = None # path to .npy files with embeddings
283
284
num_workers : int = 4
284
285
batch_size : int = 64
285
286
start_shard : int = 0
@@ -346,11 +347,14 @@ class TrainDecoderConfig(BaseModel):
346
347
def from_json_path (cls , json_path ):
347
348
with open (json_path ) as f :
348
349
config = json .load (f )
350
+ print (config )
349
351
return cls (** config )
350
352
351
- @root_validator
352
- def check_has_embeddings (cls , values ):
353
+ @model_validator ( mode = 'after' )
354
+ def check_has_embeddings (self , m ):
353
355
# Makes sure that enough information is provided to get the embeddings specified for training
356
+ values = dict (self )
357
+
354
358
data_config , decoder_config = values .get ('data' ), values .get ('decoder' )
355
359
356
360
if not exists (data_config ) or not exists (decoder_config ):
@@ -375,4 +379,4 @@ def check_has_embeddings(cls, values):
375
379
if text_emb_url :
376
380
assert using_text_embeddings , "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
377
381
378
- return values
382
+ return m
0 commit comments