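'''fusion_scannet_lseg.py

Multi-view feature fusion of LSeg features on ScanNet: per-pixel LSeg (CLIP)
features are extracted for every RGB frame of a scene, projected onto the 3D
point cloud using the depth maps and camera poses, and averaged over all views
in which a point is visible. The fused per-point features are saved to disk.

Example usage (paths are placeholders):
    python fusion_scannet_lseg.py --data_dir /path/to/scannet_processed \
        --output_dir /path/to/output --split val \
        --lseg_model dataset/lseg/lseg_checkpoint/demo_e200.ckpt
'''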
import os
import torch
import imageio
import argparse
from os.path import join, exists
import numpy as np
from glob import glob
from tqdm import tqdm, trange

# add the parent directory to the path
import sys
sys.path.append('..')

from models import LSeg_MultiEvalModule
from modules.lseg_module import LSegModule
from encoding.models.sseg import BaseNet
import torchvision.transforms as transforms
from fusion_util import (extract_lseg_img_feature, PointCloudToImageMapper,
                         save_fused_feature_with_locs, adjust_intrinsic, make_intrinsic)


def get_args():
    '''Command-line arguments.'''
    parser = argparse.ArgumentParser(
        description='Multi-view feature fusion of LSeg on ScanNet.')
    parser.add_argument('--data_dir', type=str, help='root of the processed ScanNet data (contains scannet_2d/ and scannet_3d/)')
    parser.add_argument('--output_dir', type=str, help='directory to save the fused features')
    parser.add_argument('--split', type=str, default='val', help='split: "train" | "val"')
    parser.add_argument('--lseg_model', type=str, default='dataset/lseg/lseg_checkpoint/demo_e200.ckpt', help='path to the LSeg checkpoint')
    parser.add_argument('--process_id_range', nargs='+', default=None, help='the id range of scenes to process, e.g. "0,100"')
    parser.add_argument('--prefix', type=str, default='lseg', help='prefix for the output file')
    # hyper-parameters
    parser.add_argument('--hparams', default=[], nargs="+")
    args = parser.parse_args()
    return args


def process_one_scene(data_path, out_dir, args):
    '''Fuse the multi-view LSeg features of a single ScanNet scene.'''

    # short hand
    scene_id = data_path.split('/')[-1].split('_vh')[0]
    feat_dim = args.feat_dim
    point2img_mapper = args.point2img_mapper
    depth_scale = args.depth_scale
    keep_features_in_memory = args.keep_features_in_memory
    evaluator = args.evaluator
    transform = args.transform

    # load 3D data (point cloud)
    locs_in = torch.load(data_path)[0]
    n_points = locs_in.shape[0]

    if exists(join(out_dir, args.prefix + '_features', scene_id + '.pt')):
        print(scene_id + '.pt already exists, skip!')
        return 1

    # short hand for processing 2D features
    scene = join(args.data_root_2d, scene_id)
    img_dirs = sorted(glob(join(scene, 'color/*')), key=lambda x: int(os.path.basename(x)[:-4]))
    num_img = len(img_dirs)
    device = torch.device('cpu')

    # extract image features and keep them in memory
    # default: False (extract image features on the fly)
    if keep_features_in_memory and evaluator is not None:
        img_features = []
        for img_dir in tqdm(img_dirs):
            img_features.append(extract_lseg_img_feature(img_dir, transform, evaluator))

    n_points_cur = n_points
    counter = torch.zeros((n_points_cur, 1), device=device)
    sum_features = torch.zeros((n_points_cur, feat_dim), device=device)

    ################ Feature Fusion ###################
    vis_id = torch.zeros((n_points_cur, num_img), dtype=int, device=device)
    for img_id, img_dir in enumerate(tqdm(img_dirs)):
        # load pose
        posepath = img_dir.replace('color', 'pose').replace('.jpg', '.txt')
        pose = np.loadtxt(posepath)

        # load depth and convert to meters
        depth = imageio.v2.imread(img_dir.replace('color', 'depth').replace('jpg', 'png')) / depth_scale

        # calculate the 3D-2D mapping based on the depth;
        # mapping[:, 1:3] are the pixel coordinates, mapping[:, 3] is the visibility mask
        mapping = np.ones([n_points, 4], dtype=int)
        mapping[:, 1:4] = point2img_mapper.compute_mapping(pose, locs_in, depth)
        if mapping[:, 3].sum() == 0:  # no point corresponds to this image, skip
            continue

        mapping = torch.from_numpy(mapping).to(device)
        mask = mapping[:, 3]
        vis_id[:, img_id] = mask

        if keep_features_in_memory:
            feat_2d = img_features[img_id].to(device)
        else:
            feat_2d = extract_lseg_img_feature(img_dir, transform, evaluator).to(device)  # [512, 240, 320]

        # gather the per-pixel features at the projected locations and accumulate them
        feat_2d_3d = feat_2d[:, mapping[:, 1], mapping[:, 2]].permute(1, 0)
        counter[mask != 0] += 1
        sum_features[mask != 0] += feat_2d_3d[mask != 0]

    # average the accumulated features over all views in which each point is visible
    counter[counter == 0] = 1e-5
    feat_bank = sum_features / counter
    point_ids = torch.unique(vis_id.nonzero(as_tuple=False)[:, 0])

    save_fused_feature_with_locs(feat_bank, point_ids, locs_in, n_points, out_dir, scene_id, args)


def main(args):
    '''Load the LSeg model and fuse features for every scene in the split.'''
    seed = 1457
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    ##### Dataset specific parameters #####
    img_dim = (320, 240)
    depth_scale = 1000.0
    fx = 577.870605
    fy = 577.870605
    mx = 319.5
    my = 239.5
    #######################################
    visibility_threshold = 0.25  # threshold for the visibility check

    args.depth_scale = depth_scale
    args.cut_num_pixel_boundary = 10  # do not use the features on the image boundary
    args.keep_features_in_memory = False  # keeping image features in memory is very expensive
    split = args.split
    data_dir = args.data_dir

    data_root = join(data_dir, 'scannet_3d')
    data_root_2d = join(data_dir, 'scannet_2d')
    args.data_root_2d = data_root_2d
    out_dir = args.output_dir
    args.feat_dim = 512  # CLIP feature dimension
    os.makedirs(out_dir, exist_ok=True)
    process_id_range = args.process_id_range

    if split == 'train':  # for the training set, export chunks of the point cloud
        args.n_split_points = 300000
    else:  # for the validation set, export the entire point cloud instead of chunks
        args.n_split_points = 2000000

    ##############################
    ##### load the LSeg model ####
    module = LSegModule.load_from_checkpoint(
        checkpoint_path=args.lseg_model,
        data_path='../datasets/',
        dataset='ade20k',
        backbone='clip_vitl16_384',
        aux=False,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=255,
        dropout=0.0,
        scale_inv=False,
        augment=False,
        no_batchnorm=False,
        widehead=True,
        widehead_hr=False,
        map_location="cpu",
        arch_option=0,
        block_depth=0,
        activation='lrelu',
    )
    # model
    if isinstance(module.net, BaseNet):
        model = module.net
    else:
        model = module

    model = model.eval()
    model = model.cpu()
    model.mean = [0.5, 0.5, 0.5]
    model.std = [0.5, 0.5, 0.5]

    #############################################
    # THE trick for getting proper LSeg features for ScanNet
    scales = [1]
    model.crop_size = 640
    model.base_size = 640

    evaluator = LSeg_MultiEvalModule(
        model, scales=scales, flip=True
    ).cuda()  # the LSeg model has to be on the GPU
    evaluator.eval()
    args.evaluator = evaluator

    args.transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # calculate the correspondences between image pixels and 3D points
    intrinsic = make_intrinsic(fx=fx, fy=fy, mx=mx, my=my)
    intrinsic = adjust_intrinsic(intrinsic, intrinsic_image_dim=[640, 480], image_dim=img_dim)

    args.point2img_mapper = PointCloudToImageMapper(
        image_dim=img_dim, intrinsics=intrinsic,
        visibility_threshold=visibility_threshold,
        cut_bound=args.cut_num_pixel_boundary)

    data_paths = sorted(glob(join(data_root, split, '*.pth')))
    total_num = len(data_paths)

    id_range = None
    if process_id_range is not None:
        id_range = [int(process_id_range[0].split(',')[0]), int(process_id_range[0].split(',')[1])]

    for i in trange(total_num):
        if id_range is not None and \
           (i < id_range[0] or i > id_range[1]):
            print('skip ', i, data_paths[i])
            continue

        process_one_scene(data_paths[i], out_dir, args)


if __name__ == "__main__":
    args = get_args()
    print("Arguments:")
    print(args)
    main(args)