Skip to content

Commit ff59d95

Browse files
committed
Add CountTokensIntegrationTests
1 parent 5ad783d commit ff59d95

10 files changed

+161
-45
lines changed

FirebaseVertexAI/Sources/CountTokensRequest.swift

+8-8
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@ import Foundation
1717
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
1818
struct CountTokensRequest {
1919
let generateContentRequest: GenerateContentRequest
20-
21-
let apiConfig: APIConfig
22-
let options: RequestOptions
2320
}
2421

2522
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
2623
extension CountTokensRequest: GenerativeAIRequest {
2724
typealias Response = CountTokensResponse
2825

26+
var options: RequestOptions { generateContentRequest.options }
27+
28+
var apiConfig: APIConfig { generateContentRequest.apiConfig }
29+
2930
var url: URL {
30-
URL(string:
31-
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
31+
let version = apiConfig.version.rawValue
32+
let endpoint = apiConfig.service.endpoint.rawValue
33+
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
3234
}
3335
}
3436

@@ -64,9 +66,7 @@ extension CountTokensRequest: Encodable {
6466
}
6567

6668
func encode(to encoder: any Encoder) throws {
67-
let backendAPI = encoder.userInfo[CodingUserInfoKey(rawValue: "BackendAPI")!] as! BackendAPI
68-
69-
switch backendAPI {
69+
switch apiConfig.service {
7070
case .vertexAI:
7171
var container = encoder.container(keyedBy: VertexCodingKeys.self)
7272
try container.encode(generateContentRequest.contents, forKey: .contents)

FirebaseVertexAI/Sources/FirebaseInfo.swift

+1-4
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,18 @@ struct FirebaseInfo: Sendable {
2828
let apiKey: String
2929
let googleAppID: String
3030
let app: FirebaseApp
31-
let backendAPI: BackendAPI
3231

3332
init(appCheck: AppCheckInterop? = nil,
3433
auth: AuthInterop? = nil,
3534
projectID: String,
3635
apiKey: String,
3736
googleAppID: String,
38-
firebaseApp: FirebaseApp,
39-
backendAPI: BackendAPI) {
37+
firebaseApp: FirebaseApp) {
4038
self.appCheck = appCheck
4139
self.auth = auth
4240
self.projectID = projectID
4341
self.apiKey = apiKey
4442
self.googleAppID = googleAppID
4543
app = firebaseApp
46-
self.backendAPI = backendAPI
4744
}
4845
}

FirebaseVertexAI/Sources/GenerateContentRequest.swift

+15
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ import Foundation
1818
struct GenerateContentRequest: Sendable {
1919
/// Model name.
2020
let model: String
21+
2122
let contents: [ModelContent]
2223
let generationConfig: GenerationConfig?
2324
let safetySettings: [SafetySetting]?
2425
let tools: [Tool]?
2526
let toolConfig: ToolConfig?
2627
let systemInstruction: ModelContent?
28+
2729
let apiConfig: APIConfig
2830
let apiMethod: APIMethod
2931
let options: RequestOptions
@@ -40,6 +42,19 @@ extension GenerateContentRequest: Encodable {
4042
case toolConfig
4143
case systemInstruction
4244
}
45+
46+
func encode(to encoder: any Encoder) throws {
47+
var container = encoder.container(keyedBy: CodingKeys.self)
48+
if apiMethod == .countTokens {
49+
try container.encode(model, forKey: .model)
50+
}
51+
try container.encode(contents, forKey: .contents)
52+
try container.encodeIfPresent(generationConfig, forKey: .generationConfig)
53+
try container.encodeIfPresent(safetySettings, forKey: .safetySettings)
54+
try container.encodeIfPresent(tools, forKey: .tools)
55+
try container.encodeIfPresent(toolConfig, forKey: .toolConfig)
56+
try container.encodeIfPresent(systemInstruction, forKey: .systemInstruction)
57+
}
4358
}
4459

4560
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)

FirebaseVertexAI/Sources/GenerativeModel.swift

+14-2
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,27 @@ public final class GenerativeModel: Sendable {
260260
/// - Returns: The results of running the model's tokenizer on the input; contains
261261
/// ``CountTokensResponse/totalTokens``.
262262
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
263+
let requestContent = switch apiConfig.service {
264+
case .vertexAI:
265+
content
266+
case .developer:
267+
// The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
268+
// erroneously counted towards the prompt and total token count when using the Developer API
269+
// backend; set to `nil` to avoid token count discrepancies between `countTokens` and
270+
// `generateContent` and the two backend APIs.
271+
content.map { ModelContent(role: nil, parts: $0.parts) }
272+
}
273+
263274
let generateContentRequest = GenerateContentRequest(
264275
model: modelResourceName,
265-
contents: content,
276+
contents: requestContent,
266277
generationConfig: generationConfig,
267278
safetySettings: safetySettings,
268279
tools: tools,
269280
toolConfig: toolConfig,
270281
systemInstruction: systemInstruction,
271-
isStreaming: false,
282+
apiConfig: apiConfig,
283+
apiMethod: .countTokens,
272284
options: requestOptions
273285
)
274286
let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)

FirebaseVertexAI/Sources/VertexAI.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ public class VertexAI {
178178
projectID: projectID,
179179
apiKey: apiKey,
180180
googleAppID: app.options.googleAppID,
181-
firebaseApp: app,
182-
backendAPI: .vertexAI
181+
firebaseApp: app
183182
)
184183
self.apiConfig = apiConfig
185184
self.location = location

FirebaseVertexAI/Tests/TestApp/Sources/Constants.swift

+1
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ public enum FirebaseAppNames {
2121
}
2222

2323
public enum ModelNames {
24+
public static let gemini2Flash = "gemini-2.0-flash-001"
2425
public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
2526
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseAuth
16+
import FirebaseCore
17+
import FirebaseStorage
18+
import FirebaseVertexAI
19+
import Testing
20+
import VertexAITestApp
21+
22+
@testable import struct FirebaseVertexAI.APIConfig
23+
24+
@Suite(.serialized)
25+
struct CountTokensIntegrationTests {
26+
let generationConfig = GenerationConfig(
27+
temperature: 1.2,
28+
topP: 0.95,
29+
topK: 32,
30+
candidateCount: 1,
31+
maxOutputTokens: 8192,
32+
presencePenalty: 1.5,
33+
frequencyPenalty: 1.75,
34+
stopSequences: ["cat", "dog", "bird"]
35+
)
36+
let safetySettings = [
37+
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
38+
SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
39+
SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
40+
SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
41+
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
42+
]
43+
44+
@Test(arguments: InstanceConfig.allConfigs)
45+
func countTokens_text(_ config: InstanceConfig) async throws {
46+
let prompt = "Why is the sky blue?"
47+
let model = VertexAI.componentInstance(config).generativeModel(
48+
modelName: ModelNames.gemini2Flash,
49+
generationConfig: generationConfig,
50+
safetySettings: safetySettings
51+
)
52+
53+
let response = try await model.countTokens(prompt)
54+
55+
#expect(response.totalTokens == 6)
56+
switch config.apiConfig.service {
57+
case .vertexAI:
58+
#expect(response.totalBillableCharacters == 16)
59+
case .developer:
60+
#expect(response.totalBillableCharacters == nil)
61+
}
62+
#expect(response.promptTokensDetails.count == 1)
63+
let promptTokensDetails = try #require(response.promptTokensDetails.first)
64+
#expect(promptTokensDetails.modality == .text)
65+
#expect(promptTokensDetails.tokenCount == response.totalTokens)
66+
}
67+
}

FirebaseVertexAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift

+4-25
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,8 @@ import FirebaseVertexAI
1919
import Testing
2020
import VertexAITestApp
2121

22-
@testable import struct FirebaseVertexAI.APIConfig
23-
2422
@Suite(.serialized)
2523
struct GenerateContentIntegrationTests {
26-
static let vertexV1Config =
27-
InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1))
28-
static let vertexV1BetaConfig =
29-
InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1beta))
30-
static let developerV1Config = InstanceConfig(
31-
appName: FirebaseAppNames.spark,
32-
apiConfig: APIConfig(
33-
service: .developer(endpoint: .generativeLanguage), version: .v1
34-
)
35-
)
36-
static let developerV1BetaConfig = InstanceConfig(
37-
appName: FirebaseAppNames.spark,
38-
apiConfig: APIConfig(
39-
service: .developer(endpoint: .generativeLanguage), version: .v1beta
40-
)
41-
)
42-
static let allConfigs =
43-
[vertexV1Config, vertexV1BetaConfig, developerV1Config, developerV1BetaConfig]
44-
4524
// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
4625
let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
4726
let safetySettings = [
@@ -67,7 +46,7 @@ struct GenerateContentIntegrationTests {
6746
storage = Storage.storage()
6847
}
6948

70-
@Test(arguments: allConfigs)
49+
@Test(arguments: InstanceConfig.allConfigs)
7150
func generateContent(_ config: InstanceConfig) async throws {
7251
let model = VertexAI.componentInstance(config).generativeModel(
7352
modelName: ModelNames.gemini2FlashLite,
@@ -98,10 +77,10 @@ struct GenerateContentIntegrationTests {
9877
@Test(
9978
"Generate an enum and provide a system instruction",
10079
arguments: [
101-
vertexV1Config,
102-
vertexV1BetaConfig,
80+
InstanceConfig.vertexV1,
81+
InstanceConfig.vertexV1Beta,
10382
/* System instructions are not supported on the v1 Developer API. */
104-
developerV1BetaConfig,
83+
InstanceConfig.developerV1Beta,
10584
]
10685
)
10786
func generateContentEnum(_ config: InstanceConfig) async throws {

FirebaseVertexAI/Tests/TestApp/Tests/Utilities/VertexAITestUtils.swift FirebaseVertexAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift

+42
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,25 @@
1313
// limitations under the License.
1414

1515
import FirebaseCore
16+
import Testing
1617
import VertexAITestApp
1718

1819
@testable import struct FirebaseVertexAI.APIConfig
1920
@testable import class FirebaseVertexAI.VertexAI
2021

2122
struct InstanceConfig {
23+
static let vertexV1 = InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1))
24+
static let vertexV1Beta = InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1beta))
25+
static let developerV1 = InstanceConfig(
26+
appName: FirebaseAppNames.spark,
27+
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1)
28+
)
29+
static let developerV1Beta = InstanceConfig(
30+
appName: FirebaseAppNames.spark,
31+
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
32+
)
33+
static let allConfigs = [vertexV1, vertexV1Beta, developerV1, developerV1Beta]
34+
2235
let appName: String?
2336
let location: String?
2437
let apiConfig: APIConfig
@@ -32,6 +45,35 @@ struct InstanceConfig {
3245
var app: FirebaseApp? {
3346
return appName.map { FirebaseApp.app(name: $0) } ?? FirebaseApp.app()
3447
}
48+
49+
var serviceName: String {
50+
switch apiConfig.service {
51+
case .vertexAI:
52+
return "Vertex AI"
53+
case .developer:
54+
return "Developer"
55+
}
56+
}
57+
58+
var versionName: String {
59+
return apiConfig.version.rawValue
60+
}
61+
}
62+
63+
extension InstanceConfig: CustomTestStringConvertible {
64+
var testDescription: String {
65+
let endpointSuffix = switch apiConfig.service.endpoint {
66+
case .firebaseVertexAIProd:
67+
""
68+
case .firebaseVertexAIStaging:
69+
" - Staging"
70+
case .generativeLanguage:
71+
" - Generative Language"
72+
}
73+
let locationSuffix = location.map { " - \($0)" } ?? ""
74+
75+
return "\(serviceName) (\(versionName))\(endpointSuffix)\(locationSuffix)"
76+
}
3577
}
3678

3779
extension VertexAI {

FirebaseVertexAI/Tests/TestApp/VertexAITestApp.xcodeproj/project.pbxproj

+8-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
8661385C2CC943DD00F4B78E /* TestApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385B2CC943DD00F4B78E /* TestApp.swift */; };
1313
8661385E2CC943DD00F4B78E /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385D2CC943DD00F4B78E /* ContentView.swift */; };
1414
8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661386D2CC943DE00F4B78E /* IntegrationTests.swift */; };
15+
8689CDCC2D7F8BD700BF426B /* CountTokensIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */; };
1516
868A7C482CCA931B00E449DD /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 868A7C462CCA931B00E449DD /* GoogleService-Info.plist */; };
1617
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */ = {isa = PBXBuildFile; fileRef = 868A7C4D2CCC1F4700E449DD /* Credentials.swift */; };
1718
868A7C522CCC263300E449DD /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 868A7C502CCC263300E449DD /* Preview Assets.xcassets */; };
@@ -25,7 +26,7 @@
2526
86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */; };
2627
86D77DFE2D7B5C86003D155D /* GoogleService-Info-Spark.plist in Resources */ = {isa = PBXBuildFile; fileRef = 86D77DFD2D7B5C86003D155D /* GoogleService-Info-Spark.plist */; };
2728
86D77E022D7B63AF003D155D /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77E012D7B63AC003D155D /* Constants.swift */; };
28-
86D77E042D7B6C9D003D155D /* VertexAITestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77E032D7B6C95003D155D /* VertexAITestUtils.swift */; };
29+
86D77E042D7B6C9D003D155D /* InstanceConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77E032D7B6C95003D155D /* InstanceConfig.swift */; };
2930
/* End PBXBuildFile section */
3031

