diff --git a/arcparse/parser.py b/arcparse/parser.py index 59ab7de..294f21a 100644 --- a/arcparse/parser.py +++ b/arcparse/parser.py @@ -16,8 +16,11 @@ def _extract_optional_type(typehint: type) -> type | None: return get_args(typehint)[0] elif origin in {Union, UnionType}: args = get_args(typehint) - if len(args) == 2 and args[1] == NoneType: - return args[0] + if len(args) == 2: + if args[0] == NoneType: + return args[1] + elif args[1] == NoneType: + return args[0] return None @@ -205,12 +208,21 @@ def __collect_arguments(cls) -> tuple[dict[str, type], dict[str, _BaseArgument]] if isinstance(default, _Subparsers): continue - if isinstance(default, _BaseArgument): + if get_origin(typehint) in {Union, UnionType}: + union_args = get_args(typehint) + if len(union_args) > 2 or NoneType not in union_args: + raise Exception("Union can be used only for optional arguments (length of 2, 1 of them being None)") + + if isinstance(default, _BaseValueArgument) and _extract_type_from_typehint(typehint) == bool: + raise Exception("Unable to make type=bool, everything would be True") + elif isinstance(default, _BaseArgument): argument = default else: typ = _extract_type_from_typehint(typehint) if typ is bool: + if _extract_optional_type(typehint): + raise Exception("Unable to make type=bool, everything would be True") argument = _Flag(default=default) elif isinstance(typ, StrEnum): argument = _Option(default=default, choices=list(typ), converter=typ)