Skip to content

Commit

Permalink
Add registry to simplify mesh transform API
Browse files Browse the repository at this point in the history
  • Loading branch information
luisfpereira committed Mar 4, 2025
1 parent cfc8bf0 commit f2ed978
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 34 deletions.
2 changes: 1 addition & 1 deletion notebooks/how_to/mesh_viz_compare.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"outputs": [],
"source": [
"processing_pipe = Map(\n",
" [MeshCenterer(attr=\"points\"), MeshScaler(attr=\"points\")],\n",
" [MeshCenterer(), MeshScaler()],\n",
" force_iter=True,\n",
")\n",
"\n",
Expand Down
4 changes: 4 additions & 0 deletions polpo/preprocessing/mesh/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from polpo.preprocessing.base import PreprocessingStep
from polpo.utils import params_to_kwargs

from ._register import register_vertices_attr

register_vertices_attr(pv.PolyData, "points")


class PvFromData(PreprocessingStep):
def apply(self, mesh):
Expand Down
5 changes: 5 additions & 0 deletions polpo/preprocessing/mesh/_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
VERTICES_ATTR = {}


def register_vertices_attr(Obj, name):
VERTICES_ATTR[Obj] = name
4 changes: 4 additions & 0 deletions polpo/preprocessing/mesh/_trimesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from polpo.preprocessing.base import PreprocessingStep

from ._register import register_vertices_attr

register_vertices_attr(trimesh.Trimesh, "vertices")


class TrimeshFromData(PreprocessingStep):
def apply(self, mesh):
Expand Down
58 changes: 55 additions & 3 deletions polpo/preprocessing/mesh/conversion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Mesh type conversion."""

import copy
import sys

Expand Down Expand Up @@ -29,30 +31,80 @@
from polpo.macro import create_to_classes_from_from
from polpo.preprocessing.base import PreprocessingStep

from ._register import VERTICES_ATTR


class ToVertices(PreprocessingStep):
"""Get mesh vertices."""

def apply(self, mesh):
return mesh.vertices
"""Apply step.
Parameters
----------
mesh : Mesh
Mesh with vertex information.
Returns
-------
vertices : array-like
Mesh vertices.
"""
attr = VERTICES_ATTR.get(type(mesh), "vertices")
return getattr(mesh, attr)


class ToFaces(PreprocessingStep):
"""Get mesh faces."""

def apply(self, mesh):
"""Apply step.
Parameters
----------
mesh : Mesh
Mesh with face information.
Returns
-------
faces : array-like
Mesh faces.
"""
return mesh.faces


class FromCombinatorialStructure(PreprocessingStep):
"""Get mesh from combinatorial structure.
Parameters
----------
mesh : Mesh
Mesh containing combinatorial structure.
If None, must be supplied at apply.
"""

def __init__(self, mesh=None):
super().__init__()
self.mesh = mesh

def apply(self, data):
if self.mesh is None:
"""Apply step.
Returns
-------
mesh : Mesh
Mesh with new vertices but same combinatorial structure.
"""
if isinstance(data, (list, tuple)):
mesh, vertices = data
else:
mesh = self.mesh
vertices = data

mesh = copy.copy(mesh)
mesh.vertices = vertices

vertices_attr = VERTICES_ATTR.get(type(mesh), "vertices")
setattr(mesh, vertices_attr, vertices)

return mesh

Expand Down
118 changes: 88 additions & 30 deletions polpo/preprocessing/mesh/transform.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,114 @@
"""Mesh transformations."""

import numpy as np

from polpo.preprocessing.base import PreprocessingStep

from ._register import VERTICES_ATTR


def _apply_to_attr(func, mesh, attr=None):
# NB: done in place
if attr is None:
attr = VERTICES_ATTR.get(type(mesh), "vertices")
values = func(getattr(mesh, attr))
setattr(mesh, attr, values)

return mesh

class MeshCenterer(PreprocessingStep):
def __init__(self, attr="vertices"):

class MeshTransformer(PreprocessingStep):
"""Applies function to mesh attribute.
Parameters
----------
func : callable
Function to apply to attribute.
attr : str
Mesh attribute containing vertex information.
If None, it uses registered or "vertices".
"""

def __init__(self, func, attr=None):
super().__init__()
self.func = func
self.attr = attr

def _build_func(self):
return self.func

def apply(self, mesh):
"""Center a mesh by putting its barycenter at origin of the coordinates.
"""Apply step.
Parameters
----------
mesh : trimesh.Trimesh
Mesh to center.
mesh : Mesh
Mesh to transform.
Returns
-------
centered_mesh : trimesh.Trimesh
Centered Mesh.
hippocampus_center: coordinates of center of the mesh before centering
transformed_mesh : Mesh
Transformed mesh.
"""
values = getattr(mesh, self.attr)
center = np.mean(values, axis=0)
setattr(mesh, self.attr, values - center)
if isinstance(mesh, (list, tuple)):
mesh, *transform_args = mesh
func = self._build_func(*transform_args)
else:
func = self.func

return mesh
return _apply_to_attr(func, mesh, self.attr)


class MeshScaler(PreprocessingStep):
def __init__(self, scaling_factor=20.0, attr="vertices"):
super().__init__()
self.scaling_factor = scaling_factor
self.attr = attr
class MeshCenterer(MeshTransformer):
"""Center mesh by removing vertex mean.
def apply(self, mesh):
values = getattr(mesh, self.attr)
setattr(mesh, self.attr, values / self.scaling_factor)
return mesh
Parameters
----------
attr : str
Mesh attribute containing vertex information.
If None, it uses registered or "vertices".
"""

def __init__(self, attr=None):
super().__init__(
attr=attr,
func=lambda vertices: vertices - np.mean(vertices, axis=0),
)


class MeshScaler(MeshTransformer):
"""Scale mesh.
Parameters
----------
scaling_factor : float
Scaling factor.
attr : str
Mesh attribute containing vertex information.
If None, it uses registered or "vertices".
"""

def __init__(self, scaling_factor=20.0, attr=None):
super().__init__(
attr=attr,
func=self._build_function(scaling_factor),
)

def _build_function(self, scaling_factor):
return lambda vertices: vertices / scaling_factor

class TransformVertices(PreprocessingStep):
def apply(self, data):
# TODO: accept transformation at init?
# TODO: consider in place?

mesh, transformation = data
class AffineTransformation(MeshTransformer):
"""Apply affine transform to mesh vertices."""

rotation_matrix = transformation[:3, :3]
translation = transformation[:3, 3]
def __init__(self, transform=None, attr=None):
super().__init__(
attr=attr,
func=self._build_function(transform) if transform is not None else None,
)

mesh.vertices = (rotation_matrix @ mesh.vertices.T).T + translation
def _build_function(self, transform):
rotation_matrix = transform[:3, :3]
translation = transform[:3, 3]

return mesh
return lambda vertices: (rotation_matrix @ vertices.T).T + translation

0 comments on commit f2ed978

Please sign in to comment.