37
37
38
38
39
39
class ModelCardTest (absltest .TestCase ):
40
- def test_copy_from_proto_and_to_proto_with_all_fields (self ):
40
+ def test_from_proto_and_to_proto_with_all_fields (self ):
41
41
want_proto = text_format .Parse (_FULL_PROTO , model_card_pb2 .ModelCard ())
42
- model_card_py = model_card .ModelCard ()
43
- model_card_py .copy_from_proto (want_proto )
42
+ model_card_py = model_card .ModelCard .from_proto (want_proto )
44
43
got_proto = model_card_py .to_proto ()
45
44
46
45
self .assertEqual (want_proto , got_proto )
@@ -53,23 +52,27 @@ def test_merge_from_proto_and_to_proto_with_all_fields(self):
53
52
54
53
self .assertEqual (want_proto , got_proto )
55
54
56
- def test_copy_from_proto_success (self ):
55
+ def test_copy_from_proto_shows_deprecation_warning (self ):
56
+ with self .assertWarns (DeprecationWarning ):
57
+ owner = model_card .Owner (name = "my_name1" )
58
+ owner_proto = model_card_pb2 .Owner (
59
+ name = "my_name2" , contact = "my_contact2"
60
+ )
61
+ owner .copy_from_proto (owner_proto )
62
+
63
+ def test_from_proto_success (self ):
57
64
# Test fields convert.
58
- owner = model_card .Owner (name = "my_name1" )
59
65
owner_proto = model_card_pb2 .Owner (name = "my_name2" , contact = "my_contact2" )
60
- owner . copy_from_proto (owner_proto )
66
+ owner = model_card . Owner . from_proto (owner_proto )
61
67
self .assertEqual (
62
68
owner , model_card .Owner (name = "my_name2" , contact = "my_contact2" )
63
69
)
64
70
65
71
# Test message convert.
66
- model_details = model_card .ModelDetails (
67
- owners = [model_card .Owner (name = "my_name1" )]
68
- )
69
72
model_details_proto = model_card_pb2 .ModelDetails (
70
73
owners = [model_card_pb2 .Owner (name = "my_name2" , contact = "my_contact2" )]
71
74
)
72
- model_details . copy_from_proto (model_details_proto )
75
+ model_details = model_card . ModelDetails . from_proto (model_details_proto )
73
76
self .assertEqual (
74
77
model_details ,
75
78
model_card .ModelDetails (
@@ -104,16 +107,15 @@ def test_merge_from_proto_success(self):
104
107
)
105
108
)
106
109
107
- def test_copy_from_proto_with_invalid_proto (self ):
108
- owner = model_card .Owner ()
110
+ def test_from_proto_with_invalid_proto (self ):
109
111
wrong_proto = model_card_pb2 .Version ()
110
112
with self .assertRaisesRegex (
111
113
TypeError ,
112
114
"<class 'model_card_toolkit.proto.model_card_pb2.Owner'> is expected. "
113
115
"However <class 'model_card_toolkit.proto.model_card_pb2.Version'> is "
114
116
"provided."
115
117
):
116
- owner . copy_from_proto (wrong_proto )
118
+ model_card . Owner . from_proto (wrong_proto )
117
119
118
120
def test_merge_from_proto_with_invalid_proto (self ):
119
121
owner = model_card .Owner ()
@@ -152,34 +154,16 @@ def test_to_proto_with_invalid_field(self):
152
154
owner = model_card .Owner ()
153
155
owner .wrong_field = "wrong"
154
156
with self .assertRaisesRegex (
155
- ValueError , "has no such field named ' wrong_field' ."
157
+ ValueError , "has no such field named \" wrong_field\" ."
156
158
):
157
159
owner .to_proto ()
158
160
159
161
def test_from_json_and_to_json_with_all_fields (self ):
160
162
want_json = json .loads (_FULL_JSON )
161
- model_card_py = model_card .ModelCard ()
162
- model_card_py .from_json (want_json )
163
+ model_card_py = model_card .ModelCard .from_json (want_json )
163
164
got_json = json .loads (model_card_py .to_json ())
164
165
self .assertEqual (want_json , got_json )
165
166
166
- def test_from_json_overwrites_previous_fields (self ):
167
- overwritten_limitation = model_card .Limitation (
168
- description = "This model can only be used on text up to 140 characters."
169
- )
170
- overwritten_user = model_card .User (description = "language researchers" )
171
- model_card_py = model_card .ModelCard (
172
- considerations = model_card .Considerations (
173
- limitations = [overwritten_limitation ], users = [overwritten_user ]
174
- )
175
- )
176
- model_card_json = json .loads (_FULL_JSON )
177
- model_card_py .from_json (model_card_json )
178
- self .assertNotIn (
179
- overwritten_limitation , model_card_py .considerations .limitations
180
- )
181
- self .assertNotIn (overwritten_user , model_card_py .considerations .users )
182
-
183
167
def test_merge_from_json_does_not_overwrite_all_fields (self ):
184
168
# We want the "Limitations" field to be overwritten, but not "Users".
185
169
@@ -222,7 +206,7 @@ def test_merge_from_json_dict_and_str(self):
222
206
def test_from_invalid_json (self ):
223
207
invalid_json_dict = {"model_name" : "the_greatest_model" }
224
208
with self .assertRaises (jsonschema .ValidationError ):
225
- model_card .ModelCard () .from_json (invalid_json_dict )
209
+ model_card .ModelCard .from_json (invalid_json_dict )
226
210
227
211
def test_from_invalid_json_vesion (self ):
228
212
model_card_dict = {
@@ -238,7 +222,7 @@ def test_from_invalid_json_vesion(self):
238
222
"model card."
239
223
)
240
224
):
241
- model_card .ModelCard () .from_json (model_card_dict )
225
+ model_card .ModelCard .from_json (model_card_dict )
242
226
243
227
def test_from_proto_to_json (self ):
244
228
model_card_proto = text_format .Parse (
@@ -251,10 +235,10 @@ def test_from_proto_to_json(self):
251
235
_FULL_JSON ,
252
236
model_card_py .merge_from_proto (model_card_proto ).to_json ()
253
237
)
254
- # Use copy_from_proto
238
+ # Use from_proto
255
239
self .assertJsonEqual (
256
240
_FULL_JSON ,
257
- model_card_py . copy_from_proto (model_card_proto ).to_json ()
241
+ model_card . ModelCard . from_proto (model_card_proto ).to_json ()
258
242
)
259
243
260
244
def test_from_json_to_proto (self ):
@@ -263,8 +247,7 @@ def test_from_json_to_proto(self):
263
247
)
264
248
265
249
model_card_json = json .loads (_FULL_JSON )
266
- model_card_py = model_card .ModelCard ()
267
- model_card_py .from_json (model_card_json )
250
+ model_card_py = model_card .ModelCard .from_json (model_card_json )
268
251
model_card_json2proto = model_card_py .to_proto ()
269
252
270
253
self .assertEqual (model_card_proto , model_card_json2proto )
0 commit comments