Created
February 3, 2021 13:32
-
-
Save SETIADEEPANSHU/3c2103bc49844471fd19fc3db21bdd70 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import trt_pose.coco | |
import torch | |
import torch2trt | |
from torch2trt import TRTModule | |
#import cv2 | |
import torchvision.transforms as transforms | |
import PIL.Image | |
from trt_pose.parse_objects import ParseObjects | |
import os.path | |
from inprotobuf import rgb_dl_pb2 | |
import glob | |
class Executor:
    """Run a TensorRT-optimised trt_pose network over image files.

    The constructor loads the pose topology from ``human_pose.json`` (current
    working directory), loads -- or, if it is missing, builds and caches -- the
    TensorRT engine for the selected backbone, and prepares the normalisation
    constants on the GPU. ``execute`` then runs inference on every image
    matched by a glob pattern.
    """

    def __init__(self, model="densenet"):
        """Load topology, TensorRT engine and normalisation constants.

        Args:
            model: ``'resnet'`` selects the 224x224 ResNet-18 network; any
                other value selects the 256x256 DenseNet-121 network.
        """
        with open('human_pose.json', 'r') as f:
            human_pose = json.load(f)

        topology = trt_pose.coco.coco_category_to_topology(human_pose)
        self.parse_objects = ParseObjects(topology)
        num_parts = len(human_pose['keypoints'])
        num_links = len(human_pose['skeleton'])

        if model == 'resnet':
            print('------ model = resnet loaded--------')
            backbone = 'resnet18_baseline_att'
            model_weights = 'resnet18_baseline_att_224x224_A_epoch_249.pth'
            optimized_model = 'resnet18_baseline_att_224x224_A_epoch_249_trt.pth'
            self.IMAGE_SHAPE = (224, 224)
        else:
            print('------ model = densenet loaded--------')
            backbone = 'densenet121_baseline_att'
            model_weights = 'densenet121_baseline_att_256x256_B_epoch_160.pth'
            optimized_model = 'densenet121_baseline_att_256x256_B_epoch_160_trt.pth'
            self.IMAGE_SHAPE = (256, 256)
        net_width, net_height = self.IMAGE_SHAPE

        if not os.path.exists(optimized_model):
            # No cached engine yet: build the PyTorch network, convert it with
            # torch2trt and cache the result on disk. Previously this branch
            # referenced an undefined network and weights path and could
            # never run.
            # Imported under an alias on purpose: a plain
            # ``import trt_pose.models`` here would make ``trt_pose`` a local
            # name and break the ``trt_pose.coco`` lookup above.
            from trt_pose import models as pose_models

            build_network = getattr(pose_models, backbone)
            network = build_network(num_parts, 2 * num_links).cuda().eval()
            network.load_state_dict(torch.load(model_weights))
            data = torch.zeros((1, 3, net_height, net_width)).cuda()
            converted = torch2trt.torch2trt(
                network, [data], fp16_mode=True, max_workspace_size=1 << 25)
            torch.save(converted.state_dict(), optimized_model)

        self.model_trt = TRTModule()
        self.model_trt.load_state_dict(torch.load(optimized_model))

        self.mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
        self.std = torch.Tensor([0.229, 0.224, 0.225]).cuda()
        self.device = torch.device('cuda')

    @staticmethod
    def get_keypoint(width, height, humans, hnum, peaks):
        """Build a ``PB_HumanPose`` message for one detected person.

        Args:
            width: Width in pixels of the original image.
            height: Height in pixels of the original image.
            humans: Object tensor from ``ParseObjects``; ``humans[0][hnum][j]``
                is the peak index of part ``j`` or a negative value when the
                part was not found.
            hnum: 0-based index of the person.
            peaks: Peak tensor from ``ParseObjects``; each peak is indexed as
                ``peak[0]`` = height axis, ``peak[1]`` = width axis, with
                values presumably normalised to 0.0 ~ 1.0.

        Returns:
            A ``PB_HumanPose`` with one keypoint per part (missing parts have
            score 0 at (0, 0)) and ``score`` set to the fraction of parts
            that were found.
        """
        human = humans[0][hnum]
        num_parts = human.shape[0]
        pose = rgb_dl_pb2.PB_HumanPose()
        found = 0

        for j in range(num_parts):
            k = int(human[j])
            keypoint = rgb_dl_pb2.PB_Keypoint()
            keypoint.type = j
            if k >= 0:
                peak = peaks[0][j][k]  # peak[1]: width axis, peak[0]: height axis
                keypoint.score = 1
                # x is the horizontal pixel position, y the vertical one; the
                # previous code stored them the other way round.
                keypoint.x = int(float(peak[1]) * width)
                keypoint.y = int(float(peak[0]) * height)
                found += 1
            else:
                keypoint.score = 0
                keypoint.x = 0
                keypoint.y = 0
            pose.keypoints.append(keypoint)

        # Fraction of parts detected (equals found / 18 for the 18-part
        # topology); guard against an empty part axis.
        pose.score = found / float(num_parts) if num_parts else 0.0
        return pose

    def execute(self, image_glob="/home/Jon_img/*.jpg"):
        """Run pose estimation on every image matching ``image_glob``.

        Args:
            image_glob: Glob pattern selecting the input images. Defaults to
                the path that was previously hard-coded.

        Returns:
            A list with one entry per image (in ``glob.glob`` order); each
            entry is the list of ``PB_HumanPose`` messages detected in that
            image, with keypoints scaled to the original image size.
        """
        poseslist = []
        for imagepath in glob.glob(image_glob):
            # Context manager closes the file handle; convert() guarantees the
            # three channels that the mean/std normalisation below expects.
            with PIL.Image.open(imagepath) as src:
                width, height = src.size
                image = src.convert('RGB').resize(self.IMAGE_SHAPE)

            tensor = transforms.functional.to_tensor(image).to(self.device)
            tensor.sub_(self.mean[:, None, None]).div_(self.std[:, None, None])

            cmap, paf = self.model_trt(tensor[None, ...])
            cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
            counts, objects, peaks = self.parse_objects(cmap, paf)
            print("Object counts = ",counts[0])

            # One pose per detected person; previously the poses were computed
            # but never stored, so every image produced an empty list.
            poses = [
                Executor.get_keypoint(width, height, objects, i, peaks)
                for i in range(int(counts[0]))
            ]
            poseslist.append(poses)
        return poseslist
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment