-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindoor.py
91 lines (81 loc) · 3.14 KB
/
indoor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Author: HONGTOU TU
Last modified: 19.5.2024
"""
import os, torch
import numpy as np
from torch.utils.data import Dataset
from benchmark_utils import to_o3d_pcd, to_tsfm, get_correspondences
class IndoorDataset(Dataset):
"""
功能: 为室内点云数据集提供一个数据加载接口。
输入参数:
infos: 包含数据集信息的字典。
config: 包含数据集配置的对象。
data_augmentation: 是否应用数据增强。
方法:
__len__: 返回数据集中样本的数量。
__getitem__: 根据索引获取单个样本。
备注:
Load subsampled coordinates, relative rotation and translation
Output(torch.Tensor):
src_pcd: [N,3]
tgt_pcd: [M,3]
rot: [3,3]
trans: [3,1]
"""
def __init__(self,infos,config,data_augmentation=True):
"""
功能: 初始化IndoorDataset对象。
输入参数:
infos: 包含数据集信息的字典。
config: 包含数据集配置的对象。
data_augmentation: 是否应用数据增强。
"""
super().__init__()
self.infos = infos
self.base_dir = config.root
self.overlap_radius = config.overlap_radius
self.config = config
def __len__(self):
"""
功能: 返回数据集中样本的数量。
返回值:
数据集中样本的数量。
"""
return len(self.infos['rot'])
def __getitem__(self,item):
"""
功能: 根据索引获取单个样本。
输入参数:
item: 样本的索引。
返回值:
包含源点云、目标点云、旋转矩阵、平移向量和对应点的元组。
"""
# get transformation,获取样本的旋转矩阵和平移向量
rot=self.infos['rot'][item]
trans=self.infos['trans'][item]
# get pointcloud,根据索引获取源点云和目标点云的文件路径,并加载它们
src_path=os.path.join(self.base_dir,self.infos['src'][item])
tgt_path=os.path.join(self.base_dir,self.infos['tgt'][item])
src_pcd = torch.load(src_path)
tgt_pcd = torch.load(tgt_path)
# 如果平移向量是一维的,将其转换为二维形式
if(trans.ndim==1):
trans=trans[:,None]
# get correspondence at fine level,使用旋转矩阵和平移向量生成变换矩阵,并获取源点云和目标点云之间的对应点
tsfm = to_tsfm(rot, trans)
correspondences = get_correspondences(to_o3d_pcd(src_pcd), to_o3d_pcd(tgt_pcd), tsfm,self.overlap_radius)
# 将旋转矩阵和平移向量转换为torch.Tensor类型
rot = rot.astype(np.float32)
rot = torch.from_numpy(rot)
trans = trans.astype(np.float32)
trans = torch.from_numpy(trans)
# 返回源点云、目标点云、旋转矩阵、平移向量和对应点
return (
torch.from_numpy(src_pcd),
torch.from_numpy(tgt_pcd),
rot,
trans,
correspondences,
)