Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Add APIConfig to support switching to the Developer API #14521

Merged
merged 12 commits into from
Mar 5, 2025
Prev Previous commit
Next Next commit
Refactor APIConfig.endpoint as an associated value of Service
  • Loading branch information
andrewheard committed Mar 5, 2025
commit ff5b4600aef038838e5a1d842610579975e35557
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ extension CountTokensRequest: GenerativeAIRequest {

var url: URL {
URL(string:
"\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
}
}

2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
let modelURL = "\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
let modelURL = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
switch apiMethod {
case .generateContent:
return URL(string: "\(modelURL):\(apiMethod.rawValue)")!
29 changes: 17 additions & 12 deletions FirebaseVertexAI/Sources/Types/Internal/APIConfig.swift
Original file line number Diff line number Diff line change
@@ -19,23 +19,16 @@ struct APIConfig: Sendable, Hashable {
/// This controls which backend API is used by the SDK.
let service: Service

/// The specific network address to use for API requests.
///
/// This must correspond with the API set in `service`.
let serviceEndpoint: ServiceEndpoint

/// The version of the selected API to use, e.g., "v1".
let version: Version

/// Initializes an API configuration.
///
/// - Parameters:
/// - service: The API service to use for generative AI.
/// - serviceEndpoint: The network address to use for the API service.
/// - version: The version of the API to use.
init(service: Service, serviceEndpoint: ServiceEndpoint, version: Version) {
init(service: Service, version: Version) {
self.service = service
self.serviceEndpoint = serviceEndpoint
self.version = version
}
}
@@ -46,7 +39,7 @@ extension APIConfig {
/// See [Vertex AI and Google AI
/// differences](https://cloud.google.com/vertex-ai/generative-ai/docs/overview#how-gemini-vertex-different-gemini-aistudio)
/// for a comparison of the two [API services](https://google.aip.dev/9#api-service).
enum Service {
enum Service: Hashable {
/// The Gemini Enterprise API provided by Vertex AI.
///
/// See the [Cloud
@@ -57,13 +50,25 @@ extension APIConfig {
/// The Gemini Developer API provided by Google AI.
///
/// See the [Google AI docs](https://ai.google.dev/gemini-api/docs) for more details.
case developer
case developer(endpoint: Endpoint)

/// The specific network address to use for API requests.
///
/// This must correspond with the API set in `service`.
var endpoint: Endpoint {
switch self {
case .vertexAI:
return .firebaseVertexAIProd
case let .developer(endpoint: endpoint):
return endpoint
}
}
}
}

extension APIConfig {
extension APIConfig.Service {
/// Network addresses for generative AI API services.
enum ServiceEndpoint: String {
enum Endpoint: String {
/// The Vertex AI in Firebase production endpoint.
case firebaseVertexAIProd = "https://firebasevertexai.googleapis.com"

Original file line number Diff line number Diff line change
@@ -40,9 +40,8 @@ extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodabl
typealias Response = ImagenGenerationResponse<ImageType>

var url: URL {
return URL(
string: "\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict"
)!
return URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict")!
}
}

17 changes: 6 additions & 11 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ public class VertexAI {
}
let vertexInstance = vertexAI(app: app, location: location)
assert(vertexInstance.apiConfig.service == .vertexAI)
assert(vertexInstance.apiConfig.serviceEndpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.version == .v1beta)

return vertexInstance
@@ -56,7 +56,7 @@ public class VertexAI {
public static func vertexAI(app: FirebaseApp, location: String = "us-central1") -> VertexAI {
let vertexInstance = vertexAI(app: app, location: location, apiConfig: defaultVertexAIAPIConfig)
assert(vertexInstance.apiConfig.service == .vertexAI)
assert(vertexInstance.apiConfig.serviceEndpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.version == .v1beta)

return vertexInstance
@@ -159,14 +159,9 @@ public class VertexAI {

let location: String?

static let defaultVertexAIAPIConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
static let defaultVertexAIAPIConfig = APIConfig(service: .vertexAI, version: .v1beta)
static let defaultDeveloperAPIConfig = APIConfig(
service: .developer,
serviceEndpoint: .generativeLanguage,
service: .developer(endpoint: .generativeLanguage),
version: .v1beta
)

@@ -256,14 +251,14 @@ public class VertexAI {
}

private func developerModelResourceName(modelName: String) -> String {
switch apiConfig.serviceEndpoint {
switch apiConfig.service.endpoint {
case .firebaseVertexAIStaging:
let projectID = firebaseInfo.projectID
return "projects/\(projectID)/models/\(modelName)"
case .generativeLanguage:
return "models/\(modelName)"
default:
fatalError("The Developer API is not supported on '\(apiConfig.serviceEndpoint)'.")
fatalError("The Developer API is not supported on '\(apiConfig.service.endpoint)'.")
}
}

6 changes: 1 addition & 5 deletions FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
MockURLProtocol.requestHandler = { request in

Check warning on line 47 in FirebaseVertexAI/Tests/Unit/ChatTests.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let response = HTTPURLResponse(
url: request.url!,
statusCode: 200,
@@ -65,11 +65,7 @@
googleAppID: "My app ID",
firebaseApp: app
),
apiConfig: APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
),
apiConfig: APIConfig(service: .vertexAI, version: .v1beta),
tools: nil,
requestOptions: RequestOptions(),
urlSession: urlSession
6 changes: 1 addition & 5 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
@@ -58,11 +58,7 @@
].sorted()
let testModelResourceName =
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

var urlSession: URLSession!
var model: GenerativeModel!
@@ -1496,7 +1492,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
return { request in

Check warning on line 1495 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
// This is *not* an HTTPURLResponse
let response = URLResponse(
url: request.url!,
@@ -1523,7 +1519,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
let bundle = BundleTestUtil.bundle()

Check warning on line 1522 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let fileURL = try XCTUnwrap(bundle.url(forResource: name, withExtension: ext))
return { request in
let requestURL = try XCTUnwrap(request.url)
Original file line number Diff line number Diff line change
@@ -36,11 +36,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
addWatermark: nil,
includeResponsibleAIFilterReason: includeResponsibleAIFilterReason
)
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

let instance = ImageGenerationInstance(prompt: "test-prompt")

@@ -64,7 +60,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(
request.url,
URL(string:
"\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
}

@@ -84,7 +80,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(
request.url,
URL(string:
"\(apiConfig.serviceEndpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
}

31 changes: 8 additions & 23 deletions FirebaseVertexAI/Tests/Unit/Types/Internal/APIConfigTests.swift
Original file line number Diff line number Diff line change
@@ -19,49 +19,34 @@ import XCTest
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class APIConfigTests: XCTestCase {
func testInitialize_vertexAI_prod_v1() {
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1)

XCTAssertEqual(apiConfig.serviceEndpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.version.rawValue, "v1")
}

func testInitialize_vertexAI_prod_v1beta() {
let apiConfig = APIConfig(
service: .vertexAI,
serviceEndpoint: .firebaseVertexAIProd,
version: .v1beta
)
let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

XCTAssertEqual(apiConfig.serviceEndpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://firebasevertexai.googleapis.com")
XCTAssertEqual(apiConfig.version.rawValue, "v1beta")
}

func testInitialize_developer_staging_v1beta() {
let apiConfig = APIConfig(
service: .developer,
serviceEndpoint: .firebaseVertexAIStaging,
version: .v1beta
service: .developer(endpoint: .firebaseVertexAIStaging), version: .v1beta
)

XCTAssertEqual(
apiConfig.serviceEndpoint.rawValue,
"https://staging-firebasevertexai.sandbox.googleapis.com"
apiConfig.service.endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com"
)
XCTAssertEqual(apiConfig.version.rawValue, "v1beta")
}

func testInitialize_developer_generativeLanguage_v1beta() {
let apiConfig = APIConfig(
service: .developer,
serviceEndpoint: .generativeLanguage,
version: .v1beta
)
let apiConfig = APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)

XCTAssertEqual(apiConfig.serviceEndpoint.rawValue, "https://generativelanguage.googleapis.com")
XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://generativelanguage.googleapis.com")
XCTAssertEqual(apiConfig.version.rawValue, "v1beta")
}
}
7 changes: 3 additions & 4 deletions FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@ class VertexComponentTests: XCTestCase {
XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey)
XCTAssertEqual(vertex.location, location)
XCTAssertEqual(vertex.apiConfig.service, .vertexAI)
XCTAssertEqual(vertex.apiConfig.serviceEndpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.version, .v1beta)
}

@@ -73,7 +73,7 @@ class VertexComponentTests: XCTestCase {
XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey)
XCTAssertEqual(vertex.location, location)
XCTAssertEqual(vertex.apiConfig.service, .vertexAI)
XCTAssertEqual(vertex.apiConfig.serviceEndpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.version, .v1beta)
}

@@ -179,8 +179,7 @@ class VertexComponentTests: XCTestCase {
func testModelResourceName_developerAPI_firebaseVertexAI() throws {
let app = try XCTUnwrap(VertexComponentTests.app)
let apiConfig = APIConfig(
service: .developer,
serviceEndpoint: .firebaseVertexAIStaging,
service: .developer(endpoint: .firebaseVertexAIStaging),
version: .v1beta
)
let vertex = VertexAI.developerAPI(app: app, apiConfig: apiConfig)
Loading