Skip to content

Commit 3f1af5f

Browse files
seanzhougoogle and copybara-github
authored and committed
fix: relax openapi spec to gemini schema conversion to tolerate more cases
PiperOrigin-RevId: 766903673
1 parent 98a635a commit 3f1af5f

File tree

2 files changed

+205
-48
lines changed

2 files changed

+205
-48
lines changed

src/google/adk/tools/_gemini_schema_util.py

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -74,45 +74,73 @@ def _to_snake_case(text: str) -> str:
7474
return text
7575

7676

77-
def _sanitize_schema_formats_for_gemini(schema_node: Any) -> Any:
  """Recursively copies a schema, keeping only Gemini-compatible formats.

  Dict nodes are rebuilt with every key passed through `_to_snake_case`.
  The keys nested directly under "properties" are left exactly as written.
  A "format" entry survives only when it is "int32"/"int64" on an
  "integer"/"number" node, or "date-time"/"enum" on a "string" node; any
  other format is dropped from the node. Lists are sanitized element by
  element and every other value is returned unchanged.

  Args:
    schema_node: A schema dict, a list of schema nodes, or a leaf value.

  Returns:
    A sanitized copy for dict/list input, otherwise the input itself.
  """
  if isinstance(schema_node, list):
    return [
        _sanitize_schema_formats_for_gemini(entry) for entry in schema_node
    ]
  if not isinstance(schema_node, dict):
    return schema_node

  node_type = schema_node.get("type")
  sanitized = {}
  for raw_key, child in schema_node.items():
    snake_key = _to_snake_case(raw_key)

    if snake_key == "format":
      # The formats that may be kept depend on the node's declared type.
      if node_type == "integer" or node_type == "number":
        allowed_formats = ("int32", "int64")
      elif node_type == "string":
        # only 'enum' and 'date-time' are supported for STRING type
        allowed_formats = ("date-time", "enum")
      else:
        allowed_formats = ()
      # Empty or unsupported formats are omitted from the copy entirely.
      if child and child in allowed_formats:
        sanitized[snake_key] = child
    elif snake_key == "properties":
      # Property names must not be snake_cased; only their schemas recurse.
      sanitized[snake_key] = {
          prop_name: _sanitize_schema_formats_for_gemini(prop_schema)
          for prop_name, prop_schema in child.items()
      }
    else:
      sanitized[snake_key] = _sanitize_schema_formats_for_gemini(child)
  return sanitized
77+
def _sanitize_schema_type(schema: dict[str, Any]) -> dict[str, Any]:
78+
if ("type" not in schema or not schema["type"]) and schema.keys().isdisjoint(
79+
schema
80+
):
81+
schema["type"] = "object"
82+
if isinstance(schema.get("type"), list):
83+
nullable = False
84+
non_null_type = None
85+
for t in schema["type"]:
86+
if t == "null":
87+
nullable = True
88+
elif not non_null_type:
89+
non_null_type = t
90+
if not non_null_type:
91+
non_null_type = "object"
92+
if nullable:
93+
schema["type"] = [non_null_type, "null"]
94+
else:
95+
schema["type"] = non_null_type
96+
elif schema.get("type") == "null":
97+
schema["type"] = ["object", "null"]
98+
99+
return schema
100+
101+
102+
def _sanitize_schema_formats_for_gemini(
    schema: dict[str, Any],
) -> dict[str, Any]:
  """Builds a snake_cased copy of `schema` limited to supported fields.

  Every field name is converted with `_to_snake_case`, then handled as:
    * "items": the value is sanitized recursively as a schema.
    * "any_of": each element of the value is sanitized recursively.
    * "properties": each value is sanitized recursively; the property names
      themselves are kept exactly as given.
    * "format" (non-empty): kept only for "int32"/"int64" on an
      "integer"/"number" schema, or "date-time"/"enum" on a "string" schema.
      The schema's "type" is compared against those plain strings, so a
      list-valued type never matches and the format is dropped.
    * anything else: kept only if the name is one of
      `_ExtendedJSONSchema.model_fields` and the value is not None.

  The result is finally passed through `_sanitize_schema_type`.

  Args:
    schema: The OpenAPI/JSON schema dict to sanitize. It is not modified.

  Returns:
    A new dict holding the sanitized schema.
  """
  allowed_field_names: set[str] = set(_ExtendedJSONSchema.model_fields.keys())
  sanitized: dict[str, Any] = {}

  for raw_name, raw_value in schema.items():
    name = _to_snake_case(raw_name)

    # Single nested schema ('additional_properties' to come).
    if name == "items":
      sanitized[name] = _sanitize_schema_formats_for_gemini(raw_value)
      continue

    # List of nested schemas ('one_of', 'all_of', 'not' to come).
    if name == "any_of":
      sanitized[name] = [
          _sanitize_schema_formats_for_gemini(option) for option in raw_value
      ]
      continue

    # Mapping of name -> nested schema ('defs' to come).
    if name == "properties":
      sanitized[name] = {
          prop_name: _sanitize_schema_formats_for_gemini(prop_schema)
          for prop_name, prop_schema in raw_value.items()
      }
      continue

    # special handle of format field
    if name == "format" and raw_value:
      declared_type = schema.get("type")
      if declared_type in ("integer", "number"):
        # only "int32" and "int64" are supported for integer or number type
        keep_format = raw_value in ("int32", "int64")
      elif declared_type == "string":
        # only 'enum' and 'date-time' are supported for STRING type
        keep_format = raw_value in ("date-time", "enum")
      else:
        keep_format = False
      if keep_format:
        sanitized[name] = raw_value
      continue

    # An empty "format" value falls through to this generic filter as well.
    if name in allowed_field_names and raw_value is not None:
      sanitized[name] = raw_value

  return _sanitize_schema_type(sanitized)
