Skip to content

Commit fe7a33c

Browse files
committed
VertexAI: make FunctionCallPart.args nullable
1 parent 08deb69 commit fe7a33c

File tree

6 files changed

+26
-12
lines changed

6 files changed

+26
-12
lines changed

firebase-vertexai/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Unreleased
22
* [feature] added support for `responseSchema` in `GenerationConfig`.
3+
* [changed] Made `FunctionCallPart.args` nullable.
34

45
# 16.0.0-beta03
56
* [changed] Breaking Change: changed `Schema.int` to return 32 bit integers instead of 64 bit (long).

firebase-vertexai/firebase-vertexai.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ dependencies {
6161
implementation("com.google.firebase:firebase-components:18.0.0")
6262
implementation("com.google.firebase:firebase-annotations:16.2.0")
6363
implementation("com.google.firebase:firebase-appcheck-interop:17.1.0")
64-
implementation("com.google.ai.client.generativeai:common:0.9.0")
64+
implementation("com.google.ai.client.generativeai:common:0.10.0")
6565
implementation(libs.androidx.annotation)
6666
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
6767
implementation("androidx.core:core-ktx:1.12.0")

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ internal fun Part.toInternal(): com.google.ai.client.generativeai.common.shared.
8686
)
8787
)
8888
is com.google.firebase.vertexai.type.FunctionCallPart ->
89-
FunctionCallPart(FunctionCall(name, args.orEmpty()))
89+
FunctionCallPart(FunctionCall(name, args))
9090
is com.google.firebase.vertexai.type.FunctionResponsePart ->
9191
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
9292
is FileDataPart ->
@@ -220,7 +220,7 @@ internal fun com.google.ai.client.generativeai.common.shared.Part.toPublic(): Pa
220220
is FunctionCallPart ->
221221
com.google.firebase.vertexai.type.FunctionCallPart(
222222
functionCall.name,
223-
functionCall.args.orEmpty(),
223+
functionCall.args,
224224
)
225225
is FunctionResponsePart ->
226226
com.google.firebase.vertexai.type.FunctionResponsePart(

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclarations.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ fun <T, U, W, Z> defineFunction(
368368
) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function)
369369

370370
private fun <T> FunctionCallPart.getArgOrThrow(param: Schema<T>): T {
371-
return param.fromString(args[param.name])
371+
return param.fromString(args?.get(param.name))
372372
?: throw RuntimeException(
373373
"Missing argument for parameter \"${param.name}\" for function \"$name\""
374374
)

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class BlobPart(val mimeType: String, val blob: ByteArray) : Part
4949
* @param name the name of the function to call
5050
* @param args the function parameters and values as a [Map]
5151
*/
52-
class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part
52+
class FunctionCallPart(val name: String, val args: Map<String, String?>?) : Part
5353

5454
/**
5555
* Represents function call output to be returned to the model when it requests a function call.

firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,8 @@ internal class UnarySnapshotTests {
335335
val response = model.generateContent("prompt")
336336
val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart)
337337

338-
callPart.args["season"] shouldBe null
338+
callPart.args shouldNotBe null
339+
callPart.args?.get("seasons") shouldBe null
339340
}
340341
}
341342

@@ -352,7 +353,19 @@ internal class UnarySnapshotTests {
352353
it.parts.first().shouldBeInstanceOf<FunctionCallPart>()
353354
}
354355

355-
callPart.args["current"] shouldBe "true"
356+
callPart.args?.get("current") shouldBe "true"
357+
}
358+
}
359+
360+
@Test
361+
fun `function call has no arguments field`() =
362+
goldenUnaryFile("unary-success-function-call-empty-arguments.json") {
363+
withTimeout(testTimeout) {
364+
val response = model.generateContent("prompt")
365+
val callPart = response.functionCalls.first()
366+
367+
callPart.name shouldBe "current_time"
368+
callPart.args shouldBe null
356369
}
357370
}
358371

@@ -364,7 +377,7 @@ internal class UnarySnapshotTests {
364377
val callPart = response.functionCalls.shouldNotBeEmpty().first()
365378

366379
callPart.name shouldBe "current_time"
367-
callPart.args.isEmpty() shouldBe true
380+
callPart.args?.isEmpty() shouldBe true
368381
}
369382
}
370383

@@ -376,8 +389,8 @@ internal class UnarySnapshotTests {
376389
val callPart = response.functionCalls.shouldNotBeEmpty().first()
377390

378391
callPart.name shouldBe "sum"
379-
callPart.args["x"] shouldBe "4"
380-
callPart.args["y"] shouldBe "5"
392+
callPart.args?.get("x") shouldBe "4"
393+
callPart.args?.get("y") shouldBe "5"
381394
}
382395
}
383396

@@ -391,7 +404,7 @@ internal class UnarySnapshotTests {
391404
callList.size shouldBe 3
392405
callList.forEach {
393406
it.name shouldBe "sum"
394-
it.args.size shouldBe 2
407+
it.args?.size shouldBe 2
395408
}
396409
}
397410
}
@@ -405,7 +418,7 @@ internal class UnarySnapshotTests {
405418

406419
response.text shouldBe "The sum of [1, 2, 3] is"
407420
callList.size shouldBe 2
408-
callList.forEach { it.args.size shouldBe 2 }
421+
callList.forEach { it.args?.size shouldBe 2 }
409422
}
410423
}
411424

0 commit comments

Comments
 (0)