
Commit 16bb9f2

Merge branch 'keras-team:master' into esm
2 parents: 72e9829 + c00db4e

52 files changed, +5434 -145 lines

keras_hub/api/models/__init__.py (+21)
@@ -348,6 +348,18 @@
 from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
     MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor,
 )
+from keras_hub.src.models.mixtral.mixtral_backbone import (
+    MixtralBackbone as MixtralBackbone,
+)
+from keras_hub.src.models.mixtral.mixtral_causal_lm import (
+    MixtralCausalLM as MixtralCausalLM,
+)
+from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import (
+    MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor,
+)
+from keras_hub.src.models.mixtral.mixtral_tokenizer import (
+    MixtralTokenizer as MixtralTokenizer,
+)
 from keras_hub.src.models.mobilenet.mobilenet_backbone import (
     MobileNetBackbone as MobileNetBackbone,
 )
@@ -420,6 +432,15 @@
 from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as QwenTokenizer,
 )
+from keras_hub.src.models.qwen_moe.qwen_moe_backbone import (
+    QwenMoeBackbone as QwenMoeBackbone,
+)
+from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import (
+    QwenMoeCausalLM as QwenMoeCausalLM,
+)
+from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import (
+    QwenMoeCausalLMPreprocessor as QwenMoeCausalLMPreprocessor,
+)
 from keras_hub.src.models.resnet.resnet_backbone import (
     ResNetBackbone as ResNetBackbone,
 )
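
With these exports, the Mixtral and Qwen-MoE model families become reachable from the public `keras_hub.models` namespace. A minimal usage sketch; the preset handle below is hypothetical and not taken from this commit:

    import keras_hub

    # "mixtral_8_7b_en" is an illustrative handle only; check the published
    # preset listing for real names.
    causal_lm = keras_hub.models.MixtralCausalLM.from_preset("mixtral_8_7b_en")
    causal_lm.generate("The quick brown fox", max_length=30)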

keras_hub/api/tokenizers/__init__.py (+6)
@@ -55,6 +55,9 @@
 from keras_hub.src.models.mistral.mistral_tokenizer import (
     MistralTokenizer as MistralTokenizer,
 )
+from keras_hub.src.models.mixtral.mixtral_tokenizer import (
+    MixtralTokenizer as MixtralTokenizer,
+)
 from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer
 from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
     PaliGemmaTokenizer as PaliGemmaTokenizer,
@@ -68,6 +71,9 @@
 from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as QwenTokenizer,
 )
+from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import (
+    QwenMoeTokenizer as QwenMoeTokenizer,
+)
 from keras_hub.src.models.roberta.roberta_tokenizer import (
     RobertaTokenizer as RobertaTokenizer,
 )
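
The matching tokenizers are exposed under `keras_hub.tokenizers` as well. A short sketch, again with a hypothetical preset handle:

    import keras_hub

    # Hypothetical handle, for illustration only.
    tokenizer = keras_hub.tokenizers.MixtralTokenizer.from_preset("mixtral_8_7b_en")
    token_ids = tokenizer("Keras Hub now ships a Mixtral tokenizer.")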

keras_hub/src/models/backbone.py (+5 -2)
@@ -177,14 +177,17 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
         )
         return loader.load_backbone(backbone_cls, load_weights, **kwargs)
 
-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save backbone to a preset directory.
 
         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_backbone(self)
+        saver.save_backbone(self, max_shard_size=max_shard_size)
 
     def get_lora_target_names(self):
         """Returns list of layer names which are to be LoRA-fied.

keras_hub/src/models/cspnet/cspnet_backbone.py (+51 -26)
@@ -81,7 +81,7 @@ class CSPNetBackbone(FeaturePyramidBackbone):
 
 # Pretrained backbone
 model = keras_hub.models.CSPNetBackbone.from_preset(
-    "cspdarknet53_ra_imagenet"
+    "csp_darknet_53_ra_imagenet"
 )
 model(input_data)
 
@@ -357,18 +357,6 @@ def apply(x):
             dtype=dtype,
             name=f"{name}_bottleneck_block_bn_3",
         )(x)
-        if activation == "leaky_relu":
-            x = layers.LeakyReLU(
-                negative_slope=0.01,
-                dtype=dtype,
-                name=f"{name}_bottleneck_block_activation_3",
-            )(x)
-        else:
-            x = layers.Activation(
-                activation,
-                dtype=dtype,
-                name=f"{name}_bottleneck_block_activation_3",
-            )(x)
 
         x = layers.add(
             [x, shortcut], dtype=dtype, name=f"{name}_bottleneck_block_add"
@@ -673,6 +661,13 @@ def apply(x):
                 name=f"{name}_csp_activation_1",
             )(x)
         else:
+            if strides > 1:
+                x = layers.ZeroPadding2D(
+                    1,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"{name}_csp_conv_pad_1",
+                )(x)
             x = layers.Conv2D(
                 filters=down_chs,
                 kernel_size=3,
@@ -882,6 +877,13 @@ def apply(x):
                 name=f"{name}_cs3_activation_1",
             )(x)
         else:
+            if strides > 1:
+                x = layers.ZeroPadding2D(
+                    1,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"{name}_cs3_conv_pad_1",
+                )(x)
             x = layers.Conv2D(
                 filters=down_chs,
                 kernel_size=3,
@@ -1062,6 +1064,13 @@ def apply(x):
                 name=f"{name}_dark_activation_1",
             )(x)
         else:
+            if strides > 1:
+                x = layers.ZeroPadding2D(
+                    1,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"{name}_dark_conv_pad_1",
+                )(x)
             x = layers.Conv2D(
                 filters=filters,
                 kernel_size=3,
@@ -1091,18 +1100,18 @@ def apply(x):
                 dtype=dtype,
                 name=f"{name}_dark_activation_1",
             )(x)
-        for i in range(depth):
-            x = block_fn(
-                filters=block_channels,
-                dilation=dilation,
-                bottle_ratio=bottle_ratio,
-                groups=groups,
-                activation=activation,
-                data_format=data_format,
-                channel_axis=channel_axis,
-                dtype=dtype,
-                name=f"{name}_block_{i}",
-            )(x)
+            for i in range(depth):
+                x = block_fn(
+                    filters=block_channels,
+                    dilation=dilation,
+                    bottle_ratio=bottle_ratio,
+                    groups=groups,
+                    activation=activation,
+                    data_format=data_format,
+                    channel_axis=channel_axis,
+                    dtype=dtype,
+                    name=f"{name}_block_{i}",
+                )(x)
         return x
 
     return apply
@@ -1135,6 +1144,13 @@ def apply(x):
             or (i == last_idx and strides > 2 and not pooling)
             else 1
         )
+        if conv_strides > 1:
+            x = layers.ZeroPadding2D(
+                (kernel_size - 1) // 2,
+                data_format=data_format,
+                dtype=dtype,
+                name=f"csp_stem_pad_{i}",
+            )(x)
         x = layers.Conv2D(
             filters=chs,
             kernel_size=kernel_size,
@@ -1167,10 +1183,19 @@ def apply(x):
 
     if pooling == "max":
         assert strides > 2
+        # Pad manually with -inf so the implicit zeros added by "same"
+        # padding are never picked as the max; negative values from
+        # LeakyReLU-style activations are preserved instead.
+        pad_width = [[1, 1], [1, 1]]
+        if data_format == "channels_last":
+            pad_width += [[0, 0]]
+        else:
+            pad_width = [[0, 0]] + pad_width
+        pad_width = [[0, 0]] + pad_width
+        x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf"))
         x = layers.MaxPooling2D(
             pool_size=3,
             strides=2,
-            padding="same",
             data_format=data_format,
             dtype=dtype,
             name="csp_stem_pool",

keras_hub/src/models/cspnet/cspnet_backbone_test.py (+3 -2)
@@ -22,6 +22,7 @@ def setUp(self):
             "expand_ratio": (2.0,) + (1.0,),
             "block_type": "dark_block",
             "stage_type": "csp",
+            "stem_padding": "same",
         }
         self.input_size = 64
         self.input_data = ops.ones((2, self.input_size, self.input_size, 3))
@@ -38,9 +39,9 @@ def test_backbone_basics(self, stage_type, block_type):
                 "stage_type": stage_type,
             },
             input_data=self.input_data,
-            expected_output_shape=(2, 6, 6, 48),
+            expected_output_shape=(2, 8, 8, 48),
             expected_pyramid_output_keys=["P2", "P3", "P4"],
-            expected_pyramid_image_sizes=[(30, 30), (14, 14), (6, 6)],
+            expected_pyramid_image_sizes=[(32, 32), (16, 16), (8, 8)],
         )
 
     @pytest.mark.large

keras_hub/src/models/cspnet/cspnet_presets.py (+38 -3)
@@ -6,11 +6,46 @@
             "description": (
                 "A CSP-DarkNet (Cross-Stage-Partial) image classification model"
                 " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
-                "a 224x224 resolution."
+                "a 256x256 resolution."
             ),
-            "params": 26652512,
+            "params": 27642184,
             "path": "cspnet",
         },
-        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/1",
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/2",
+    },
+    "csp_resnext_50_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "A CSP-ResNeXt (Cross-Stage-Partial) image classification model"
+                " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                "a 256x256 resolution."
+            ),
+            "params": 20569896,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnext_50_ra_imagenet/1",
+    },
+    "csp_resnet_50_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "A CSP-ResNet (Cross-Stage-Partial) image classification model"
+                " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                "a 256x256 resolution."
+            ),
+            "params": 21616168,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnet_50_ra_imagenet/1",
+    },
+    "darknet_53_imagenet": {
+        "metadata": {
+            "description": (
+                "A DarkNet image classification model pre-trained on the "
+                "ImageNet 1k dataset at a 256x256 resolution."
+            ),
+            "params": 41609928,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/darknet_53_imagenet/1",
     },
 }
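
Once the new handles are published on Kaggle, they load through the usual `from_preset` path. A sketch using one of the handles registered above (it assumes the hosted asset is live):

    import keras_hub
    from keras import ops

    backbone = keras_hub.models.CSPNetBackbone.from_preset("csp_resnext_50_ra_imagenet")
    features = backbone(ops.ones((1, 256, 256, 3)))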
