From 84f454cf27575f19acc76e8bc96356d93fef3568 Mon Sep 17 00:00:00 2001 From: Jaden Fiotto-Kaufman Date: Mon, 22 Jan 2024 13:15:16 -0500 Subject: [PATCH] Handling annoying built in methods for whitelist --- src/nnsight/pydantics/format/functions.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/nnsight/pydantics/format/functions.py b/src/nnsight/pydantics/format/functions.py index e863d986..63dd1d3f 100644 --- a/src/nnsight/pydantics/format/functions.py +++ b/src/nnsight/pydantics/format/functions.py @@ -1,5 +1,6 @@ import operator -from inspect import getmembers, isbuiltin, isfunction, ismethoddescriptor +from inspect import (getmembers, isbuiltin, isfunction, ismethod, + ismethoddescriptor) import einops import torch @@ -8,11 +9,19 @@ from ...tracing.Proxy import Proxy -def get_function_name(fn): +def get_function_name(fn, module_name=None): if isinstance(fn, str): return fn - return f"{getattr(fn, '__module__', '')}.{fn.__qualname__}" + if module_name is not None: + return f"{module_name}.{fn.__name__}" + + module_name = getattr(fn, "__module__", None) + + if module_name is None: + return fn.__qualname__ + + return f"{module_name}.{fn.__qualname__}" FUNCTIONS_WHITELIST = {} @@ -30,8 +39,8 @@ def get_function_name(fn): ) FUNCTIONS_WHITELIST.update( { - get_function_name(value): value - for key, value in getmembers(torch._C._TensorBase, ismethoddescriptor) + get_function_name(value, module_name="Tensor"): value + for key, value in getmembers(torch.Tensor, ismethoddescriptor) } ) FUNCTIONS_WHITELIST.update(