diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 41fcd8f5cb3c..be7d84686b5c 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -4532,39 +4532,30 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } // If the type is a TypeVar and we're not expecting a type, convert - // a TypeVar or TypeVarTuple into a runtime type. We don't currently - // do this for ParamSpec (although we arguably should) because it's - // problematic for handling P.args and P.kwargs. + // a TypeVar, TypeVarTuple or ParamSpec into a runtime type. function convertTypeVarToRuntimeInstance(node: ExpressionNode, type: Type, flags: EvaluatorFlags) { - if ( - node.nodeType === ParseNodeType.Name && - isTypeVar(type) && - node.value === type.details.name && - !type.isVariadicInUnion && - (flags & EvaluatorFlags.ExpectingInstantiableType) === 0 - ) { - if ((flags & EvaluatorFlags.SkipConvertParamSpecToRuntimeObject) !== 0 && type.details.isParamSpec) { - return type; - } + if (!type.specialForm || type.typeAliasInfo) { + return type; + } - // Handle the special case of a PEP 604 union. These can appear within - // an implied type alias where we are not expecting a type. - const isPep604Union = - node.parent?.nodeType === ParseNodeType.BinaryOperation && - node.parent.operator === OperatorType.BitwiseOr; + if (!isTypeVar(type) || type.isVariadicInUnion || (flags & EvaluatorFlags.ExpectingInstantiableType) !== 0) { + return type; + } - if (!isPep604Union) { - // A TypeVar in contexts where we're not expecting a type is - // simply a runtime object. - if (type.details.runtimeClass) { - type = ClassType.cloneAsInstance(type.details.runtimeClass); - } else { - type = UnknownType.create(); - } - } + if ((flags & EvaluatorFlags.SkipConvertParamSpecToRuntimeObject) !== 0 && type.details.isParamSpec) { + return TypeBase.cloneAsSpecialForm(type, undefined); } - return type; + // Handle the special case of a PEP 604 union. These can appear within + // an implied type alias where we are not expecting a type. + const isPep604Union = + node.parent?.nodeType === ParseNodeType.BinaryOperation && node.parent.operator === OperatorType.BitwiseOr; + + if (isPep604Union) { + return TypeBase.cloneAsSpecialForm(type, undefined); + } + + return ClassType.cloneAsInstance(type.specialForm); } // Handles the case where a variable or parameter is defined in an outer @@ -12145,7 +12136,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions addError(LocMessage.typeVarFirstArg(), firstArg.valueExpression || errorNode); } - const typeVar = TypeVarType.createInstantiable(typeVarName, /* isParamSpec */ false, classType); + const typeVar = TypeBase.cloneAsSpecialForm( + TypeVarType.createInstantiable(typeVarName, /* isParamSpec */ false), + ClassType.cloneAsInstance(classType) + ); // Parse the remaining parameters. const paramNameMap = new Map(); @@ -12313,7 +12307,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions addError(LocMessage.typeVarFirstArg(), firstArg.valueExpression || errorNode); } - const typeVar = TypeVarType.createInstantiable(typeVarName, /* isParamSpec */ false, classType); + const typeVar = TypeBase.cloneAsSpecialForm( + TypeVarType.createInstantiable(typeVarName, /* isParamSpec */ false), + ClassType.cloneAsInstance(classType) + ); typeVar.details.isVariadic = true; // Parse the remaining parameters. @@ -12384,7 +12381,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions addError(LocMessage.paramSpecFirstArg(), firstArg.valueExpression || errorNode); } - const paramSpec = TypeVarType.createInstantiable(paramSpecName, /* isParamSpec */ true, classType); + const paramSpec = TypeBase.cloneAsSpecialForm( + TypeVarType.createInstantiable(paramSpecName, /* isParamSpec */ true), + ClassType.cloneAsInstance(classType) + ); // Parse the remaining parameters. for (let i = 1; i < argList.length; i++) { @@ -16002,9 +16002,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions getTypeOfExpression(arg.valueExpression); argType = UnknownType.create(); } else { - argType = makeTopLevelTypeVarsConcrete( - getTypeOfExpression(arg.valueExpression, exprFlags).type - ); + argType = getTypeOfExpression(arg.valueExpression, exprFlags).type; + + if (isTypeVar(argType) && argType.specialForm && TypeBase.isInstance(argType.specialForm)) { + addDiagnostic(DiagnosticRule.reportGeneralTypeIssues, LocMessage.baseClassInvalid(), arg); + argType = UnknownType.create(); + } + + argType = makeTopLevelTypeVarsConcrete(argType); } // In some stub files, classes are conditionally defined (e.g. based @@ -20566,9 +20571,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions let typeVar = TypeVarType.createInstantiable( node.name.value, - node.typeParamCategory === TypeParameterCategory.ParamSpec, - runtimeClass + node.typeParamCategory === TypeParameterCategory.ParamSpec ); + if (runtimeClass) { + typeVar = TypeBase.cloneAsSpecialForm(typeVar, ClassType.cloneAsInstance(runtimeClass)); + } typeVar.details.isTypeParamSyntax = true; if (node.typeParamCategory === TypeParameterCategory.TypeVarTuple) { diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 7465714e1dca..57b6aac45fa5 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -199,10 +199,16 @@ export namespace TypeBase { return clone; } - export function cloneAsSpecialForm(type: T, specialForm: ClassType): T { + export function cloneAsSpecialForm(type: T, specialForm: ClassType | undefined): T { const clone = { ...type }; delete clone.cached; - clone.specialForm = specialForm; + + if (specialForm) { + clone.specialForm = specialForm; + } else { + delete clone.specialForm; + } + return clone; } @@ -2448,7 +2454,6 @@ export interface TypeVarDetails { constraints: Type[]; boundType?: Type | undefined; defaultType?: Type | undefined; - runtimeClass?: ClassType | undefined; isParamSpec: boolean; isVariadic: boolean; @@ -2525,8 +2530,8 @@ export namespace TypeVarType { return create(name, /* isParamSpec */ false, TypeFlags.Instance); } - export function createInstantiable(name: string, isParamSpec = false, runtimeClass?: ClassType) { - return create(name, isParamSpec, TypeFlags.Instantiable, runtimeClass); + export function createInstantiable(name: string, isParamSpec = false) { + return create(name, isParamSpec, TypeFlags.Instantiable); } export function cloneAsInstance(type: TypeVarType): TypeVarType { @@ -2657,7 +2662,7 @@ export namespace TypeVarType { return `${name}.${scopeId}`; } - function create(name: string, isParamSpec: boolean, typeFlags: TypeFlags, runtimeClass?: ClassType): TypeVarType { + function create(name: string, isParamSpec: boolean, typeFlags: TypeFlags): TypeVarType { const newTypeVarType: TypeVarType = { category: TypeCategory.TypeVar, details: { @@ -2667,7 +2672,6 @@ export namespace TypeVarType { isParamSpec, isVariadic: false, isSynthesized: false, - runtimeClass, }, flags: typeFlags, }; diff --git a/packages/pyright-internal/src/tests/samples/classes1.py b/packages/pyright-internal/src/tests/samples/classes1.py index 45ffe6e6e994..5c2f506283b7 100644 --- a/packages/pyright-internal/src/tests/samples/classes1.py +++ b/packages/pyright-internal/src/tests/samples/classes1.py @@ -2,7 +2,10 @@ # handle various class definition cases. -from typing import Any +from typing import Any, TypeVar + + +T = TypeVar("T") class A: @@ -60,3 +63,17 @@ class Y(x): pass return Y() + + +# This should generate an error because a TypeVar can't be used as a base class. +class K(T): + pass + + +class L(type[T]): + pass + + +def func2(cls: type[T]): + class M(cls): + pass diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index 567db4517fa2..990f54b64ed3 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -766,7 +766,7 @@ test('RecursiveTypeAlias14', () => { test('Classes1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['classes1.py']); - TestUtils.validateResults(analysisResults, 1); + TestUtils.validateResults(analysisResults, 2); }); test('Classes3', () => {