Skip to content

VertexAI: make FunctionCallPart.args nullable #6106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase-vertexai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Unreleased
* [feature] added support for `responseSchema` in `GenerationConfig`.
* [changed] Made `FunctionCallPart.args` nullable.

# 16.0.0-beta03
* [changed] Breaking Change: changed `Schema.int` to return 32 bit integers instead of 64 bit (long).
Expand Down
2 changes: 1 addition & 1 deletion firebase-vertexai/firebase-vertexai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dependencies {
implementation("com.google.firebase:firebase-components:18.0.0")
implementation("com.google.firebase:firebase-annotations:16.2.0")
implementation("com.google.firebase:firebase-appcheck-interop:17.1.0")
implementation("com.google.ai.client.generativeai:common:0.9.0")
implementation("com.google.ai.client.generativeai:common:0.10.0")
implementation(libs.androidx.annotation)
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("androidx.core:core-ktx:1.12.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ internal fun Part.toInternal(): com.google.ai.client.generativeai.common.shared.
)
)
is com.google.firebase.vertexai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args.orEmpty()))
FunctionCallPart(FunctionCall(name, args))
is com.google.firebase.vertexai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
is FileDataPart ->
Expand Down Expand Up @@ -220,7 +220,7 @@ internal fun com.google.ai.client.generativeai.common.shared.Part.toPublic(): Pa
is FunctionCallPart ->
com.google.firebase.vertexai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty(),
functionCall.args,
)
is FunctionResponsePart ->
com.google.firebase.vertexai.type.FunctionResponsePart(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ fun <T, U, W, Z> defineFunction(
) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function)

private fun <T> FunctionCallPart.getArgOrThrow(param: Schema<T>): T {
return param.fromString(args[param.name])
return param.fromString(args?.get(param.name))
?: throw RuntimeException(
"Missing argument for parameter \"${param.name}\" for function \"$name\""
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class BlobPart(val mimeType: String, val blob: ByteArray) : Part
* @param name the name of the function to call
* @param args the function parameters and values as a [Map]
*/
class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part
class FunctionCallPart(val name: String, val args: Map<String, String?>?) : Part

/**
* Represents function call output to be returned to the model when it requests a function call.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart)

callPart.args["season"] shouldBe null
callPart.args shouldNotBe null
callPart.args?.get("seasons") shouldBe null
}
}

Expand All @@ -352,7 +353,19 @@ internal class UnarySnapshotTests {
it.parts.first().shouldBeInstanceOf<FunctionCallPart>()
}

callPart.args["current"] shouldBe "true"
callPart.args?.get("current") shouldBe "true"
}
}

@Test
fun `function call has no arguments field`() =
goldenUnaryFile("unary-success-function-call-empty-arguments.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.first()

callPart.name shouldBe "current_time"
callPart.args shouldBe null
}
}

Expand All @@ -364,7 +377,7 @@ internal class UnarySnapshotTests {
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.name shouldBe "current_time"
callPart.args.isEmpty() shouldBe true
callPart.args?.isEmpty() shouldBe true
}
}

Expand All @@ -376,8 +389,8 @@ internal class UnarySnapshotTests {
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.name shouldBe "sum"
callPart.args["x"] shouldBe "4"
callPart.args["y"] shouldBe "5"
callPart.args?.get("x") shouldBe "4"
callPart.args?.get("y") shouldBe "5"
}
}

Expand All @@ -391,7 +404,7 @@ internal class UnarySnapshotTests {
callList.size shouldBe 3
callList.forEach {
it.name shouldBe "sum"
it.args.size shouldBe 2
it.args?.size shouldBe 2
}
}
}
Expand All @@ -405,7 +418,7 @@ internal class UnarySnapshotTests {

response.text shouldBe "The sum of [1, 2, 3] is"
callList.size shouldBe 2
callList.forEach { it.args.size shouldBe 2 }
callList.forEach { it.args?.size shouldBe 2 }
}
}

Expand Down
Loading