diff --git a/matbench_discovery/structure/prototype.py b/matbench_discovery/structure/prototype.py index 225568af..4b8ca853 100644 --- a/matbench_discovery/structure/prototype.py +++ b/matbench_discovery/structure/prototype.py @@ -142,17 +142,12 @@ def get_protostructure_label( spg_num = symmetry_data.number # Group sites by orbit - orbit_groups: list[list[int]] = [] - current_orbit: list[int] = [] - - for idx, orbit_id in enumerate(symmetry_data.orbits): - if not current_orbit or orbit_id == symmetry_data.orbits[current_orbit[0]]: - current_orbit += [idx] - else: - orbit_groups += [current_orbit] - current_orbit = [idx] - if current_orbit: - orbit_groups += [current_orbit] + orbit_groups: dict[int, list[int]] = {} + + for idx, orbit_id in enumerate(moyo_data.orbits): + if orbit_id not in orbit_groups: + orbit_groups[orbit_id] = [] + orbit_groups[orbit_id].append(idx) # Create equivalent_wyckoff_labels from orbit groups element_dict: dict[str, int] = {} @@ -166,7 +161,7 @@ def get_protostructure_label( str.maketrans("", "", string.digits) ), ) - for orbit in orbit_groups + for orbit in orbit_groups.values() ] equivalent_wyckoff_labels = sorted( equivalent_wyckoff_labels, key=lambda x: (x[0], x[2])