Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Add Developer API encoding CountTokensRequest #14512

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,21 @@ import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct CountTokensRequest {
let model: String

let contents: [ModelContent]
let systemInstruction: ModelContent?
let tools: [Tool]?
let generationConfig: GenerationConfig?

let apiConfig: APIConfig
let options: RequestOptions
let generateContentRequest: GenerateContentRequest
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension CountTokensRequest: GenerativeAIRequest {
typealias Response = CountTokensResponse

var options: RequestOptions { generateContentRequest.options }

var apiConfig: APIConfig { generateContentRequest.apiConfig }

var url: URL {
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
let version = apiConfig.version.rawValue
let endpoint = apiConfig.service.endpoint.rawValue
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
}
}

Expand All @@ -57,12 +54,42 @@ public struct CountTokensResponse {

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension CountTokensRequest: Encodable {
enum CodingKeys: CodingKey {
enum VertexCodingKeys: CodingKey {
case contents
case systemInstruction
case tools
case generationConfig
}

enum DeveloperCodingKeys: CodingKey {
case generateContentRequest
}

func encode(to encoder: any Encoder) throws {
switch apiConfig.service {
case .vertexAI:
try encodeForVertexAI(to: encoder)
case .developer:
try encodeForDeveloper(to: encoder)
}
}

private func encodeForVertexAI(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: VertexCodingKeys.self)
try container.encode(generateContentRequest.contents, forKey: .contents)
try container.encodeIfPresent(
generateContentRequest.systemInstruction, forKey: .systemInstruction
)
try container.encodeIfPresent(generateContentRequest.tools, forKey: .tools)
try container.encodeIfPresent(
generateContentRequest.generationConfig, forKey: .generationConfig
)
}

private func encodeForDeveloper(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: DeveloperCodingKeys.self)
try container.encode(generateContentRequest, forKey: .generateContentRequest)
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
19 changes: 19 additions & 0 deletions FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ import Foundation
struct GenerateContentRequest: Sendable {
/// Model name.
let model: String

let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let toolConfig: ToolConfig?
let systemInstruction: ModelContent?

let apiConfig: APIConfig
let apiMethod: APIMethod
let options: RequestOptions
Expand All @@ -32,13 +34,30 @@ struct GenerateContentRequest: Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension GenerateContentRequest: Encodable {
enum CodingKeys: String, CodingKey {
case model
case contents
case generationConfig
case safetySettings
case tools
case toolConfig
case systemInstruction
}

func encode(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
// The model name only needs to be encoded when this `GenerateContentRequest` instance is used
// in a `CountTokensRequest` (calling `countTokens`). When calling `generateContent` or
// `generateContentStream`, the `model` field is populated in the backend from the `url`.
if apiMethod == .countTokens {
try container.encode(model, forKey: .model)
}
try container.encode(contents, forKey: .contents)
try container.encodeIfPresent(generationConfig, forKey: .generationConfig)
try container.encodeIfPresent(safetySettings, forKey: .safetySettings)
try container.encodeIfPresent(tools, forKey: .tools)
try container.encodeIfPresent(toolConfig, forKey: .toolConfig)
try container.encodeIfPresent(systemInstruction, forKey: .systemInstruction)
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
24 changes: 20 additions & 4 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,31 @@ public final class GenerativeModel: Sendable {
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
let requestContent = switch apiConfig.service {
case .vertexAI:
content
case .developer:
// The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
// counted towards the prompt and total token count when using the Developer API backend; set
// it to `nil` to avoid token count discrepancies between `countTokens` and `generateContent`,
// and between the two backend APIs.
content.map { ModelContent(role: nil, parts: $0.parts) }
}

let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
contents: content,
systemInstruction: systemInstruction,
tools: tools,
contents: requestContent,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .countTokens,
options: requestOptions
)
let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)

return try await generativeAIService.loadRequest(request: countTokensRequest)
}

Expand Down
1 change: 1 addition & 0 deletions FirebaseVertexAI/Tests/TestApp/Sources/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ public enum FirebaseAppNames {
}

public enum ModelNames {
public static let gemini2Flash = "gemini-2.0-flash-001"
public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAuth
import FirebaseCore
import FirebaseStorage
import FirebaseVertexAI
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

/// Integration tests for the `countTokens` API across the Vertex AI and Developer API backends.
///
/// Serialized so tests do not issue concurrent requests against the shared backend projects.
@Suite(.serialized)
struct CountTokensIntegrationTests {
  // A deliberately non-default generation config; `countTokens` must accept and encode it
  // without affecting the reported token counts.
  let generationConfig = GenerationConfig(
    temperature: 1.2,
    topP: 0.95,
    topK: 32,
    candidateCount: 1,
    maxOutputTokens: 8192,
    presencePenalty: 1.5,
    frequencyPenalty: 1.75,
    stopSequences: ["cat", "dog", "bird"]
  )
  // One setting per harm category, all at the strictest blocking threshold.
  let safetySettings = [
    SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
  ]
  // System instruction used by the system-instruction tests below; not supported on the
  // v1 Developer API (see `countTokens_text_systemInstruction_unsupported`).
  let systemInstruction = ModelContent(
    role: "system",
    parts: "You are a friendly and helpful assistant."
  )

  /// Counts tokens for a plain-text prompt on every backend configuration.
  ///
  /// The expected counts are golden values for this prompt with the
  /// `gemini-2.0-flash-001` tokenizer.
  @Test(arguments: InstanceConfig.allConfigs)
  func countTokens_text(_ config: InstanceConfig) async throws {
    let prompt = "Why is the sky blue?"
    let model = VertexAI.componentInstance(config).generativeModel(
      modelName: ModelNames.gemini2Flash,
      generationConfig: generationConfig,
      safetySettings: safetySettings
    )

    let response = try await model.countTokens(prompt)

    #expect(response.totalTokens == 6)
    switch config.apiConfig.service {
    case .vertexAI:
      // Only the Vertex AI backend reports billable characters.
      #expect(response.totalBillableCharacters == 16)
    case .developer:
      #expect(response.totalBillableCharacters == nil)
    }
    // A text-only prompt should produce exactly one (text) modality entry, accounting
    // for all of the prompt's tokens.
    #expect(response.promptTokensDetails.count == 1)
    let promptTokensDetails = try #require(response.promptTokensDetails.first)
    #expect(promptTokensDetails.modality == .text)
    #expect(promptTokensDetails.tokenCount == response.totalTokens)
  }

  /// Counts tokens for a text prompt plus a system instruction on the configurations that
  /// support system instructions.
  @Test(arguments: [
    InstanceConfig.vertexV1,
    InstanceConfig.vertexV1Beta,
    /* System instructions are not supported on the v1 Developer API. */
    InstanceConfig.developerV1Beta,
  ])
  func countTokens_text_systemInstruction(_ config: InstanceConfig) async throws {
    let model = VertexAI.componentInstance(config).generativeModel(
      modelName: ModelNames.gemini2Flash,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      systemInstruction: systemInstruction // Not supported on the v1 Developer API
    )

    let response = try await model.countTokens("What is your favourite colour?")

    // Golden value: prompt plus system-instruction tokens.
    #expect(response.totalTokens == 14)
    switch config.apiConfig.service {
    case .vertexAI:
      // Only the Vertex AI backend reports billable characters.
      #expect(response.totalBillableCharacters == 61)
    case .developer:
      #expect(response.totalBillableCharacters == nil)
    }
    #expect(response.promptTokensDetails.count == 1)
    let promptTokensDetails = try #require(response.promptTokensDetails.first)
    #expect(promptTokensDetails.modality == .text)
    #expect(promptTokensDetails.tokenCount == response.totalTokens)
  }

  /// Verifies that configurations which do NOT support system instructions reject the request
  /// with a `BackendError` rather than silently ignoring the instruction.
  @Test(arguments: [
    /* System instructions are not supported on the v1 Developer API. */
    InstanceConfig.developerV1,
  ])
  func countTokens_text_systemInstruction_unsupported(_ config: InstanceConfig) async throws {
    let model = VertexAI.componentInstance(config).generativeModel(
      modelName: ModelNames.gemini2Flash,
      systemInstruction: systemInstruction // Not supported on the v1 Developer API
    )

    // If the backend starts supporting system instructions on this configuration, this
    // `#require(throws:)` fails with the migration instructions in its message.
    try await #require(
      throws: BackendError.self,
      """
      If this test fails (i.e., `countTokens` succeeds), remove \(config) from this test and add it
      to `countTokens_text_systemInstruction`.
      """,
      performing: {
        try await model.countTokens("What is your favourite colour?")
      }
    )
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,8 @@ import FirebaseVertexAI
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

@Suite(.serialized)
struct GenerateContentIntegrationTests {
static let vertexV1Config =
InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1))
static let vertexV1BetaConfig =
InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1beta))
static let developerV1Config = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(
service: .developer(endpoint: .generativeLanguage), version: .v1
)
)
static let developerV1BetaConfig = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(
service: .developer(endpoint: .generativeLanguage), version: .v1beta
)
)
static let allConfigs =
[vertexV1Config, vertexV1BetaConfig, developerV1Config, developerV1BetaConfig]

// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
let safetySettings = [
Expand All @@ -67,7 +46,7 @@ struct GenerateContentIntegrationTests {
storage = Storage.storage()
}

@Test(arguments: allConfigs)
@Test(arguments: InstanceConfig.allConfigs)
func generateContent(_ config: InstanceConfig) async throws {
let model = VertexAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2FlashLite,
Expand Down Expand Up @@ -98,10 +77,10 @@ struct GenerateContentIntegrationTests {
@Test(
"Generate an enum and provide a system instruction",
arguments: [
vertexV1Config,
vertexV1BetaConfig,
InstanceConfig.vertexV1,
InstanceConfig.vertexV1Beta,
/* System instructions are not supported on the v1 Developer API. */
developerV1BetaConfig,
InstanceConfig.developerV1Beta,
]
)
func generateContentEnum(_ config: InstanceConfig) async throws {
Expand Down
Loading
Loading