From 4c0ca26bf4e8501aa6c703e1002c49fd8d6d2f66 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 9 Jan 2025 11:46:13 +0800 Subject: [PATCH 1/3] init --- dspy/signatures/__init__.py | 2 - dspy/signatures/signature.py | 222 ++++++++++++++--------------------- 2 files changed, 90 insertions(+), 134 deletions(-) diff --git a/dspy/signatures/__init__.py b/dspy/signatures/__init__.py index 3b902dbb5..322654793 100644 --- a/dspy/signatures/__init__.py +++ b/dspy/signatures/__init__.py @@ -2,7 +2,6 @@ from dspy.signatures.signature import ( SignatureMeta, Signature, - update_signatures, ensure_signature, make_signature, infer_prefix, @@ -18,7 +17,6 @@ "SignatureMeta", "Signature", "infer_prefix", - "update_signatures", "ensure_signature", "make_signature", ] diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index bd7b35a86..ea13fd1de 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,13 +1,11 @@ import ast +import importlib import inspect -import logging import re import types import typing -from contextlib import ExitStack, contextmanager from copy import deepcopy from typing import Any, Dict, Tuple, Type, Union # noqa: UP035 -import importlib from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo @@ -16,10 +14,10 @@ from dspy.signatures.field import InputField, OutputField -def _default_instructions(cls) -> str: - inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields]) - outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields]) - return f"Given the fields {inputs_}, produce the fields {outputs_}." +def _default_instructions(signature) -> str: + input_fields = ", ".join([f"`{field}`" for field in signature.input_fields]) + output_fields = ", ".join([f"`{field}`" for field in signature.output_fields]) + return f"Given the fields {input_fields}, produce the fields {output_fields}." class SignatureMeta(type(BaseModel)): @@ -73,15 +71,41 @@ def _validate_fields(cls): field_type = extra.get("__dspy_field_type") if field_type not in ["input", "output"]: raise TypeError( - f"Field '{name}' in '{cls.__name__}' must be declared with " - "InputField or OutputField. {field.json_schema_extra=}", + f"Field `{name}` in `{cls.__name__}` must be declared with InputField or OutputField, but " + f"field `{name}` has `field.json_schema_extra={field.json_schema_extra}`", ) + +# A signature for a predictor. +# +# You typically subclass it, like this: +# class MySignature(Signature): +# input: str = InputField(desc="...") +# output: int = OutputField(desc="...") +# +# You can call Signature("input1, input2 -> output1, output2") to create a new signature type. +# You can also include instructions, Signature("input -> output", "This is a test"). +# But it's generally better to use the make_signature function. +# +# If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), +# or a signature, you can use the ensure_signature function. +# +# For compatibility with the legacy dsp format, you can use the signature_to_template function. +# + + +class Signature(BaseModel, metaclass=SignatureMeta): + "" # noqa: D419 + + # Note: Don't put a docstring here, as it will become the default instructions + # for any signature that doesn't define it's own instructions. + @property def signature(cls) -> str: - in_args = ", ".join(cls.input_fields.keys()) - out_args = ", ".join(cls.output_fields.keys()) - return f"{in_args} -> {out_args}" + """The string representation of the signature.""" + input_fields = ", ".join(cls.input_fields.keys()) + output_fields = ", ".join(cls.output_fields.keys()) + return f"{input_fields} -> {output_fields}" @property def instructions(cls) -> str: @@ -100,12 +124,22 @@ def fields(cls) -> dict[str, FieldInfo]: return {**cls.input_fields, **cls.output_fields} def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]: - """Update the field, name, in a new Signature type. + """Create a new Signature class with the updated field information. - Returns a new Signature type with the field, name, updated + Returns a new Signature class with the field, name, updated with fields[name].json_schema_extra[key] = value. + + Args: + name: The name of the field to update. + type_: The new type of the field. + **kwargs: The new values for the field. + + Returns: + A new Signature class (not an instance) with the updated field information. """ fields_copy = deepcopy(cls.fields) + # Update `fields_copy[name].json_schema_extra` with the new kwargs, on conflicts + # we use the new value in kwargs. fields_copy[name].json_schema_extra = { **fields_copy[name].json_schema_extra, **kwargs, @@ -148,7 +182,10 @@ def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signa if index < 0: index += len(lst) + 1 if index < 0 or index > len(lst): - raise ValueError(f"Invalid index: {index}") + raise ValueError( + f"Invalid index to insert: {index}, index must be in the range of [{len(lst) - 1}, {len(lst)}] for " + f"{field.json_schema_extra['__dspy_field_type']} fields, but received: {index}.", + ) lst.insert(index, (name, (type_, field))) new_fields = dict(input_fields + output_fields) @@ -177,7 +214,7 @@ def load_state(cls, state): return signature_copy def equals(cls, other) -> bool: - """Compare the JSON schema of two Pydantic models.""" + """Compare the JSON schema of two Signature classes.""" if not isinstance(other, type) or not issubclass(other, BaseModel): return False if cls.instructions != other.instructions: @@ -185,7 +222,8 @@ def equals(cls, other) -> bool: for name in cls.fields.keys() | other.fields.keys(): if name not in other.fields or name not in cls.fields: return False - # TODO: Should we compare the fields? + if cls.fields[name].json_schema_extra != other.fields[name].json_schema_extra: + return False return True def __repr__(cls): @@ -205,98 +243,6 @@ def __repr__(cls): return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)" -# A signature for a predictor. -# -# You typically subclass it, like this: -# class MySignature(Signature): -# input: str = InputField(desc="...") # noqa: ERA001 -# output: int = OutputField(desc="...") # noqa: ERA001 -# -# You can call Signature("input1, input2 -> output1, output2") to create a new signature type. -# You can also include instructions, Signature("input -> output", "This is a test"). -# But it's generally better to use the make_signature function. -# -# If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), -# or a signature, you can use the ensure_signature function. -# -# For compatibility with the legacy dsp format, you can use the signature_to_template function. -# - - -class Signature(BaseModel, metaclass=SignatureMeta): - "" # noqa: D419 - - # Note: Don't put a docstring here, as it will become the default instructions - # for any signature that doesn't define it's own instructions. - pass - - @classmethod - @contextmanager - def replace( - cls, - new_signature: "Type[Signature]", - validate_new_signature: bool = True, - ) -> typing.Generator[None, None, None]: - """Replace the signature with an updated version. - - This is useful for updating the internal signatures of dspy - - Args: - new_signature: The new signature to replace the old one with. - validate_new_signature: Whether to validate the new signature against the old one - to ensure that no fields are missing. - """ - if validate_new_signature: - for field in cls.model_fields: - if field not in new_signature.model_fields: - raise ValueError( - f"Field '{field}' is missing from the updated signature '{new_signature.__class__}.", - ) - - class OldSignature(cls): - pass - - def swap_attributes(source: Type[Signature]): - unhandled = {} - - for attr in ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"]: - try: - setattr(cls, attr, getattr(source, attr)) - except AttributeError as exc: - if attr in ("__pydantic_fields__", "model_fields"): - version = "< 2.10" if attr == "__pydantic_fields__" else ">= 2.10" - logging.debug(f"Model attribute {attr} not replaced, expected with pydantic {version}") - unhandled[attr] = exc - else: - raise exc - - # if neither of the attributes were replaced, raise an error to prevent silent failures - if set(unhandled.keys()) >= {"model_fields", "__pydantic_fields__"}: - raise ValueError("Failed to replace either model_fields or __pydantic_fields__") from ( - unhandled.get("model_fields") or unhandled.get("__pydantic_fields__") - ) - - swap_attributes(new_signature) - cls.model_rebuild(force=True) - - yield - - swap_attributes(OldSignature) - cls.model_rebuild(force=True) - - -@contextmanager -def update_signatures( - signature_map: Dict[Type[Signature], Type[Signature]], - validate_new_signature: bool = True, -) -> typing.Generator[None, None, None]: - """Replace multiple signatures with updated versions, according to a mapping between the old and new signatures.""" - with ExitStack() as stack: - for old_signature, new_signature in signature_map.items(): - stack.enter_context(old_signature.replace(new_signature, validate_new_signature=validate_new_signature)) - yield - - def ensure_signature(signature: Union[str, Type[Signature]], instructions=None) -> Signature: if signature is None: return None @@ -312,17 +258,31 @@ def make_signature( instructions: str = None, signature_name: str = "StringSignature", ) -> Type[Signature]: - """Create a new Signature type with the given fields and instructions. - - Note: - Even though we're calling a type, we're not making an instance of the type. - In general, instances of Signature types are not allowed to be made. The call - syntax is provided for convenience. + """Create a new Signature subclass with the specified fields and instructions. Args: - signature: The signature format, specified as "input1, input2 -> output1, output2". - instructions: An optional prompt for the signature. - signature_name: An optional name for the new signature type. + signature: Either a string in the format "input1, input2 -> output1, output2" + or a dictionary mapping field names to tuples of (type, FieldInfo). + instructions: Optional string containing instructions/prompt for the signature. + If not provided, defaults to a basic description of inputs and outputs. + signature_name: Optional string to name the generated Signature subclass. + Defaults to "StringSignature". + + Returns: + A new signature class with the specified fields and instructions. + + Examples: + + ``` + # Using string format + sig1 = make_signature("question, context -> answer") + + # Using dictionary format + sig2 = make_signature({ + "question": (str, InputField()), + "answer": (str, OutputField()) + }) + ``` """ fields = _parse_signature(signature) if isinstance(signature, str) else signature @@ -331,28 +291,24 @@ def make_signature( fixed_fields = {} for name, type_field in fields.items(): if not isinstance(name, str): - raise ValueError(f"Field names must be strings, not {type(name)}") + raise ValueError(f"Field names must be strings, but received: {name}.") if isinstance(type_field, FieldInfo): type_ = type_field.annotation field = type_field else: if not isinstance(type_field, tuple): - raise ValueError(f"Field values must be tuples, not {type(type_field)}") + raise ValueError(f"Field values must be tuples, but received: {type_field}.") type_, field = type_field # It might be better to be explicit about the type, but it currently would break # program of thought and teleprompters, so we just silently default to string. if type_ is None: type_ = str - # if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type): if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias, typing._SpecialForm)): - raise ValueError(f"Field types must be types, not {type(type_)}") + raise ValueError(f"Field types must be types, but received: {type_} of type {type(type_)}.") if not isinstance(field, FieldInfo): - raise ValueError(f"Field values must be Field instances, not {type(field)}") + raise ValueError(f"Field values must be Field instances, but received: {field}.") fixed_fields[name] = (type_, field) - # Fixing the fields shouldn't change the order - assert list(fixed_fields.keys()) == list(fields.keys()) # noqa: S101 - # Default prompt when no instructions are provided if instructions is None: sig = Signature(signature, "") # Simple way to parse input/output fields @@ -366,7 +322,7 @@ def make_signature( ) -def _parse_signature(signature: str) -> Tuple[Type, Field]: +def _parse_signature(signature: str) -> Dict[str, Tuple[Type, Field]]: if signature.count("->") != 1: raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.") @@ -393,7 +349,7 @@ def _parse_type_node(node, names=None) -> Any: if names is None: names = dict(typing.__dict__) - names['NoneType'] = type(None) + names["NoneType"] = type(None) def resolve_name(id_: str): # Check if it's a built-in known type or in the provided names @@ -404,16 +360,17 @@ def resolve_name(id_: str): builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray] # Try PIL Image if 'Image' encountered - if 'Image' not in names: + if "Image" not in names: try: from PIL import Image - names['Image'] = Image + + names["Image"] = Image except ImportError: pass # If we have PIL Image and id_ is 'Image', return it - if 'Image' in names and id_ == 'Image': - return names['Image'] + if "Image" in names and id_ == "Image": + return names["Image"] # Check if it matches any known built-in type by name for t in builtin_types: @@ -490,6 +447,7 @@ def resolve_name(id_: str): raise ValueError(f"Unhandled AST node type in annotation: {ast.dump(node)}") + def infer_prefix(attribute_name: str) -> str: """Infer a prefix from an attribute name.""" # Convert camelCase to snake_case, but handle sequences of capital letters properly From c7f93350bb1f8a5028895c0284bca61f8e31fd4b Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 9 Jan 2025 14:29:06 +0800 Subject: [PATCH 2/3] incremental --- tests/signatures/test_signature.py | 61 ++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index 8aeebdc34..de53104ad 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -310,18 +310,27 @@ def test_typed_signatures_unions_and_optionals(): # Depending on the environment, it might resolve to Union[str, None] or Optional[str], either is correct. # We'll just check for a Union containing str and NoneType: input_opt_annotation = sig.input_fields["input_opt"].annotation - assert (input_opt_annotation == Optional[str] or - (getattr(input_opt_annotation, '__origin__', None) is Union and str in input_opt_annotation.__args__ and type(None) in input_opt_annotation.__args__)) + assert input_opt_annotation == Optional[str] or ( + getattr(input_opt_annotation, "__origin__", None) is Union + and str in input_opt_annotation.__args__ + and type(None) in input_opt_annotation.__args__ + ) assert "input_union" in sig.input_fields input_union_annotation = sig.input_fields["input_union"].annotation - assert (getattr(input_union_annotation, '__origin__', None) is Union and - int in input_union_annotation.__args__ and type(None) in input_union_annotation.__args__) + assert ( + getattr(input_union_annotation, "__origin__", None) is Union + and int in input_union_annotation.__args__ + and type(None) in input_union_annotation.__args__ + ) assert "output_union" in sig.output_fields output_union_annotation = sig.output_fields["output_union"].annotation - assert (getattr(output_union_annotation, '__origin__', None) is Union and - int in output_union_annotation.__args__ and str in output_union_annotation.__args__) + assert ( + getattr(output_union_annotation, "__origin__", None) is Union + and int in output_union_annotation.__args__ + and str in output_union_annotation.__args__ + ) def test_typed_signatures_any(): @@ -336,22 +345,22 @@ def test_typed_signatures_nested(): # Nested generics and unions sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]") input_nested_ann = sig.input_fields["input_nested"].annotation - assert getattr(input_nested_ann, '__origin__', None) is list + assert getattr(input_nested_ann, "__origin__", None) is list assert len(input_nested_ann.__args__) == 1 union_arg = input_nested_ann.__args__[0] - assert getattr(union_arg, '__origin__', None) is Union + assert getattr(union_arg, "__origin__", None) is Union assert str in union_arg.__args__ and int in union_arg.__args__ output_nested_ann = sig.output_fields["output_nested"].annotation - assert getattr(output_nested_ann, '__origin__', None) is tuple + assert getattr(output_nested_ann, "__origin__", None) is tuple assert output_nested_ann.__args__[0] == int # The second arg is Optional[float], which is Union[float, None] second_arg = output_nested_ann.__args__[1] - assert getattr(second_arg, '__origin__', None) is Union + assert getattr(second_arg, "__origin__", None) is Union assert float in second_arg.__args__ and type(None) in second_arg.__args__ # The third arg is List[str] third_arg = output_nested_ann.__args__[2] - assert getattr(third_arg, '__origin__', None) is list + assert getattr(third_arg, "__origin__", None) is list assert third_arg.__args__[0] == str @@ -374,30 +383,44 @@ def test_typed_signatures_from_dict(): def test_typed_signatures_complex_combinations(): # Test a very complex signature with multiple nested constructs # input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]] - sig = Signature("input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]") + sig = Signature( + "input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]" + ) input_complex_ann = sig.input_fields["input_complex"].annotation - assert getattr(input_complex_ann, '__origin__', None) is dict + assert getattr(input_complex_ann, "__origin__", None) is dict key_arg, value_arg = input_complex_ann.__args__ assert key_arg == str # value_arg: List[Optional[Tuple[int, str]]] - assert getattr(value_arg, '__origin__', None) is list + assert getattr(value_arg, "__origin__", None) is list inner_union = value_arg.__args__[0] # inner_union should be Optional[Tuple[int, str]] # which is Union[Tuple[int, str], None] - assert getattr(inner_union, '__origin__', None) is Union + assert getattr(inner_union, "__origin__", None) is Union tuple_type = [t for t in inner_union.__args__ if t != type(None)][0] - assert getattr(tuple_type, '__origin__', None) is tuple + assert getattr(tuple_type, "__origin__", None) is tuple assert tuple_type.__args__ == (int, str) output_complex_ann = sig.output_fields["output_complex"].annotation - assert getattr(output_complex_ann, '__origin__', None) is Union + assert getattr(output_complex_ann, "__origin__", None) is Union assert len(output_complex_ann.__args__) == 2 possible_args = set(output_complex_ann.__args__) # Expecting List[str] and Dict[str, Any] # Because sets don't preserve order, just check membership. # Find the List[str] arg - list_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is list) - dict_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is dict) + list_arg = next(a for a in possible_args if getattr(a, "__origin__", None) is list) + dict_arg = next(a for a in possible_args if getattr(a, "__origin__", None) is dict) assert list_arg.__args__ == (str,) k, v = dict_arg.__args__ assert k == str and v == Any + + +def test_make_signature_from_string(): + sig = Signature("input1: int, input2: Dict[str, int] -> output1: List[str], output2: Union[int, str]") + assert "input1" in sig.input_fields + assert sig.input_fields["input1"].annotation == int + assert "input2" in sig.input_fields + assert sig.input_fields["input2"].annotation == Dict[str, int] + assert "output1" in sig.output_fields + assert sig.output_fields["output1"].annotation == List[str] + assert "output2" in sig.output_fields + assert sig.output_fields["output2"].annotation == Union[int, str] From 1262ff4069e6a0c9872b9f3c7c212b483e599945 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 9 Jan 2025 20:58:26 +0800 Subject: [PATCH 3/3] cleanup --- dspy/signatures/signature.py | 270 +++++++++++++++++------------ tests/signatures/test_signature.py | 44 ----- 2 files changed, 158 insertions(+), 156 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index ea13fd1de..e5052c31c 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,3 +1,20 @@ +"""Signature class for DSPy. + +You typically subclass the Signature class, like this: + class MySignature(dspy.Signature): + input: str = InputField(desc="...") + output: int = OutputField(desc="...") + +You can call Signature("input1, input2 -> output1, output2") to create a new signature type. +You can also include instructions, Signature("input -> output", "This is a test"). +But it's generally better to use the make_signature function. + +If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), +or a signature, you can use the ensure_signature function. + +For compatibility with the legacy dsp format, you can use the signature_to_template function. +""" + import ast import importlib import inspect @@ -14,15 +31,16 @@ from dspy.signatures.field import InputField, OutputField -def _default_instructions(signature) -> str: - input_fields = ", ".join([f"`{field}`" for field in signature.input_fields]) - output_fields = ", ".join([f"`{field}`" for field in signature.output_fields]) - return f"Given the fields {input_fields}, produce the fields {output_fields}." +def _default_instructions(cls) -> str: + inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields]) + outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields]) + return f"Given the fields {inputs_}, produce the fields {outputs_}." class SignatureMeta(type(BaseModel)): def __call__(cls, *args, **kwargs): # noqa: ANN002 if cls is Signature: + # We don't create an actual Signature instance, instead, we create a new Signature class. return make_signature(*args, **kwargs) return super().__call__(*args, **kwargs) @@ -75,30 +93,26 @@ def _validate_fields(cls): f"field `{name}` has `field.json_schema_extra={field.json_schema_extra}`", ) + @property + def instructions(cls) -> str: + return inspect.cleandoc(getattr(cls, "__doc__", "")) -# A signature for a predictor. -# -# You typically subclass it, like this: -# class MySignature(Signature): -# input: str = InputField(desc="...") -# output: int = OutputField(desc="...") -# -# You can call Signature("input1, input2 -> output1, output2") to create a new signature type. -# You can also include instructions, Signature("input -> output", "This is a test"). -# But it's generally better to use the make_signature function. -# -# If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), -# or a signature, you can use the ensure_signature function. -# -# For compatibility with the legacy dsp format, you can use the signature_to_template function. -# + @instructions.setter + def instructions(cls, instructions: str) -> None: + setattr(cls, "__doc__", instructions) + @property + def input_fields(cls) -> dict[str, FieldInfo]: + return cls._get_fields_with_type("input") -class Signature(BaseModel, metaclass=SignatureMeta): - "" # noqa: D419 + @property + def output_fields(cls) -> dict[str, FieldInfo]: + return cls._get_fields_with_type("output") - # Note: Don't put a docstring here, as it will become the default instructions - # for any signature that doesn't define it's own instructions. + @property + def fields(cls) -> dict[str, FieldInfo]: + # Make sure to give input fields before output fields + return {**cls.input_fields, **cls.output_fields} @property def signature(cls) -> str: @@ -107,22 +121,37 @@ def signature(cls) -> str: output_fields = ", ".join(cls.output_fields.keys()) return f"{input_fields} -> {output_fields}" - @property - def instructions(cls) -> str: - return inspect.cleandoc(getattr(cls, "__doc__", "")) + def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]: + return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type} - @instructions.setter - def instructions(cls, instructions: str) -> None: - setattr(cls, "__doc__", instructions) + def __repr__(cls): + """Output a representation of the signature. + Uses the form: + Signature(question, context -> answer + question: str = InputField(desc="..."), + context: List[str] = InputField(desc="..."), + answer: int = OutputField(desc="..."), + ). + """ + field_reprs = [] + for name, field in cls.fields.items(): + field_reprs.append(f"{name} = Field({field})") + field_repr = "\n ".join(field_reprs) + return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)" + + +class Signature(BaseModel, metaclass=SignatureMeta): + "" # noqa: D419 + + # Note: Don't put a docstring here, as it will become the default instructions + # for any signature that doesn't define it's own instructions. + + @classmethod def with_instructions(cls, instructions: str) -> Type["Signature"]: return Signature(cls.fields, instructions) - @property - def fields(cls) -> dict[str, FieldInfo]: - # Make sure to give input fields before output fields - return {**cls.input_fields, **cls.output_fields} - + @classmethod def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]: """Create a new Signature class with the updated field information. @@ -148,23 +177,15 @@ def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]: fields_copy[name].annotation = type_ return Signature(fields_copy, cls.instructions) - @property - def input_fields(cls) -> dict[str, FieldInfo]: - return cls._get_fields_with_type("input") - - @property - def output_fields(cls) -> dict[str, FieldInfo]: - return cls._get_fields_with_type("output") - - def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]: - return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type} - + @classmethod def prepend(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(0, name, field, type_) + @classmethod def append(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(-1, name, field, type_) + @classmethod def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signature"]: # It's possible to set the type as annotation=type in pydantic.Field(...) # But this may be annoying for users, so we allow them to pass the type @@ -191,6 +212,21 @@ def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signa new_fields = dict(input_fields + output_fields) return Signature(new_fields, cls.instructions) + @classmethod + def equals(cls, other) -> bool: + """Compare the JSON schema of two Signature classes.""" + if not isinstance(other, type) or not issubclass(other, BaseModel): + return False + if cls.instructions != other.instructions: + return False + for name in cls.fields.keys() | other.fields.keys(): + if name not in other.fields or name not in cls.fields: + return False + if cls.fields[name].json_schema_extra != other.fields[name].json_schema_extra: + return False + return True + + @classmethod def dump_state(cls): state = {"instructions": cls.instructions, "fields": []} for field in cls.fields: @@ -203,6 +239,7 @@ def dump_state(cls): return state + @classmethod def load_state(cls, state): signature_copy = Signature(deepcopy(cls.fields), cls.instructions) @@ -213,35 +250,6 @@ def load_state(cls, state): return signature_copy - def equals(cls, other) -> bool: - """Compare the JSON schema of two Signature classes.""" - if not isinstance(other, type) or not issubclass(other, BaseModel): - return False - if cls.instructions != other.instructions: - return False - for name in cls.fields.keys() | other.fields.keys(): - if name not in other.fields or name not in cls.fields: - return False - if cls.fields[name].json_schema_extra != other.fields[name].json_schema_extra: - return False - return True - - def __repr__(cls): - """Output a representation of the signature. - - Uses the form: - Signature(question, context -> answer - question: str = InputField(desc="..."), - context: List[str] = InputField(desc="..."), - answer: int = OutputField(desc="..."), - ). - """ - field_reprs = [] - for name, field in cls.fields.items(): - field_reprs.append(f"{name} = Field({field})") - field_repr = "\n ".join(field_reprs) - return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)" - def ensure_signature(signature: Union[str, Type[Signature]], instructions=None) -> Signature: if signature is None: @@ -329,65 +337,83 @@ def _parse_signature(signature: str) -> Dict[str, Tuple[Type, Field]]: inputs_str, outputs_str = signature.split("->") fields = {} - for name, type_ in _parse_arg_string(inputs_str): - fields[name] = (type_, InputField()) - for name, type_ in _parse_arg_string(outputs_str): - fields[name] = (type_, OutputField()) + for field_name, field_type in _parse_field_string(inputs_str): + fields[field_name] = (field_type, InputField()) + for field_name, field_type in _parse_field_string(outputs_str): + fields[field_name] = (field_type, OutputField()) return fields -def _parse_arg_string(string: str, names=None) -> Dict[str, str]: - args = ast.parse("def f(" + string + "): pass").body[0].args.args +def _parse_field_string(field_string: str) -> Dict[str, str]: + """Extract the field name and type from field string in the string-based Signature. + + It takes a string like "x: int, y: str" and returns a dictionary mapping field names to their types. + For example, "x: int, y: str" -> [("x", int), ("y", str)]. This function utitlizes the Python AST to parse the + fields and types. + """ + + args = ast.parse(f"def f({field_string}): pass").body[0].args.args names = [arg.arg for arg in args] types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args] return zip(names, types) def _parse_type_node(node, names=None) -> Any: - """Recursively parse an AST node representing a type annotation.""" + """Recursively parse an AST node representing a type annotation. + + This function converts Python's Abstract Syntax Tree (AST) nodes into actual Python types. + It's used to parse type annotations in signature strings like "x: List[int] -> y: str". + + Examples: + - For "x: int", the AST node represents 'int' and returns the int type + - For "x: List[str]", it processes a subscript node to return typing.List[str] + - For "x: Optional[int]", it handles the Union type to return Optional[int] + - For "x: MyModule.CustomType", it processes attribute access to return the actual type + + Args: + node: An AST node from Python's ast module, representing a type annotation. + Common node types include: + - ast.Name: Simple types like 'int', 'str' + - ast.Attribute: Nested types like 'typing.List' + - ast.Subscript: Generic types like 'List[int]' + names: Optional dictionary mapping type names to their actual type objects. + Defaults to Python's typing module contents plus NoneType. + + Returns: + The actual Python type represented by the AST node. + + Raises: + ValueError: If the AST node represents an unknown or invalid type annotation. + """ if names is None: names = dict(typing.__dict__) names["NoneType"] = type(None) - def resolve_name(id_: str): + def resolve_name(type_name: str): # Check if it's a built-in known type or in the provided names - if id_ in names: - return names[id_] - + if type_name in names: + return names[type_name] # Common built-in types builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray] - # Try PIL Image if 'Image' encountered - if "Image" not in names: - try: - from PIL import Image - - names["Image"] = Image - except ImportError: - pass - - # If we have PIL Image and id_ is 'Image', return it - if "Image" in names and id_ == "Image": - return names["Image"] - # Check if it matches any known built-in type by name for t in builtin_types: - if t.__name__ == id_: + if t.__name__ == type_name: return t # Attempt to import a module with this name dynamically # This allows handling of module-based annotations like `dspy.Image`. try: - mod = importlib.import_module(id_) - names[id_] = mod + mod = importlib.import_module(type_name) + names[type_name] = mod return mod except ImportError: pass # If we don't know the type or module, raise an error - raise ValueError(f"Unknown name: {id_}") + raise ValueError(f"Unknown name: {type_name}") if isinstance(node, ast.Module): if len(node.body) != 1: @@ -445,34 +471,54 @@ def resolve_name(id_: str): values.append(_parse_type_node(kw.value, names)) return Field(**dict(zip(keys, values))) - raise ValueError(f"Unhandled AST node type in annotation: {ast.dump(node)}") + raise ValueError( + f"Failed to parse string-base Signature due to unhandled AST node type in annotation: {ast.dump(node)}. " + "Please consider using class-based DSPy Signatures instead." + ) def infer_prefix(attribute_name: str) -> str: - """Infer a prefix from an attribute name.""" - # Convert camelCase to snake_case, but handle sequences of capital letters properly + """Infer a prefix from an attribute name by converting it to a human-readable format. + + Examples: + "camelCaseText" -> "Camel Case Text" + "snake_case_text" -> "Snake Case Text" + "text2number" -> "Text 2 Number" + "HTMLParser" -> "HTML Parser" + """ + # Step 1: Convert camelCase to snake_case + # Example: "camelCase" -> "camel_Case" s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name) + + # Handle consecutive capitals + # Example: "camel_Case" -> "camel_case" intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) - # Insert underscores around numbers to ensure spaces in the final output + # Step 2: Handle numbers by adding underscores around them + # Example: "text2number" -> "text_2_number" with_underscores_around_numbers = re.sub( - r"([a-zA-Z])(\d)", - r"\1_\2", + r"([a-zA-Z])(\d)", # Match letter followed by number + r"\1_\2", # Add underscore between them intermediate_name, ) + # Example: "2text" -> "2_text" with_underscores_around_numbers = re.sub( - r"(\d)([a-zA-Z])", - r"\1_\2", + r"(\d)([a-zA-Z])", # Match number followed by letter + r"\1_\2", # Add underscore between them with_underscores_around_numbers, ) - # Convert snake_case to 'Proper Title Case', but ensure acronyms are uppercased + # Step 3: Convert to Title Case while preserving acronyms words = with_underscores_around_numbers.split("_") title_cased_words = [] for word in words: if word.isupper(): + # Preserve acronyms like 'HTML', 'API' as-is title_cased_words.append(word) else: + # Capitalize first letter: 'text' -> 'Text' title_cased_words.append(word.capitalize()) + # Join words with spaces + # Example: ["Text", "2", "Number"] -> "Text 2 Number" return " ".join(title_cased_words) diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index de53104ad..864a574a7 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -196,50 +196,6 @@ class MySignature(Signature): assert predictor().output == "short answer" -def test_replaced_by_replace_context_manager(): - class SignatureOne(Signature): - input1 = InputField() - output = OutputField() - - class SignatureTwo(Signature): - input2 = InputField() - output = OutputField() - - with SignatureOne.replace(SignatureTwo, validate_new_signature=False): - # assert SignatureOne.input_fields has been replaced with SignatureTwo.input_fields - assert "input2" in SignatureOne.input_fields - # after the context manager, the original input_fields should be restored - assert SignatureOne.input_fields["input1"].json_schema_extra["prefix"] == "Input 1:" - - -def test_multiple_replaced_by_update_signatures(): - class SignatureOne(Signature): - input1 = InputField() - output = OutputField() - - class SignatureTwo(Signature): - input2 = InputField() - output = OutputField() - - class SignatureThree(Signature): - input3 = InputField() - output = OutputField() - - class SignatureFour(Signature): - input4 = InputField() - output = OutputField() - - signature_map = { - SignatureOne: SignatureThree, - SignatureTwo: SignatureFour, - } - with dspy.update_signatures(signature_map, validate_new_signature=False): - assert "input3" in SignatureOne.input_fields - assert "input4" in SignatureTwo.input_fields - assert "input1" in SignatureOne.input_fields - assert "input2" in SignatureTwo.input_fields - - def test_dump_and_load_state(): class CustomSignature(dspy.Signature): """I am just an instruction."""