Skip to content

Commit

Permalink
Fix body pose bugs (#174)
Browse files Browse the repository at this point in the history
* add mirroring

* rename score to confidence

* update mirror keypoint function

* resize video detection keypoints
  • Loading branch information
ziyuan-linn authored Jul 19, 2024
1 parent 17b6637 commit 542e5a5
Showing 1 changed file with 71 additions and 8 deletions.
79 changes: 71 additions & 8 deletions src/BodyPose/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { mediaReady } from "../utils/imageUtilities";
import handleOptions from "../utils/handleOptions";
import { handleModelName } from "../utils/handleOptions";
import objectRenameKey from "../utils/objectRenameKey";
import { isVideo } from "../utils/handleArguments";

class BodyPose {
/**
Expand Down Expand Up @@ -208,6 +209,17 @@ class BodyPose {
modelConfig.modelType =
poseDetection.movenet.modelType.MULTIPOSE_LIGHTNING;
}
this.runtimeConfig = handleOptions(
this.config,
{
flipHorizontal: {
type: "boolean",
alias: "flipped",
default: false,
},
},
"bodyPose"
);
}

// Load the detector model
Expand Down Expand Up @@ -242,13 +254,22 @@ class BodyPose {
const { image, callback } = argumentObject;

await mediaReady(image, false);
const predictions = await this.model.estimatePoses(
image,
this.runtimeConfig
);
const predictions = await this.model.estimatePoses(image);
let result = predictions;
// modify the raw tfjs output to a more usable format
this.renameScoreToConfidence(result);
if (this.modelName === "MoveNet" && isVideo(image)) {
this.resizeKeypoints(
result,
image.videoWidth,
image.videoHeight,
image.width,
image.height
);
}
if (this.runtimeConfig.flipHorizontal) {
this.mirrorKeypoints(result, image.width);
}
this.addKeypoints(result);
this.resizeBoundingBoxes(result, image.width, image.height);

Expand Down Expand Up @@ -298,12 +319,21 @@ class BodyPose {
async detectLoop() {
await mediaReady(this.detectMedia, false);
while (!this.signalStop) {
const predictions = await this.model.estimatePoses(
this.detectMedia,
this.runtimeConfig
);
const predictions = await this.model.estimatePoses(this.detectMedia);
let result = predictions;
this.renameScoreToConfidence(result);
if (this.modelName === "MoveNet" && isVideo(this.detectMedia)) {
this.resizeKeypoints(
result,
this.detectMedia.videoWidth,
this.detectMedia.videoHeight,
this.detectMedia.width,
this.detectMedia.height
);
}
if (this.runtimeConfig.flipHorizontal) {
this.mirrorKeypoints(result, this.detectMedia.width);
}
this.addKeypoints(result);
this.resizeBoundingBoxes(
result,
Expand Down Expand Up @@ -340,6 +370,39 @@ class BodyPose {
objectRenameKey(keypoint, "score", "confidence");
});
}
if (pose.score) objectRenameKey(pose, "score", "confidence");
});
}

/**
* Mirror the keypoints around x-axis.
* @param {HTMLVideoElement | HTMLImageElement | HTMLCanvasElement} detectMedia
* @param {Object} poses - the original detection results.
* @private
*/
mirrorKeypoints(poses, mediaWidth) {
poses.forEach((pose) => {
pose.keypoints.forEach((keypoint) => {
keypoint.x = mediaWidth - keypoint.x;
});
});
}

/**
* Resize the keypoints output of moveNet model to match the display size.
*
* @param {Object} poses - the original detection results.
* @param {HTMLVideoElement} mediaWidth - the actual width of the video.
* @param {HTMLVideoElement} mediaHeight- the actual height of the video.
* @param {HTMLVideoElement} displayWidth - the display width of the video.
* @param {HTMLVideoElement} displayHeight - the display height of the video.
*/
resizeKeypoints(poses, mediaWidth, mediaHeight, displayWidth, displayHeight) {
poses.forEach((pose) => {
pose.keypoints.forEach((keypoint) => {
keypoint.x = (keypoint.x / mediaWidth) * displayWidth;
keypoint.y = (keypoint.y / mediaHeight) * displayHeight;
});
});
}

Expand Down

0 comments on commit 542e5a5

Please sign in to comment.