Skip to content

Commit

Permalink
feat(torch): add dataloadaer collator
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitMY committed Dec 31, 2023
1 parent 79fdf23 commit c4b8ece
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/python/pose_format/numpy/pose_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class NumPyPoseBody(PoseBody):
confidence array of the pose keypoints.
"""

tensor_reader = 'unpack_numpy' """Specifies the method name for unpacking a numpy array (Value: 'unpack_numpy')."""
"""Specifies the method name for unpacking a numpy array (Value: 'unpack_numpy')."""
tensor_reader = 'unpack_numpy'

def __init__(self, fps: float, data: Union[ma.MaskedArray, np.ndarray], confidence: np.ndarray):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/python/pose_format/tensorflow/pose_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class TensorflowPoseBody(PoseBody):
confidence : tf.Tensor
The confidence scores for the pose data.
"""
tensor_reader = 'unpack_tensorflow' """str: The method used to read the tensor data. (Type: str)"""

"""str: The method used to read the tensor data. (Type: str)"""
tensor_reader = 'unpack_tensorflow'

def __init__(self, fps: float, data: Union[MaskedTensor, tf.Tensor], confidence: tf.Tensor):
"""
Expand Down
64 changes: 64 additions & 0 deletions src/python/pose_format/torch/masked/collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from pose_format.torch.masked import MaskedTensor, MaskedTorch


def pad_tensors(batch: List[Union[torch.Tensor, MaskedTensor]], pad_value=0):
datum = batch[0]
torch_cls = MaskedTorch if isinstance(datum, MaskedTensor) else torch

max_len = max(len(t) for t in batch)
if max_len == 1:
return torch_cls.stack(batch, dim=0)

new_batch = []
for tensor in batch:
missing = list(tensor.shape)
missing[0] = max_len - tensor.shape[0]

if missing[0] > 0:
padding_tensor = torch.full(missing, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
tensor = torch_cls.cat([tensor, padding_tensor], dim=0)

new_batch.append(tensor)

return torch_cls.stack(new_batch, dim=0)


def collate_tensors(batch: List, pad_value=0) -> Union[torch.Tensor, List]:
datum = batch[0]

if isinstance(datum, dict): # Recurse over dictionaries
return zero_pad_collator(batch)

if isinstance(datum, (int, np.int32)):
return torch.tensor(batch, dtype=torch.long)

if isinstance(datum, (MaskedTensor, torch.Tensor)):
return pad_tensors(batch, pad_value=pad_value)

return batch


def zero_pad_collator(batch) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, ...]]:
datum = batch[0]

# For strings
if isinstance(datum, str):
return batch

# For tuples
if isinstance(datum, tuple):
return tuple(collate_tensors([b[i] for b in batch]) for i in range(len(datum)))

# For tensors
if isinstance(datum, MaskedTensor):
return collate_tensors(batch)

# For dictionaries
keys = datum.keys()
return {k: collate_tensors([b[k] for b in batch]) for k in keys}


4 changes: 3 additions & 1 deletion src/python/pose_format/torch/pose_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class TorchPoseBody(PoseBody):
This class extends the PoseBody class and provides methods for manipulating pose data using PyTorch tensors.
"""
tensor_reader = 'unpack_torch' """str: Reader format for unpacking Torch tensors."""

"""str: Reader format for unpacking Torch tensors."""
tensor_reader = 'unpack_torch'

def __init__(self, fps: float, data: Union[MaskedTensor, torch.Tensor], confidence: torch.Tensor):
if isinstance(data, torch.Tensor): # If array is not masked
Expand Down
2 changes: 1 addition & 1 deletion src/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "pose_format"
description = "Library for viewing, augmenting, and handling .pose files"
version = "0.2.3"
version = "0.3.0"
keywords = ["Pose Files", "Pose Interpolation", "Pose Augmentation"]
authors = [
{ name = "Amit Moryossef", email = "amitmoryossef@gmail.com" },
Expand Down

0 comments on commit c4b8ece

Please sign in to comment.