
Commit ec0a914

Sharded weights support (#2218)
* Add support for sharded weights.
* Add `max_shard_size` to Backbone and Task. Simplify the test.
* Split the functions from `KerasPresetSaver`. Clean up the comments.
1 parent 7480e31 commit ec0a914

File tree: 6 files changed, +160 -14 lines
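For orientation, a minimal usage sketch of the new behavior; the backbone class, preset name, local paths, and shard size below are illustrative choices, not taken from the diff:

import keras_hub

# Weights larger than `max_shard_size` (in GB) are written as
# `model_XXXXX.weights.h5` shards plus a `model.weights.json` index;
# smaller models keep a single `model.weights.h5`.
backbone = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
backbone.save_to_preset("./my_preset", max_shard_size=5)

# Loading is unchanged: the preset loader detects sharded weights and
# downloads every shard listed in the index before calling `load_weights`.
restored = keras_hub.models.GemmaBackbone.from_preset("./my_preset")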

keras_hub/src/models/backbone.py (+5, -2)

@@ -177,14 +177,17 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
         )
         return loader.load_backbone(backbone_cls, load_weights, **kwargs)

-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save backbone to a preset directory.

         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_backbone(self)
+        saver.save_backbone(self, max_shard_size=max_shard_size)

     def get_lora_target_names(self):
         """Returns list of layer names which are to be LoRA-fied.

keras_hub/src/models/task.py (+5, -2)

@@ -236,14 +236,17 @@ def save_task_weights(self, filepath):
             objects_to_skip=backbone_layer_ids,
         )

-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save task to a preset directory.

         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_task(self)
+        saver.save_task(self, max_shard_size=max_shard_size)

     @property
     def layers(self):
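The same argument flows through `Task.save_to_preset`; a short hypothetical sketch, with the task class and preset name chosen for illustration:

import keras_hub

# Saving a task forwards `max_shard_size` to the backbone, so the backbone
# weights may be sharded while `task.weights.h5` is written as before.
causal_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
causal_lm.save_to_preset("./my_task_preset", max_shard_size=2)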

keras_hub/src/utils/keras_utils.py (+11)

@@ -1,3 +1,4 @@
+import inspect
 import sys

 import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
         ]
     else:
         return [""]
+
+
+def sharded_weights_available():
+    """Whether sharded weights serialization is available.
+
+    Returns:
+        `True` if sharded weights are available, `False` otherwise.
+    """
+    save_weights_signature = inspect.signature(keras.saving.save_weights)
+    return "max_shard_size" in save_weights_signature.parameters

keras_hub/src/utils/preset_utils.py (+69, -9)

@@ -10,6 +10,8 @@

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
+from keras_hub.src.utils.keras_utils import sharded_weights_available
+from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits

 try:
     import kagglehub
@@ -48,6 +50,7 @@
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"

 # HuggingFace filenames.
 README_FILE = "README.md"
@@ -647,7 +650,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone

     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
                 task.load_task_weights(task_weights)
             else:
                 jax_memory_cleanup(task.backbone)
-                backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-                task.backbone.load_weights(backbone_weights)
+                self._load_backbone_weights(task.backbone)
         return task

     def load_preprocessor(
@@ -726,18 +728,64 @@ def _load_serialized_object(self, config, **kwargs):
             config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)

+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            if not sharded_weights_available():
+                raise RuntimeError(
+                    "Sharded weights loading is not supported in the current "
+                    f"Keras version {keras.__version__}. "
+                    "Please update to a newer version."
+                )
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+

 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir

-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=10):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)

+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the size of the backbone is larger than `max_shard_size`, save
+        # sharded weights.
+        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
@@ -755,18 +803,20 @@ def save_audio_converter(self, converter):
     def save_image_converter(self, converter):
         self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)

-    def save_task(self, task):
+    def save_task(self, task, max_shard_size=10):
         # Save task specific config and weights.
         self._save_serialized_object(task, TASK_CONFIG_FILE)
         if task.has_task_weights():
             task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
             task.save_task_weights(task_weight_path)
         # Save backbone.
         if hasattr(task.backbone, "save_to_preset"):
-            task.backbone.save_to_preset(self.preset_dir)
+            task.backbone.save_to_preset(
+                self.preset_dir, max_shard_size=max_shard_size
+            )
         else:
             # Allow saving a `keras.Model` that is not a backbone subclass.
-            self.save_backbone(task.backbone)
+            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
         # Save preprocessor.
         if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
             task.preprocessor.save_to_preset(self.preset_dir)
@@ -823,3 +873,13 @@ def _save_metadata(self, layer):
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += get_tensor_size_in_bits(shape, dtype)
+        return total_memory_size / 8
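The sharded index consumed by `_get_sharded_filenames` is produced by Keras, and the loader above relies only on the values of its `weight_map`. Below is an illustrative sketch of that reliance; the key names and the `metadata` field are assumptions, not taken from this commit:

import json

# Write a toy index in the shape the loader expects.
example_index = {
    "metadata": {"total_size": 123456789},  # assumed field, unused by the loader
    "weight_map": {
        "layers/dense/kernel": "model_00000.weights.h5",
        "layers/dense/bias": "model_00000.weights.h5",
        "layers/embedding/weights": "model_00001.weights.h5",
    },
}
with open("model.weights.json", "w", encoding="utf-8") as f:
    json.dump(example_index, f, indent=4)

# Mirror `_get_sharded_filenames`: the unique, sorted shard names that the
# preset loader downloads one by one before calling `load_weights` on the
# index file.
with open("model.weights.json", encoding="utf-8") as f:
    weight_map = json.load(f)["weight_map"]
print(sorted(set(weight_map.values())))
# ['model_00000.weights.h5', 'model_00001.weights.h5']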

keras_hub/src/utils/preset_utils_test.py (+43)

@@ -10,12 +10,55 @@
 )
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
+from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.tests.test_case import TestCase
+from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
 from keras_hub.src.utils.preset_utils import upload_preset


 class PresetUtilsTest(TestCase):
+    @pytest.mark.large
+    def test_sharded_weights(self):
+        if not sharded_weights_available():
+            self.skipTest("Sharded weights are not available.")
+
+        init_kwargs = {
+            "vocabulary_size": 1024,
+            "num_layers": 12,
+            "num_query_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_dim": 32,
+            "intermediate_dim": 64,
+            "head_dim": 4,
+            "sliding_window_size": 5,
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)  # ~422KB
+
+        # Save the sharded weights.
+        preset_dir = self.get_temp_dir()
+        backbone.save_to_preset(preset_dir, max_shard_size=0.0002)
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model.weights.json"))
+        )
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
+        )
+
+        # Load the sharded weights.
+        revived_backbone = GemmaBackbone.from_preset(preset_dir)
+        for v1, v2 in zip(
+            backbone.trainable_variables, revived_backbone.trainable_variables
+        ):
+            self.assertAllClose(v1, v2)
+
     @pytest.mark.large
     def test_preset_errors(self):
         with self.assertRaisesRegex(ValueError, "must be a string"):

keras_hub/src/utils/tensor_utils.py (+27, -1)

@@ -1,6 +1,8 @@
 import contextlib
 import functools
 import inspect
+import math
+import re
 import threading

 import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)


+def get_dtype_size_in_bits(dtype):
+    """Get the size of a given dtype in bits."""
+    dtype = keras.backend.standardize_dtype(dtype)
+    # If dtype is bool, return 1 immediately.
+    if dtype == "bool":
+        return 1
+    # Else, we extract the bit size from the string.
+    return int(re.sub(r"bfloat|float|uint|int", "", dtype))
+
+
+def get_tensor_size_in_bits(shape, dtype):
+    """Calculate the size given dtype and shape in bits.
+
+    Args:
+        dtype: The dtype of the tensor.
+        shape: List of iterables representing the shape of the tensor.
+
+    Returns:
+        The size of the tensor in bytes.
+    """
+    return math.prod(shape) * get_dtype_size_in_bits(dtype)
+
+
 def any_equal(inputs, values, padding_mask):
     """Return a mask that is True anywhere `inputs` has a value in `values`.

@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
     Returns:
         A tensor with `inputs` shape where each position is True if it contains
         a value from any `values`. Padding mask will be applied before
-        returning."""
+        returning.
+    """
     output = ops.equal(inputs, values[0])
     for value in values[1:]:
         value_equality = ops.equal(inputs, value)
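A quick worked example of the new size helpers; note that both return bit counts, which `_get_variables_size_in_bytes` in preset_utils.py divides by 8 to obtain bytes:

from keras_hub.src.utils.tensor_utils import get_dtype_size_in_bits
from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits

print(get_dtype_size_in_bits("bfloat16"))  # 16
print(get_dtype_size_in_bits("bool"))  # 1

# A (1024, 1024) float32 tensor: 1024 * 1024 * 32 = 33,554,432 bits,
# i.e. 4,194,304 bytes (4 MiB) after dividing by 8.
print(get_tensor_size_in_bits((1024, 1024), "float32"))  # 33554432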
