Commit
Feat/0.2.2.5 (#318)
conversation retrieve: no duplication
yaojin3616 authored Jan 31, 2024
2 parents fe534d2 + 115d8aa commit e0eab7d
Showing 9 changed files with 129 additions and 10 deletions.
3 changes: 1 addition & 2 deletions src/backend/bisheng/api/v1/endpoints.py
@@ -198,8 +198,7 @@ async def process_flow(
message=answer,
source=source)
with session_getter() as session:
session.add(question)
session.add(message)
session.add_all([question, message])
session.commit()
session.refresh(message)
extra.update({'source': source, 'message_id': message.id})
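The hunk above replaces two separate session.add calls with a single session.add_all, staging both the question row and the answer row before one commit. A minimal self-contained sketch of the same pattern; the Note model and the in-memory SQLite engine are illustrative stand-ins, not bisheng models:

from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine


class Note(SQLModel, table=True):
    # Illustrative table standing in for the question/message models.
    id: Optional[int] = Field(default=None, primary_key=True)
    text: str


engine = create_engine('sqlite://')
SQLModel.metadata.create_all(engine)

question, message = Note(text='question'), Note(text='answer')
with Session(engine) as session:
    session.add_all([question, message])  # one call stages both rows
    session.commit()
    session.refresh(message)              # reload the generated primary key
    print(message.id)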
1 change: 0 additions & 1 deletion src/backend/bisheng/api/v1/knowledge.py
@@ -474,7 +474,6 @@ def addEmbedding(collection_name, index_name, knowledge_id: int, model: str, chu
object_name_original = f'original/{db_file.id}'
setattr(db_file, 'object_name', object_name_original)
session.add(db_file)
session.flush()
session.commit()
session.refresh(db_file)

2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/report.py
@@ -58,12 +58,14 @@ async def get_template(*, flow_id: str):
db_report = Report(flow_id=flow_id)
elif db_report.object_name:
file_url = minio_client.MinioClient().get_share_link(db_report.object_name)

if not db_report.newversion_key or not db_report.object_name:
version_key = uuid4().hex
db_report.newversion_key = version_key
with session_getter() as session:
session.add(db_report)
session.commit()
session.refresh()
else:
version_key = db_report.newversion_key
res = {
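When the report has no template object or version key yet, the endpoint mints a fresh key with uuid4().hex and persists the Report row so later uploads can reference it. The key itself is just an opaque 32-character hex string:

from uuid import uuid4

version_key = uuid4().hex   # 32 hex characters, no dashes
print(len(version_key))     # 32

As a sketch-level note on the persistence step: SQLAlchemy's Session.refresh takes the instance to reload, so that call would normally be written session.refresh(db_report).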
1 change: 0 additions & 1 deletion src/backend/bisheng/api/v1/user.py
@@ -381,7 +381,6 @@ async def access_list(*, role_id: int, type: Optional[int] = None, Authorize: Au
with session_getter() as session:
db_role_access = session.exec(sql).all()
total_count = session.scalar(count_sql)
session.commit()
# convert between uuid and str
for access in db_role_access:
if access.type == AccessType.FLOW.value:
11 changes: 7 additions & 4 deletions src/backend/bisheng/chat/utils.py
@@ -139,6 +139,8 @@ async def process_source_document(source_document: List[Document], chat_id, mess
logger.error('不能使用配置模型进行关键词抽取,配置不正确')

answer_keywords = extract_answer_keys(answer, model, host_base_url)

batch_insert = []
for doc in source_document:
if 'bbox' in doc.metadata:
# indicates that source tracing is supported
@@ -149,7 +151,8 @@
file_id=doc.metadata.get('file_id'),
meta_data=json.dumps(doc.metadata),
message_id=message_id)
with session_getter() as db_session:
db_session.add(recall_chunk)
db_session.commit()
db_session.refresh(recall_chunk)
batch_insert.append(recall_chunk)
if batch_insert:
with session_getter() as db_session:
db_session.add_all(batch_insert)
db_session.commit()
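The rewritten loop no longer opens a session and commits once per source document; traceable chunks are accumulated in batch_insert and written with a single add_all/commit. A self-contained sketch of that batching pattern (the Chunk model and documents list are illustrative, not bisheng's RecallChunk):

import json
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine


class Chunk(SQLModel, table=True):
    # Illustrative stand-in for the RecallChunk model.
    id: Optional[int] = Field(default=None, primary_key=True)
    chat_id: str
    meta_data: str
    message_id: int


engine = create_engine('sqlite://')
SQLModel.metadata.create_all(engine)

documents = [{'bbox': [0, 0, 10, 10], 'file_id': 7}, {'file_id': 8}]   # second one has no bbox
batch_insert = []
for meta in documents:
    if 'bbox' in meta:   # only chunks that support source tracing are recorded
        batch_insert.append(Chunk(chat_id='c1', meta_data=json.dumps(meta), message_id=1))

if batch_insert:         # a single session and commit instead of one per chunk
    with Session(engine) as session:
        session.add_all(batch_insert)
        session.commit()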
1 change: 1 addition & 0 deletions src/backend/bisheng/interface/initialize/loading.py
@@ -290,6 +290,7 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict, id_di
# sequence chain
if node_type == 'SequentialChain':
# adapt SequentialChain to support a custom chain order
params.pop('input_node', '')  # SequentialChain does not accept extra input parameters
try:
chain_order = json.loads(params.pop('chain_order'))
except Exception:
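The one-line addition strips the UI-supplied input_node parameter before the SequentialChain is built, since the chain's constructor does not accept it; the surrounding code then reads the custom chain order from a JSON-encoded parameter. A small sketch of that pre-processing (the params dict and the fallback branch are illustrative; the real fallback is not shown in this hunk):

import json

params = {
    'input_node': 'upstream-node-id',          # injected by the flow editor, not a chain kwarg
    'chain_order': '["summarize", "answer"]',  # JSON-encoded custom ordering
    'chains': {'summarize': None, 'answer': None},
}

params.pop('input_node', '')                   # SequentialChain would reject this argument
try:
    chain_order = json.loads(params.pop('chain_order'))
except Exception:
    chain_order = list(params['chains'])       # illustrative fallback: keep the given order
print(chain_order)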
5 changes: 3 additions & 2 deletions src/bisheng-langchain/bisheng_langchain/chains/__init__.py
@@ -1,12 +1,13 @@
from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain
from bisheng_langchain.chains.combine_documents.stuff import StuffDocumentsChain
from bisheng_langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from bisheng_langchain.chains.retrieval.retrieval_chain import RetrievalChain
from bisheng_langchain.chains.router.multi_rule import MultiRuleChain
from bisheng_langchain.chains.router.rule_router import RuleBasedRouter

from .loader_output import LoaderOutputChain

__all__ = [
'StuffDocumentsChain', 'LoaderOutputChain', 'AutoGenChain', 'RuleBasedRouter', 'MultiRuleChain',
'RetrievalChain'
'StuffDocumentsChain', 'LoaderOutputChain', 'AutoGenChain', 'RuleBasedRouter',
'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain'
]
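With the new re-export, the bisheng ConversationalRetrievalChain can be imported from the package root like the other chains. A small check of what the diff establishes, namely that the class specializes langchain's chain of the same name:

from bisheng_langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversational_retrieval.base import (
    ConversationalRetrievalChain as LangChainConversationalRetrievalChain)

# The bisheng class overrides _call/_acall but keeps the langchain interface.
assert issubclass(ConversationalRetrievalChain, LangChainConversationalRetrievalChain)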
Empty file.
115 changes: 115 additions & 0 deletions src/bisheng-langchain/bisheng_langchain/chains/conversational_retrieval/base.py
@@ -0,0 +1,115 @@
from __future__ import annotations

import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
from langchain.chains.conversational_retrieval.base import \
ConversationalRetrievalChain as BaseConversationalRetrievalChain
from langchain_core.messages import BaseMessage

# Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

_ROLE_MAP = {'human': 'Human: ', 'ai': 'Assistant: '}


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
buffer = ''
for dialogue_turn in chat_history:
if isinstance(dialogue_turn, BaseMessage):
role_prefix = _ROLE_MAP.get(dialogue_turn.type, f'{dialogue_turn.type}: ')
buffer += f'\n{role_prefix}{dialogue_turn.content}'
elif isinstance(dialogue_turn, tuple):
human = 'Human: ' + dialogue_turn[0]
ai = 'Assistant: ' + dialogue_turn[1]
buffer += '\n' + '\n'.join([human, ai])
else:
raise ValueError(f'Unsupported chat history format: {type(dialogue_turn)}.'
f' Full chat history: {chat_history} ')
return buffer


class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""ConversationalRetrievalChain is a chain you can use to have a conversation with a character from a series."""

def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs['question']
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs['chat_history'])

if chat_history_str:
# callbacks = _run_manager.get_child()
new_question = self.question_generator.run(question=question,
chat_history=chat_history_str)
else:
new_question = question
accepts_run_manager = ('run_manager' in inspect.signature(self._get_docs).parameters)
if accepts_run_manager:
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
else:
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
output: Dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
new_inputs = inputs.copy()
if self.rephrase_question:
new_inputs['question'] = new_question
new_inputs['chat_history'] = chat_history_str
answer = self.combine_docs_chain.run(input_documents=docs,
callbacks=_run_manager.get_child(),
**new_inputs)
output[self.output_key] = answer

if self.return_source_documents:
output['source_documents'] = docs
if self.return_generated_question:
output['generated_question'] = new_question
return output

async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs['question']
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs['chat_history'])
if chat_history_str:
# callbacks = _run_manager.get_child()
new_question = await self.question_generator.arun(question=question,
chat_history=chat_history_str)
else:
new_question = question
accepts_run_manager = ('run_manager' in inspect.signature(self._aget_docs).parameters)
if accepts_run_manager:
docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
else:
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]

output: Dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
new_inputs = inputs.copy()
if self.rephrase_question:
new_inputs['question'] = new_question
new_inputs['chat_history'] = chat_history_str
answer = await self.combine_docs_chain.arun(input_documents=docs,
callbacks=_run_manager.get_child(),
**new_inputs)
output[self.output_key] = answer

if self.return_source_documents:
output['source_documents'] = docs
if self.return_generated_question:
output['generated_question'] = new_question
return output
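The subclass keeps langchain's interface and only reimplements _call/_acall; the visible differences are that the question generator is run without the parent callbacks (the commented-out callbacks=_run_manager.get_child() above, consistent with the commit's goal of avoiding duplicated output) and that the module ships its own _get_chat_history helper accepting either (human, ai) tuples or BaseMessage objects. A runnable check of that helper, assuming the module path implied by the package import (bisheng_langchain.chains.conversational_retrieval.base):

from langchain_core.messages import AIMessage, HumanMessage

from bisheng_langchain.chains.conversational_retrieval.base import _get_chat_history

history = [
    ('What is bisheng?', 'An open-source LLM application platform.'),   # (human, ai) tuple
    HumanMessage(content='Does it support retrieval?'),
    AIMessage(content='Yes, through its retrieval chains.'),
]
print(_get_chat_history(history))
# Each turn is rendered on its own line with a 'Human: ' or 'Assistant: ' prefix.

Chain construction itself is unchanged from langchain, for example via the inherited from_llm classmethod with an LLM and a retriever; this commit does not show how bisheng wires that up.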
