diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index 05456ceb25..211faa50c6 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -898,7 +898,8 @@ export function getCodeFlowEngine( } } - const effectiveType = typesToCombine.length > 0 ? combineTypes(typesToCombine) : undefined; + const effectiveType = + typesToCombine.length > 0 ? combineTypes(typesToCombine, undefined, evaluator) : undefined; return setCacheEntry(branchNode, effectiveType, sawIncomplete); } diff --git a/packages/pyright-internal/src/analyzer/operations.ts b/packages/pyright-internal/src/analyzer/operations.ts index afc08dfb84..7330ced158 100644 --- a/packages/pyright-internal/src/analyzer/operations.ts +++ b/packages/pyright-internal/src/analyzer/operations.ts @@ -656,7 +656,7 @@ export function getTypeOfBinaryOperation( flags | EvaluatorFlags.ExpectingInstantiableType ); - let newUnion = combineTypes([adjustedLeftType, adjustedRightType]); + let newUnion = combineTypes([adjustedLeftType, adjustedRightType], undefined, evaluator); const unionClass = evaluator.getUnionClassType(); if (unionClass && isInstantiableClass(unionClass)) { diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 924b2302b9..b68092ab1f 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -12,6 +12,8 @@ import { Uri } from '../common/uri/uri'; import { ArgumentNode, ExpressionNode, NameNode, ParameterCategory } from '../parser/parseNodes'; import { ClassDeclaration, FunctionDeclaration, SpecialBuiltInClassDeclaration } from './declaration'; import { Symbol, SymbolTable } from './symbol'; +import { TypeEvaluator } from './typeEvaluatorTypes'; +import { AssignTypeFlags } from './typeUtils'; export const enum TypeCategory { // Name is not bound to a value of any type. @@ -3237,7 +3239,6 @@ export function removeUnbound(type: Type): Type { return type; } - export function removeFromUnion(type: Type, removeFilter: (type: Type) => boolean) { if (isUnion(type)) { const remainingTypes = type.subtypes.filter((t) => !removeFilter(t)); @@ -3265,11 +3266,28 @@ export function findSubtype(type: Type, filter: (type: UnionableType | NeverType return filter(type) ? type : undefined; } -// Combines multiple types into a single type. If the types are -// the same, only one is returned. If they differ, they -// are combined into a UnionType. NeverTypes are filtered out. -// If no types remain in the end, a NeverType is returned. -export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type { +/** + * Combines multiple types into a single type. If the types are + * the same, only one is returned. If they differ, they + * are combined into a UnionType. NeverTypes are filtered out. + * If no types remain in the end, a NeverType is returned. + * + * if a {@link TypeEvaluator} is provided, it not only checks that + * the types aren't the same, but also prevents redundant subtypes from + * being added to the union. eg. adding `Literal[1]` to a union of `int | str` + * is useless, so the union is left as-is. when adding a supertype to a union + * that contains a subtype of it, that subtype becomes redundant and therefore + * gets removed (eg. adding `int` to `Literal[1] | str` will result in + * `int | str`). this is useful to prevent cases where a narrowed type would be + * treated as partially unknown unnecessarily (eg. `object | list[Any]`). + * + * a {@link TypeEvaluator} should not be provided in cases where the union + * intentionally contains redundant information for the purpose of autocomplete. + * i don't think there are any situations where this is supported currently, but + * it's something to keep in mind if we end up implementing + * https://github.com/DetachHead/basedpyright/issues/320 + */ +export function combineTypes(subtypes: Type[], maxSubtypeCount?: number, evaluator?: TypeEvaluator): Type { // Filter out any "Never" and "NoReturn" types. let sawNoReturn = false; @@ -3352,21 +3370,87 @@ export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type { return UnknownType.create(); } - const newUnionType = UnionType.create(); + let newUnionType = UnionType.create(); if (typeAliasSources.size > 0) { newUnionType.typeAliasSources = typeAliasSources; } let hitMaxSubtypeCount = false; - expandedTypes.forEach((subtype, index) => { - if (index === 0) { - UnionType.addType(newUnionType, subtype as UnionableType); - } else { - if (maxSubtypeCount === undefined || newUnionType.subtypes.length < maxSubtypeCount) { - _addTypeIfUnique(newUnionType, subtype as UnionableType); + expandedTypes.forEach((subtype) => { + let shouldAddType = false; + if ( + // if an evaluator isn't specified, don't do the redundant type check + !evaluator || + // if it's a type var (including recursive type aliases which get synthesized into typevars), + // we don't know the bound type at this point so it's not safe to do the redundant type check + isTypeVar(subtype) || + // no types have been added to the union yet which causes its type to be Never, which would break + // the redundant type check + !newUnionType.subtypes.length + ) { + shouldAddType = true; + } else if ( + // i cant figure out how to check whether a special form is assignable, for now we just skip the + // redundant check on special forms + subtype.specialForm || + newUnionType.subtypes.find((subtype) => subtype.specialForm) + ) { + shouldAddType = true; + } else if ( + // if the new type is a subtype of a type that's already in the union, it's redundant and therefore + // does not need to be added to the union + !evaluator.assignType( + newUnionType, + subtype, + undefined, + undefined, + undefined, + AssignTypeFlags.OverloadOverlapCheck + ) + ) { + shouldAddType = true; + if ( + // if the new type is a supertype of a type that's already in the union, we need to get rid of that + // type then replace it with the new wider one + evaluator.assignType( + subtype, + newUnionType, + undefined, + undefined, + undefined, + AssignTypeFlags.OverloadOverlapCheck + ) + ) { + const filteredType = removeFromUnion(newUnionType, (type) => + evaluator.assignType( + subtype, + type, + undefined, + undefined, + undefined, + AssignTypeFlags.OverloadOverlapCheck + ) + ); + if (isUnion(filteredType)) { + newUnionType = filteredType; + } else { + newUnionType = UnionType.create(); + if (filteredType.category !== TypeCategory.Never) { + UnionType.addType(newUnionType, filteredType as UnionableType); + } + } + } + } + if (shouldAddType) { + if (!newUnionType.subtypes.length) { + UnionType.addType(newUnionType, subtype as UnionableType); } else { - hitMaxSubtypeCount = true; + if (maxSubtypeCount === undefined || newUnionType.subtypes.length < maxSubtypeCount) { + _addTypeIfUnique(newUnionType, subtype as UnionableType); + } else { + hitMaxSubtypeCount = true; + } } } }); diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingBased.py b/packages/pyright-internal/src/tests/samples/typeNarrowingBased.py new file mode 100644 index 0000000000..b34d65f44e --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingBased.py @@ -0,0 +1,16 @@ +from typing import Any, assert_type + + +def foo(value: object): + print(value) + if isinstance(value, list): + _ = assert_type(value, list[Any]) + _ = assert_type(value, object) + +def bar(value: object): + print(value) + if isinstance(value, list): + _ = assert_type(value, list[Any]) + else: + _ = assert_type(value, object) + _ = assert_type(value, object) \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIsNone1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIsNone1.py index 87a95cc11c..3ee855ec64 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIsNone1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIsNone1.py @@ -77,7 +77,7 @@ def __bool__(self) -> Literal[False]: def func7(x: NoneProto | None): if x is None: - reveal_type(x, expected_text="None") + reveal_type(x, expected_text="Never") # should be None. see https://github.com/DetachHead/basedpyright/issues/459 else: reveal_type(x, expected_text="NoneProto") diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py index 4427961938..4a23a069a2 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py @@ -38,4 +38,4 @@ def func1( if isinstance(obj, Callable): reveal_type(obj, expected_text="((int, str) -> int) | B | TCall1@func1") else: - reveal_type(obj, expected_text="list[int] | C | D | A") + reveal_type(obj, expected_text="list[int] | C | A") diff --git a/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts b/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts index a13d1f91b0..ad368c6ee6 100644 --- a/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts @@ -117,3 +117,12 @@ test('subscript context manager types on 3.8', () => { ], }); }); + +test("useless type isn't added to union after if statement", () => { + const configOptions = new ConfigOptions(Uri.empty()); + configOptions.diagnosticRuleSet.reportAssertTypeFailure = 'error'; + const analysisResults = typeAnalyzeSampleFiles(['typeNarrowingBased.py'], configOptions); + validateResultsButBased(analysisResults, { + errors: [], + }); +});