diff --git a/torchlens/model_history.py b/torchlens/model_history.py index fa0ef07..4dee46a 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -4886,11 +4886,17 @@ def _rename_model_history_layer_names(self): self.conditional_branch_edges[t] = (new_child, new_parent) for module_pass, arglist in self.module_layer_argnames.items(): + inds_to_remove = [] for a, arg in enumerate(arglist): raw_name = self.module_layer_argnames[module_pass][a][0] + if raw_name not in self.raw_to_final_layer_labels: + inds_to_remove.append(a) + continue new_name = self.raw_to_final_layer_labels[raw_name] argname = self.module_layer_argnames[module_pass][a][1] self.module_layer_argnames[module_pass][a] = (new_name, argname) + self.module_layer_argnames[module_pass] = [self.module_layer_argnames[module_pass][i] + for i in range(len(arglist)) if i not in inds_to_remove] def _trim_and_reorder_model_history_fields(self): """ @@ -6441,7 +6447,7 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor( and ((('scale_factor' in layer_to_validate_parents_for.creation_kwargs) and torch.equal( self[layers_to_perturb[0]].tensor_contents, - layer_to_validate_parents_for.creation_kwargs['scale_factor'])) + torch.tensor(layer_to_validate_parents_for.creation_kwargs['scale_factor']))) or ((len(layer_to_validate_parents_for.creation_args) >= 3) and torch.equal( self[layers_to_perturb[0]].tensor_contents, @@ -6472,7 +6478,7 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor( ]: # TODO: fix this recomputed_output = input_args["args"][0] - if type(recomputed_output) in [list, tuple]: + if any([issubclass(type(recomputed_output), which_type) for which_type in [list, tuple]]): recomputed_output = recomputed_output[ layer_to_validate_parents_for.iterable_output_index ]