Commit
[inference] fold openai support into provider param (#1205)
i.e. no need to override an `endpoint` anymore.

This only works in "client-side" mode, i.e. when passing a provider key.

WDYT?

---------

Co-authored-by: Wauplin <lucainp@gmail.com>
Co-authored-by: SBrandeis <simon@huggingface.co>
3 people authored Feb 28, 2025
1 parent fac3157 commit 822ab9e
Showing 6 changed files with 7,561 additions and 7,458 deletions.
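
In practice, the change folds the old `endpoint` override into the `provider` param. A minimal before/after sketch, distilled from the test changes below (`OPENAI_KEY` stands in for a provider API key):

import { HfInference } from "@huggingface/inference";

const OPENAI_KEY = process.env.OPENAI_KEY; // provider key, not an hf_ token
const hf = new HfInference(OPENAI_KEY);

// Before: override the endpoint and pass the provider's raw model ID
const ep = hf.endpoint("https://api.openai.com");
await ep.chatCompletion({
    model: "gpt-3.5-turbo",
    messages: [{ role: "user", content: "Complete the equation one + one =" }],
});

// After: pick the provider via `provider`; the model ID carries an
// "openai/" prefix and the request is routed client-side
await hf.chatCompletion({
    provider: "openai",
    model: "openai/gpt-3.5-turbo",
    messages: [{ role: "user", content: "Complete the equation one + one =" }],
});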
51 changes: 38 additions & 13 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -10,6 +10,7 @@ import { NOVITA_CONFIG } from "../providers/novita";
 import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
+import { OPENAI_CONFIG } from "../providers/openai";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -33,6 +34,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
     "fireworks-ai": FIREWORKS_AI_CONFIG,
     "hf-inference": HF_INFERENCE_CONFIG,
     hyperbolic: HYPERBOLIC_CONFIG,
+    openai: OPENAI_CONFIG,
     nebius: NEBIUS_CONFIG,
     novita: NOVITA_CONFIG,
     replicate: REPLICATE_CONFIG,
@@ -72,22 +74,38 @@ export async function makeRequestOptions(
     if (!providerConfig) {
         throw new Error(`No provider config found for provider ${provider}`);
     }
+    if (providerConfig.clientSideRoutingOnly && !maybeModel) {
+        throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
+    }
     // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
     const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-    const model = await getProviderModelId({ model: hfModel, provider }, args, {
-        task,
-        chatCompletion,
-        fetch: options?.fetch,
-    });
+    const model = providerConfig.clientSideRoutingOnly
+        ? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+          removeProviderPrefix(maybeModel!, provider)
+        : // For closed-models API providers, one needs to pass the model ID directly (e.g. "gpt-3.5-turbo")
+          await getProviderModelId({ model: hfModel, provider }, args, {
+              task,
+              chatCompletion,
+              fetch: options?.fetch,
+          });
 
-    /// If accessToken is passed, it should take precedence over includeCredentials
-    const authMethod = accessToken
-        ? accessToken.startsWith("hf_")
-            ? "hf-token"
-            : "provider-key"
-        : includeCredentials === "include"
-          ? "credentials-include"
-          : "none";
+    const authMethod = (() => {
+        if (providerConfig.clientSideRoutingOnly) {
+            // Closed-source providers require an accessToken (cannot be routed).
+            if (accessToken && accessToken.startsWith("hf_")) {
+                throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
+            }
+            return "provider-key";
+        }
+        if (accessToken) {
+            return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
+        }
+        if (includeCredentials === "include") {
+            // If accessToken is passed, it should take precedence over includeCredentials
+            return "credentials-include";
+        }
+        return "none";
+    })();
 
     // Make URL
     const url = endpointUrl
@@ -176,3 +194,10 @@ async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[]
     }
     return await res.json();
 }
+
+function removeProviderPrefix(model: string, provider: string): string {
+    if (!model.startsWith(`${provider}/`)) {
+        throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
+    }
+    return model.slice(provider.length + 1);
+}
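
For illustration, the new `removeProviderPrefix` helper behaves as follows (hypothetical standalone calls, not part of the diff):

removeProviderPrefix("openai/gpt-3.5-turbo", "openai"); // → "gpt-3.5-turbo"
removeProviderPrefix("gpt-3.5-turbo", "openai");
// → throws: Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".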
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -24,6 +24,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
     hyperbolic: {},
     nebius: {},
     novita: {},
+    openai: {},
     replicate: {},
     sambanova: {},
     together: {},
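
For context, `HARDCODED_MODEL_ID_MAPPING` appears to be a static override consulted when resolving provider model IDs (judging by its name; its call site isn't shown in this diff), and the new `openai: {}` entry simply keeps the record exhaustive over all providers. A hypothetical entry, with invented IDs, for illustration only:

// Hypothetical entry (both IDs invented): pin an HF model ID to the ID a
// provider expects, without going through the Hub mapping.
HARDCODED_MODEL_ID_MAPPING["together"]["meta-llama/Llama-3.3-70B-Instruct"] =
    "meta-llama/Llama-3.3-70B-Instruct-Turbo";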
35 changes: 35 additions & 0 deletions packages/inference/src/providers/openai.ts
@@ -0,0 +1,35 @@
+/**
+ * Special case: provider configuration for a private models provider (OpenAI in this case).
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const OPENAI_API_BASE_URL = "https://api.openai.com";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+    if (!params.chatCompletion) {
+        throw new Error("OpenAI only supports chat completions.");
+    }
+    return {
+        ...params.args,
+        model: params.model,
+    };
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+    return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+    if (!params.chatCompletion) {
+        throw new Error("OpenAI only supports chat completions.");
+    }
+    return `${params.baseUrl}/v1/chat/completions`;
+};
+
+export const OPENAI_CONFIG: ProviderConfig = {
+    baseUrl: OPENAI_API_BASE_URL,
+    makeBody,
+    makeHeaders,
+    makeUrl,
+    clientSideRoutingOnly: true,
+};
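
To make the plumbing concrete, here is a simplified sketch of how `makeRequestOptions` might compose these helpers into a request. This is not the library's actual wiring, and the param objects are narrowed to the fields the helpers above actually read:

// Simplified sketch; the casts paper over fields of UrlParams/HeaderParams/
// BodyParams that the OpenAI helpers never touch.
const accessToken = "sk-..."; // provider key supplied by the caller
const model = "gpt-3.5-turbo"; // "openai/" prefix already stripped upstream
const args = { messages: [{ role: "user", content: "one + one =" }] };

const url = makeUrl({ baseUrl: OPENAI_API_BASE_URL, chatCompletion: true } as UrlParams);
const headers = makeHeaders({ accessToken } as HeaderParams);
const body = makeBody({ args, model, chatCompletion: true } as BodyParams);

await fetch(url, {
    method: "POST",
    headers: { ...headers, "Content-Type": "application/json" },
    body: JSON.stringify(body),
});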
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
@@ -37,6 +37,7 @@ export const INFERENCE_PROVIDERS = [
     "hyperbolic",
     "nebius",
     "novita",
+    "openai",
     "replicate",
     "sambanova",
     "together",
@@ -96,6 +97,7 @@ export interface ProviderConfig {
     makeBody: (params: BodyParams) => Record<string, unknown>;
     makeHeaders: (params: HeaderParams) => Record<string, string>;
     makeUrl: (params: UrlParams) => string;
+    clientSideRoutingOnly?: boolean;
 }
 
 export interface HeaderParams {
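
The new optional flag is all a closed-source provider needs in order to opt out of HF routing. A hypothetical config for an invented provider shows the shape this enables (nothing here exists upstream):

// Hypothetical "acme-ai" provider config; name, URL, and route are invented.
const ACME_AI_CONFIG: ProviderConfig = {
    baseUrl: "https://api.acme.example",
    makeBody: (params) => ({ ...params.args, model: params.model }),
    makeHeaders: (params) => ({ Authorization: `Bearer ${params.accessToken}` }),
    makeUrl: (params) => `${params.baseUrl}/v1/chat/completions`,
    clientSideRoutingOnly: true, // caller must bring a provider key
};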
15 changes: 12 additions & 3 deletions packages/inference/test/HfInference.spec.ts
@@ -755,9 +755,9 @@ describe.concurrent("HfInference", () => {
         it("custom openai - OpenAI Specs", async () => {
             const OPENAI_KEY = env.OPENAI_KEY;
             const hf = new HfInference(OPENAI_KEY);
-            const ep = hf.endpoint("https://api.openai.com");
-            const stream = ep.chatCompletionStream({
-                model: "gpt-3.5-turbo",
+            const stream = hf.chatCompletionStream({
+                provider: "openai",
+                model: "openai/gpt-3.5-turbo",
                 messages: [{ role: "user", content: "Complete the equation one + one =" }],
             }) as AsyncGenerator<ChatCompletionStreamOutput>;
             let out = "";
@@ -768,6 +768,15 @@
             }
             expect(out).toContain("two");
         });
+        it("OpenAI client side routing - model should have provider as prefix", async () => {
+            await expect(
+                new HfInference("dummy_token").chatCompletion({
+                    model: "gpt-3.5-turbo", // must be "openai/gpt-3.5-turbo"
+                    provider: "openai",
+                    messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+                })
+            ).rejects.toThrowError(`Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".`);
+        });
     },
     TIMEOUT
 );
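
The flip side of client-side routing, per the `authMethod` logic above: an `hf_` token is rejected outright for a closed-source provider. A sketch of the resulting behavior, with the error message taken from the diff:

// Sketch: hf_ tokens cannot be routed to closed-source providers.
try {
    await new HfInference("hf_xxx").chatCompletion({
        provider: "openai",
        model: "openai/gpt-3.5-turbo",
        messages: [{ role: "user", content: "one + one =" }],
    });
} catch (error) {
    console.error(error); // Error: Provider openai is closed-source and does not support HF tokens.
}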