Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Add Developer API encoding CountTokensRequest #14512

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,21 @@ import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct CountTokensRequest {
let model: String

let contents: [ModelContent]
let systemInstruction: ModelContent?
let tools: [Tool]?
let generationConfig: GenerationConfig?

let apiConfig: APIConfig
let options: RequestOptions
let generateContentRequest: GenerateContentRequest
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension CountTokensRequest: GenerativeAIRequest {
typealias Response = CountTokensResponse

var options: RequestOptions { generateContentRequest.options }

var apiConfig: APIConfig { generateContentRequest.apiConfig }

var url: URL {
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
let version = apiConfig.version.rawValue
let endpoint = apiConfig.service.endpoint.rawValue
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
}
}

Expand All @@ -57,12 +54,34 @@ public struct CountTokensResponse {

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension CountTokensRequest: Encodable {
enum CodingKeys: CodingKey {
enum VertexCodingKeys: CodingKey {
case contents
case systemInstruction
case tools
case generationConfig
}

enum DeveloperCodingKeys: CodingKey {
case generateContentRequest
}

func encode(to encoder: any Encoder) throws {
switch apiConfig.service {
case .vertexAI:
var container = encoder.container(keyedBy: VertexCodingKeys.self)
try container.encode(generateContentRequest.contents, forKey: .contents)
try container.encodeIfPresent(
generateContentRequest.systemInstruction, forKey: .systemInstruction
)
try container.encodeIfPresent(generateContentRequest.tools, forKey: .tools)
try container.encodeIfPresent(
generateContentRequest.generationConfig, forKey: .generationConfig
)
case .developer:
var container = encoder.container(keyedBy: DeveloperCodingKeys.self)
try container.encode(generateContentRequest, forKey: .generateContentRequest)
}
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
16 changes: 16 additions & 0 deletions FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ import Foundation
struct GenerateContentRequest: Sendable {
/// Model name.
let model: String

let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let toolConfig: ToolConfig?
let systemInstruction: ModelContent?

let apiConfig: APIConfig
let apiMethod: APIMethod
let options: RequestOptions
Expand All @@ -32,13 +34,27 @@ struct GenerateContentRequest: Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension GenerateContentRequest: Encodable {
enum CodingKeys: String, CodingKey {
case model
case contents
case generationConfig
case safetySettings
case tools
case toolConfig
case systemInstruction
}

func encode(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
if apiMethod == .countTokens {
try container.encode(model, forKey: .model)
}
try container.encode(contents, forKey: .contents)
try container.encodeIfPresent(generationConfig, forKey: .generationConfig)
try container.encodeIfPresent(safetySettings, forKey: .safetySettings)
try container.encodeIfPresent(tools, forKey: .tools)
try container.encodeIfPresent(toolConfig, forKey: .toolConfig)
try container.encodeIfPresent(systemInstruction, forKey: .systemInstruction)
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
24 changes: 20 additions & 4 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,31 @@ public final class GenerativeModel: Sendable {
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
let requestContent = switch apiConfig.service {
case .vertexAI:
content
case .developer:
// The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
// erroneously counted towards the prompt and total token count when using the Developer API
// backend; set to `nil` to avoid token count discrepancies between `countTokens` and
// `generateContent` and the two backend APIs.
content.map { ModelContent(role: nil, parts: $0.parts) }
}

let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
contents: content,
systemInstruction: systemInstruction,
tools: tools,
contents: requestContent,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .countTokens,
options: requestOptions
)
let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)

return try await generativeAIService.loadRequest(request: countTokensRequest)
}

Expand Down
1 change: 1 addition & 0 deletions FirebaseVertexAI/Tests/TestApp/Sources/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ public enum FirebaseAppNames {
}

public enum ModelNames {
public static let gemini2Flash = "gemini-2.0-flash-001"
public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAuth
import FirebaseCore
import FirebaseStorage
import FirebaseVertexAI
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

@Suite(.serialized)
struct CountTokensIntegrationTests {
let generationConfig = GenerationConfig(
temperature: 1.2,
topP: 0.95,
topK: 32,
candidateCount: 1,
maxOutputTokens: 8192,
presencePenalty: 1.5,
frequencyPenalty: 1.75,
stopSequences: ["cat", "dog", "bird"]
)
let safetySettings = [
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
]

@Test(arguments: InstanceConfig.allConfigs)
func countTokens_text(_ config: InstanceConfig) async throws {
let prompt = "Why is the sky blue?"
let model = VertexAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2Flash,
generationConfig: generationConfig,
safetySettings: safetySettings
)

let response = try await model.countTokens(prompt)

#expect(response.totalTokens == 6)
switch config.apiConfig.service {
case .vertexAI:
#expect(response.totalBillableCharacters == 16)
case .developer:
#expect(response.totalBillableCharacters == nil)
}
#expect(response.promptTokensDetails.count == 1)
let promptTokensDetails = try #require(response.promptTokensDetails.first)
#expect(promptTokensDetails.modality == .text)
#expect(promptTokensDetails.tokenCount == response.totalTokens)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,8 @@ import FirebaseVertexAI
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

@Suite(.serialized)
struct GenerateContentIntegrationTests {
static let vertexV1Config =
InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1))
static let vertexV1BetaConfig =
InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1beta))
static let developerV1Config = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(
service: .developer(endpoint: .generativeLanguage), version: .v1
)
)
static let developerV1BetaConfig = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(
service: .developer(endpoint: .generativeLanguage), version: .v1beta
)
)
static let allConfigs =
[vertexV1Config, vertexV1BetaConfig, developerV1Config, developerV1BetaConfig]

// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
let safetySettings = [
Expand All @@ -67,7 +46,7 @@ struct GenerateContentIntegrationTests {
storage = Storage.storage()
}

@Test(arguments: allConfigs)
@Test(arguments: InstanceConfig.allConfigs)
func generateContent(_ config: InstanceConfig) async throws {
let model = VertexAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2FlashLite,
Expand Down Expand Up @@ -98,10 +77,10 @@ struct GenerateContentIntegrationTests {
@Test(
"Generate an enum and provide a system instruction",
arguments: [
vertexV1Config,
vertexV1BetaConfig,
InstanceConfig.vertexV1,
InstanceConfig.vertexV1Beta,
/* System instructions are not supported on the v1 Developer API. */
developerV1BetaConfig,
InstanceConfig.developerV1Beta,
]
)
func generateContentEnum(_ config: InstanceConfig) async throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,25 @@
// limitations under the License.

import FirebaseCore
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig
@testable import class FirebaseVertexAI.VertexAI

struct InstanceConfig {
static let vertexV1 = InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1))
static let vertexV1Beta = InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1beta))
static let developerV1 = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1)
)
static let developerV1Beta = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
)
static let allConfigs = [vertexV1, vertexV1Beta, developerV1, developerV1Beta]

