Skip to content

Commit

Permalink
[pydrake] Use nice names for default template classes (#18972)
Browse files Browse the repository at this point in the history
When a template class has a default type parameter, name the default
instantiation directly using the default name, instead of via an alias.
This makes IDE auto-complete and type annotations more natural.
  • Loading branch information
jwnimmer-tri authored Mar 14, 2023
1 parent 24239b6 commit 91e3dcc
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 80 deletions.
16 changes: 3 additions & 13 deletions bindings/pydrake/_math_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,37 +112,27 @@ def _indented_repr(o):
return repr(o).replace("\n", "\n ")


def _remove_float_suffix(typename):
suffix = "_[float]"
if typename.endswith(suffix):
return typename[:-len(suffix)]
return typename


def _roll_pitch_yaw_repr(rpy):
cls_name = _remove_float_suffix(_pretty_class_name(type(rpy)))
return (
f"{cls_name}("
f"{_pretty_class_name(type(rpy))}("
f"roll={repr(rpy.roll_angle())}, "
f"pitch={repr(rpy.pitch_angle())}, "
f"yaw={repr(rpy.yaw_angle())})")


def _rotation_matrix_repr(R):
cls_name = _remove_float_suffix(_pretty_class_name(type(R)))
M = R.matrix().tolist()
return (
f"{cls_name}([\n"
f"{_pretty_class_name(type(R))}([\n"
f" {_indented_repr(M[0])},\n"
f" {_indented_repr(M[1])},\n"
f" {_indented_repr(M[2])},\n"
f"])")


def _rigid_transform_repr(X):
cls_name = _remove_float_suffix(_pretty_class_name(type(X)))
return (
f"{cls_name}(\n"
f"{_pretty_class_name(type(X))}(\n"
f" R={_indented_repr(X.rotation())},\n"
f" p={_indented_repr(X.translation().tolist())},\n"
f")")
Expand Down
7 changes: 7 additions & 0 deletions bindings/pydrake/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ def __getattr__(name):
name = _MangledName.mangle(name)
if name in module_globals:
return module_globals[name]
float_tag = "_{}float{}".format(
_MangledName.UNICODE_LEFT_BRACKET,
_MangledName.UNICODE_RIGHT_BRACKET)
if name.endswith(float_tag):
shorter_name = name[:-len(float_tag)]
if shorter_name in module_globals:
return module_globals[shorter_name]
raise AttributeError(
f"module {module_name!r} has no attribute {name!r}")

Expand Down
35 changes: 18 additions & 17 deletions bindings/pydrake/common/cpp_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def get_instantiation(self, param=None, throw_error=True):
if instantiation is TemplateBase._deferred:
assert self._instantiation_func is not None
instantiation = self._instantiation_func(param)
self._add_instantiation_internal(param, instantiation)
self._add_instantiation_internal(param, instantiation,
skip_rename=False)
elif instantiation is None and throw_error:
raise RuntimeError("Invalid instantiation: {}".format(
self._instantiation_name(param)))
Expand All @@ -183,7 +184,7 @@ def get_instantiation(self, param=None, throw_error=True):
_warn_deprecated(deprecation.message, date=deprecation.date)
return (instantiation, param)

def add_instantiation(self, param, instantiation):
def add_instantiation(self, param, instantiation, skip_rename=False):
"""Adds a unique instantiation.
Note:
Expand All @@ -196,15 +197,15 @@ def add_instantiation(self, param, instantiation):
"Parameter instantiation already registered: {}".format(param))
# Register it.
self.param_list.append(param)
self._add_instantiation_internal(param, instantiation)
self._add_instantiation_internal(param, instantiation, skip_rename)
return param

def _add_instantiation_internal(self, param, instantiation):
def _add_instantiation_internal(self, param, instantiation, skip_rename):
# Adds instantiation. Permits overwriting for deferred cases.
assert instantiation is not None
if instantiation is not TemplateBase._deferred:
old = instantiation
instantiation = self._on_add(param, instantiation)
instantiation = self._on_add(param, instantiation, skip_rename)
assert instantiation is not None, (self, param, old)
if instantiation is not old:
self._instantiation_alias_map[old] = instantiation
Expand Down Expand Up @@ -322,7 +323,7 @@ def __str__(self):
cls_name = pretty_class_name(type(self))
return "<{} {}>".format(cls_name, self._full_name())

def _on_add(self, param, instantiation):
def _on_add(self, param, instantiation, skip_rename):
# To be overridden by child classes.
return instantiation

Expand Down Expand Up @@ -370,18 +371,16 @@ def decorator(instantiation_func):

class TemplateClass(TemplateBase):
"""Extension of `TemplateBase` for classes."""
def __init__(self, name, override_meta=True, scope=None, **kwargs):
def __init__(self, name, *, scope=None, **kwargs):
if scope is None:
scope = _get_module_from_stack()
TemplateBase.__init__(self, name, scope=scope, **kwargs)
self._override_meta = override_meta

def _on_add(self, param, cls):
if self._override_meta:
# Rename the class now to reflect its `template_name` and `param`.
# C++ templates are initially bound using a `TemporaryClassName()`
# which we overwrite here. Python templates are usually declared as
# a nested class helper, which likewise we need to replace.

def _on_add(self, param, cls, skip_rename):
# Unless this class was a default template instantiation, we need to
# rename it now to describe its template arguments. (Most templated
# C++ classes are bound using the TemporaryClassName() function.)
if not skip_rename:
cls._original_name = cls.__name__
cls._original_qualname = getattr(cls, "__qualname__", cls.__name__)
cls.__name__ = self._instantiation_name(param, mangle=True)
Expand Down Expand Up @@ -444,7 +443,8 @@ def f(*args, **kwargs): return orig(*args, **kwargs)

class TemplateFunction(TemplateBase):
"""Extension of `TemplateBase` for functions."""
def _on_add(self, param, func):
def _on_add(self, param, func, skip_rename):
assert skip_rename is False
new_name = self._instantiation_name(param, mangle=True)
func = _rename_callable(func, self._scope, new_name)
setattr(self._scope, func.__name__, func)
Expand All @@ -461,7 +461,8 @@ def __init__(self, name, cls, scope=None, **kwargs):
# only.
self._cls = cls

def _on_add(self, param, func):
def _on_add(self, param, func, skip_rename):
assert skip_rename is False
new_name = self._instantiation_name(param, mangle=True)
func = _rename_callable(func, self._scope, new_name, self._cls)
setattr(self._cls, func.__name__, func)
Expand Down
39 changes: 26 additions & 13 deletions bindings/pydrake/common/cpp_template_pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ inline py::object GetOrInitTemplate( // BR
}

// Adds instantiation to a Python template.
inline void AddInstantiation(
py::handle py_template, py::handle obj, py::tuple param) {
py_template.attr("add_instantiation")(param, obj);
inline void AddInstantiation(py::handle py_template, py::handle obj,
py::tuple param, bool skip_rename = false) {
py_template.attr("add_instantiation")(param, obj, skip_rename);
}

// Gets name for a given instantiation.
Expand Down Expand Up @@ -69,10 +69,10 @@ std::string TemporaryClassName(const std::string& name = "TemporaryName") {
/// @param param Parameters for the instantiation.
inline py::object AddTemplateClass( // BR
py::handle scope, const std::string& template_name, py::handle py_class,
py::tuple param) {
py::tuple param, bool skip_rename = false) {
py::object py_template =
internal::GetOrInitTemplate(scope, template_name, "TemplateClass");
internal::AddInstantiation(py_template, py_class, param);
internal::AddInstantiation(py_template, py_class, param, skip_rename);
return py_template;
}

Expand All @@ -85,16 +85,29 @@ template <typename Class, typename... Options>
py::class_<Class, Options...> DefineTemplateClassWithDefault( // BR
py::handle scope, const std::string& default_name, py::tuple param,
const char* doc_string = "", const std::string& template_suffix = "_") {
// The default instantiation is immediately assigned its correct class name.
// Other instantiations use a temporary name here that will be overwritten
// by the AddTemplateClass function during registration.
const bool is_default = !py::hasattr(scope, default_name.c_str());
const std::string class_name =
is_default ? default_name : TemporaryClassName<Class>();
const std::string template_name = default_name + template_suffix;
// Define class with temporary name.
py::class_<Class, Options...> py_class(
scope, TemporaryClassName<Class>().c_str(), doc_string);
// Register instantiation.
AddTemplateClass(scope, template_name, py_class, param);
// Declare default instantiation if it does not already exist.
if (!py::hasattr(scope, default_name.c_str())) {
scope.attr(default_name.c_str()) = py_class;
// Define the class.
std::string doc;
if (is_default) {
doc = fmt::format(
"{}\n\nNote:\n\n"
" This class is templated; see :class:`{}`\n"
" for the list of instantiations.",
doc_string, template_name);
} else {
doc = doc_string;
}
py::class_<Class, Options...> py_class(
scope, class_name.c_str(), doc.c_str());
// Register it as a template instantiation.
const bool skip_rename = is_default;
AddTemplateClass(scope, template_name, py_class, param, skip_rename);
return py_class;
}

Expand Down
2 changes: 1 addition & 1 deletion bindings/pydrake/common/test/eigen_geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_quaternion(self, T):
if T == float:
self.assertEqual(
str(q_identity),
"Quaternion_[float](w=1.0, x=0.0, y=0.0, z=0.0)")
"Quaternion(w=1.0, x=0.0, y=0.0, z=0.0)")
else:
self.assertIn("Quaternion_[", str(q_identity))
self.check_cast(mut.Quaternion_, T)
Expand Down
10 changes: 1 addition & 9 deletions bindings/pydrake/multibody/_math_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,13 @@ def _indented_repr(o):
return repr(o).replace("\n", "\n ")


def _remove_float_suffix(typename):
suffix = "_[float]"
if typename.endswith(suffix):
return typename[:-len(suffix)]
return typename


def _spatial_vector_repr(rotation_name, translation_name):

def repr_with_closure(self):
cls_name = _remove_float_suffix(_pretty_class_name(type(self)))
rotation = self.rotational().tolist()
translation = self.translational().tolist()
return (
f"{cls_name}(\n"
f"{_pretty_class_name(type(self))}(\n"
f" {rotation_name}={_indented_repr(rotation)},\n"
f" {translation_name}={_indented_repr(translation)},\n"
f")")
Expand Down
10 changes: 5 additions & 5 deletions bindings/pydrake/multibody/test/plant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def check_repr(element, expected):
self._test_joint_api(T, shoulder)
check_repr(
shoulder,
"<RevoluteJoint_[float] name='ShoulderJoint' index=0 "
"<RevoluteJoint name='ShoulderJoint' index=0 "
"model_instance=2>")
np.testing.assert_array_equal(
shoulder.position_lower_limits(), [-np.inf])
Expand All @@ -352,12 +352,12 @@ def check_repr(element, expected):
self.assertEqual(len(plant.GetBodyIndices(model_instance)), 2)
check_repr(
link1,
"<RigidBody_[float] name='Link1' index=1 model_instance=2>")
"<RigidBody name='Link1' index=1 model_instance=2>")
self._test_frame_api(T, plant.GetFrameByName(name="Link1"))
link1_frame = plant.GetFrameByName(name="Link1")
check_repr(
link1_frame,
"<BodyFrame_[float] name='Link1' index=1 model_instance=2>")
"<BodyFrame name='Link1' index=1 model_instance=2>")
self.assertIs(
link1_frame,
plant.GetFrameByName(name="Link1", model_instance=model_instance))
Expand Down Expand Up @@ -400,7 +400,7 @@ def check_repr(element, expected):
plant.GetJointActuatorIndices(model_instance=model_instance))
check_repr(
joint_actuator,
"<JointActuator_[float] name='ElbowJoint' index=0 "
"<JointActuator name='ElbowJoint' index=0 "
"model_instance=2>")
self.assertIsInstance(
plant.get_frame(frame_index=world_frame_index()), Frame)
Expand Down Expand Up @@ -806,7 +806,7 @@ def test_multibody_force_element(self, T):
if T == float:
self.assertEqual(
repr(linear_spring),
"<LinearSpringDamper_[float] index=1 model_instance=1>")
"<LinearSpringDamper index=1 model_instance=1>")
revolute_joint = plant.AddJoint(RevoluteJoint_[T](
name="revolve_joint", frame_on_parent=body_a.body_frame(),
frame_on_child=body_b.body_frame(), axis=[0, 0, 1],
Expand Down
4 changes: 2 additions & 2 deletions bindings/pydrake/systems/scalar_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def wrapped(param):

return decorator

def _on_add(self, param, cls):
TemplateClass._on_add(self, param, cls)
def _on_add(self, param, cls, skip_rename):
TemplateClass._on_add(self, param, cls, skip_rename)
T, = param

# Check that the user has not defined `__init__`, and has defined
Expand Down
7 changes: 3 additions & 4 deletions bindings/pydrake/systems/test/value_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,17 @@ def test_str_and_repr(self):
vector_f = [1.]
value_f = BasicVector_[float](vector_f)
self.assertEqual(str(value_f), "[1.0]")
self.assertEqual(repr(value_f), "BasicVector_[float]([1.0])")
self.assertEqual(repr(value_f), "BasicVector([1.0])")
# Check repr() invariant.
self.assert_basic_vector_equal(value_f, eval(repr(value_f)))
# - Empty.
value_f_empty = BasicVector_[float]([])
self.assertEqual(str(value_f_empty), "[]")
self.assertEqual(repr(value_f_empty), "BasicVector_[float]([])")
self.assertEqual(repr(value_f_empty), "BasicVector([])")
# - Multiple values.
value_f_multi = BasicVector_[float]([1., 2.])
self.assertEqual(str(value_f_multi), "[1.0, 2.0]")
self.assertEqual(
repr(value_f_multi), "BasicVector_[float]([1.0, 2.0])")
self.assertEqual(repr(value_f_multi), "BasicVector([1.0, 2.0])")
# TODO(eric.cousineau): Make repr() for AutoDiffXd and Expression be
# semi-usable.
# T=AutoDiffXd
Expand Down
21 changes: 5 additions & 16 deletions doc/pydrake/pydrake_sphinx_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,16 @@ class TemplateDocumenter(autodoc.ModuleLevelDocumenter):
# Take priority over attributes.
priority = 1 + autodoc.AttributeDocumenter.priority

option_spec = {
'show-all-instantiations': autodoc.bool_option,
}
# Permit propagation of class-specific properties.
option_spec.update(autodoc.ClassDocumenter.option_spec)

@classmethod
def can_document_member(cls, member, membername, isattr, parent):
"""Overrides base to check for template objects."""
return isinstance(member, TemplateBase)

def get_object_members(self, want_all):
"""Overrides base to return instantiations from templates."""
members = []
for param in self.object.param_list:
instantiation = self.object[param]
members.append((instantiation.__name__, instantiation))
if not self.options.show_all_instantiations:
break
return False, members
"""Overrides base; we shouldn't show any details beyond the list of
instantiations.
"""
return False, []

def check_module(self):
"""Overrides base to show template objects given the correct module."""
Expand Down Expand Up @@ -199,8 +189,7 @@ def tpl_attrgetter(obj, name, *defargs):
"""
# N.B. Rather than try to evaluate parameters from the string, we instead
# match based on instantiation name.
if "[" in name:
assert name.endswith(']'), name
if isinstance(obj, TemplateBase) and name[0] != "_":
for param in obj.param_list:
inst = obj[param]
if inst.__name__ == name:
Expand Down

0 comments on commit 91e3dcc

Please sign in to comment.