Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reference tag to allow getting a varaible by reference #3025

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 155 additions & 10 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ class _Keys(str, Enum):
RECURSIVE = "_recursive_"
ARGS = "_args_"
PARTIAL = "_partial_"
REFERENCE = "_reference_"


class TaggedValue:
def __init__(self, traverse_node=None, value=None):
self.traverse_node = traverse_node
self.value = value

@property
def is_list(self):
return isinstance(self.traverse_node, list)

@property
def is_dict(self):
return isinstance(self.traverse_node, dict)

def __getitem__(self, key):
if self.is_list:
return self.traverse_node[int(key)]
elif self.is_dict:
return self.traverse_node[key]
else:
raise Exception("Can't traverse node")


def _is_target(x: Any) -> bool:
Expand All @@ -31,6 +54,12 @@ def _is_target(x: Any) -> bool:
return "_target_" in x
return False

def _is_reference(x: Any) -> bool:
if isinstance(x, dict):
return "_reference_" in x
if OmegaConf.is_dict(x):
return "_reference_" in x
return False

def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]:
config_args = kwargs.pop(_Keys.ARGS, ())
Expand Down Expand Up @@ -174,6 +203,7 @@ def instantiate(
config: Any,
*args: Any,
_skip_instantiate_full_deepcopy_: bool = False,
references: Any = None,
**kwargs: Any,
) -> Any:
"""
Expand Down Expand Up @@ -240,6 +270,8 @@ def instantiate(
config = OmegaConf.structured(config, flags={"allow_objects": True})

if OmegaConf.is_dict(config):
if references is None:
references = TaggedValue()
# Finalize config (convert targets to strings, merge with kwargs)
# Create copy to avoid mutating original
if _skip_instantiate_full_deepcopy_:
Expand All @@ -262,9 +294,12 @@ def instantiate(
_partial_ = config.pop(_Keys.PARTIAL, False)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config, config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_,
ref_root=references, ref_node=references,
)
elif OmegaConf.is_list(config):
if references is None:
references = TaggedValue()
# Finalize config (convert targets to strings, merge with kwargs)
# Create copy to avoid mutating original
if _skip_instantiate_full_deepcopy_:
Expand All @@ -289,7 +324,8 @@ def instantiate(
)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config, config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_,
ref_root=references, ref_node=references,
)
else:
raise InstantiationException(
Expand Down Expand Up @@ -317,18 +353,77 @@ def _convert_node(node: Any, convert: Union[ConvertMode, str]) -> Any:
return node


def _resolve_reference(
config_root: Any,
config_node: Any,
ref_root: TaggedValue,
ref_node: TaggedValue,
trace: list,
*args: Any,
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
) -> Any:
# check if this is the leaf node
if len(trace) == 0:
if ref_node.value is None:
instantiate_node(
root=config_root,
node=config_node,
ref_root=ref_root,
ref_node=ref_node,
*args,
convert=convert,
recursive=recursive,
partial=partial
)
return ref_node.value
head, tail = trace[0], trace[1:]
# if we haven't instantiated this node do that now
if ref_node.traverse_node is None:
instantiate_node(
root=config_root,
node=config_node,
ref_root=ref_root,
ref_node=ref_node,
*args,
convert=convert,
recursive=recursive,
partial=partial
)
if ref_node.is_list:
head = int(head)
return _resolve_reference(
config_root=config_root,
config_node=config_node[head],
ref_root=ref_root,
ref_node=ref_node[head],
trace=tail,
*args,
convert=convert,
recursive=recursive,
partial=partial
)

def instantiate_node(
root: Any,
node: Any,
ref_root: Any,
ref_node: TaggedValue,
*args: Any,
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
) -> Any:
if ref_node.value is not None:
return ref_node.value
# Return None if config is None
if node is None or (OmegaConf.is_config(node) and node._is_none()):
ref_node.value = None
return None

if not OmegaConf.is_config(node):
ref_node.value = node
return node

# Override parent modes from config if specified
Expand All @@ -355,39 +450,79 @@ def instantiate_node(

# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(node):
ref_node.traverse_node = []
for idx, item in enumerate(node._iter_ex(resolve=True)):
ref_node.traverse_node.append(TaggedValue())

items = [
instantiate_node(item, convert=convert, recursive=recursive)
for item in node._iter_ex(resolve=True)
instantiate_node(
root, item,
ref_root, ref_node.traverse_node[idx],
convert=convert, recursive=recursive)
for idx, item in enumerate(node._iter_ex(resolve=True))
]

if convert in (ConvertMode.ALL, ConvertMode.PARTIAL, ConvertMode.OBJECT):
# If ALL or PARTIAL or OBJECT, use plain list as container
ref_node.value = items
return items
else:
# Otherwise, use ListConfig as container
lst = OmegaConf.create(items, flags={"allow_objects": True})
lst._set_parent(node)
ref_node.value = lst
return lst

elif OmegaConf.is_dict(node):
ref_node.traverse_node = {}
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
kwargs = {}
is_partial = node.get("_partial_", False) or partial
# instantiate all the keys in ref_node to avoid circular refs.
for key in node.keys():
if key not in exclude_keys:
if OmegaConf.is_missing(node, key) and is_partial:
continue
ref_node.traverse_node[key] = TaggedValue()

for key in node.keys():
if key not in exclude_keys:
if OmegaConf.is_missing(node, key) and is_partial:
continue
value = node[key]
if recursive:
value = instantiate_node(
value, convert=convert, recursive=recursive
root, value,
ref_root=ref_root,
ref_node=ref_node.traverse_node[key],
convert=convert, recursive=recursive
)
kwargs[key] = _convert_node(value, convert)

return _call_target(_target_, partial, args, kwargs, full_key)
target_value = _call_target(_target_, partial, args, kwargs, full_key)
ref_node.value = target_value
return target_value
elif _is_reference(node):
_reference_ = node.get(_Keys.REFERENCE)
ref_value = _resolve_reference(
config_root=root,
config_node=root,
ref_root=ref_root,
ref_node=ref_root,
trace=_reference_.split("."),
*args,
convert=convert,
recursive=recursive,
partial=partial,
)
ref_node.value = ref_value
return ref_value
else:
# instantiate all the keys in ref_node to avoid circular refs.
for key, value in node.items():
ref_node.traverse_node[key] = TaggedValue()

# If ALL or PARTIAL non structured or OBJECT non structured,
# instantiate in dict and resolve interpolations eagerly.
if convert == ConvertMode.ALL or (
Expand All @@ -398,20 +533,30 @@ def instantiate_node(
for key, value in node.items():
# list items inherits recursive flag from the containing dict.
dict_items[key] = instantiate_node(
value, convert=convert, recursive=recursive
root, value,
ref_root=ref_root,
ref_node=ref_node.traverse_node[key],
convert=convert, recursive=recursive
)
ref_node.value = dict_items
return dict_items
else:
# Otherwise use DictConfig and resolve interpolations lazily.
cfg = OmegaConf.create({}, flags={"allow_objects": True})
for key, value in node.items():
cfg[key] = instantiate_node(
value, convert=convert, recursive=recursive
root, value,
ref_root=ref_root,
ref_node=ref_node.traverse_node[key],
convert=convert, recursive=recursive
)
cfg._set_parent(node)
cfg._metadata.object_type = node._metadata.object_type
if convert == ConvertMode.OBJECT:
return OmegaConf.to_object(cfg)
obj_value = OmegaConf.to_object(cfg)
ref_node.value = obj_value
return obj_value
ref_node.value = cfg
return cfg

else:
Expand Down