From b28b5dc276bc05bfe8a592a3a07ec541bb1d1864 Mon Sep 17 00:00:00 2001 From: Jeremy Nimmer Date: Thu, 12 Sep 2024 10:57:05 -0700 Subject: [PATCH] [pydrake] Bind Joint.kTypeName constants (#21896) --- .../pydrake/examples/gym/envs/cart_pole.py | 10 +++++++--- bindings/pydrake/multibody/test/plant_test.py | 1 + bindings/pydrake/multibody/tree_py.cc | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/bindings/pydrake/examples/gym/envs/cart_pole.py b/bindings/pydrake/examples/gym/envs/cart_pole.py index a95ed39e6ed2..1448d9e0067a 100644 --- a/bindings/pydrake/examples/gym/envs/cart_pole.py +++ b/bindings/pydrake/examples/gym/envs/cart_pole.py @@ -30,6 +30,10 @@ MultibodyPlant, MultibodyPlantConfig, ) +from pydrake.multibody.tree import ( + PrismaticJoint, + RevoluteJoint, +) from pydrake.systems.analysis import Simulator from pydrake.systems.drawing import plot_graphviz, plot_system_graphviz from pydrake.systems.framework import ( @@ -338,14 +342,14 @@ def reset_handler(simulator, diagram_context, seed): # Ensure the positions are within the joint limits. for pair in home_positions: joint = plant.GetJointByName(pair[0]) - if joint.type_name() == "revolute": + if joint.type_name() == RevoluteJoint.kTypeName: joint.set_angle(plant_context, np.clip(pair[1], joint.position_lower_limit(), joint.position_upper_limit() ) ) - if joint.type_name() == "prismatic": + if joint.type_name() == PrismaticJoint.kTypeName: joint.set_translation(plant_context, np.clip(pair[1], joint.position_lower_limit(), @@ -354,7 +358,7 @@ def reset_handler(simulator, diagram_context, seed): ) for pair in home_velocities: joint = plant.GetJointByName(pair[0]) - if joint.type_name() == "revolute": + if joint.type_name() == RevoluteJoint.kTypeName: joint.set_angular_rate(plant_context, np.clip(pair[1], joint.velocity_lower_limit(), diff --git a/bindings/pydrake/multibody/test/plant_test.py b/bindings/pydrake/multibody/test/plant_test.py index 10a64aead2c6..e9ac224a9da5 100644 --- a/bindings/pydrake/multibody/test/plant_test.py +++ b/bindings/pydrake/multibody/test/plant_test.py @@ -654,6 +654,7 @@ def _test_joint_api(self, T, joint): self._test_multibody_tree_element_mixin(T, joint) self.assertIsInstance(joint.name(), str) self.assertIsInstance(joint.type_name(), str) + self.assertEqual(joint.type_name(), joint.kTypeName) self.assertIsInstance(joint.parent_body(), Body) self.assertIsInstance(joint.child_body(), Body) self.assertIsInstance(joint.frame_on_parent(), Frame) diff --git a/bindings/pydrake/multibody/tree_py.cc b/bindings/pydrake/multibody/tree_py.cc index a138ccf4674b..4297e30cb060 100644 --- a/bindings/pydrake/multibody/tree_py.cc +++ b/bindings/pydrake/multibody/tree_py.cc @@ -492,6 +492,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "BallRpyJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def( py::init&, const Frame&, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -523,6 +525,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "PlanarJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, Vector3>(), py::arg("name"), py::arg("frame_on_parent"), @@ -572,6 +576,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "PrismaticJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, const Vector3&, double, double, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -628,6 +634,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "QuaternionFloatingJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, double, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -688,6 +696,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "RevoluteJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, const Vector3&, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -744,6 +754,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "RpyFloatingJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, double, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -801,6 +813,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "ScrewJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, double, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -854,6 +868,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "UniversalJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def( py::init&, const Frame&, double>(), py::arg("name"), py::arg("frame_on_parent"), @@ -884,6 +900,8 @@ void DoScalarDependentDefinitions(py::module m, T) { auto cls = DefineTemplateClassWithDefault>( m, "WeldJoint", param, cls_doc.doc); cls // BR + .def_property_readonly_static( + "kTypeName", [](py::object /* self */) { return Class::kTypeName; }) .def(py::init&, const Frame&, const RigidTransform&>(), py::arg("name"), py::arg("frame_on_parent_F"),