Skip to content

Commit

Permalink
add desqueeze output to nn
Browse files Browse the repository at this point in the history
  • Loading branch information
Serafadam committed Oct 2, 2024
1 parent 37f4c4b commit d8d1b99
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 47 deletions.
2 changes: 0 additions & 2 deletions depthai_filters/config/detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,3 @@
camera_i_nn_type: rgb
rgb_i_enable_preview: true
nn_i_enable_passthrough: true
/detection_overlay:
desqueeze: true
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,23 @@
#include "nodelet/nodelet.h"
#include "ros/ros.h"
#include "sensor_msgs/Image.h"
#include "sensor_msgs/CameraInfo.h"
#include "vision_msgs/Detection2DArray.h"

namespace depthai_filters {
class Detection2DOverlay : public nodelet::Nodelet {
public:
void onInit() override;

void overlayCB(const sensor_msgs::ImageConstPtr& preview, const sensor_msgs::CameraInfoConstPtr& info, const vision_msgs::Detection2DArrayConstPtr& detections);
void overlayCB(const sensor_msgs::ImageConstPtr& preview, const vision_msgs::Detection2DArrayConstPtr& detections);

message_filters::Subscriber<sensor_msgs::Image> previewSub;
message_filters::Subscriber<sensor_msgs::CameraInfo> infoSub;
message_filters::Subscriber<vision_msgs::Detection2DArray> detSub;

typedef message_filters::sync_policies::ApproximateTime<sensor_msgs::Image, sensor_msgs::CameraInfo, vision_msgs::Detection2DArray> syncPolicy;
typedef message_filters::sync_policies::ApproximateTime<sensor_msgs::Image, vision_msgs::Detection2DArray> syncPolicy;
std::unique_ptr<message_filters::Synchronizer<syncPolicy>> sync;
ros::Publisher overlayPub;
std::vector<std::string> labelMap = {"background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
"car", "cat", "chair", "cow", "diningtable", "dog", "horse",
"motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
bool desqueeze = false;
};
} // namespace depthai_filters
1 change: 0 additions & 1 deletion depthai_filters/launch/example_det2d_overlay.launch
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

<node name="overlay" pkg="nodelet" type="nodelet" output="screen" required="true" args="load depthai_filters/Detection2DOverlay $(arg name)_nodelet_manager">
<remap from="/rgb/preview/image_raw" to="$(arg name)/nn/passthrough/image_raw"/>
<remap from="/rgb/preview/camera_info" to="$(arg name)/nn/passthrough/camera_info"/>
<remap from="/nn/detections" to="$(arg name)/nn/detections"/>

</node>
Expand Down
32 changes: 2 additions & 30 deletions depthai_filters/src/detection2d_overlay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,54 +14,26 @@ void Detection2DOverlay::onInit() {
auto pNH = getPrivateNodeHandle();
previewSub.subscribe(pNH, "/rgb/preview/image_raw", 1);
detSub.subscribe(pNH, "/nn/detections", 1);
infoSub.subscribe(pNH, "/rgb/preview/camera_info", 1);
sync = std::make_unique<message_filters::Synchronizer<syncPolicy>>(syncPolicy(10), previewSub, infoSub, detSub);
sync = std::make_unique<message_filters::Synchronizer<syncPolicy>>(syncPolicy(10), previewSub, detSub);
pNH.getParam("label_map", labelMap);
pNH.getParam("desqueeze", desqueeze);
sync->registerCallback(std::bind(&Detection2DOverlay::overlayCB, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3));
sync->registerCallback(std::bind(&Detection2DOverlay::overlayCB, this, std::placeholders::_1, std::placeholders::_2));
overlayPub = pNH.advertise<sensor_msgs::Image>("overlay", 10);
}

void Detection2DOverlay::overlayCB(const sensor_msgs::ImageConstPtr& preview,
const sensor_msgs::CameraInfoConstPtr& info,
const vision_msgs::Detection2DArrayConstPtr& detections) {
cv::Mat previewMat = utils::msgToMat(preview, sensor_msgs::image_encodings::BGR8);
auto white = cv::Scalar(255, 255, 255);
auto black = cv::Scalar(0, 0, 0);
auto blue = cv::Scalar(255, 0, 0);

double ratioX = 1.0;
double ratioY = 1.0;
int offsetX = 0;
double offsetY = 0;
// if preview size is less than camera info size
if(previewMat.rows < info->height || previewMat.cols < info->width) {
ratioY = double(info->height) / double(previewMat.rows);
if(desqueeze) {
ratioX = double(info->width) / double(previewMat.cols);
} else {
ratioX = ratioY;
offsetX = (info->width - info->height) / 2.0;
}
} else {
ratioY = double(previewMat.rows) / double(info->height);
if(desqueeze) {
ratioX = double(previewMat.cols) / double(info->width);
} else {
ratioX = double(previewMat.cols) / double(info->width);
}
}
for(auto& detection : detections->detections) {
auto x1 = detection.bbox.center.x - detections->detections[0].bbox.size_x / 2.0;
auto x2 = detection.bbox.center.x + detections->detections[0].bbox.size_x / 2.0;
auto y1 = detection.bbox.center.y - detections->detections[0].bbox.size_y / 2.0;
auto y2 = detection.bbox.center.y + detections->detections[0].bbox.size_y / 2.0;
auto labelStr = labelMap[detection.results[0].id];
auto confidence = detection.results[0].score;
x1 = x1 * ratioX + offsetX;
x2 = x2 * ratioX + offsetX;
y1 = y1 * ratioY + offsetY;
y2 = y2 * ratioY + offsetY;
cv::putText(previewMat, labelStr, cv::Point(x1 + 10, y1 + 20), cv::FONT_HERSHEY_TRIPLEX, 0.5, white, 3);
cv::putText(previewMat, labelStr, cv::Point(x1 + 10, y1 + 20), cv::FONT_HERSHEY_TRIPLEX, 0.5, black);
std::stringstream confStr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class Detection : public BaseNode {
if(ph->getParam<bool>("i_disable_resize")) {
width = ph->getOtherNodeParam<int>(socketName, "i_preview_width");
height = ph->getOtherNodeParam<int>(socketName, "i_preview_height");
} else if(ph->getParam<bool>("i_desqueeze_output")) {
width = ph->getOtherNodeParam<int>(socketName, "i_width");
height = ph->getOtherNodeParam<int>(socketName, "i_height");
} else {
width = imageManip->initialConfig.getResizeConfig().width;
height = imageManip->initialConfig.getResizeConfig().height;
Expand All @@ -86,12 +89,12 @@ class Detection : public BaseNode {
convConf.updateROSBaseTimeOnRosMsg = ph->getParam<bool>("i_update_ros_base_time_on_ros_msg");

utils::ImgPublisherConfig pubConf;
pubConf.width = width;
pubConf.height = height;
pubConf.width = width;
pubConf.height = height;
pubConf.daiNodeName = getName();
pubConf.topicName = getName() + "/passthrough";
pubConf.infoSuffix = "/passthrough";
pubConf.infoMgrSuffix = "/passthrough";
pubConf.infoMgrSuffix = "/passthrough";
pubConf.socket = static_cast<dai::CameraBoardSocket>(ph->getParam<int>("i_board_socket_id"));

ptPub->setup(device, convConf, pubConf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class NNParamHandler : public BaseParamHandler {
template <typename T>
void declareParams(std::shared_ptr<T> nn, std::shared_ptr<dai::node::ImageManip> imageManip) {
declareAndLogParam<bool>("i_disable_resize", false);
declareAndLogParam<bool>("i_desqueeze_output", false);
declareAndLogParam<bool>("i_enable_passthrough", false);
declareAndLogParam<bool>("i_enable_passthrough_depth", false);
declareAndLogParam<bool>("i_get_base_device_timestamp", false);
Expand Down Expand Up @@ -124,4 +125,4 @@ class NNParamHandler : public BaseParamHandler {
std::vector<std::string> labels;
};
} // namespace param_handlers
} // namespace depthai_ros_driver
} // namespace depthai_ros_driver
6 changes: 3 additions & 3 deletions depthai_ros_driver/src/dai_nodes/nn/nn_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ NNWrapper::NNWrapper(const std::string& daiNodeName, ros::NodeHandle node, std::
auto family = ph->getNNFamily();
switch(family) {
case param_handlers::nn::NNFamily::Yolo: {
nnNode = std::make_unique<dai_nodes::nn::Detection<dai::node::YoloDetectionNetwork>>(getName(), getROSNode(), pipeline);
nnNode = std::make_unique<dai_nodes::nn::Detection<dai::node::YoloDetectionNetwork>>(getName(), getROSNode(), pipeline, socket);
break;
}
case param_handlers::nn::NNFamily::Mobilenet: {
nnNode = std::make_unique<dai_nodes::nn::Detection<dai::node::MobileNetDetectionNetwork>>(getName(), getROSNode(), pipeline);
nnNode = std::make_unique<dai_nodes::nn::Detection<dai::node::MobileNetDetectionNetwork>>(getName(), getROSNode(), pipeline, socket);
break;
}
case param_handlers::nn::NNFamily::Segmentation: {
nnNode = std::make_unique<dai_nodes::nn::Segmentation>(getName(), getROSNode(), pipeline);
nnNode = std::make_unique<dai_nodes::nn::Segmentation>(getName(), getROSNode(), pipeline, socket);
break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions depthai_ros_driver/src/dai_nodes/nn/spatial_nn_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ SpatialNNWrapper::SpatialNNWrapper(const std::string& daiNodeName,
auto family = ph->getNNFamily();
switch(family) {
case param_handlers::nn::NNFamily::Yolo: {
nnNode = std::make_unique<dai_nodes::nn::SpatialDetection<dai::node::YoloSpatialDetectionNetwork>>(getName(), getROSNode(), pipeline);
nnNode = std::make_unique<dai_nodes::nn::SpatialDetection<dai::node::YoloSpatialDetectionNetwork>>(getName(), getROSNode(), pipeline, socket);
break;
}
case param_handlers::nn::NNFamily::Mobilenet: {
nnNode = std::make_unique<dai_nodes::nn::SpatialDetection<dai::node::MobileNetSpatialDetectionNetwork>>(getName(), getROSNode(), pipeline);
nnNode = std::make_unique<dai_nodes::nn::SpatialDetection<dai::node::MobileNetSpatialDetectionNetwork>>(getName(), getROSNode(), pipeline, socket);
break;
}
case param_handlers::nn::NNFamily::Segmentation: {
Expand Down

0 comments on commit d8d1b99

Please sign in to comment.