let appName: String?
let location: String?
let apiConfig: APIConfig
Expand All @@ -32,6 +45,35 @@ struct InstanceConfig {
var app: FirebaseApp? {
return appName.map { FirebaseApp.app(name: $0) } ?? FirebaseApp.app()
}

var serviceName: String {
switch apiConfig.service {
case .vertexAI:
return "Vertex AI"
case .developer:
return "Developer"
}
}

var versionName: String {
return apiConfig.version.rawValue
}
}

extension InstanceConfig: CustomTestStringConvertible {
var testDescription: String {
let endpointSuffix = switch apiConfig.service.endpoint {
case .firebaseVertexAIProd:
""
case .firebaseVertexAIStaging:
" - Staging"
case .generativeLanguage:
" - Generative Language"
}
let locationSuffix = location.map { " - \($0)" } ?? ""

return "\(serviceName) (\(versionName))\(endpointSuffix)\(locationSuffix)"
}
}

extension VertexAI {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
8661385C2CC943DD00F4B78E /* TestApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385B2CC943DD00F4B78E /* TestApp.swift */; };
8661385E2CC943DD00F4B78E /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385D2CC943DD00F4B78E /* ContentView.swift */; };
8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661386D2CC943DE00F4B78E /* IntegrationTests.swift */; };
8689CDCC2D7F8BD700BF426B /* CountTokensIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */; };
868A7C482CCA931B00E449DD /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 868A7C462CCA931B00E449DD /* GoogleService-Info.plist */; };
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */ = {isa = PBXBuildFile; fileRef = 868A7C4D2CCC1F4700E449DD /* Credentials.swift */; };
868A7C522CCC263300E449DD /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 868A7C502CCC263300E449DD /* Preview Assets.xcassets */; };
Expand All @@ -25,7 +26,7 @@
86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */; };
86D77DFE2D7B5C86003D155D /* GoogleService-Info-Spark.plist in Resources */ = {isa = PBXBuildFile; fileRef = 86D77DFD2D7B5C86003D155D /* GoogleService-Info-Spark.plist */; };
86D77E022D7B63AF003D155D /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77E012D7B63AC003D155D /* Constants.swift */; };
86D77E042D7B6C9D003D155D /* VertexAITestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77E032D7B6C95003D155D /* VertexAITestUtils.swift */; };
86D77E042D7B6C9D003D155D /* InstanceConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77E032D7B6C95003D155D /* InstanceConfig.swift */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand All @@ -46,6 +47,7 @@
8661385D2CC943DD00F4B78E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
866138692CC943DE00F4B78E /* IntegrationTests-SPM.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "IntegrationTests-SPM.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IntegrationTests.swift; sourceTree = "<group>"; };
8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CountTokensIntegrationTests.swift; sourceTree = "<group>"; };
868A7C462CCA931B00E449DD /* GoogleService-Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = "GoogleService-Info.plist"; sourceTree = "<group>"; };
868A7C4D2CCC1F4700E449DD /* Credentials.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Credentials.swift; sourceTree = "<group>"; };
868A7C502CCC263300E449DD /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
Expand All @@ -56,7 +58,7 @@
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerateContentIntegrationTests.swift; sourceTree = "<group>"; };
86D77DFD2D7B5C86003D155D /* GoogleService-Info-Spark.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = "GoogleService-Info-Spark.plist"; sourceTree = "<group>"; };
86D77E012D7B63AC003D155D /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
86D77E032D7B6C95003D155D /* VertexAITestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VertexAITestUtils.swift; sourceTree = "<group>"; };
86D77E032D7B6C95003D155D /* InstanceConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InstanceConfig.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -134,6 +136,7 @@
868A7C572CCC27AF00E449DD /* Integration */ = {
isa = PBXGroup;
children = (
8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */,
868A7C4D2CCC1F4700E449DD /* Credentials.swift */,
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */,
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */,
Expand All @@ -154,7 +157,7 @@
8698D7442CD3CEF700ABA833 /* Utilities */ = {
isa = PBXGroup;
children = (
86D77E032D7B6C95003D155D /* VertexAITestUtils.swift */,
86D77E032D7B6C95003D155D /* InstanceConfig.swift */,
8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */,
862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */,
);
Expand Down Expand Up @@ -283,7 +286,8 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
86D77E042D7B6C9D003D155D /* VertexAITestUtils.swift in Sources */,
8689CDCC2D7F8BD700BF426B /* CountTokensIntegrationTests.swift in Sources */,
86D77E042D7B6C9D003D155D /* InstanceConfig.swift in Sources */,
8698D7462CD3CF3600ABA833 /* FirebaseAppTestUtils.swift in Sources */,
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */,
864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */,
Expand Down
Loading