@@ -10,6 +10,8 @@
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
+from keras_hub.src.utils.keras_utils import sharded_weights_available
+from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
 
 try:
     import kagglehub
@@ -48,6 +50,7 @@
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
 
 # HuggingFace filenames.
 README_FILE = "README.md"
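
Note: the only key in `model.weights.json` that the loader added below relies on is `weight_map`, which maps each variable path to the shard file storing it; `_get_sharded_filenames` collects its sorted, de-duplicated values to know which files to fetch. An illustrative (not normative) config, with hypothetical variable paths and shard filenames:

    {
        "weight_map": {
            "layers/dense/kernel": "model_00000.weights.h5",
            "layers/dense/bias": "model_00000.weights.h5",
            "layers/embedding/embeddings": "model_00001.weights.h5"
        }
    }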
@@ -647,7 +650,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
                 task.load_task_weights(task_weights)
             else:
                 jax_memory_cleanup(task.backbone)
-                backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-                task.backbone.load_weights(backbone_weights)
+                self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
@@ -726,18 +728,64 @@ def _load_serialized_object(self, config, **kwargs):
             config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            if not sharded_weights_available():
+                raise RuntimeError(
+                    "Sharded weights loading is not supported in the current "
+                    f"Keras version {keras.__version__}. "
+                    "Please update to a newer version."
+                )
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=10):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the size of the backbone is larger than `max_shard_size`, save
+        # sharded weights.
+        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
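
On the read side nothing changes for callers: `_load_backbone_weights` picks the single-file or sharded path automatically based on which files the preset contains. A minimal usage sketch (the preset handle is a placeholder, not a real model):

    import keras_hub

    # If the preset ships model.weights.json instead of model.weights.h5,
    # every shard listed in weight_map is downloaded first, then
    # load_weights() is called on the sharded config file.
    backbone = keras_hub.models.Backbone.from_preset("org/large-model-preset")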
@@ -755,18 +803,20 @@ def save_audio_converter(self, converter):
     def save_image_converter(self, converter):
         self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
 
-    def save_task(self, task):
+    def save_task(self, task, max_shard_size=10):
         # Save task specific config and weights.
         self._save_serialized_object(task, TASK_CONFIG_FILE)
         if task.has_task_weights():
             task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
             task.save_task_weights(task_weight_path)
         # Save backbone.
         if hasattr(task.backbone, "save_to_preset"):
-            task.backbone.save_to_preset(self.preset_dir)
+            task.backbone.save_to_preset(
+                self.preset_dir, max_shard_size=max_shard_size
+            )
         else:
             # Allow saving a `keras.Model` that is not a backbone subclass.
-            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
+            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
         # Save preprocessor.
         if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
             task.preprocessor.save_to_preset(self.preset_dir)
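
On the write side, `max_shard_size` is a threshold in gigabytes, defaulting to 10. A hedged usage sketch, assuming the backbone's `save_to_preset` forwards the argument to `KerasPresetSaver.save_backbone` as the diff above does (directory name is illustrative):

    # Weights totalling more than 5 GB are written as model.weights.json
    # plus shard files; smaller backbones still get a single
    # model.weights.h5. Requires a Keras version where
    # save_weights(..., max_shard_size=...) is available.
    backbone.save_to_preset("./my_preset", max_shard_size=5)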
@@ -823,3 +873,13 @@ def _save_metadata(self, layer):
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += get_tensor_size_in_bits(shape, dtype)
+        return total_memory_size / 8
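
A quick worked example of the accounting above, assuming `get_tensor_size_in_bits` returns element count times dtype bit width: a float32 variable of shape (1024, 1024) contributes 1024 * 1024 * 32 bits, and dividing by 8 gives 4,194,304 bytes (4 MiB). Variables shared between layers are counted once, thanks to the `id()` de-duplication.

    # Hypothetical standalone check of the same arithmetic.
    bits = 1024 * 1024 * 32          # elements * bits per float32 element
    assert bits / 8 == 4 * 1024**2   # 4 MiB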