ram_sam_ros_wrapper.py (forked from IDEA-Research/Grounded-Segment-Anything)
#!/usr/bin/env python3
import rospy
from sensor_msgs.msg import Image as ROSImage
from cv_bridge import CvBridge, CvBridgeError

import argparse
import os
import json
import sys

import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as TS
from PIL import Image as PILImage

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# Segment Anything
from segment_anything import (
    build_sam,
    build_sam_hq,
    SamPredictor,
)

# Recognize Anything Model (RAM) & Tag2Text
# NOTE: the Tag2Text checkout is expected at this hard-coded path inside the workspace.
sys.path.append('/home/appuser/catkin_ws/src/grounded_sam/script/Tag2Text')
from Tag2Text.models import tag2text
from Tag2Text import inference_ram
class GroundedSAMServiceNode:
    def __init__(self):
        # ROS initialization
        rospy.init_node('grounded_sam_service_node')

        # Store the latest image (None until the first message arrives, so the
        # timer callback can skip processing before any image has been received)
        self.bridge = CvBridge()
        self.latest_image_msg = None

        try:
            self.check_and_get_parameters()
            # Load all models
            self.grounding_dino_model, self.ram_model, self.sam_predictor = self.load_all_models()
            # ROS subscriber and timer
            self.image_subscriber = rospy.Subscriber("/input_image", ROSImage, self.image_callback)
            self.timer = rospy.Timer(rospy.Duration(1.0), self.timer_callback)
        except Exception as e:
            rospy.logerr(f"Error during initialization: {e}")
            rospy.signal_shutdown("Error during initialization")

    def get_required_param(self, name, default=None):
        value = rospy.get_param(name, default)
        if value is None:
            rospy.logerr(f"Required parameter {name} is missing!")
            raise ValueError(f"Required parameter {name} is not set on the ROS parameter server!")
        rospy.loginfo(f"Loaded parameter {name}: {value}")
        return value

    def get_float_param(self, name, default=None):
        value = rospy.get_param(name, default)
        if not isinstance(value, (float, int)):
            rospy.logerr(f"Parameter {name} should be a float but got {type(value)}!")
            raise ValueError(f"Parameter {name} is not a float!")
        rospy.loginfo(f"Loaded parameter {name}: {value}")
        return float(value)
    def check_and_get_parameters(self):
        self.config = self.get_required_param("~config")
        self.ram_checkpoint = self.get_required_param("~ram_checkpoint")
        self.grounded_checkpoint = self.get_required_param("~grounded_checkpoint")
        self.sam_checkpoint = self.get_required_param("~sam_checkpoint")
        self.use_sam_hq = self.get_required_param("~use_sam_hq", default=False)
        # The HQ checkpoint is only needed when SAM-HQ is actually requested;
        # without this, load_sam_model() would fail on a missing attribute.
        if self.use_sam_hq:
            self.sam_hq_checkpoint = self.get_required_param("~sam_hq_checkpoint")
        # self.input_image_path = self.get_required_param("~input_image")
        self.split = self.get_required_param("~split", default=",")
        # self.openai_key = self.get_required_param("~openai_key", default=None)
        # self.openai_proxy = self.get_required_param("~openai_proxy", default=None)
        # self.output_dir = self.get_required_param("~output_dir")
        self.box_threshold = self.get_float_param("~box_threshold", default=0.25)
        self.text_threshold = self.get_float_param("~text_threshold", default=0.2)
        self.iou_threshold = self.get_float_param("~iou_threshold", default=0.5)
        self.device = self.get_required_param("~device", default="cuda")
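    # A minimal, hypothetical parameter sketch for this node (package/type names are
    # inferred from the paths above, file paths and values are placeholders, not this
    # repository's actual configuration). In a launch file the private parameters read
    # above might be set roughly like this:
    #
    #   <node pkg="grounded_sam" type="ram_sam_ros_wrapper.py" name="grounded_sam_service_node" output="screen">
    #     <param name="config"              value="/path/to/GroundingDINO_SwinT_OGC.py"/>
    #     <param name="ram_checkpoint"      value="/path/to/ram_swin_large_14m.pth"/>
    #     <param name="grounded_checkpoint" value="/path/to/groundingdino_swint_ogc.pth"/>
    #     <param name="sam_checkpoint"      value="/path/to/sam_vit_h_4b8939.pth"/>
    #     <param name="use_sam_hq"          value="false"/>
    #     <param name="box_threshold"       value="0.25"/>
    #     <param name="text_threshold"      value="0.2"/>
    #     <param name="iou_threshold"       value="0.5"/>
    #     <param name="device"              value="cuda"/>
    #   </node>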
    def load_grounding_dino_model(self):
        args = SLConfig.fromfile(self.config)
        args.device = self.device
        model = build_model(args)
        checkpoint = torch.load(self.grounded_checkpoint, map_location="cpu")
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        model = model.eval().to(self.device)
        return model

    def load_ram_model(self):
        ram_model = tag2text.ram(pretrained=self.ram_checkpoint, image_size=384, vit='swin_l')
        ram_model = ram_model.eval().to(self.device)
        return ram_model

    def load_sam_model(self):
        if self.use_sam_hq:
            predictor = SamPredictor(build_sam_hq(checkpoint=self.sam_hq_checkpoint).to(self.device))
        else:
            predictor = SamPredictor(build_sam(checkpoint=self.sam_checkpoint).to(self.device))
        return predictor

    def load_all_models(self):
        grounding_dino_model = self.load_grounding_dino_model()
        ram_model = self.load_ram_model()
        sam_predictor = self.load_sam_model()
        return grounding_dino_model, ram_model, sam_predictor
    def show_mask(self, mask, ax, random_color=False):
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    def show_box(self, box, ax, label):
        x0, y0 = box[0], box[1]
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
        ax.text(x0, y0, label)
    def image_callback(self, data):
        self.latest_image_msg = data

    def timer_callback(self, event):
        # Skip until the first image has been received
        if self.latest_image_msg is None:
            return
        # Run RAM + Grounding DINO + SAM on the latest image
        mask = self.generate_mask_and_tags(self.latest_image_msg)
        # If you have other publishers, you can publish the result here.
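        # A hypothetical sketch (not part of the original node) of how the result
        # could be published, assuming a publisher created in __init__ such as
        #   self.mask_publisher = rospy.Publisher("/output_mask", ROSImage, queue_size=1)
        # the first mask could then be sent out as a mono8 image:
        #
        #   if mask is not None and len(mask) > 0:
        #       mask_np = mask[0].cpu().numpy().squeeze().astype(np.uint8) * 255
        #       self.mask_publisher.publish(self.bridge.cv2_to_imgmsg(mask_np, encoding="mono8"))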
    def get_grounding_output(self, model, image, caption, box_threshold, text_threshold, device="cpu"):
        caption = caption.lower().strip()
        if not caption.endswith("."):
            caption = caption + "."
        model = model.to(device)
        image = image.to(device)
        with torch.no_grad():
            outputs = model(image[None], captions=[caption])
        logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
        boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

        # filter output: keep queries whose best token score exceeds box_threshold
        logits_filt = logits.clone()
        boxes_filt = boxes.clone()
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
        logits_filt = logits_filt[filt_mask]  # (num_filt, 256)
        boxes_filt = boxes_filt[filt_mask]  # (num_filt, 4)

        # get phrases
        tokenizer = model.tokenizer
        tokenized = tokenizer(caption)

        # build predictions
        pred_phrases = []
        scores = []
        for logit, box in zip(logits_filt, boxes_filt):
            pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
            scores.append(logit.max().item())

        return boxes_filt, torch.Tensor(scores), pred_phrases
    def load_image_from_msg(self, image_msg):
        # Convert sensor_msgs/Image to a PIL image, then to the normalized tensor
        # expected by Grounding DINO
        cv_image = self.bridge.imgmsg_to_cv2(image_msg, "rgb8")
        image_pil = PILImage.fromarray(cv_image)
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image, _ = transform(image_pil, None)  # (3, h, w)
        return image_pil, image
    def generate_mask_and_tags(self, image_msg):
        image_pil, image = self.load_image_from_msg(image_msg)

        # RAM preprocessing (384x384, ImageNet normalization)
        normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform = TS.Compose([TS.Resize((384, 384)), TS.ToTensor(), normalize])
        raw_image = image_pil.resize((384, 384))
        raw_image = transform(raw_image).unsqueeze(0).to(self.device)

        # Get tags using the RAM model.
        # Currently ", " is better for detecting single tags,
        # while ". " is a little worse in some cases.
        res = inference_ram.inference(raw_image, self.ram_model)
        tags = res[0].replace(' |', ',')
        rospy.loginfo(f"Image tags: {res[0]}")

        # Grounding DINO: detect boxes for the tagged phrases
        boxes_filt, scores, pred_phrases = self.get_grounding_output(
            self.grounding_dino_model, image, tags, self.box_threshold, self.text_threshold, device=self.device
        )

        # SAM expects the raw RGB image as an HxWx3 uint8 array,
        # not the normalized Grounding DINO tensor
        image_np = np.array(image_pil)
        predictor = self.sam_predictor
        predictor.set_image(image_np)

        # Convert boxes from normalized cxcywh to absolute xyxy coordinates
        size = image_pil.size
        H, W = size[1], size[0]
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]
        boxes_filt = boxes_filt.cpu()

        # Use NMS to handle overlapping boxes
        rospy.loginfo(f"Before NMS: {boxes_filt.shape[0]} boxes")
        nms_idx = torchvision.ops.nms(boxes_filt, scores, self.iou_threshold).numpy().tolist()
        boxes_filt = boxes_filt[nms_idx]
        pred_phrases = [pred_phrases[idx] for idx in nms_idx]
        rospy.loginfo(f"After NMS: {boxes_filt.shape[0]} boxes")

        # Prompt SAM with the detected boxes
        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2]).to(self.device)
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        # Draw the output image (note: the figure is only rendered in memory here;
        # it is neither saved to disk nor published)
        plt.figure(figsize=(10, 10))
        plt.imshow(image_np)
        for mask in masks:
            self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes_filt, pred_phrases):
            self.show_box(box.numpy(), plt.gca(), label)
        plt.close()  # avoid accumulating figures across timer callbacks

        return masks
if __name__ == "__main__":
    node = GroundedSAMServiceNode()
    rospy.spin()
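# ---------------------------------------------------------------------------
# A minimal sketch (hypothetical, not part of this package) of how the node can
# be exercised: a separate script that publishes a test image on /input_image.
# The topic name comes from the subscriber above; the image path is a placeholder.
#
#   #!/usr/bin/env python3
#   import rospy, cv2
#   from cv_bridge import CvBridge
#   from sensor_msgs.msg import Image
#
#   rospy.init_node("test_image_publisher")
#   pub = rospy.Publisher("/input_image", Image, queue_size=1)
#   bridge = CvBridge()
#   img = cv2.cvtColor(cv2.imread("/path/to/test.jpg"), cv2.COLOR_BGR2RGB)
#   rate = rospy.Rate(1)
#   while not rospy.is_shutdown():
#       pub.publish(bridge.cv2_to_imgmsg(img, encoding="rgb8"))
#       rate.sleep()
# ---------------------------------------------------------------------------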