diff --git a/server/prisma/seed.ts b/server/prisma/seed.ts
index 5f1505e2..1c2968a6 100644
--- a/server/prisma/seed.ts
+++ b/server/prisma/seed.ts
@@ -354,6 +354,22 @@ const LLMS: {
     stream_available: true,
     model_provider: "OpenAI",
     config: "{}",
+  },
+  {
+    model_id: "gemini-1.5-flash-dbase",
+    name: "Gemini 1.5 Flash (Google)",
+    model_type: "chat",
+    stream_available: true,
+    model_provider: "Google",
+    config: "{}",
+  },
+  {
+    model_id: "gemini-1.5-pro-dbase",
+    name: "Gemini 1.5 Pro (Google)",
+    model_type: "chat",
+    stream_available: true,
+    model_provider: "Google",
+    config: "{}",
   }
 ];
 
diff --git a/server/src/chain/index.ts b/server/src/chain/index.ts
index 2ab5c696..cf63f0c1 100644
--- a/server/src/chain/index.ts
+++ b/server/src/chain/index.ts
@@ -1,4 +1,5 @@
 import { BaseLanguageModel } from "@langchain/core/language_models/base";
+import { BaseChatModel } from "@langchain/core/language_models/chat_models";
 import { Document } from "@langchain/core/documents";
 import {
   ChatPromptTemplate,
@@ -14,6 +15,7 @@ import {
   RunnableMap,
   RunnableSequence,
 } from "@langchain/core/runnables";
+
 type RetrievalChainInput = {
   chat_history: string;
   question: string;
@@ -107,8 +109,8 @@ export const createChain = ({
   retriever,
   response_template,
 }: {
-  llm: BaseLanguageModel;
-  question_llm: BaseLanguageModel;
+  llm: BaseLanguageModel | BaseChatModel;
+  question_llm: BaseLanguageModel | BaseChatModel;
   retriever: Runnable;
   question_template: string;
   response_template: string;
diff --git a/server/src/utils/models.ts b/server/src/utils/models.ts
index ee1fb64e..84538b5c 100644
--- a/server/src/utils/models.ts
+++ b/server/src/utils/models.ts
@@ -84,7 +84,7 @@ export const chatModelProvider = (
         maxOutputTokens: 2048,
         apiKey: process.env.GOOGLE_API_KEY,
         ...otherFields,
-      });
+      }) as any;
     case "ollama":
       return new ChatOllama({
         baseUrl: otherFields.baseURL,
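
For context on the first hunk: rows in the `LLMS` array are typically written to the database with an idempotent upsert so reruns of the seed script don't duplicate the new Gemini entries. A minimal sketch of that pattern, assuming a Prisma delegate named `dialoqbaseModels` with a unique `model_id` column; both names are guesses for illustration, not taken from the diff.

```ts
// Hypothetical seeding loop: `dialoqbaseModels` and the unique `model_id`
// constraint are assumptions for illustration only. `LLMS` refers to the
// array being extended in seed.ts.
import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

async function seedModels() {
  for (const llm of LLMS) {
    await prisma.dialoqbaseModels.upsert({
      where: { model_id: llm.model_id }, // skip rows written by an earlier run
      update: {},                        // leave existing rows untouched
      create: llm,
    });
  }
}
```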
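
The widened union in `createChain` means a Google chat model can now be passed straight through. A minimal sketch of such a call, assuming the current `@langchain/google-genai` constructor surface; the model id, import path, templates, and retriever are placeholders declared only so the sketch type-checks.

```ts
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import { Runnable } from "@langchain/core/runnables";
import { createChain } from "../chain"; // path assumes a caller under server/src

// Stand-ins so the sketch compiles; real values come from the bot config.
declare const retriever: Runnable;
declare const questionTemplate: string;
declare const responseTemplate: string;

// Options mirror the existing "google" branch in models.ts; the model id is a
// placeholder.
const gemini = new ChatGoogleGenerativeAI({
  model: "gemini-1.5-flash",
  maxOutputTokens: 2048,
  apiKey: process.env.GOOGLE_API_KEY,
});

const chain = createChain({
  llm: gemini,
  question_llm: gemini,
  retriever,
  question_template: questionTemplate,
  response_template: responseTemplate,
});
```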
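
On the last hunk: the `as any` cast silences a generics mismatch between `@langchain/google-genai` and the provider's inferred return type, at the cost of making that branch untyped for every caller. If a narrower escape hatch is preferred, one option is to cast once to the interface callers actually rely on. A sketch; `googleChatModel` and the derived options type are illustrative, not existing code.

```ts
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";

// Derive the constructor's options type so the sketch tracks the library.
type GoogleFields = NonNullable<
  ConstructorParameters<typeof ChatGoogleGenerativeAI>[0]
>;

// Cast once, to the surface callers actually use, instead of to `any`.
const googleChatModel = (otherFields: GoogleFields): BaseChatModel =>
  new ChatGoogleGenerativeAI({
    maxOutputTokens: 2048,
    apiKey: process.env.GOOGLE_API_KEY,
    ...otherFields,
  }) as unknown as BaseChatModel;
```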