Skip to content
This repository was archived by the owner on Sep 27, 2024. It is now read-only.

Commit dffe074

Browse files
authored
Add .from_proto() and .from_json() class methods (#278)
1 parent c520a26 commit dffe074

File tree

4 files changed

+74
-70
lines changed

4 files changed

+74
-70
lines changed

model_card_toolkit/base_model_card_field.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,23 @@
2020
import abc
2121
import dataclasses
2222
import json as json_lib
23-
from typing import Any, Dict
23+
from textwrap import dedent
24+
from typing import Any, Dict, Type, TypeVar
25+
from warnings import warn
2426

2527
from google.protobuf import descriptor, message
2628

2729
from model_card_toolkit.utils import json_utils
2830

31+
T = TypeVar('T', bound='BaseModelCardField')
32+
2933

3034
class BaseModelCardField(abc.ABC):
3135
"""Model card field base class.
3236
3337
This is an abstract class. All the model card fields should inherit this class
3438
and override the _proto_type property to the corresponding proto type. This
35-
abstract class provides methods `copy_from_proto`, `merge_from_proto` and
39+
abstract class provides methods `from_proto`, `merge_from_proto` and
3640
`to_proto` to convert the class from and to proto. The child class does not
3741
need to override this unless it needs some special process.
3842
"""
@@ -55,7 +59,7 @@ def to_proto(self) -> message.Message:
5559
for field_name, field_value in self.__dict__.items():
5660
if not hasattr(proto, field_name):
5761
raise ValueError(
58-
"%s has no such field named '%s'." % (type(proto), field_name)
62+
'%s has no such field named "%s".' % (type(proto), field_name)
5963
)
6064
if not field_value:
6165
continue
@@ -80,19 +84,19 @@ def to_proto(self) -> message.Message:
8084

8185
return proto
8286

83-
def _from_proto(self, proto: message.Message) -> "BaseModelCardField":
87+
def _from_proto(self: T, proto: message.Message) -> T:
8488
"""Convert proto to this class object."""
8589
if not isinstance(proto, self._proto_type):
8690
raise TypeError(
87-
"%s is expected. However %s is provided." %
91+
'%s is expected. However %s is provided.' %
8892
(self._proto_type, type(proto))
8993
)
9094

9195
for field_descriptor in proto.DESCRIPTOR.fields:
9296
field_name = field_descriptor.name
9397
if not hasattr(self, field_name):
9498
raise ValueError(
95-
"%s has no such field named '%s.'" % (self, field_name)
99+
'%s has no such field named "%s".' % (self, field_name)
96100
)
97101

98102
# Process Message type.
@@ -120,28 +124,44 @@ def _from_proto(self, proto: message.Message) -> "BaseModelCardField":
120124

121125
return self
122126

123-
def merge_from_proto(self, proto: message.Message) -> "BaseModelCardField":
127+
def merge_from_proto(self: T, proto: message.Message) -> T:
124128
"""Merges the contents of the model card proto into current object."""
125129
current = self.to_proto()
126130
current.MergeFrom(proto)
127131
self.clear()
128132
return self._from_proto(current)
129133

130-
def copy_from_proto(self, proto: message.Message) -> "BaseModelCardField":
134+
def copy_from_proto(self: T, proto: message.Message) -> T:
131135
"""Copies the contents of the model card proto into current object."""
136+
notice = dedent(
137+
'''
138+
This function is deprecated and will be removed in a future version.
139+
140+
If you would like to create a new model card from a proto, please use
141+
`ModelCard.from_proto(proto)` instead.
142+
143+
If you would like to copy the contents of a proto into an existing model
144+
card, please use `model_card.clear()` and `model_card.merge_from_proto(proto)`
145+
instead.
146+
'''
147+
)
148+
warn(notice, DeprecationWarning, stacklevel=2)
132149
self.clear()
133150
return self._from_proto(proto)
134151

135-
def _from_json(
136-
self, json_dict: Dict[str, Any], field: "BaseModelCardField"
137-
) -> "BaseModelCardField":
152+
@classmethod
153+
def from_proto(cls: Type[T], proto: message.Message) -> T:
154+
"""Constructs an object of this class from a model card proto."""
155+
return cls()._from_proto(proto)
156+
157+
def _from_json(self: T, json_dict: Dict[str, Any], field: T) -> T:
138158
"""Parses a JSON dictionary into the current object."""
139159
for subfield_key, subfield_json_value in json_dict.items():
140160
if subfield_key.startswith(json_utils.SCHEMA_VERSION_STRING):
141161
continue
142162
elif not hasattr(field, subfield_key):
143163
raise ValueError(
144-
"BaseModelCardField %s has no such field named '%s.'" %
164+
'BaseModelCardField %s has no such field named "%s".' %
145165
(field, subfield_key)
146166
)
147167
elif isinstance(subfield_json_value, dict):

model_card_toolkit/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _read_proto_file(self, path: str) -> Optional[ModelCard]:
193193
model_card_proto = model_card_pb2.ModelCard()
194194
with open(path, 'rb') as f:
195195
model_card_proto.ParseFromString(f.read())
196-
return ModelCard().copy_from_proto(model_card_proto)
196+
return ModelCard.from_proto(model_card_proto)
197197

198198
def _annotate_eval_results(self, model_card: ModelCard) -> ModelCard:
199199
"""Annotates a model card with info from TFMA evaluation results.

model_card_toolkit/model_card.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,14 @@ def to_json(self) -> str:
489489
] = json_utils.get_latest_schema_version()
490490
return json_lib.dumps(model_card_dict, indent=2)
491491

492-
def from_json(self, json_dict: Dict[str, Any]) -> None:
492+
def merge_from_json(self, json: Union[Dict[str, Any], str]) -> 'ModelCard':
493493
"""Reads ModelCard from JSON.
494494
495-
This function will overwrite all existing ModelCard fields.
495+
This function will only overwrite ModelCard fields specified in the JSON.
496496
497497
Args:
498-
json_dict: A JSON dict from which to populate fields in the model card
499-
schema.
498+
json: A JSON object from which to populate fields in the model card. This
499+
can be provided as either a dictionary or a string.
500500
501501
Raises:
502502
JSONDecodeError: If `json` is a string that is not valid JSON.
@@ -505,19 +505,19 @@ def from_json(self, json_dict: Dict[str, Any]) -> None:
505505
ValueError: If `json` contains a value not in the class or schema
506506
definition.
507507
"""
508+
if isinstance(json, str):
509+
json = json_lib.loads(json)
510+
json_utils.validate_json_schema(json)
511+
self._from_json(json, self)
512+
return self
508513

509-
json_utils.validate_json_schema(json_dict)
510-
self.clear()
511-
self._from_json(json_dict, self)
512-
513-
def merge_from_json(self, json: Union[Dict[str, Any], str]) -> None:
514-
"""Reads ModelCard from JSON.
515-
516-
This function will only overwrite ModelCard fields specified in the JSON.
514+
@classmethod
515+
def from_json(cls, json_dict: Dict[str, Any]) -> 'ModelCard':
516+
"""Constructs a ModelCard from JSON.
517517
518518
Args:
519-
json: A JSON object from whichto populate fields in the model card. This
520-
can be provided as either a dictionary or a string.
519+
json_dict: A JSON dict from which to populate fields in the model card
520+
schema.
521521
522522
Raises:
523523
JSONDecodeError: If `json_dict` is not a valid JSON string.
@@ -526,7 +526,8 @@ def merge_from_json(self, json: Union[Dict[str, Any], str]) -> None:
526526
ValueError: If `json_dict` contains a value not in the class or schema
527527
definition.
528528
"""
529-
if isinstance(json, str):
530-
json = json_lib.loads(json)
531-
json_utils.validate_json_schema(json)
532-
self._from_json(json, self)
529+
530+
json_utils.validate_json_schema(json_dict)
531+
model_card = cls()
532+
model_card._from_json(json_dict, model_card)
533+
return model_card

model_card_toolkit/model_card_test.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@
3737

3838

3939
class ModelCardTest(absltest.TestCase):
40-
def test_copy_from_proto_and_to_proto_with_all_fields(self):
40+
def test_from_proto_and_to_proto_with_all_fields(self):
4141
want_proto = text_format.Parse(_FULL_PROTO, model_card_pb2.ModelCard())
42-
model_card_py = model_card.ModelCard()
43-
model_card_py.copy_from_proto(want_proto)
42+
model_card_py = model_card.ModelCard.from_proto(want_proto)
4443
got_proto = model_card_py.to_proto()
4544

4645
self.assertEqual(want_proto, got_proto)
@@ -53,23 +52,27 @@ def test_merge_from_proto_and_to_proto_with_all_fields(self):
5352

5453
self.assertEqual(want_proto, got_proto)
5554

56-
def test_copy_from_proto_success(self):
55+
def test_copy_from_proto_shows_deprecation_warning(self):
56+
with self.assertWarns(DeprecationWarning):
57+
owner = model_card.Owner(name="my_name1")
58+
owner_proto = model_card_pb2.Owner(
59+
name="my_name2", contact="my_contact2"
60+
)
61+
owner.copy_from_proto(owner_proto)
62+
63+
def test_from_proto_success(self):
5764
# Test fields convert.
58-
owner = model_card.Owner(name="my_name1")
5965
owner_proto = model_card_pb2.Owner(name="my_name2", contact="my_contact2")
60-
owner.copy_from_proto(owner_proto)
66+
owner = model_card.Owner.from_proto(owner_proto)
6167
self.assertEqual(
6268
owner, model_card.Owner(name="my_name2", contact="my_contact2")
6369
)
6470

6571
# Test message convert.
66-
model_details = model_card.ModelDetails(
67-
owners=[model_card.Owner(name="my_name1")]
68-
)
6972
model_details_proto = model_card_pb2.ModelDetails(
7073
owners=[model_card_pb2.Owner(name="my_name2", contact="my_contact2")]
7174
)
72-
model_details.copy_from_proto(model_details_proto)
75+
model_details = model_card.ModelDetails.from_proto(model_details_proto)
7376
self.assertEqual(
7477
model_details,
7578
model_card.ModelDetails(
@@ -104,16 +107,15 @@ def test_merge_from_proto_success(self):
104107
)
105108
)
106109

107-
def test_copy_from_proto_with_invalid_proto(self):
108-
owner = model_card.Owner()
110+
def test_from_proto_with_invalid_proto(self):
109111
wrong_proto = model_card_pb2.Version()
110112
with self.assertRaisesRegex(
111113
TypeError,
112114
"<class 'model_card_toolkit.proto.model_card_pb2.Owner'> is expected. "
113115
"However <class 'model_card_toolkit.proto.model_card_pb2.Version'> is "
114116
"provided."
115117
):
116-
owner.copy_from_proto(wrong_proto)
118+
model_card.Owner.from_proto(wrong_proto)
117119

118120
def test_merge_from_proto_with_invalid_proto(self):
119121
owner = model_card.Owner()
@@ -152,34 +154,16 @@ def test_to_proto_with_invalid_field(self):
152154
owner = model_card.Owner()
153155
owner.wrong_field = "wrong"
154156
with self.assertRaisesRegex(
155-
ValueError, "has no such field named 'wrong_field'."
157+
ValueError, "has no such field named \"wrong_field\"."
156158
):
157159
owner.to_proto()
158160

159161
def test_from_json_and_to_json_with_all_fields(self):
160162
want_json = json.loads(_FULL_JSON)
161-
model_card_py = model_card.ModelCard()
162-
model_card_py.from_json(want_json)
163+
model_card_py = model_card.ModelCard.from_json(want_json)
163164
got_json = json.loads(model_card_py.to_json())
164165
self.assertEqual(want_json, got_json)
165166

166-
def test_from_json_overwrites_previous_fields(self):
167-
overwritten_limitation = model_card.Limitation(
168-
description="This model can only be used on text up to 140 characters."
169-
)
170-
overwritten_user = model_card.User(description="language researchers")
171-
model_card_py = model_card.ModelCard(
172-
considerations=model_card.Considerations(
173-
limitations=[overwritten_limitation], users=[overwritten_user]
174-
)
175-
)
176-
model_card_json = json.loads(_FULL_JSON)
177-
model_card_py.from_json(model_card_json)
178-
self.assertNotIn(
179-
overwritten_limitation, model_card_py.considerations.limitations
180-
)
181-
self.assertNotIn(overwritten_user, model_card_py.considerations.users)
182-
183167
def test_merge_from_json_does_not_overwrite_all_fields(self):
184168
# We want the "Limitations" field to be overwritten, but not "Users".
185169

@@ -222,7 +206,7 @@ def test_merge_from_json_dict_and_str(self):
222206
def test_from_invalid_json(self):
223207
invalid_json_dict = {"model_name": "the_greatest_model"}
224208
with self.assertRaises(jsonschema.ValidationError):
225-
model_card.ModelCard().from_json(invalid_json_dict)
209+
model_card.ModelCard.from_json(invalid_json_dict)
226210

227211
def test_from_invalid_json_vesion(self):
228212
model_card_dict = {
@@ -238,7 +222,7 @@ def test_from_invalid_json_vesion(self):
238222
"model card."
239223
)
240224
):
241-
model_card.ModelCard().from_json(model_card_dict)
225+
model_card.ModelCard.from_json(model_card_dict)
242226

243227
def test_from_proto_to_json(self):
244228
model_card_proto = text_format.Parse(
@@ -251,10 +235,10 @@ def test_from_proto_to_json(self):
251235
_FULL_JSON,
252236
model_card_py.merge_from_proto(model_card_proto).to_json()
253237
)
254-
# Use copy_from_proto
238+
# Use from_proto
255239
self.assertJsonEqual(
256240
_FULL_JSON,
257-
model_card_py.copy_from_proto(model_card_proto).to_json()
241+
model_card.ModelCard.from_proto(model_card_proto).to_json()
258242
)
259243

260244
def test_from_json_to_proto(self):
@@ -263,8 +247,7 @@ def test_from_json_to_proto(self):
263247
)
264248

265249
model_card_json = json.loads(_FULL_JSON)
266-
model_card_py = model_card.ModelCard()
267-
model_card_py.from_json(model_card_json)
250+
model_card_py = model_card.ModelCard.from_json(model_card_json)
268251
model_card_json2proto = model_card_py.to_proto()
269252

270253
self.assertEqual(model_card_proto, model_card_json2proto)

0 commit comments

Comments
 (0)