diff --git a/include/twist_mux/topic_handle.hpp b/include/twist_mux/topic_handle.hpp index 6c02471..6e51e8a 100644 --- a/include/twist_mux/topic_handle.hpp +++ b/include/twist_mux/topic_handle.hpp @@ -154,12 +154,19 @@ class TopicHandle_ T msg_; }; -class VelocityTopicHandle : public TopicHandle_ +template +class VelocityTopicHandle : public TopicHandle_ { private: - typedef TopicHandle_ base_type; + typedef TopicHandle_ base_type; public: + using base_type::subscriber_; + using base_type::mux_; + using base_type::topic_; + using base_type::stamp_; + using base_type::msg_; + typedef typename base_type::priority_type priority_type; VelocityTopicHandle( @@ -167,7 +174,7 @@ class VelocityTopicHandle : public TopicHandle_ priority_type priority, TwistMux * mux) : base_type(name, topic, timeout, priority, mux) { - subscriber_ = mux_->create_subscription( + subscriber_ = mux_->template create_subscription( topic_, rclcpp::SystemDefaultsQoS(), std::bind(&VelocityTopicHandle::callback, this, std::placeholders::_1)); } @@ -175,48 +182,10 @@ class VelocityTopicHandle : public TopicHandle_ bool isMasked(priority_type lock_priority) const { // std::cout << hasExpired() << " / " << (getPriority() < lock_priority) << std::endl; - return hasExpired() || (getPriority() < lock_priority); - } - - void callback(const geometry_msgs::msg::Twist::ConstSharedPtr msg) - { - stamp_ = mux_->now(); - msg_ = *msg; - - // Check if this twist has priority. - // Note that we have to check all the locks because they might time out - // and since we have several topics we must look for the highest one in - // all the topic list; so far there's no O(1) solution. - if (mux_->hasPriority(*this)) { - mux_->publishTwist(msg); - } - } -}; - -class VelocityStampedTopicHandle : public TopicHandle_ -{ -private: - typedef TopicHandle_ base_type; - -public: - typedef typename base_type::priority_type priority_type; - - VelocityStampedTopicHandle( - const std::string & name, const std::string & topic, const rclcpp::Duration & timeout, - priority_type priority, TwistMux * mux) - : base_type(name, topic, timeout, priority, mux) - { - subscriber_ = mux_->create_subscription( - topic_, rclcpp::SystemDefaultsQoS(), - std::bind(&VelocityStampedTopicHandle::callback, this, std::placeholders::_1)); - } - - bool isMasked(priority_type lock_priority) const - { - return hasExpired() || (getPriority() < lock_priority); + return base_type::hasExpired() || (base_type::getPriority() < lock_priority); } - void callback(const geometry_msgs::msg::TwistStamped::ConstSharedPtr msg) + void callback(const typename T::ConstSharedPtr msg) { stamp_ = mux_->now(); msg_ = *msg; @@ -225,8 +194,8 @@ class VelocityStampedTopicHandle : public TopicHandle_hasPriorityStamped(*this)) { - mux_->publishTwistStamped(msg); + if (mux_->template hasPriority(*this)) { + mux_->template publishTwist(msg); } } }; @@ -244,7 +213,7 @@ class LockTopicHandle : public TopicHandle_ priority_type priority, TwistMux * mux) : base_type(name, topic, timeout, priority, mux) { - subscriber_ = mux_->create_subscription( + subscriber_ = mux_->template create_subscription( topic_, rclcpp::SystemDefaultsQoS(), std::bind(&LockTopicHandle::callback, this, std::placeholders::_1)); } diff --git a/include/twist_mux/twist_mux.hpp b/include/twist_mux/twist_mux.hpp index 6715caa..9b69b1d 100644 --- a/include/twist_mux/twist_mux.hpp +++ b/include/twist_mux/twist_mux.hpp @@ -51,8 +51,8 @@ namespace twist_mux // Forwarding declarations: class TwistMuxDiagnostics; struct TwistMuxDiagnosticsStatus; +template class VelocityTopicHandle; -class VelocityStampedTopicHandle; class LockTopicHandle; /** @@ -64,9 +64,11 @@ class TwistMux : public rclcpp::Node public: template using handle_container = std::list; + using velocity_handle_variant = std::variant, VelocityTopicHandle>; + using publisher_variant = std::variant::SharedPtr, rclcpp::Publisher::SharedPtr>; + using message_variant = std::variant; - using velocity_topic_container = handle_container; - using velocity_stamped_topic_container = handle_container; + using velocity_topic_container = handle_container; using lock_topic_container = handle_container; TwistMux(); @@ -74,13 +76,12 @@ class TwistMux : public rclcpp::Node void init(); - bool hasPriority(const VelocityTopicHandle & twist); + template + bool hasPriority(const VelocityTopicHandleT & twist); - bool hasPriorityStamped(const VelocityStampedTopicHandle & twist); - - void publishTwist(const geometry_msgs::msg::Twist::ConstSharedPtr & msg); - - void publishTwistStamped(const geometry_msgs::msg::TwistStamped::ConstSharedPtr & msg); + + template + void publishTwist(const MessageConstSharedPtrT & msg); void updateDiagnostics(); @@ -99,14 +100,12 @@ class TwistMux : public rclcpp::Node * must reserve the number of handles initially. */ std::shared_ptr velocity_hs_; - std::shared_ptr velocity_stamped_hs_; std::shared_ptr lock_hs_; - rclcpp::Publisher::SharedPtr cmd_pub_; - rclcpp::Publisher::SharedPtr cmd_pub_stamped_; + publisher_variant cmd_pub_; + message_variant last_cmd_; - geometry_msgs::msg::Twist last_cmd_; - geometry_msgs::msg::TwistStamped last_cmd_stamped_; + bool output_stamped; template void getTopicHandles(const std::string & param_name, handle_container & topic_hs); diff --git a/include/twist_mux/twist_mux_diagnostics_status.hpp b/include/twist_mux/twist_mux_diagnostics_status.hpp index 8bbd5d8..a90a4fc 100644 --- a/include/twist_mux/twist_mux_diagnostics_status.hpp +++ b/include/twist_mux/twist_mux_diagnostics_status.hpp @@ -55,21 +55,16 @@ struct TwistMuxDiagnosticsStatus LockTopicHandle::priority_type priority; - bool use_stamped; - std::shared_ptr velocity_hs; - std::shared_ptr velocity_stamped_hs; std::shared_ptr lock_hs; TwistMuxDiagnosticsStatus() : reading_age(0), last_loop_update(rclcpp::Clock().now()), main_loop_time(0), - priority(0), - use_stamped(true) + priority(0) { velocity_hs = std::make_shared(); - velocity_stamped_hs = std::make_shared(); lock_hs = std::make_shared(); } }; diff --git a/launch/twist_mux_launch.py b/launch/twist_mux_launch.py index e7c3824..24d655d 100644 --- a/launch/twist_mux_launch.py +++ b/launch/twist_mux_launch.py @@ -52,13 +52,18 @@ def generate_launch_description(): 'use_sim_time', default_value='False', description='Use simulation time'), + DeclareLaunchArgument( + 'output_stamped', + default_value=False, + description='Output as geometry_msgs/TwistStamped instead of geometry_msgs/Twist'), Node( package='twist_mux', executable='twist_mux', output='screen', remappings={('/cmd_vel_out', LaunchConfiguration('cmd_vel_out'))}, parameters=[ - {'use_sim_time': LaunchConfiguration('use_sim_time')}, + {'use_sim_time': LaunchConfiguration('use_sim_time'), + 'output_stamped': LaunchConfiguration('output_stamped')}, LaunchConfiguration('config_locks'), LaunchConfiguration('config_topics')] ), diff --git a/src/twist_marker.cpp b/src/twist_marker.cpp index e693f8e..71d69e7 100644 --- a/src/twist_marker.cpp +++ b/src/twist_marker.cpp @@ -175,4 +175,4 @@ int main(int argc, char * argv[]) rclcpp::shutdown(); return EXIT_SUCCESS; -} +} \ No newline at end of file diff --git a/src/twist_mux.cpp b/src/twist_mux.cpp index 873e79e..3918520 100644 --- a/src/twist_mux.cpp +++ b/src/twist_mux.cpp @@ -71,54 +71,39 @@ constexpr std::chrono::duration TwistMux::DIAGNOSTICS_PERIOD; TwistMux::TwistMux() : Node("twist_mux", "", rclcpp::NodeOptions().allow_undeclared_parameters( - true).automatically_declare_parameters_from_overrides(true)) + true).automatically_declare_parameters_from_overrides(true)), output_stamped(false) { } void TwistMux::init() { - // Get use stamped parameter - bool use_stamped; - auto nh = std::shared_ptr(this, [](rclcpp::Node *) {}); - fetch_param(nh, "use_stamped", use_stamped); - /// Get topics and locks: - if(use_stamped) - { - velocity_stamped_hs_ = std::make_shared(); - getTopicHandles("topics", *velocity_stamped_hs_); - } - else - { - velocity_hs_ = std::make_shared(); - getTopicHandles("topics", *velocity_hs_); - } + velocity_hs_ = std::make_shared(); lock_hs_ = std::make_shared(); + getTopicHandles("topics", *velocity_hs_); getTopicHandles("locks", *lock_hs_); - - /// Publisher for output topic: - if(use_stamped) - { - cmd_pub_stamped_ = - this->create_publisher( - "cmd_vel_out", - rclcpp::QoS(rclcpp::KeepLast(1))); - } - else + + try { + output_stamped = get_parameter("output_stamped").as_bool(); + } catch (const rclcpp::exceptions::ParameterNotDeclaredException& e) { - cmd_pub_ = - this->create_publisher( - "cmd_vel_out", - rclcpp::QoS(rclcpp::KeepLast(1))); + declare_parameter("output_stamped", false); } + /// Publisher for output topic: + if (output_stamped) { + cmd_pub_ = this->create_publisher( + "cmd_vel_out", rclcpp::QoS(rclcpp::KeepLast(1))); + } else { + cmd_pub_ = this->create_publisher( + "cmd_vel_out", rclcpp::QoS(rclcpp::KeepLast(1))); + } + /// Diagnostics: diagnostics_ = std::make_shared(this); status_ = std::make_shared(); status_->velocity_hs = velocity_hs_; - status_->velocity_stamped_hs = velocity_stamped_hs_; status_->lock_hs = lock_hs_; - status_->use_stamped = use_stamped; diagnostics_timer_ = this->create_wall_timer( DIAGNOSTICS_PERIOD, [this]() -> void { @@ -129,20 +114,51 @@ void TwistMux::init() void TwistMux::updateDiagnostics() { status_->priority = getLockPriority(); - RCLCPP_DEBUG(get_logger(), "updateDiagnostics: lol"); diagnostics_->updateStatus(status_); - RCLCPP_DEBUG(get_logger(), "returned from updateStatus"); } -void TwistMux::publishTwist(const geometry_msgs::msg::Twist::ConstSharedPtr & msg) +template +void TwistMux::publishTwist(const MessageConstSharedPtrT & msg) { - cmd_pub_->publish(*msg); + std::visit([&msg, this](auto&& pub) { + /* + There are four possible combinations: + In -> Out + 1. TwistStamped -> TwistStamped + 2. TwistStamped -> Twist + 3. Twist -> TwistStamped + 4. Twist -> Twist + */ + + // Decide based on output_stamped at runtime + if (output_stamped) { + if (auto twist_stamped_pub = std::dynamic_pointer_cast>(pub)) { + // If we have a TwistStamped publisher and the output needs to be TwistStamped + if constexpr (std::is_same_v, geometry_msgs::msg::TwistStamped::ConstSharedPtr>) { + twist_stamped_pub->publish(*msg); // Publish TwistStamped directly + } else if constexpr (std::is_same_v, geometry_msgs::msg::Twist::ConstSharedPtr>) { + geometry_msgs::msg::TwistStamped twist_stamped_msg; + twist_stamped_msg.twist = *msg; // Wrap Twist in TwistStamped + twist_stamped_pub->publish(twist_stamped_msg); // Publish the wrapped message + } + } else { + RCLCPP_FATAL(get_logger(), "Expected TwistStamped publisher, but received different type."); + } + } else { + if (auto twist_pub = std::dynamic_pointer_cast>(pub)) { + // If we have a Twist publisher and the output needs to be Twist + if constexpr (std::is_same_v, geometry_msgs::msg::TwistStamped::ConstSharedPtr>) { + twist_pub->publish(msg->twist); // Extract Twist from TwistStamped and publish + } else if constexpr (std::is_same_v, geometry_msgs::msg::Twist::ConstSharedPtr>) { + twist_pub->publish(*msg); // Publish Twist directly + } + } else { + RCLCPP_FATAL(get_logger(), "Expected Twist publisher, but received different type."); + } + } + }, cmd_pub_); } -void TwistMux::publishTwistStamped(const geometry_msgs::msg::TwistStamped::ConstSharedPtr & msg) -{ - cmd_pub_stamped_->publish(*msg); -} template void TwistMux::getTopicHandles(const std::string & param_name, std::list & topic_hs) @@ -158,6 +174,7 @@ void TwistMux::getTopicHandles(const std::string & param_name, std::list & to std::string topic; double timeout = 0; int priority = 0; + bool stamped = false; auto nh = std::shared_ptr(this, [](rclcpp::Node *) {}); @@ -168,8 +185,23 @@ void TwistMux::getTopicHandles(const std::string & param_name, std::list & to RCLCPP_DEBUG(get_logger(), "Retrieved topic: %s", topic.c_str()); RCLCPP_DEBUG(get_logger(), "Listed prefix: %.2f", timeout); RCLCPP_DEBUG(get_logger(), "Listed prefix: %d", priority); - - topic_hs.emplace_back(prefix, topic, std::chrono::duration(timeout), priority, this); + + if constexpr (std::is_same_v){ + try { + fetch_param(nh, prefix + ".stamped", stamped); + } catch (const ParamsHelperException& e) { + RCLCPP_WARN(get_logger(), ".stamped is not defined, false is assumed."); + } + if(stamped) { + topic_hs.emplace_back(std::in_place_type>, + prefix, topic, std::chrono::duration(timeout), priority, this); + } else { + topic_hs.emplace_back(std::in_place_type>, + prefix, topic, std::chrono::duration(timeout), priority, this); + } + } else { + topic_hs.emplace_back(prefix, topic, std::chrono::duration(timeout), priority, this); + } } } catch (const ParamsHelperException & e) { RCLCPP_FATAL(get_logger(), "Error parsing params '%s':\n\t%s", param_name.c_str(), e.what()); @@ -197,7 +229,8 @@ int TwistMux::getLockPriority() return priority; } -bool TwistMux::hasPriority(const VelocityTopicHandle & twist) +template +bool TwistMux::hasPriority(const VelocityTopicHandleT & twist) { const auto lock_priority = getLockPriority(); @@ -207,36 +240,15 @@ bool TwistMux::hasPriority(const VelocityTopicHandle & twist) /// max_element on the priority of velocity topic handles satisfying /// that is NOT masked by the lock priority: for (const auto & velocity_h : *velocity_hs_) { - if (!velocity_h.isMasked(lock_priority)) { - const auto velocity_priority = velocity_h.getPriority(); - if (priority < velocity_priority) { - priority = velocity_priority; - velocity_name = velocity_h.getName(); - } - } - } - - return twist.getName() == velocity_name; -} - - -bool TwistMux::hasPriorityStamped(const VelocityStampedTopicHandle & twist) -{ - const auto lock_priority = getLockPriority(); - - LockTopicHandle::priority_type priority = 0; - std::string velocity_name = "NULL"; - - /// max_element on the priority of velocity topic handles satisfying - /// that is NOT masked by the lock priority: - for (const auto & velocity_stamped_h : *velocity_stamped_hs_) { - if (!velocity_stamped_h.isMasked(lock_priority)) { - const auto velocity_priority = velocity_stamped_h.getPriority(); - if (priority < velocity_priority) { - priority = velocity_priority; - velocity_name = velocity_stamped_h.getName(); + std::visit([&](const auto& handle) { + if (!handle.isMasked(lock_priority)) { + const auto velocity_priority = handle.getPriority(); + if (priority < velocity_priority) { + priority = velocity_priority; + velocity_name = handle.getName(); + } } - } + }, velocity_h); } return twist.getName() == velocity_name; diff --git a/src/twist_mux_diagnostics.cpp b/src/twist_mux_diagnostics.cpp index 22ad29c..82f5119 100644 --- a/src/twist_mux_diagnostics.cpp +++ b/src/twist_mux_diagnostics.cpp @@ -58,13 +58,11 @@ void TwistMuxDiagnostics::update() void TwistMuxDiagnostics::updateStatus(const status_type::ConstPtr & status) { status_->velocity_hs = status->velocity_hs; - status_->velocity_stamped_hs = status->velocity_stamped_hs; status_->lock_hs = status->lock_hs; status_->priority = status->priority; status_->main_loop_time = status->main_loop_time; status_->reading_age = status->reading_age; - status_->use_stamped = status->use_stamped; update(); } @@ -80,25 +78,14 @@ void TwistMuxDiagnostics::diagnostics(diagnostic_updater::DiagnosticStatusWrappe stat.summary(OK, "ok"); } - if(status_->use_stamped) - { - for (auto & velocity_stamped_h : *status_->velocity_stamped_hs) { - stat.addf( - "velocity " + velocity_stamped_h.getName(), " %s (listening to %s @ %fs with priority #%d)", - (velocity_stamped_h.isMasked(status_->priority) ? "masked" : "unmasked"), - velocity_stamped_h.getTopic().c_str(), - velocity_stamped_h.getTimeout().seconds(), static_cast(velocity_stamped_h.getPriority())); - } - } - else - { - for (auto & velocity_h : *status_->velocity_hs) { + for (auto & velocity_h : *status_->velocity_hs) { + std::visit([&stat, this](auto&& velocity_h) { stat.addf( "velocity " + velocity_h.getName(), " %s (listening to %s @ %fs with priority #%d)", (velocity_h.isMasked(status_->priority) ? "masked" : "unmasked"), velocity_h.getTopic().c_str(), velocity_h.getTimeout().seconds(), static_cast(velocity_h.getPriority())); - } + }, velocity_h ); } for (const auto & lock_h : *status_->lock_hs) {