Skip to content

Commit b3709fa

Browse files
author
SunDoge
committed
update saving strategy
1 parent 3d0346f commit b3709fa

File tree

5 files changed

+64
-33
lines changed

5 files changed

+64
-33
lines changed

_test_args.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,26 @@
33
from dataclasses import dataclass
44

55
# from typed_args import TypedArgs, add_argument
6-
import typed_args as tp
6+
import typed_args as ta
77
import argparse
88

99
logging.basicConfig(level=logging.DEBUG)
1010

1111

1212
@dataclass()
13-
class Args(tp.TypedArgs):
13+
class Args(ta.TypedArgs):
1414
foo: str = 'bar'
15-
data: str = tp.add_argument(
15+
data: str = ta.add_argument(
1616
metavar='DIR',
1717
help='path to dataset'
1818
)
19-
arch: str = tp.add_argument(
19+
arch: str = ta.add_argument(
2020
'-a', '--arch',
2121
metavar='ARCH',
2222
default='resnet18',
2323
help='model architecture (default: resnet18)'
2424
)
25-
num_workers: int = tp.add_argument(
25+
num_workers: int = ta.add_argument(
2626
'-j', '--workers',
2727
default=4,
2828
metavar='N',
@@ -31,7 +31,7 @@ class Args(tp.TypedArgs):
3131

3232
@dataclass
3333
class Args1(Args):
34-
foo: str = tp.add_argument('--foo')
34+
foo: str = ta.add_argument('--foo')
3535

3636

3737
def test_args():

tests/test_add_argument.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import typed_args as tp
1+
import typed_args as ta
22
import pytest
33
from dataclasses import dataclass
44

55

66
def test_name_or_flags():
77
@dataclass
8-
class Args(tp.TypedArgs):
9-
foo: str = tp.add_argument('-f', '--foo')
10-
bar: str = tp.add_argument()
8+
class Args(ta.TypedArgs):
9+
foo: str = ta.add_argument('-f', '--foo')
10+
bar: str = ta.add_argument()
1111

1212
args1 = Args.from_args(['BAR'])
1313
assert args1 == Args(bar='BAR')
@@ -23,17 +23,20 @@ class Args(tp.TypedArgs):
2323

2424
def test_store_action():
2525
@dataclass
26-
class Args(tp.TypedArgs):
27-
foo: str = tp.add_argument('--foo')
26+
class Args(ta.TypedArgs):
27+
foo: str = ta.add_argument('--foo')
2828

2929
args = Args.from_args('--foo 1'.split())
3030
assert args == Args(foo='1')
3131

3232

3333
def test_store_const_action():
3434
@dataclass
35-
class Args(tp.TypedArgs):
36-
foo: int = tp.add_argument('--foo', action='store_const', const=42)
35+
class Args(ta.TypedArgs):
36+
foo: int = ta.add_argument('--foo', action='store_const', const=42)
3737

3838
args = Args.from_args(['--foo'])
3939
assert args == Args(foo=42)
40+
41+
42+

tests/test_list.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
1-
# from typed_args import TypedArgs, add_argument, dataclass
2-
# from typing import List
1+
from typed_args import TypedArgs, add_argument
2+
from dataclasses import dataclass
3+
from typing import List
4+
import typed_args as ta
5+
import pickle
36

47

5-
# def test_list():
6-
# """
7-
# https://docs.python.org/3/library/argparse.html#nargs
8-
# :return:
9-
# """
8+
def test_list():
9+
"""
10+
https://docs.python.org/3/library/argparse.html#nargs
11+
:return:
12+
"""
1013

11-
# @dataclass()
12-
# class Args(TypedArgs):
13-
# foo: List[str] = add_argument('--foo', nargs=2)
14-
# bar: List[str] = add_argument(nargs=1)
14+
@dataclass()
15+
class Args(TypedArgs):
16+
foo: List[str] = add_argument('--foo', nargs=2, type=str)
17+
bar: List[str] = add_argument(nargs=1, type=str)
1518

16-
# args = Args.from_args('c --foo a b'.split())
19+
args = Args.from_args('c --foo a b'.split())
1720

18-
# assert args.foo == ['a', 'b']
19-
# assert args.bar == ['c']
21+
assert args.foo == ['a', 'b']
22+
assert args.bar == ['c']
23+
24+
25+
def test_default_list():
26+
@dataclass
27+
class Args(ta.TypedArgs):
28+
foo: int = ta.add_argument('--foo', type=int, default=42)
29+
bar: List[int] = ta.add_argument(nargs='*', default=[1, 2, 3])
30+
31+
args = Args.from_args([])
32+
33+
assert args.bar == [1, 2, 3]

tests/test_pickle.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# from typed_args import TypedArgs, add_argument, dataclass
2-
import typed_args as tp
2+
import typed_args as ta
33
import pickle
4-
from typing import Optional
4+
from typing import List, Optional
55
from dataclasses import dataclass
66

77

88
@dataclass()
9-
class Args(tp.TypedArgs):
10-
foo: Optional[str] = tp.add_argument('--foo')
9+
class Args(ta.TypedArgs):
10+
foo: Optional[str] = ta.add_argument('--foo')
11+
bar: List[int] = ta.add_argument(nargs='*', default=[1, 2, 3])
1112

1213

1314
def test_pickle():

typed_args/_typed_args.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
}
77
"""
88

9+
import dataclasses
910
import inspect
1011
import logging
1112
from argparse import ArgumentParser, Namespace
@@ -151,8 +152,20 @@ def _add_argument(*args, **kwargs) -> Field:
151152
"""
152153
metadata = {'type': 'add_argument', 'args': args, 'kwargs': kwargs}
153154
default = kwargs.get('default', None)
155+
156+
if isinstance(default, (list, dict, set)):
157+
_logger.debug(
158+
'mutable object cannot be dataclass default attribute, make default_factory'
159+
)
160+
161+
def default_factory(): return default
162+
163+
default = dataclasses.MISSING
164+
else:
165+
default_factory = dataclasses.MISSING
166+
154167
_logger.debug('metadata: %s', metadata)
155-
return field(default=default, metadata=metadata)
168+
return field(default=default, default_factory=default_factory, metadata=metadata)
156169

157170

158171
def add_parser():

0 commit comments

Comments
 (0)