diff --git a/packages/pyright-internal/src/analyzer/dataClasses.ts b/packages/pyright-internal/src/analyzer/dataClasses.ts index 43e772ed273b..764976b65aa8 100644 --- a/packages/pyright-internal/src/analyzer/dataClasses.ts +++ b/packages/pyright-internal/src/analyzer/dataClasses.ts @@ -613,6 +613,11 @@ export function synthesizeDataClassMethods( }); defaultType = makeTypeVarsFree(defaultType, liveTypeVars); + + if (entry.mroClass && requiresSpecialization(defaultType)) { + const solution = buildSolutionFromSpecializedClass(entry.mroClass); + defaultType = applySolvedTypeVars(defaultType, solution); + } } } } @@ -1135,7 +1140,7 @@ export function addInheritedDataClassEntries(classType: ClassType, entries: Data // If the type from the parent class is generic, we need to convert // to the type parameter namespace of child class. - const updatedEntry = { ...entry }; + const updatedEntry = { ...entry, mroClass }; updatedEntry.type = applySolvedTypeVars(updatedEntry.type, solution); if (entry.isClassVar) { diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index c0c7bfbb8fd9..e14e736af735 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -538,6 +538,7 @@ export namespace ModuleType { export interface DataClassEntry { name: string; classType: ClassType; + mroClass?: ClassType; isClassVar: boolean; isKeywordOnly: boolean; alias?: string | undefined; diff --git a/packages/pyright-internal/src/tests/samples/dataclass10.py b/packages/pyright-internal/src/tests/samples/dataclass10.py index 5bce5a1efbe1..d6633be4a43d 100644 --- a/packages/pyright-internal/src/tests/samples/dataclass10.py +++ b/packages/pyright-internal/src/tests/samples/dataclass10.py @@ -7,15 +7,38 @@ @dataclass -class Foo(Generic[T]): +class ABase(Generic[T]): value: Union[str, T] -reveal_type(Foo(""), expected_text="Foo[Unknown]") +reveal_type(ABase(""), expected_text="ABase[Unknown]") -class Bar(Foo[int]): +class AChild(ABase[int]): pass -reveal_type(Bar(123), expected_text="Bar") +reveal_type(AChild(123), expected_text="AChild") + + +class B(Generic[T]): + pass + + +@dataclass +class CBase(Generic[T]): + x: B[T] = B[T]() + + +@dataclass +class CChild(CBase[T]): + pass + + +c1 = CBase[int]() +reveal_type(c1, expected_text="CBase[int]") +reveal_type(c1.x, expected_text="B[int]") + +c2 = CChild[int]() +reveal_type(c2, expected_text="CChild[int]") +reveal_type(c2.x, expected_text="B[int]")