3132
/* Begin PBXContainerItemProxy section */
@@ -46,6 +47,7 @@
4647
8661385D2CC943DD00F4B78E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
4748
866138692CC943DE00F4B78E /* IntegrationTests-SPM.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "IntegrationTests-SPM.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
4849
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IntegrationTests.swift; sourceTree = "<group>"; };
50+
8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CountTokensIntegrationTests.swift; sourceTree = "<group>"; };
4951
868A7C462CCA931B00E449DD /* GoogleService-Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = "GoogleService-Info.plist"; sourceTree = "<group>"; };
5052
868A7C4D2CCC1F4700E449DD /* Credentials.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Credentials.swift; sourceTree = "<group>"; };
5153
868A7C502CCC263300E449DD /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
@@ -56,7 +58,7 @@
5658
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerateContentIntegrationTests.swift; sourceTree = "<group>"; };
5759
86D77DFD2D7B5C86003D155D /* GoogleService-Info-Spark.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = "GoogleService-Info-Spark.plist"; sourceTree = "<group>"; };
5860
86D77E012D7B63AC003D155D /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
59-
86D77E032D7B6C95003D155D /* VertexAITestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VertexAITestUtils.swift; sourceTree = "<group>"; };
61+
86D77E032D7B6C95003D155D /* InstanceConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InstanceConfig.swift; sourceTree = "<group>"; };
6062
/* End PBXFileReference section */
6163

