Part 03: Deploying a trained semantic segmentation model
This article builds upon the previous two posts, viz., training and inference, and shows how a live video stream from an Intel RealSense camera can be used for inference. This only gives a basic idea; there are other options such as TensorRT, TorchScript, ONNX, etc. for faster inference.
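As a rough illustration of that route (not part of the pipeline used in this series), a trained model can be exported to TorchScript and later reloaded without the original Python class definitions. The checkpoint path and the dummy input shape below are assumptions made only for this sketch.

# Minimal TorchScript export sketch -- checkpoint path and input shape are assumed,
# not taken from this series.
import torch
from lightningModel import OurModel  # same model class used during training

model = OurModel(n_classes=20)
model.load_state_dict(torch.load("chk_pts/model.pth", map_location="cpu"))
model.eval()

# Trace with a dummy input matching whatever shape the preprocessing produces
# (1 x 3 x 512 x 1024 is assumed here).
example = torch.randn(1, 3, 512, 1024)
scripted = torch.jit.trace(model, example)
scripted.save("chk_pts/model_scripted.pt")

# The scripted model can later be loaded without OurModel's source code:
loaded = torch.jit.load("chk_pts/model_scripted.pt")
with torch.no_grad():
    out = loaded(example)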
The implementation uses ROS, but it can also be done without ROS by using the RealSense SDK directly, as sketched below.
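For reference, the non-ROS route would look roughly like the following sketch using pyrealsense2, the Python wrapper of the RealSense SDK. The stream resolution and frame rate are assumptions, and this is not the code used in this post.

# Sketch: grabbing colour frames directly from the RealSense SDK (pyrealsense2),
# bypassing ROS. Stream resolution/FPS below are assumptions.
import numpy as np
import pyrealsense2 as rs

pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
pipeline.start(config)

try:
    while True:
        frames = pipeline.wait_for_frames()
        color_frame = frames.get_color_frame()
        if not color_frame:
            continue
        image = np.asanyarray(color_frame.get_data())  # numpy array, usable with cv2
        # ... feed `image` into the same preprocessing / prediction steps shown below
finally:
    pipeline.stop()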
import rospy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
import cv2
import torch
import numpy as np

import utils_fn as util
from lightningModel import OurModel


class Deploy(object):
    def __init__(self, modelPath="chk_pts/trained_cityscapes_final.pth", lightning=False):
        # Params
        self.image = None
        self.lightning = lightning
        self.br = CvBridge()
        # Node cycle rate (in Hz).
        self.loop_rate = rospy.Rate(1)
        self.modelPath = modelPath
        self.loadModel(self.modelPath, lightning=self.lightning)
        self.imgPath = "../datasets/cityscapes/leftImg8bit/img.png"

        # Subscribers
        rospy.Subscriber("/camera/color/image_raw", Image, self.callback)

    def callback(self, msg):
        # Cache the latest camera frame as an OpenCV image
        if msg is not None:
            rospy.loginfo('Image received...')
            self.image = self.br.imgmsg_to_cv2(msg)

    def loadModel(self, modelPath=None, lightning=False):
        if lightning:
            # Lightning checkpoint: rebuild the model and load its state dict
            # (note: this branch uses a fixed checkpoint path)
            self.model = OurModel(n_classes=20)
            self.model.load_state_dict(torch.load('chk_pts/model.pth'))
        elif modelPath is not None:
            # Plain checkpoint: the whole model is stored under the "model" key
            self.model = torch.load(modelPath)["model"]
        else:
            self.model = torch.load(self.modelPath)["model"]
        self.model.to(util.device)
        self.model.eval()

    def getPrediction(self, img_tensor):
        # Forward pass without gradient tracking; drop the batch dimension
        with torch.no_grad():
            prediction = self.model(img_tensor.to(util.device)).squeeze(0)
        return prediction.detach().cpu()

    def postProcess(self, prediction):
        # Per-pixel argmax over class scores, then colour-code the class map
        t = torch.argmax(prediction, 0)
        decoded_output = util.decode_segmap(t)
        return cv2.cvtColor(decoded_output.astype(np.float32), cv2.COLOR_RGB2BGR)

    def showImage(self, img):
        cv2.namedWindow("Image window", cv2.WINDOW_NORMAL)
        cv2.imshow("Image window", img)
        cv2.waitKey(3)

    def start(self):
        while not rospy.is_shutdown():
            if self.image is not None:
                # Preprocess the latest frame, run inference, and display
                # the (de-normalised) input and the prediction side by side
                imgCV2 = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
                img_tensor = util.transform(image=imgCV2)["image"].unsqueeze(0)
                invimg = util.inv_normalize(img_tensor).squeeze(0)
                predicted = self.getPrediction(img_tensor)
                processed_image = self.postProcess(predicted)
                numpy_horizontal = np.hstack((np.moveaxis(invimg.numpy(), 0, 2), processed_image))
                self.showImage(numpy_horizontal)
            self.loop_rate.sleep()


if __name__ == '__main__':
    rospy.init_node("deploy", anonymous=True)
    modelPath = "chk_pts/hrnet_latest.pth"
    my_node = Deploy(modelPath=modelPath, lightning=True)
    my_node.start()
    print("done")
Explanation
The cv_bridge package is used to convert sensor_msgs/Image messages to OpenCV format. This converted image is sent to the model for performing inference.
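If the incoming image encoding does not match what the model expects, imgmsg_to_cv2 also accepts an explicit target encoding. A small sketch, with msg standing for any incoming sensor_msgs/Image:

# Sketch: requesting a specific encoding from cv_bridge instead of the default passthrough.
from cv_bridge import CvBridge

bridge = CvBridge()
cv_image = bridge.imgmsg_to_cv2(msg, desired_encoding="bgr8")  # msg: an incoming sensor_msgs/Image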