116144

117145

118146
def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema:

tests/unittests/tools/test_gemini_schema_utils.py renamed to tests/unittests/tools/test_gemini_schema_util.py

Lines changed: 138 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.adk.tools._gemini_schema_util import _sanitize_schema_formats_for_gemini
1516
from google.adk.tools._gemini_schema_util import _to_gemini_schema
1617
from google.adk.tools._gemini_schema_util import _to_snake_case
1718
from google.genai.types import Schema
@@ -31,7 +32,7 @@ def test_to_gemini_schema_not_dict(self):
3132
def test_to_gemini_schema_empty_dict(self):
3233
result = _to_gemini_schema({})
3334
assert isinstance(result, Schema)
34-
assert result.type is None
35+
assert result.type is Type.OBJECT
3536
assert result.properties is None
3637

3738
def test_to_gemini_schema_dict_with_only_object_type(self):
@@ -64,10 +65,8 @@ def test_to_gemini_schema_array_string_types(self):
6465
"nonnullable_string": {"type": ["string"]},
6566
"nullable_string": {"type": ["string", "null"]},
6667
"nullable_number": {"type": ["null", "integer"]},
67-
"object_nullable": {"type": "null"}, # invalid
68-
"multi_types_nullable": {
69-
"type": ["string", "null", "integer"]
70-
}, # invalid
68+
"object_nullable": {"type": "null"},
69+
"multi_types_nullable": {"type": ["string", "null", "integer"]},
7170
"empty_default_object": {},
7271
},
7372
}
@@ -85,14 +84,14 @@ def test_to_gemini_schema_array_string_types(self):
8584
assert gemini_schema.properties["nullable_number"].type == Type.INTEGER
8685
assert gemini_schema.properties["nullable_number"].nullable
8786

88-
assert gemini_schema.properties["object_nullable"].type is None
87+
assert gemini_schema.properties["object_nullable"].type == Type.OBJECT
8988
assert gemini_schema.properties["object_nullable"].nullable
9089

91-
assert gemini_schema.properties["multi_types_nullable"].type is None
90+
assert gemini_schema.properties["multi_types_nullable"].type == Type.STRING
9291
assert gemini_schema.properties["multi_types_nullable"].nullable
9392

94-
assert gemini_schema.properties["empty_default_object"].type is None
95-
assert not gemini_schema.properties["empty_default_object"].nullable
93+
assert gemini_schema.properties["empty_default_object"].type == Type.OBJECT
94+
assert gemini_schema.properties["empty_default_object"].nullable is None
9695

