# HPEModule: human pose estimation on camera/video streams using mediapipe.
import os
|
|
from xml.dom import minidom
|
|
|
|
import cv2
|
|
import mediapipe as mp
|
|
from DataStreamModule import DataStreamModule
|
|
from IHPEModule import IHPEModule
|
|
import xml.etree.ElementTree as ET
|
|
|
|
# initialize mediapipe drawing utilities and pose models
# mp_drawing: helpers that render landmark points/connections onto images
mp_drawing = mp.solutions.drawing_utils
# mp_pose: the mediapipe Pose solution (provides Pose() and POSE_CONNECTIONS)
mp_pose = mp.solutions.pose
# This class implements Human Pose Estimation (HPE) using the mediapipe library
class HPEModule(IHPEModule):

    # This method starts HPE using a camera specified by its name
    def startHPEwithCamera(self, camera_name, export_landmarks):
        """Run pose estimation on a video file or a live camera stream.

        If ``camera_name`` is an existing file path, the video is read with
        OpenCV and an annotated copy is written next to it as
        ``<name>_output.mp4``. Otherwise ``camera_name`` is treated as a
        camera identifier and frames are pulled from
        ``DataStreamModule.get_camera_stream`` (an iterator of frames —
        assumed; confirm against DataStreamModule). Frames are annotated
        with mediapipe pose landmarks and shown in an OpenCV window until
        the stream ends, the window is closed, or Esc/'q' is pressed.

        Args:
            camera_name: Path to a video file, or a camera name understood
                by DataStreamModule.
            export_landmarks: Object exposing ``.get() -> bool`` (e.g. a
                tkinter BooleanVar); when True, per-frame landmark
                coordinates are written to ``<camera_name>_output.xml``.
        """
        out = None  # non-None also flags "reading from a file" below
        if os.path.isfile(camera_name):
            # File input: open it and prepare a writer with matching
            # frame rate and size for the annotated output video.
            cap = cv2.VideoCapture(camera_name)
            output_path = os.path.splitext(camera_name)[0] + "_output.mp4"
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        else:
            # Live input: obtain a frame stream for the named camera.
            data_stream = DataStreamModule()
            cap = data_stream.get_camera_stream(camera_name)

        window_name = f"Pose Estimation on Camera {camera_name}"

        # Read the export flag once instead of on every frame.
        export = export_landmarks.get() is True
        root = None
        if export:
            # One XML element per mediapipe pose landmark (33 total);
            # each frame adds a child with X/Y/Z/visibility attributes.
            root = ET.Element("Landmarks")
            for i in range(33):
                ET.SubElement(root, f'landmark{i}')

        framecounter = 0

        try:
            # start pose detection
            with mp_pose.Pose(min_detection_confidence=0.5,
                              min_tracking_confidence=0.5) as pose:
                while True:
                    # Get the next frame from the file or the live stream.
                    if out is not None:
                        ret, frame = cap.read()
                        if not ret:
                            break
                    else:
                        try:
                            frame = next(cap)
                        except StopIteration:
                            # Live stream exhausted — exit cleanly instead
                            # of propagating StopIteration.
                            break

                    # mediapipe expects RGB; marking the array read-only
                    # lets mediapipe process it by reference.
                    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image.flags.writeable = False

                    # Make detection
                    results = pose.process(image)

                    # Recolor back to BGR for OpenCV display/writing.
                    image.flags.writeable = True
                    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

                    # pose_landmarks is None when no person is detected;
                    # skip export/drawing for that frame (draw_landmarks
                    # is a no-op on None anyway).
                    if results.pose_landmarks is not None:
                        if export:
                            for j, landmark in enumerate(results.pose_landmarks.landmark):
                                frame_element = ET.SubElement(
                                    root.find(f'landmark{j}'), f'frame{framecounter}')
                                frame_element.set("X", str(landmark.x))
                                frame_element.set("Y", str(landmark.y))
                                frame_element.set("Z", str(landmark.z))
                                frame_element.set("visibility", str(landmark.visibility))
                            framecounter += 1

                        # Render detections
                        mp_drawing.draw_landmarks(
                            image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                            mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),
                            mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2),
                        )

                    cv2.imshow(window_name, image)

                    if out is not None:
                        out.write(image)

                    # waitKey also pumps the GUI event loop; allow Esc or
                    # 'q' to quit in addition to closing the window.
                    key = cv2.waitKey(1) & 0xFF
                    if key in (27, ord('q')):
                        break
                    if cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
                        break
        finally:
            # Release capture/writer and close the window on every exit
            # path (the original leaked all three).
            if out is not None:
                cap.release()
                out.release()
            try:
                cv2.destroyWindow(window_name)
            except cv2.error:
                pass  # window was never created (e.g. zero frames read)

        if export:
            # Pretty-print the collected landmarks and write them out.
            xml_string = ET.tostring(root, encoding="UTF-8")
            parsed_xml = minidom.parseString(xml_string)
            pretty_xml_string = parsed_xml.toprettyxml(indent=" ")
            with open(f"{camera_name}_output.xml", "w") as xml_file:
                xml_file.write(pretty_xml_string)