From f2ed9781dddf7a3964fb53209dc7f3f38958058c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Mon, 3 Mar 2025 18:04:45 -0800 Subject: [PATCH] Add registry to simplify mesh transform API --- notebooks/how_to/mesh_viz_compare.ipynb | 2 +- polpo/preprocessing/mesh/_pyvista.py | 4 + polpo/preprocessing/mesh/_register.py | 5 + polpo/preprocessing/mesh/_trimesh.py | 4 + polpo/preprocessing/mesh/conversion.py | 58 +++++++++++- polpo/preprocessing/mesh/transform.py | 118 ++++++++++++++++++------ 6 files changed, 157 insertions(+), 34 deletions(-) create mode 100644 polpo/preprocessing/mesh/_register.py diff --git a/notebooks/how_to/mesh_viz_compare.ipynb b/notebooks/how_to/mesh_viz_compare.ipynb index f19978d..4d51db5 100644 --- a/notebooks/how_to/mesh_viz_compare.ipynb +++ b/notebooks/how_to/mesh_viz_compare.ipynb @@ -151,7 +151,7 @@ "outputs": [], "source": [ "processing_pipe = Map(\n", - " [MeshCenterer(attr=\"points\"), MeshScaler(attr=\"points\")],\n", + " [MeshCenterer(), MeshScaler()],\n", " force_iter=True,\n", ")\n", "\n", diff --git a/polpo/preprocessing/mesh/_pyvista.py b/polpo/preprocessing/mesh/_pyvista.py index 0409602..76115d4 100644 --- a/polpo/preprocessing/mesh/_pyvista.py +++ b/polpo/preprocessing/mesh/_pyvista.py @@ -7,6 +7,10 @@ from polpo.preprocessing.base import PreprocessingStep from polpo.utils import params_to_kwargs +from ._register import register_vertices_attr + +register_vertices_attr(pv.PolyData, "points") + class PvFromData(PreprocessingStep): def apply(self, mesh): diff --git a/polpo/preprocessing/mesh/_register.py b/polpo/preprocessing/mesh/_register.py new file mode 100644 index 0000000..4882ae7 --- /dev/null +++ b/polpo/preprocessing/mesh/_register.py @@ -0,0 +1,5 @@ +VERTICES_ATTR = {} + + +def register_vertices_attr(Obj, name): + VERTICES_ATTR[Obj] = name diff --git a/polpo/preprocessing/mesh/_trimesh.py b/polpo/preprocessing/mesh/_trimesh.py index 74851e4..ba86be1 100644 --- a/polpo/preprocessing/mesh/_trimesh.py +++ b/polpo/preprocessing/mesh/_trimesh.py @@ -5,6 +5,10 @@ from polpo.preprocessing.base import PreprocessingStep +from ._register import register_vertices_attr + +register_vertices_attr(trimesh.Trimesh, "vertices") + class TrimeshFromData(PreprocessingStep): def apply(self, mesh): diff --git a/polpo/preprocessing/mesh/conversion.py b/polpo/preprocessing/mesh/conversion.py index 85ccec7..aaa54a1 100644 --- a/polpo/preprocessing/mesh/conversion.py +++ b/polpo/preprocessing/mesh/conversion.py @@ -1,3 +1,5 @@ +"""Mesh type conversion.""" + import copy import sys @@ -29,30 +31,80 @@ from polpo.macro import create_to_classes_from_from from polpo.preprocessing.base import PreprocessingStep +from ._register import VERTICES_ATTR + class ToVertices(PreprocessingStep): + """Get mesh vertices.""" + def apply(self, mesh): - return mesh.vertices + """Apply step. + + Parameters + ---------- + mesh : Mesh + Mesh with vertex information. + + Returns + ------- + vertices : array-like + Mesh vertices. + """ + attr = VERTICES_ATTR.get(type(mesh), "vertices") + return getattr(mesh, attr) class ToFaces(PreprocessingStep): + """Get mesh faces.""" + def apply(self, mesh): + """Apply step. + + Parameters + ---------- + mesh : Mesh + Mesh with face information. + + Returns + ------- + faces : array-like + Mesh faces. + """ return mesh.faces class FromCombinatorialStructure(PreprocessingStep): + """Get mesh from combinatorial structure. + + Parameters + ---------- + mesh : Mesh + Mesh containing combinatorial structure. + If None, must be supplied at apply. + """ + def __init__(self, mesh=None): + super().__init__() self.mesh = mesh def apply(self, data): - if self.mesh is None: + """Apply step. + + Returns + ------- + mesh : Mesh + Mesh with new vertices but same combinatorial structure. + """ + if isinstance(data, (list, tuple)): mesh, vertices = data else: mesh = self.mesh vertices = data mesh = copy.copy(mesh) - mesh.vertices = vertices + + vertices_attr = VERTICES_ATTR.get(type(mesh), "vertices") + setattr(mesh, vertices_attr, vertices) return mesh diff --git a/polpo/preprocessing/mesh/transform.py b/polpo/preprocessing/mesh/transform.py index 7462127..3aeaa7d 100644 --- a/polpo/preprocessing/mesh/transform.py +++ b/polpo/preprocessing/mesh/transform.py @@ -1,56 +1,114 @@ +"""Mesh transformations.""" + import numpy as np from polpo.preprocessing.base import PreprocessingStep +from ._register import VERTICES_ATTR + + +def _apply_to_attr(func, mesh, attr=None): + # NB: done in place + if attr is None: + attr = VERTICES_ATTR.get(type(mesh), "vertices") + values = func(getattr(mesh, attr)) + setattr(mesh, attr, values) + + return mesh -class MeshCenterer(PreprocessingStep): - def __init__(self, attr="vertices"): + +class MeshTransformer(PreprocessingStep): + """Applies function to mesh attribute. + + Parameters + ---------- + func : callable + Function to apply to attribute. + attr : str + Mesh attribute containing vertex information. + If None, it uses registered or "vertices". + """ + + def __init__(self, func, attr=None): super().__init__() + self.func = func self.attr = attr + def _build_func(self): + return self.func + def apply(self, mesh): - """Center a mesh by putting its barycenter at origin of the coordinates. + """Apply step. Parameters ---------- - mesh : trimesh.Trimesh - Mesh to center. + mesh : Mesh + Mesh to transform. Returns ------- - centered_mesh : trimesh.Trimesh - Centered Mesh. - hippocampus_center: coordinates of center of the mesh before centering + transformed_mesh : Mesh + Transformed mesh. """ - values = getattr(mesh, self.attr) - center = np.mean(values, axis=0) - setattr(mesh, self.attr, values - center) + if isinstance(mesh, (list, tuple)): + mesh, *transform_args = mesh + func = self._build_func(*transform_args) + else: + func = self.func - return mesh + return _apply_to_attr(func, mesh, self.attr) -class MeshScaler(PreprocessingStep): - def __init__(self, scaling_factor=20.0, attr="vertices"): - super().__init__() - self.scaling_factor = scaling_factor - self.attr = attr +class MeshCenterer(MeshTransformer): + """Center mesh by removing vertex mean. - def apply(self, mesh): - values = getattr(mesh, self.attr) - setattr(mesh, self.attr, values / self.scaling_factor) - return mesh + Parameters + ---------- + attr : str + Mesh attribute containing vertex information. + If None, it uses registered or "vertices". + """ + + def __init__(self, attr=None): + super().__init__( + attr=attr, + func=lambda vertices: vertices - np.mean(vertices, axis=0), + ) + + +class MeshScaler(MeshTransformer): + """Scale mesh. + + Parameters + ---------- + scaling_factor : float + Scaling factor. + attr : str + Mesh attribute containing vertex information. + If None, it uses registered or "vertices". + """ + + def __init__(self, scaling_factor=20.0, attr=None): + super().__init__( + attr=attr, + func=self._build_function(scaling_factor), + ) + def _build_function(self, scaling_factor): + return lambda vertices: vertices / scaling_factor -class TransformVertices(PreprocessingStep): - def apply(self, data): - # TODO: accept transformation at init? - # TODO: consider in place? - mesh, transformation = data +class AffineTransformation(MeshTransformer): + """Apply affine transform to mesh vertices.""" - rotation_matrix = transformation[:3, :3] - translation = transformation[:3, 3] + def __init__(self, transform=None, attr=None): + super().__init__( + attr=attr, + func=self._build_function(transform) if transform is not None else None, + ) - mesh.vertices = (rotation_matrix @ mesh.vertices.T).T + translation + def _build_function(self, transform): + rotation_matrix = transform[:3, :3] + translation = transform[:3, 3] - return mesh + return lambda vertices: (rotation_matrix @ vertices.T).T + translation