9796
def test_to_gemini_schema_nested_objects(self):
9897
openapi_schema = {
@@ -382,6 +381,136 @@ def test_to_gemini_schema_property_ordering(self):
382381
gemini_schema = _to_gemini_schema(openapi_schema)
383382
assert gemini_schema.property_ordering == ["name", "age"]
384383

384+
def test_sanitize_schema_formats_for_gemini(self):
  """Checks format filtering, key casing and unsupported-key removal."""
  source_schema = {
      "type": "object",
      "description": "Test schema",  # Top-level description
      "properties": {
          "valid_int": {"type": "integer", "format": "int32"},
          "invalid_format_prop": {"type": "integer", "format": "unsigned"},
          "valid_string": {"type": "string", "format": "date-time"},
          "camelCaseKey": {"type": "string"},
          "prop_with_extra_key": {
              "type": "boolean",
              "unknownInternalKey": "discard_this_value",
          },
      },
      "required": ["valid_int"],
      # The two keys below are unsupported at the top level.
      "additionalProperties": False,
      "unknownTopLevelKey": "discard_me_too",
  }
  result = _sanitize_schema_formats_for_gemini(source_schema)

  # The description survives sanitization.
  assert result["description"] == "Test schema"

  # Properties are present and individually sanitized.
  assert "properties" in result
  props = result["properties"]

  # A supported integer format is kept.
  assert "valid_int" in props
  valid_int = props["valid_int"]
  assert valid_int["type"] == "integer"
  assert valid_int["format"] == "int32"

  # An unsupported integer format is removed, the type stays.
  assert "invalid_format_prop" in props
  invalid_format_prop = props["invalid_format_prop"]
  assert invalid_format_prop["type"] == "integer"
  assert "format" not in invalid_format_prop

  # A supported string format is kept.
  assert "valid_string" in props
  valid_string = props["valid_string"]
  assert valid_string["type"] == "string"
  assert valid_string["format"] == "date-time"

  # Property names are not snake_cased.
  assert "camel_case_key" not in props
  assert "camelCaseKey" in props
  assert props["camelCaseKey"]["type"] == "string"

  # Unsupported keys inside a property definition are dropped
  # (checked under the snake_cased spelling of unknownInternalKey).
  assert "prop_with_extra_key" in props
  prop_with_extra_key = props["prop_with_extra_key"]
  assert prop_with_extra_key["type"] == "boolean"
  assert "unknown_internal_key" not in prop_with_extra_key

  # Unsupported top-level fields are gone, both in their snake_cased and
  # in their original spelling.
  for dropped_key in (
      "additional_properties",
      "unknown_top_level_key",
      "additionalProperties",
      "unknownTopLevelKey",
  ):
    assert dropped_key not in result

  # The required list is preserved.
  assert result["required"] == ["valid_int"]

  # A property whose type is given as a list of types.
  list_typed_schema = {
      "type": "object",
      "properties": {
          "nullable_field": {"type": ["string", "null"], "format": "uuid"}
      },
  }
  nullable_field = _sanitize_schema_formats_for_gemini(list_typed_schema)[
      "properties"
  ]["nullable_field"]
  # 'uuid' is not a supported string format, so it must be removed.
  assert "format" not in nullable_field
  # The type list goes through _sanitize_schema_type and is preserved.
  assert nullable_field["type"] == ["string", "null"]
468+
469+
def test_sanitize_schema_formats_for_gemini_nullable(self):
  """Checks sanitization of an anyOf-based nullable property.

  The expected output shows "anyOf" renamed to "any_of", the bare "null"
  alternative rewritten to ["object", "null"], and the None-valued
  "default" entry removed.
  """
  page_token_description = (
      "The nextPageToken to fetch the next page of results."
  )
  case_id_schema = {
      "description": "The ID of the case.",
      "title": "Case Id",
      "type": "string",
  }
  raw_schema = {
      "properties": {
          "case_id": dict(case_id_schema),
          "next_page_token": {
              "anyOf": [{"type": "string"}, {"type": "null"}],
              "default": None,
              "description": page_token_description,
              "title": "Next Page Token",
          },
      },
      "required": ["case_id"],
      "title": "list_alerts_by_caseArguments",
      "type": "object",
  }
  expected_schema = {
      "properties": {
          "case_id": dict(case_id_schema),
          "next_page_token": {
              "any_of": [
                  {"type": "string"},
                  {"type": ["object", "null"]},
              ],
              "description": page_token_description,
              "title": "Next Page Token",
          },
      },
      "required": ["case_id"],
      "title": "list_alerts_by_caseArguments",
      "type": "object",
  }

  actual_schema = _sanitize_schema_formats_for_gemini(raw_schema)

  assert actual_schema == expected_schema
513+
385514

386515
class TestToSnakeCase:
387516

0 commit comments

Comments
 (0)