From 34479f2eb0590a8b491bda028aba4ffe0fcb9019 Mon Sep 17 00:00:00 2001 From: vlado ovtcharov Date: Fri, 31 Jan 2025 16:11:41 -0500 Subject: [PATCH] Add reference tag to allow getting a varaible by reference --- hydra/_internal/instantiate/_instantiate2.py | 165 +++++++++++++++++-- 1 file changed, 155 insertions(+), 10 deletions(-) diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index bc27918274..0f948bf9f4 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -22,6 +22,29 @@ class _Keys(str, Enum): RECURSIVE = "_recursive_" ARGS = "_args_" PARTIAL = "_partial_" + REFERENCE = "_reference_" + + +class TaggedValue: + def __init__(self, traverse_node=None, value=None): + self.traverse_node = traverse_node + self.value = value + + @property + def is_list(self): + return isinstance(self.traverse_node, list) + + @property + def is_dict(self): + return isinstance(self.traverse_node, dict) + + def __getitem__(self, key): + if self.is_list: + return self.traverse_node[int(key)] + elif self.is_dict: + return self.traverse_node[key] + else: + raise Exception("Can't traverse node") def _is_target(x: Any) -> bool: @@ -31,6 +54,12 @@ def _is_target(x: Any) -> bool: return "_target_" in x return False +def _is_reference(x: Any) -> bool: + if isinstance(x, dict): + return "_reference_" in x + if OmegaConf.is_dict(x): + return "_reference_" in x + return False def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]: config_args = kwargs.pop(_Keys.ARGS, ()) @@ -174,6 +203,7 @@ def instantiate( config: Any, *args: Any, _skip_instantiate_full_deepcopy_: bool = False, + references: Any = None, **kwargs: Any, ) -> Any: """ @@ -240,6 +270,8 @@ def instantiate( config = OmegaConf.structured(config, flags={"allow_objects": True}) if OmegaConf.is_dict(config): + if references is None: + references = TaggedValue() # Finalize config (convert targets to strings, merge with kwargs) # Create copy to avoid mutating original if _skip_instantiate_full_deepcopy_: @@ -262,9 +294,12 @@ def instantiate( _partial_ = config.pop(_Keys.PARTIAL, False) return instantiate_node( - config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + config, config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_, + ref_root=references, ref_node=references, ) elif OmegaConf.is_list(config): + if references is None: + references = TaggedValue() # Finalize config (convert targets to strings, merge with kwargs) # Create copy to avoid mutating original if _skip_instantiate_full_deepcopy_: @@ -289,7 +324,8 @@ def instantiate( ) return instantiate_node( - config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + config, config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_, + ref_root=references, ref_node=references, ) else: raise InstantiationException( @@ -317,18 +353,77 @@ def _convert_node(node: Any, convert: Union[ConvertMode, str]) -> Any: return node +def _resolve_reference( + config_root: Any, + config_node: Any, + ref_root: TaggedValue, + ref_node: TaggedValue, + trace: list, + *args: Any, + convert: Union[str, ConvertMode] = ConvertMode.NONE, + recursive: bool = True, + partial: bool = False, +) -> Any: + # check if this is the leaf node + if len(trace) == 0: + if ref_node.value is None: + instantiate_node( + root=config_root, + node=config_node, + ref_root=ref_root, + ref_node=ref_node, + *args, + convert=convert, + recursive=recursive, + partial=partial + ) + return ref_node.value + head, tail = trace[0], trace[1:] + # if we haven't instantiated this node do that now + if ref_node.traverse_node is None: + instantiate_node( + root=config_root, + node=config_node, + ref_root=ref_root, + ref_node=ref_node, + *args, + convert=convert, + recursive=recursive, + partial=partial + ) + if ref_node.is_list: + head = int(head) + return _resolve_reference( + config_root=config_root, + config_node=config_node[head], + ref_root=ref_root, + ref_node=ref_node[head], + trace=tail, + *args, + convert=convert, + recursive=recursive, + partial=partial + ) + def instantiate_node( + root: Any, node: Any, + ref_root: Any, + ref_node: TaggedValue, *args: Any, convert: Union[str, ConvertMode] = ConvertMode.NONE, recursive: bool = True, partial: bool = False, ) -> Any: + if ref_node.value is not None: + return ref_node.value # Return None if config is None if node is None or (OmegaConf.is_config(node) and node._is_none()): + ref_node.value = None return None if not OmegaConf.is_config(node): + ref_node.value = node return node # Override parent modes from config if specified @@ -355,26 +450,43 @@ def instantiate_node( # If OmegaConf list, create new list of instances if recursive if OmegaConf.is_list(node): + ref_node.traverse_node = [] + for idx, item in enumerate(node._iter_ex(resolve=True)): + ref_node.traverse_node.append(TaggedValue()) + items = [ - instantiate_node(item, convert=convert, recursive=recursive) - for item in node._iter_ex(resolve=True) + instantiate_node( + root, item, + ref_root, ref_node.traverse_node[idx], + convert=convert, recursive=recursive) + for idx, item in enumerate(node._iter_ex(resolve=True)) ] if convert in (ConvertMode.ALL, ConvertMode.PARTIAL, ConvertMode.OBJECT): # If ALL or PARTIAL or OBJECT, use plain list as container + ref_node.value = items return items else: # Otherwise, use ListConfig as container lst = OmegaConf.create(items, flags={"allow_objects": True}) lst._set_parent(node) + ref_node.value = lst return lst elif OmegaConf.is_dict(node): + ref_node.traverse_node = {} exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) if _is_target(node): _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) kwargs = {} is_partial = node.get("_partial_", False) or partial + # instantiate all the keys in ref_node to avoid circular refs. + for key in node.keys(): + if key not in exclude_keys: + if OmegaConf.is_missing(node, key) and is_partial: + continue + ref_node.traverse_node[key] = TaggedValue() + for key in node.keys(): if key not in exclude_keys: if OmegaConf.is_missing(node, key) and is_partial: @@ -382,12 +494,35 @@ def instantiate_node( value = node[key] if recursive: value = instantiate_node( - value, convert=convert, recursive=recursive + root, value, + ref_root=ref_root, + ref_node=ref_node.traverse_node[key], + convert=convert, recursive=recursive ) kwargs[key] = _convert_node(value, convert) - - return _call_target(_target_, partial, args, kwargs, full_key) + target_value = _call_target(_target_, partial, args, kwargs, full_key) + ref_node.value = target_value + return target_value + elif _is_reference(node): + _reference_ = node.get(_Keys.REFERENCE) + ref_value = _resolve_reference( + config_root=root, + config_node=root, + ref_root=ref_root, + ref_node=ref_root, + trace=_reference_.split("."), + *args, + convert=convert, + recursive=recursive, + partial=partial, + ) + ref_node.value = ref_value + return ref_value else: + # instantiate all the keys in ref_node to avoid circular refs. + for key, value in node.items(): + ref_node.traverse_node[key] = TaggedValue() + # If ALL or PARTIAL non structured or OBJECT non structured, # instantiate in dict and resolve interpolations eagerly. if convert == ConvertMode.ALL or ( @@ -398,20 +533,30 @@ def instantiate_node( for key, value in node.items(): # list items inherits recursive flag from the containing dict. dict_items[key] = instantiate_node( - value, convert=convert, recursive=recursive + root, value, + ref_root=ref_root, + ref_node=ref_node.traverse_node[key], + convert=convert, recursive=recursive ) + ref_node.value = dict_items return dict_items else: # Otherwise use DictConfig and resolve interpolations lazily. cfg = OmegaConf.create({}, flags={"allow_objects": True}) for key, value in node.items(): cfg[key] = instantiate_node( - value, convert=convert, recursive=recursive + root, value, + ref_root=ref_root, + ref_node=ref_node.traverse_node[key], + convert=convert, recursive=recursive ) cfg._set_parent(node) cfg._metadata.object_type = node._metadata.object_type if convert == ConvertMode.OBJECT: - return OmegaConf.to_object(cfg) + obj_value = OmegaConf.to_object(cfg) + ref_node.value = obj_value + return obj_value + ref_node.value = cfg return cfg else: