From c4b8ece9e8d7e25340c5ce4d58cecb24a376d2c6 Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Sun, 31 Dec 2023 12:01:58 +0100 Subject: [PATCH] feat(torch): add dataloader collator --- src/python/pose_format/numpy/pose_body.py | 3 +- .../pose_format/tensorflow/pose_body.py | 4 +- .../pose_format/torch/masked/collator.py | 64 +++++++++++++++++++ src/python/pose_format/torch/pose_body.py | 4 +- src/python/pyproject.toml | 2 +- 5 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 src/python/pose_format/torch/masked/collator.py diff --git a/src/python/pose_format/numpy/pose_body.py b/src/python/pose_format/numpy/pose_body.py index 0bc7fd2..2ba1014 100644 --- a/src/python/pose_format/numpy/pose_body.py +++ b/src/python/pose_format/numpy/pose_body.py @@ -37,7 +37,8 @@ class NumPyPoseBody(PoseBody): confidence array of the pose keypoints. """ - tensor_reader = 'unpack_numpy' """Specifies the method name for unpacking a numpy array (Value: 'unpack_numpy').""" + """Specifies the method name for unpacking a numpy array (Value: 'unpack_numpy').""" + tensor_reader = 'unpack_numpy' def __init__(self, fps: float, data: Union[ma.MaskedArray, np.ndarray], confidence: np.ndarray): """ diff --git a/src/python/pose_format/tensorflow/pose_body.py b/src/python/pose_format/tensorflow/pose_body.py index 1e74d5f..b804cf8 100644 --- a/src/python/pose_format/tensorflow/pose_body.py +++ b/src/python/pose_format/tensorflow/pose_body.py @@ -28,7 +28,9 @@ class TensorflowPoseBody(PoseBody): confidence : tf.Tensor The confidence scores for the pose data. """ - tensor_reader = 'unpack_tensorflow' """str: The method used to read the tensor data. (Type: str)""" + + """str: The method used to read the tensor data. 
(Type: str)""" + tensor_reader = 'unpack_tensorflow' def __init__(self, fps: float, data: Union[MaskedTensor, tf.Tensor], confidence: tf.Tensor): """ diff --git a/src/python/pose_format/torch/masked/collator.py b/src/python/pose_format/torch/masked/collator.py new file mode 100644 index 0000000..9ec2aa9 --- /dev/null +++ b/src/python/pose_format/torch/masked/collator.py @@ -0,0 +1,64 @@ +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from pose_format.torch.masked import MaskedTensor, MaskedTorch + + +def pad_tensors(batch: List[Union[torch.Tensor, MaskedTensor]], pad_value=0): + datum = batch[0] + torch_cls = MaskedTorch if isinstance(datum, MaskedTensor) else torch + + max_len = max(len(t) for t in batch) + if max_len == 1: + return torch_cls.stack(batch, dim=0) + + new_batch = [] + for tensor in batch: + missing = list(tensor.shape) + missing[0] = max_len - tensor.shape[0] + + if missing[0] > 0: + padding_tensor = torch.full(missing, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device) + tensor = torch_cls.cat([tensor, padding_tensor], dim=0) + + new_batch.append(tensor) + + return torch_cls.stack(new_batch, dim=0) + + +def collate_tensors(batch: List, pad_value=0) -> Union[torch.Tensor, List]: + datum = batch[0] + + if isinstance(datum, dict): # Recurse over dictionaries + return zero_pad_collator(batch) + + if isinstance(datum, (int, np.int32)): + return torch.tensor(batch, dtype=torch.long) + + if isinstance(datum, (MaskedTensor, torch.Tensor)): + return pad_tensors(batch, pad_value=pad_value) + + return batch + + +def zero_pad_collator(batch) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, ...]]: + datum = batch[0] + + # For strings + if isinstance(datum, str): + return batch + + # For tuples + if isinstance(datum, tuple): + return tuple(collate_tensors([b[i] for b in batch]) for i in range(len(datum))) + + # For tensors + if isinstance(datum, MaskedTensor): + return collate_tensors(batch) + + # For 
dictionaries + keys = datum.keys() + return {k: collate_tensors([b[k] for b in batch]) for k in keys} + + diff --git a/src/python/pose_format/torch/pose_body.py b/src/python/pose_format/torch/pose_body.py index c3e163f..a062cea 100644 --- a/src/python/pose_format/torch/pose_body.py +++ b/src/python/pose_format/torch/pose_body.py @@ -15,7 +15,9 @@ class TorchPoseBody(PoseBody): This class extends the PoseBody class and provides methods for manipulating pose data using PyTorch tensors. """ - tensor_reader = 'unpack_torch' """str: Reader format for unpacking Torch tensors.""" + + """str: Reader format for unpacking Torch tensors.""" + tensor_reader = 'unpack_torch' def __init__(self, fps: float, data: Union[MaskedTensor, torch.Tensor], confidence: torch.Tensor): if isinstance(data, torch.Tensor): # If array is not masked diff --git a/src/python/pyproject.toml b/src/python/pyproject.toml index b947584..2459d37 100644 --- a/src/python/pyproject.toml +++ b/src/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "pose_format" description = "Library for viewing, augmenting, and handling .pose files" -version = "0.2.3" +version = "0.3.0" keywords = ["Pose Files", "Pose Interpolation", "Pose Augmentation"] authors = [ { name = "Amit Moryossef", email = "amitmoryossef@gmail.com" },