11
11
from argparse import ArgumentParser , Namespace
12
12
from dataclasses import Field , dataclass , field
13
13
from enum import Enum
14
- from typing import Dict , List , Optional , Sequence , Tuple , TypeVar
14
+ from typing import Dict , List , Optional , Sequence , Tuple , Type , TypeVar
15
15
16
16
_logger = logging .getLogger (__name__ )
17
17
18
- T = TypeVar ('T' )
18
+ T = TypeVar ('T' , bound = 'TypedArgs' )
19
19
20
20
21
21
def _get_dataclass_fields (cls ) -> Dict [str , Field ]:
@@ -27,11 +27,11 @@ class TypedArgs:
27
27
28
28
@classmethod
29
29
def from_args (
30
- cls ,
30
+ cls : Type [ T ] ,
31
31
args : Optional [List [str ]] = None ,
32
32
namespace : Optional [Namespace ] = None ,
33
33
parser : Optional [ArgumentParser ] = None ,
34
- ):
34
+ ) -> T :
35
35
36
36
if parser is None :
37
37
parser = ArgumentParser ()
@@ -47,7 +47,7 @@ def from_args(
47
47
48
48
@classmethod
49
49
def from_known_args (
50
- cls ,
50
+ cls : Type [ T ] ,
51
51
args : Optional [List [str ]] = None ,
52
52
namespace : Optional [Namespace ] = None ,
53
53
parser : Optional [ArgumentParser ] = None ,
@@ -66,10 +66,9 @@ def from_known_args(
66
66
67
67
@classmethod
68
68
def _add_arguments (
69
- cls : T ,
69
+ cls ,
70
70
parser : ArgumentParser ,
71
71
prefix : str = '' ,
72
-
73
72
):
74
73
fields = _get_dataclass_fields (cls )
75
74
for name , field in fields .items ():
0 commit comments