From a5c9dbafe7cd1bb3eec8530fdd89760370b9d0ea Mon Sep 17 00:00:00 2001 From: yaojin Date: Wed, 31 Jan 2024 20:32:46 +0800 Subject: [PATCH 1/2] fix bug --- src/backend/bisheng/api/v1/endpoints.py | 3 +-- src/backend/bisheng/api/v1/knowledge.py | 1 - src/backend/bisheng/api/v1/report.py | 2 ++ src/backend/bisheng/api/v1/user.py | 1 - src/backend/bisheng/chat/utils.py | 11 +++++++---- src/backend/bisheng/interface/initialize/loading.py | 1 + 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/backend/bisheng/api/v1/endpoints.py b/src/backend/bisheng/api/v1/endpoints.py index bb931f3a6..a9d1bf2de 100644 --- a/src/backend/bisheng/api/v1/endpoints.py +++ b/src/backend/bisheng/api/v1/endpoints.py @@ -198,8 +198,7 @@ async def process_flow( message=answer, source=source) with session_getter() as session: - session.add(question) - session.add(message) + session.add_all([question, message]) session.commit() session.refresh(message) extra.update({'source': source, 'message_id': message.id}) diff --git a/src/backend/bisheng/api/v1/knowledge.py b/src/backend/bisheng/api/v1/knowledge.py index 0f8a63b7d..739aa551f 100644 --- a/src/backend/bisheng/api/v1/knowledge.py +++ b/src/backend/bisheng/api/v1/knowledge.py @@ -474,7 +474,6 @@ def addEmbedding(collection_name, index_name, knowledge_id: int, model: str, chu object_name_original = f'original/{db_file.id}' setattr(db_file, 'object_name', object_name_original) session.add(db_file) - session.flush() session.commit() session.refresh(db_file) diff --git a/src/backend/bisheng/api/v1/report.py b/src/backend/bisheng/api/v1/report.py index ff9ddfb6d..a9d2be8ab 100644 --- a/src/backend/bisheng/api/v1/report.py +++ b/src/backend/bisheng/api/v1/report.py @@ -58,12 +58,14 @@ async def get_template(*, flow_id: str): db_report = Report(flow_id=flow_id) elif db_report.object_name: file_url = minio_client.MinioClient().get_share_link(db_report.object_name) + if not db_report.newversion_key or not db_report.object_name: 
version_key = uuid4().hex db_report.newversion_key = version_key with session_getter() as session: session.add(db_report) session.commit() + session.refresh(db_report) else: version_key = db_report.newversion_key res = { diff --git a/src/backend/bisheng/api/v1/user.py b/src/backend/bisheng/api/v1/user.py index e7dba2f25..2ae3b43b1 100644 --- a/src/backend/bisheng/api/v1/user.py +++ b/src/backend/bisheng/api/v1/user.py @@ -381,7 +381,6 @@ async def access_list(*, role_id: int, type: Optional[int] = None, Authorize: Au with session_getter() as session: db_role_access = session.exec(sql).all() total_count = session.scalar(count_sql) - session.commit() # uuid 和str的转化 for access in db_role_access: if access.type == AccessType.FLOW.value: diff --git a/src/backend/bisheng/chat/utils.py b/src/backend/bisheng/chat/utils.py index c50095355..e77c481b2 100644 --- a/src/backend/bisheng/chat/utils.py +++ b/src/backend/bisheng/chat/utils.py @@ -139,6 +139,8 @@ async def process_source_document(source_document: List[Document], chat_id, mess logger.error('不能使用配置模型进行关键词抽取,配置不正确') answer_keywords = extract_answer_keys(answer, model, host_base_url) + + batch_insert = [] for doc in source_document: if 'bbox' in doc.metadata: # 表示支持溯源 @@ -149,7 +151,8 @@ async def process_source_document(source_document: List[Document], chat_id, mess file_id=doc.metadata.get('file_id'), meta_data=json.dumps(doc.metadata), message_id=message_id) - with session_getter() as db_session: - db_session.add(recall_chunk) - db_session.commit() - db_session.refresh(recall_chunk) + batch_insert.append(recall_chunk) + if batch_insert: + with session_getter() as db_session: + db_session.add_all(batch_insert) + db_session.commit() diff --git a/src/backend/bisheng/interface/initialize/loading.py b/src/backend/bisheng/interface/initialize/loading.py index a27b47463..e8155793c 100644 --- a/src/backend/bisheng/interface/initialize/loading.py +++ b/src/backend/bisheng/interface/initialize/loading.py @@ -290,6 +290,7 @@ def 
instantiate_chains(node_type, class_object: Type[Chain], params: Dict, id_di # sequence chain if node_type == 'SequentialChain': # 改造sequence 支持自定义chain顺序 + params.pop('input_node', '') # sequential 不支持增加入参 try: chain_order = json.loads(params.pop('chain_order')) except Exception: From a25044c241658396692aae91efc357bf3b09d89f Mon Sep 17 00:00:00 2001 From: yaojin Date: Wed, 31 Jan 2024 23:27:18 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E9=87=8D=E8=BD=BDconversationRetriev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../bisheng_langchain/chains/__init__.py | 5 +- .../conversational_retrieval/__init__.py | 0 .../chains/conversational_retrieval/base.py | 115 ++++++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/__init__.py create mode 100644 src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/base.py diff --git a/src/bisheng-langchain/bisheng_langchain/chains/__init__.py b/src/bisheng-langchain/bisheng_langchain/chains/__init__.py index 03907f3f2..79a27210c 100644 --- a/src/bisheng-langchain/bisheng_langchain/chains/__init__.py +++ b/src/bisheng-langchain/bisheng_langchain/chains/__init__.py @@ -1,5 +1,6 @@ from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain from bisheng_langchain.chains.combine_documents.stuff import StuffDocumentsChain +from bisheng_langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain from bisheng_langchain.chains.retrieval.retrieval_chain import RetrievalChain from bisheng_langchain.chains.router.multi_rule import MultiRuleChain from bisheng_langchain.chains.router.rule_router import RuleBasedRouter @@ -7,6 +8,6 @@ from .loader_output import LoaderOutputChain __all__ = [ - 'StuffDocumentsChain', 'LoaderOutputChain', 'AutoGenChain', 'RuleBasedRouter', 'MultiRuleChain', - 'RetrievalChain' + 'StuffDocumentsChain', 'LoaderOutputChain', 
'AutoGenChain', 'RuleBasedRouter', + 'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain' ] diff --git a/src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/__init__.py b/src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/base.py b/src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/base.py new file mode 100644 index 000000000..4eacd807a --- /dev/null +++ b/src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/base.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun +from langchain.chains.conversational_retrieval.base import \ + ConversationalRetrievalChain as BaseConversationalRetrievalChain +from langchain_core.messages import BaseMessage + +# Depending on the memory type and configuration, the chat history format may differ. +# This needs to be consolidated. +CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage] + +_ROLE_MAP = {'human': 'Human: ', 'ai': 'Assistant: '} + + +def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str: + buffer = '' + for dialogue_turn in chat_history: + if isinstance(dialogue_turn, BaseMessage): + role_prefix = _ROLE_MAP.get(dialogue_turn.type, f'{dialogue_turn.type}: ') + buffer += f'\n{role_prefix}{dialogue_turn.content}' + elif isinstance(dialogue_turn, tuple): + human = 'Human: ' + dialogue_turn[0] + ai = 'Assistant: ' + dialogue_turn[1] + buffer += '\n' + '\n'.join([human, ai]) + else: + raise ValueError(f'Unsupported chat history format: {type(dialogue_turn)}.' 
+ f' Full chat history: {chat_history} ') + return buffer + + +class ConversationalRetrievalChain(BaseConversationalRetrievalChain): + """ConversationalRetrievalChain is a chain you can use to have a conversation with a character from a series.""" + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + question = inputs['question'] + get_chat_history = self.get_chat_history or _get_chat_history + chat_history_str = get_chat_history(inputs['chat_history']) + + if chat_history_str: + # callbacks = _run_manager.get_child() + new_question = self.question_generator.run(question=question, + chat_history=chat_history_str) + else: + new_question = question + accepts_run_manager = ('run_manager' in inspect.signature(self._get_docs).parameters) + if accepts_run_manager: + docs = self._get_docs(new_question, inputs, run_manager=_run_manager) + else: + docs = self._get_docs(new_question, inputs) # type: ignore[call-arg] + output: Dict[str, Any] = {} + if self.response_if_no_docs_found is not None and len(docs) == 0: + output[self.output_key] = self.response_if_no_docs_found + else: + new_inputs = inputs.copy() + if self.rephrase_question: + new_inputs['question'] = new_question + new_inputs['chat_history'] = chat_history_str + answer = self.combine_docs_chain.run(input_documents=docs, + callbacks=_run_manager.get_child(), + **new_inputs) + output[self.output_key] = answer + + if self.return_source_documents: + output['source_documents'] = docs + if self.return_generated_question: + output['generated_question'] = new_question + return output + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + question = inputs['question'] + get_chat_history = 
self.get_chat_history or _get_chat_history + chat_history_str = get_chat_history(inputs['chat_history']) + if chat_history_str: + # callbacks = _run_manager.get_child() + new_question = await self.question_generator.arun(question=question, + chat_history=chat_history_str) + else: + new_question = question + accepts_run_manager = ('run_manager' in inspect.signature(self._aget_docs).parameters) + if accepts_run_manager: + docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager) + else: + docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg] + + output: Dict[str, Any] = {} + if self.response_if_no_docs_found is not None and len(docs) == 0: + output[self.output_key] = self.response_if_no_docs_found + else: + new_inputs = inputs.copy() + if self.rephrase_question: + new_inputs['question'] = new_question + new_inputs['chat_history'] = chat_history_str + answer = await self.combine_docs_chain.arun(input_documents=docs, + callbacks=_run_manager.get_child(), + **new_inputs) + output[self.output_key] = answer + + if self.return_source_documents: + output['source_documents'] = docs + if self.return_generated_question: + output['generated_question'] = new_question + return output