6264
/* Begin PBXFrameworksBuildPhase section */
@@ -134,6 +136,7 @@
134136
868A7C572CCC27AF00E449DD /* Integration */ = {
135137
isa = PBXGroup;
136138
children = (
139+
8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */,
137140
868A7C4D2CCC1F4700E449DD /* Credentials.swift */,
138141
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */,
139142
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */,
@@ -154,7 +157,7 @@
154157
8698D7442CD3CEF700ABA833 /* Utilities */ = {
155158
isa = PBXGroup;
156159
children = (
157-
86D77E032D7B6C95003D155D /* VertexAITestUtils.swift */,
160+
86D77E032D7B6C95003D155D /* InstanceConfig.swift */,
158161
8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */,
159162
862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */,
160163
);
@@ -283,7 +286,8 @@
283286
isa = PBXSourcesBuildPhase;
284287
buildActionMask = 2147483647;
285288
files = (
286-
86D77E042D7B6C9D003D155D /* VertexAITestUtils.swift in Sources */,
289+
8689CDCC2D7F8BD700BF426B /* CountTokensIntegrationTests.swift in Sources */,
290+
86D77E042D7B6C9D003D155D /* InstanceConfig.swift in Sources */,
287291
8698D7462CD3CF3600ABA833 /* FirebaseAppTestUtils.swift in Sources */,
288292
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */,
289293
864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */,

0 commit comments

Comments
 (0)