From b1a74d099fae44d41750b79e58455282d919dd78 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 3 Jun 2025 22:58:56 -0700 Subject: [PATCH] fix: relax openapi spec to gemini schema conversion to tolerate more cases PiperOrigin-RevId: 766985694 --- src/google/adk/tools/_gemini_schema_util.py | 106 ++++++++----- ...ma_utils.py => test_gemini_schema_util.py} | 147 ++++++++++++++++-- 2 files changed, 205 insertions(+), 48 deletions(-) rename tests/unittests/tools/{test_gemini_schema_utils.py => test_gemini_schema_util.py} (74%) diff --git a/src/google/adk/tools/_gemini_schema_util.py b/src/google/adk/tools/_gemini_schema_util.py index 92bf770f5..020e38fce 100644 --- a/src/google/adk/tools/_gemini_schema_util.py +++ b/src/google/adk/tools/_gemini_schema_util.py @@ -74,45 +74,73 @@ def _to_snake_case(text: str) -> str: return text -def _sanitize_schema_formats_for_gemini(schema_node: Any) -> Any: - """Helper function to sanitize schema formats for Gemini compatibility""" - if isinstance(schema_node, dict): - new_node = {} - current_type = schema_node.get("type") - - for key, value in schema_node.items(): - key = _to_snake_case(key) - - # special handle of format field - if key == "format": - current_format = value - format_to_keep = None - if current_format: - if current_type == "integer" or current_type == "number": - if current_format in ("int32", "int64"): - format_to_keep = current_format - elif current_type == "string": - # only 'enum' and 'date-time' are supported for STRING type" - if current_format in ("date-time", "enum"): - format_to_keep = current_format - # For any other type or unhandled format - # the 'format' key will be effectively removed for that node. - if format_to_keep: - new_node[key] = format_to_keep - continue - # don't change property name - if key == "properties": - new_node[key] = { - k: _sanitize_schema_formats_for_gemini(v) for k, v in value.items() - } - continue - # Recursively sanitize other parts of the schema - new_node[key] = _sanitize_schema_formats_for_gemini(value) - return new_node - elif isinstance(schema_node, list): - return [_sanitize_schema_formats_for_gemini(item) for item in schema_node] - else: - return schema_node +def _sanitize_schema_type(schema: dict[str, Any]) -> dict[str, Any]: + if ("type" not in schema or not schema["type"]) and schema.keys().isdisjoint( + schema + ): + schema["type"] = "object" + if isinstance(schema.get("type"), list): + nullable = False + non_null_type = None + for t in schema["type"]: + if t == "null": + nullable = True + elif not non_null_type: + non_null_type = t + if not non_null_type: + non_null_type = "object" + if nullable: + schema["type"] = [non_null_type, "null"] + else: + schema["type"] = non_null_type + elif schema.get("type") == "null": + schema["type"] = ["object", "null"] + + return schema + + +def _sanitize_schema_formats_for_gemini( + schema: dict[str, Any], +) -> dict[str, Any]: + """Filters the schema to only include fields that are supported by JSONSchema.""" + supported_fields: set[str] = set(_ExtendedJSONSchema.model_fields.keys()) + schema_field_names: set[str] = {"items"} # 'additional_properties' to come + list_schema_field_names: set[str] = { + "any_of", # 'one_of', 'all_of', 'not' to come + } + snake_case_schema = {} + dict_schema_field_names: tuple[str] = ("properties",) # 'defs' to come + for field_name, field_value in schema.items(): + field_name = _to_snake_case(field_name) + if field_name in schema_field_names: + snake_case_schema[field_name] = _sanitize_schema_formats_for_gemini( + field_value + ) + elif field_name in list_schema_field_names: + snake_case_schema[field_name] = [ + _sanitize_schema_formats_for_gemini(value) for value in field_value + ] + elif field_name in dict_schema_field_names: + snake_case_schema[field_name] = { + key: _sanitize_schema_formats_for_gemini(value) + for key, value in field_value.items() + } + # special handle of format field + elif field_name == "format" and field_value: + current_type = schema.get("type") + if ( + # only "int32" and "int64" are supported for integer or number type + (current_type == "integer" or current_type == "number") + and field_value in ("int32", "int64") + or + # only 'enum' and 'date-time' are supported for STRING type" + (current_type == "string" and field_value in ("date-time", "enum")) + ): + snake_case_schema[field_name] = field_value + elif field_name in supported_fields and field_value is not None: + snake_case_schema[field_name] = field_value + + return _sanitize_schema_type(snake_case_schema) def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema: diff --git a/tests/unittests/tools/test_gemini_schema_utils.py b/tests/unittests/tools/test_gemini_schema_util.py similarity index 74% rename from tests/unittests/tools/test_gemini_schema_utils.py rename to tests/unittests/tools/test_gemini_schema_util.py index 56aeca1f8..71143debc 100644 --- a/tests/unittests/tools/test_gemini_schema_utils.py +++ b/tests/unittests/tools/test_gemini_schema_util.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.adk.tools._gemini_schema_util import _sanitize_schema_formats_for_gemini from google.adk.tools._gemini_schema_util import _to_gemini_schema from google.adk.tools._gemini_schema_util import _to_snake_case from google.genai.types import Schema @@ -31,7 +32,7 @@ def test_to_gemini_schema_not_dict(self): def test_to_gemini_schema_empty_dict(self): result = _to_gemini_schema({}) assert isinstance(result, Schema) - assert result.type is None + assert result.type is Type.OBJECT assert result.properties is None def test_to_gemini_schema_dict_with_only_object_type(self): @@ -64,10 +65,8 @@ def test_to_gemini_schema_array_string_types(self): "nonnullable_string": {"type": ["string"]}, "nullable_string": {"type": ["string", "null"]}, "nullable_number": {"type": ["null", "integer"]}, - "object_nullable": {"type": "null"}, # invalid - "multi_types_nullable": { - "type": ["string", "null", "integer"] - }, # invalid + "object_nullable": {"type": "null"}, + "multi_types_nullable": {"type": ["string", "null", "integer"]}, "empty_default_object": {}, }, } @@ -85,14 +84,14 @@ def test_to_gemini_schema_array_string_types(self): assert gemini_schema.properties["nullable_number"].type == Type.INTEGER assert gemini_schema.properties["nullable_number"].nullable - assert gemini_schema.properties["object_nullable"].type is None + assert gemini_schema.properties["object_nullable"].type == Type.OBJECT assert gemini_schema.properties["object_nullable"].nullable - assert gemini_schema.properties["multi_types_nullable"].type is None + assert gemini_schema.properties["multi_types_nullable"].type == Type.STRING assert gemini_schema.properties["multi_types_nullable"].nullable - assert gemini_schema.properties["empty_default_object"].type is None - assert not gemini_schema.properties["empty_default_object"].nullable + assert gemini_schema.properties["empty_default_object"].type == Type.OBJECT + assert gemini_schema.properties["empty_default_object"].nullable is None def test_to_gemini_schema_nested_objects(self): openapi_schema = { @@ -382,6 +381,136 @@ def test_to_gemini_schema_property_ordering(self): gemini_schema = _to_gemini_schema(openapi_schema) assert gemini_schema.property_ordering == ["name", "age"] + def test_sanitize_schema_formats_for_gemini(self): + schema = { + "type": "object", + "description": "Test schema", # Top-level description + "properties": { + "valid_int": {"type": "integer", "format": "int32"}, + "invalid_format_prop": {"type": "integer", "format": "unsigned"}, + "valid_string": {"type": "string", "format": "date-time"}, + "camelCaseKey": {"type": "string"}, + "prop_with_extra_key": { + "type": "boolean", + "unknownInternalKey": "discard_this_value", + }, + }, + "required": ["valid_int"], + "additionalProperties": False, # This is an unsupported top-level key + "unknownTopLevelKey": ( + "discard_me_too" + ), # Another unsupported top-level key + } + sanitized = _sanitize_schema_formats_for_gemini(schema) + + # Check description is preserved + assert sanitized["description"] == "Test schema" + + # Check properties and their sanitization + assert "properties" in sanitized + sanitized_props = sanitized["properties"] + + assert "valid_int" in sanitized_props + assert sanitized_props["valid_int"]["type"] == "integer" + assert sanitized_props["valid_int"]["format"] == "int32" + + assert "invalid_format_prop" in sanitized_props + assert sanitized_props["invalid_format_prop"]["type"] == "integer" + assert ( + "format" not in sanitized_props["invalid_format_prop"] + ) # Invalid format removed + + assert "valid_string" in sanitized_props + assert sanitized_props["valid_string"]["type"] == "string" + assert sanitized_props["valid_string"]["format"] == "date-time" + + # Check camelCase keys not changed for properties + assert "camel_case_key" not in sanitized_props + assert "camelCaseKey" in sanitized_props + assert sanitized_props["camelCaseKey"]["type"] == "string" + + # Check removal of unsupported keys within a property definition + assert "prop_with_extra_key" in sanitized_props + assert sanitized_props["prop_with_extra_key"]["type"] == "boolean" + assert ( + "unknown_internal_key" # snake_cased version of unknownInternalKey + not in sanitized_props["prop_with_extra_key"] + ) + + # Check removal of unsupported top-level fields (after snake_casing) + assert "additional_properties" not in sanitized + assert "unknown_top_level_key" not in sanitized + + # Check original unsupported top-level field names are not there either + assert "additionalProperties" not in sanitized + assert "unknownTopLevelKey" not in sanitized + + # Check required is preserved + assert sanitized["required"] == ["valid_int"] + + # Test with a schema that has a list of types for a property + schema_with_list_type = { + "type": "object", + "properties": { + "nullable_field": {"type": ["string", "null"], "format": "uuid"} + }, + } + sanitized_list_type = _sanitize_schema_formats_for_gemini( + schema_with_list_type + ) + # format should be removed because 'uuid' is not supported for string + assert "format" not in sanitized_list_type["properties"]["nullable_field"] + # type should be processed by _sanitize_schema_type and preserved + assert sanitized_list_type["properties"]["nullable_field"]["type"] == [ + "string", + "null", + ] + + def test_sanitize_schema_formats_for_gemini_nullable(self): + openapi_schema = { + "properties": { + "case_id": { + "description": "The ID of the case.", + "title": "Case Id", + "type": "string", + }, + "next_page_token": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": ( + "The nextPageToken to fetch the next page of results." + ), + "title": "Next Page Token", + }, + }, + "required": ["case_id"], + "title": "list_alerts_by_caseArguments", + "type": "object", + } + openapi_schema = _sanitize_schema_formats_for_gemini(openapi_schema) + assert openapi_schema == { + "properties": { + "case_id": { + "description": "The ID of the case.", + "title": "Case Id", + "type": "string", + }, + "next_page_token": { + "any_of": [ + {"type": "string"}, + {"type": ["object", "null"]}, + ], + "description": ( + "The nextPageToken to fetch the next page of results." + ), + "title": "Next Page Token", + }, + }, + "required": ["case_id"], + "title": "list_alerts_by_caseArguments", + "type": "object", + } + class TestToSnakeCase: