diff --git a/plugins/openai/src/index.ts b/plugins/openai/src/index.ts index a7cdee1e..3192c87a 100644 --- a/plugins/openai/src/index.ts +++ b/plugins/openai/src/index.ts @@ -36,6 +36,8 @@ import { } from './gpt.js'; import { SUPPORTED_TTS_MODELS, ttsModel, tts1, tts1Hd } from './tts.js'; import { whisper1, whisper1Model } from './whisper.js'; +import { ModelAction } from '@genkit-ai/ai/model'; +import z from 'zod'; export { dallE3, gpt35Turbo, @@ -52,8 +54,12 @@ export { textEmbeddingAda002, }; +export type OpenAICustomModelAction = (client: OpenAI) => ModelAction + export interface PluginOptions { apiKey?: string; + baseURL?: string; + customModels?: Array; } /** @@ -106,7 +112,8 @@ export const openAI: Plugin<[PluginOptions] | []> = genkitPlugin( throw new Error( 'please pass in the API key or set the OPENAI_API_KEY environment variable' ); - const client = new OpenAI({ apiKey }); + const baseURL = options?.baseURL || process.env.OPENAI_BASE_URL; + const client = new OpenAI({ apiKey, baseURL }); return { models: [ ...Object.keys(SUPPORTED_GPT_MODELS).map((name) => @@ -117,6 +124,7 @@ export const openAI: Plugin<[PluginOptions] | []> = genkitPlugin( ), dallE3Model(client), whisper1Model(client), + ...(options?.customModels?.map(model => model(client)) || []), ], embedders: Object.keys(SUPPORTED_EMBEDDING_MODELS).map((name) => openaiEmbedder(name, options)