Skip to content

Commit

Permalink
Added check for name mismatch for enum classes defined using the func…
Browse files Browse the repository at this point in the history
…tional syntax. This addresses #7025. (#7038)
  • Loading branch information
erictraut authored Jan 18, 2024
1 parent 81e85d1 commit 922e746
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 2 deletions.
18 changes: 17 additions & 1 deletion packages/pyright-internal/src/analyzer/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
*/

import { assert } from '../common/debug';
import { DiagnosticRule } from '../common/diagnosticRules';
import { LocMessage } from '../localization/localize';
import { ArgumentCategory, ExpressionNode, NameNode, ParseNode, ParseNodeType } from '../parser/parseNodes';
import { getFileInfo } from './analyzerNodeInfo';
import { VariableDeclaration } from './declaration';
Expand Down Expand Up @@ -104,14 +106,28 @@ export function createEnumType(
}
const classInstanceType = ClassType.cloneAsInstance(classType);

// Check for name consistency if the enum class is assigned to a variable.
if (
errorNode.parent?.nodeType === ParseNodeType.Assignment &&
errorNode.parent.leftExpression.nodeType === ParseNodeType.Name &&
errorNode.parent.leftExpression.value !== className
) {
evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocMessage.enumNameMismatch(),
errorNode.parent.leftExpression
);
return undefined;
}

// The Enum functional form supports various forms of arguments:
// Enum('name', 'a b c')
// Enum('name', 'a,b,c')
// Enum('name', ['a', 'b', 'c'])
// Enum('name', ('a', 'b', 'c'))
// Enum('name', (('a', 1), ('b', 2), ('c', 3)))
// Enum('name', [('a', 1), ('b', 2), ('c', 3))]
// Enum('name', {'a': 1, 'b': 2, 'c': 3}
// Enum('name', {'a': 1, 'b': 2, 'c': 3})
if (initArg.valueExpression.nodeType === ParseNodeType.StringList) {
// Don't allow format strings in the init arg.
if (!initArg.valueExpression.strings.every((str) => str.nodeType === ParseNodeType.String)) {
Expand Down
1 change: 1 addition & 0 deletions packages/pyright-internal/src/localization/localize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ export namespace Localizer {
export const ellipsisSecondArg = () => getRawString('Diagnostic.ellipsisSecondArg');
export const enumClassOverride = () =>
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.enumClassOverride'));
export const enumNameMismatch = () => getRawString('Diagnostic.enumNameMismatch');
export const exceptionGroupIncompatible = () => getRawString('Diagnostic.exceptionGroupIncompatible');
export const exceptionTypeIncorrect = () =>
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.exceptionTypeIncorrect'));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
"ellipsisContext": "\"...\" is not allowed in this context",
"ellipsisSecondArg": "\"...\" is allowed only as the second of two arguments",
"enumClassOverride": "Enum class \"{name}\" is final and cannot be subclassed",
"enumNameMismatch": "Enum class must be assigned to a variable with the same name",
"exceptionGroupIncompatible": "Exception group syntax (\"except*\") requires Python 3.11 or newer",
"exceptionTypeIncorrect": "\"{type}\" does not derive from BaseException",
"exceptionTypeNotClass": "\"{type}\" is not a valid exception class",
Expand Down
4 changes: 4 additions & 0 deletions packages/pyright-internal/src/tests/samples/enum1.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,7 @@ class TestEnum12(Enum):
reveal_type(TestEnum12.a, expected_text="Literal[TestEnum12.a]")
reveal_type(TestEnum12.b, expected_text="() -> None")
reveal_type(TestEnum12.c, expected_text="() -> None")


# This should generate an error because of the name mismatch.
BadName = Enum("GoodName", "A", "B")
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator3.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ test('MethodOverride6', () => {
test('Enum1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['enum1.py']);

TestUtils.validateResults(analysisResults, 3);
TestUtils.validateResults(analysisResults, 4);
});

test('Enum2', () => {
Expand Down

0 comments on commit 922e746

Please sign in to comment.