Feat/0.2.2 (#230)
yaojin3616 authored Jan 1, 2024
2 parents 2bb1005 + 0ba770b commit 73bd8a1
Showing 38 changed files with 1,311 additions and 437 deletions.
16 changes: 14 additions & 2 deletions .github/workflows/release.yml
@@ -20,6 +20,19 @@ jobs:
runs-on: ubuntu-latest
#if: startsWith(github.event.ref, 'refs/tags')
steps:
# deploy
- name: get commit message
id: get_commit_message
run: |
echo "::set-output name=commit_message::$(git log -1 --pretty=%B)"
- name: notify feishu
uses: fjogeleit/http-request-action@v1
with:
url: ' https://open.feishu.cn/open-apis/bot/v2/hook/2cfe0d8d-647c-4408-9f39-c59134035c4b'
method: 'POST'
data: '{"msg_type":"text","content":{"text":"${{steps.get_commit_message.outputs.commit_message}}"}}'

- name: checkout
uses: actions/checkout@v2

@@ -71,7 +84,6 @@ jobs:
run: |
cd ./src/backend
poetry source add --priority=supplemental foo http://${{ secrets.NEXUS_PUBLIC }}:${{ secrets.NEXUS_PUBLIC_PASSWORD }}@${{ env.PY_NEXUS }}/repository/pypi-group/simple
# sed -i 's/^bisheng_langchain.*/bisheng_langchain = "0.0.0"/g' pyproject.toml
poetry lock
cd ../../
@@ -118,5 +130,5 @@ jobs:
- name: Deploy Stage
uses: fjogeleit/http-request-action@v1
with:
url: 'http://110.16.193.170:50055/cgi-bin/delpoy.py'
url: 'https://bisheng.dataelem.com/deploy/cgi-bin/deploy_script.py'
method: 'GET'
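
Note: the two new steps above read the latest commit message with `git log -1 --pretty=%B`, expose it through the `::set-output` workflow command (which GitHub has since deprecated in favor of writing to `$GITHUB_OUTPUT`), and post it to a Feishu bot webhook via `fjogeleit/http-request-action`. A rough local equivalent of that notification step, as a sketch only — the webhook URL and payload shape come from the workflow above, while the use of `requests` and `subprocess` here is an assumption, not part of the commit:

```python
# Sketch of the "notify feishu" step for local testing; not part of the commit.
import subprocess
import requests

WEBHOOK = 'https://open.feishu.cn/open-apis/bot/v2/hook/2cfe0d8d-647c-4408-9f39-c59134035c4b'

def notify_feishu() -> None:
    # Same command the workflow uses to grab the latest commit message.
    commit_message = subprocess.run(
        ['git', 'log', '-1', '--pretty=%B'],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    # Feishu custom bots accept a plain-text payload of this shape.
    payload = {'msg_type': 'text', 'content': {'text': commit_message}}
    requests.post(WEBHOOK, json=payload, timeout=10).raise_for_status()

if __name__ == '__main__':
    notify_feishu()
```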
2 changes: 2 additions & 0 deletions docker/nginx/conf.d/default.conf
@@ -34,6 +34,8 @@ server {
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
client_max_body_size 50m;
add_header Access-Control-Allow-Origin xxxxx;

}

}
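
Note: the nginx change adds an `Access-Control-Allow-Origin` response header inside the proxied location block (the origin value is left as a placeholder in this hunk). A minimal way to verify the header from a client, assuming `requests` is installed and the gateway is reachable at a hypothetical `http://localhost:3001` — both the host and the `/api/v1/env` path are assumptions, not values from this commit:

```python
# Quick CORS check against the nginx gateway; host, port, and path are placeholders.
import requests

resp = requests.get('http://localhost:3001/api/v1/env', timeout=5)
print(resp.status_code)
print(resp.headers.get('Access-Control-Allow-Origin'))  # should echo the configured origin
```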
5 changes: 4 additions & 1 deletion src/backend/bisheng/api/JWT.py
@@ -5,6 +5,9 @@
class Settings(BaseModel):
authjwt_secret_key: str = settings.jwt_secret
# Configure application to store and get JWT from cookies
authjwt_token_location: set = {'cookies'}
authjwt_token_location: set = {'cookies', 'headers'}
# Disable CSRF Protection for this example. default is True
authjwt_cookie_csrf_protect: bool = False
jwt_optional_claims = ['exp', 'nbf', 'aud', 'iss', 'iat']
iss = 'AITools'
aud = 'AITools'
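
Note: adding `'headers'` to `authjwt_token_location` lets fastapi-jwt-auth accept the JWT from an `Authorization: Bearer` header as well as from cookies, so API clients can authenticate without a cookie jar. For reference, a minimal sketch of how this `Settings` class is wired into the library — `@AuthJWT.load_config` is fastapi-jwt-auth's documented configuration hook, and the import path for `Settings` is assumed from the file location:

```python
# Sketch: registering the JWT settings with fastapi-jwt-auth.
from fastapi_jwt_auth import AuthJWT
from bisheng.api.JWT import Settings  # assumed import path for the class above

@AuthJWT.load_config
def get_config():
    return Settings()

# With 'headers' enabled, clients may now send
#   Authorization: Bearer <access_token>
# instead of relying solely on the access-token cookie.
```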
67 changes: 10 additions & 57 deletions src/backend/bisheng/api/v1/chat.py
@@ -1,8 +1,8 @@
import json
from typing import List, Optional
from typing import List, Optional, Union
from uuid import UUID

from bisheng.api.utils import build_flow, build_flow_no_yield, build_input_keys_response
from bisheng.api.utils import build_flow, build_input_keys_response
from bisheng.api.v1.schemas import BuildStatus, BuiltResponse, ChatList, InitResponse, StreamData
from bisheng.cache.redis import redis_client
from bisheng.chat.manager import ChatManager
@@ -111,11 +111,14 @@ def get_chatlist_list(*, session: Session = Depends(get_session), Authorize: Aut


@router.websocket('/chat/{flow_id}')
async def chat(flow_id: str,
websocket: WebSocket,
chat_id: Optional[str] = None,
type: Optional[str] = None,
Authorize: AuthJWT = Depends()):
async def chat(
flow_id: str,
websocket: WebSocket,
chat_id: Optional[str] = None,
type: Optional[str] = None,
session_id: Union[None, str] = None, # noqa: F821
Authorize: AuthJWT = Depends(),
):
Authorize.jwt_required(auth_from='websocket', websocket=websocket)
payload = json.loads(Authorize.get_jwt_subject())
user_id = payload.get('user_id')
@@ -159,56 +162,6 @@ async def chat(flow_id: str,
logger.error(str(e))


@router.websocket('/chat/ws/{client_id}')
async def union_websocket(client_id: str,
websocket: WebSocket,
chat_id: Optional[str] = None,
type: Optional[str] = None,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required(auth_from='websocket', websocket=websocket)
payload = json.loads(Authorize.get_jwt_subject())
user_id = payload.get('user_id')
"""Websocket endpoint for chat."""
if type and type == 'L1':
with next(get_session()) as session:
db_flow = session.get(Flow, client_id)
if not db_flow:
await websocket.accept()
message = '该技能已被删除'
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
if db_flow.status != 2:
await websocket.accept()
message = '当前技能未上线,无法直接对话'
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
graph_data = db_flow.data
else:
flow_data_key = 'flow_data_' + client_id
if str(flow_data_store.hget(flow_data_key, 'status'), 'utf-8') != BuildStatus.SUCCESS.value:
await websocket.accept()
message = '当前编译没通过'
await websocket.close(code=status.WS_1013_TRY_AGAIN_LATER, reason=message)
graph_data = json.loads(flow_data_store.hget(flow_data_key, 'graph_data'))

try:
process_file = False if chat_id else True
graph = build_flow_no_yield(graph_data=graph_data,
artifacts={},
process_file=process_file,
flow_id=UUID(client_id).hex,
chat_id=chat_id)
langchain_object = graph.build()
for node in langchain_object:
key_node = get_cache_key(client_id, chat_id, node.id)
chat_manager.set_cache(key_node, node._built_object)
chat_manager.set_cache(get_cache_key(client_id, chat_id), node._built_object)
await chat_manager.handle_websocket(client_id, chat_id, websocket, user_id)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except Exception as e:
logger.error(str(e))


@router.post('/build/init/{flow_id}', response_model=InitResponse, status_code=201)
async def init_build(*, graph_data: dict, session: Session = Depends(get_session), flow_id: str):
"""Initialize the build by storing graph data and returning a unique session ID."""
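Note: with the duplicate `/chat/ws/{client_id}` endpoint removed, websocket clients go through `/chat/{flow_id}`, whose signature now also takes an optional `session_id` query parameter; combined with the JWT change above, the token can be passed in a header instead of a cookie. A minimal client sketch, assuming the `websockets` package, a gateway at a hypothetical `ws://localhost:7860`, an `/api/v1` mount prefix, and a placeholder message payload — none of these specifics come from the diff itself:

```python
# Sketch of a websocket client for /chat/{flow_id}; host, prefix, and message
# shape are assumptions. `extra_headers` is the argument name in websockets
# releases before 14.0 (newer releases use `additional_headers`).
import asyncio
import json
import websockets

async def chat(flow_id: str, chat_id: str, token: str) -> None:
    url = f'ws://localhost:7860/api/v1/chat/{flow_id}?chat_id={chat_id}'
    headers = {'Authorization': f'Bearer {token}'}
    async with websockets.connect(url, extra_headers=headers) as ws:
        await ws.send(json.dumps({'inputs': {'input': 'hello'}}))
        print(await ws.recv())

# asyncio.run(chat('<flow-id>', '<chat-id>', '<jwt>'))
```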
108 changes: 100 additions & 8 deletions src/backend/bisheng/api/v1/endpoints.py
@@ -1,25 +1,37 @@
import copy
import json
from typing import Optional
from typing import Annotated, Optional, Union

import yaml
from bisheng import settings
from bisheng.api.v1 import knowledge
from bisheng.api.v1.schemas import ProcessResponse, UploadFileResponse
from bisheng.cache.redis import redis_client
from bisheng.cache.utils import save_uploaded_file
from bisheng.chat.utils import judge_source, process_source_document
from bisheng.database.base import get_session
from bisheng.database.models.config import Config
from bisheng.database.models.flow import Flow
from bisheng.database.models.message import ChatMessage
from bisheng.interface.types import langchain_types_dict
from bisheng.processing.process import process_graph_cached, process_tweaks
from bisheng.services.deps import get_session_service, get_task_service
from bisheng.services.task.service import TaskService
from bisheng.settings import parse_key
from bisheng.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import delete
from sqlmodel import Session, select

try:
from bisheng.worker import process_graph_cached_task
except ImportError:

def process_graph_cached_task(*args, **kwargs):
raise NotImplementedError('Celery is not installed')


# build router
router = APIRouter(tags=['Base'])

@@ -31,8 +43,10 @@ def get_all():

@router.get('/env')
def getn_env():
uns_support = ['png', 'jpg', 'jpeg', 'bmp', 'doc', 'docx', 'ppt',
'pptx', 'xls', 'xlsx', 'txt', 'md', 'html', 'pdf']
uns_support = [
'png', 'jpg', 'jpeg', 'bmp', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'txt', 'md',
'html', 'pdf'
]
env = {}
if isinstance(settings.settings.environment, str):
env['env'] = settings.settings.environment
@@ -91,10 +105,14 @@ def save_config(data: dict, session: Session = Depends(get_session)):
@router.post('/predict/{flow_id}', response_model=ProcessResponse)
@router.post('/process/{flow_id}', response_model=ProcessResponse)
async def process_flow(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[dict] = None,
tweaks: Optional[dict] = None,
session: Session = Depends(get_session),
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
task_service: 'TaskService' = Depends(get_task_service),
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
):
"""
Endpoint to process an input with a given flow_id.
@@ -106,17 +124,91 @@ async def process_flow(
flow = session.get(Flow, flow_id)
if flow is None:
raise ValueError(f'Flow {flow_id} not found')

if flow.data is None:
raise ValueError(f'Flow {flow_id} has no data')

graph_data = flow.data
if tweaks:
try:
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f'Error processing tweaks: {exc}')
response = process_graph_cached(graph_data, inputs)
return ProcessResponse(result=response,)

# process
if sync:
result = await process_graph_cached(
graph_data,
inputs,
clear_cache,
session_id,
)
if isinstance(result, dict) and 'result' in result:
task_result = result['result']
session_id = result['session_id']
elif hasattr(result, 'result') and hasattr(result, 'session_id'):
task_result = result.result
session_id = result.session_id
else:
logger.warning('This is an experimental feature and may not work as expected.'
'Please report any issues to our GitHub repository.')
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(session_id=session_id,
data_graph=graph_data)
task_id, task = await task_service.launch_task(
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
if task.status == 'SUCCESS':
task_result = task.result
if hasattr(task_result, 'result'):
task_result = task_result.result
else:
logger.error(f'task_id={task_id} exception task result={task}')

# Determine source attribution (判断溯源)
source_documents = task_result.pop('source_documents', '')
answer = list(task_result.values())[0]
extra = {}
source = await judge_source(answer, source_documents, session_id, extra)

try:
question = ChatMessage(user_id=0,
is_bot=False,
type='end',
chat_id=session_id,
category='question',
flow_id=flow_id,
message=inputs)
message = ChatMessage(user_id=0,
is_bot=True,
chat_id=session_id,
flow_id=flow_id,
type='end',
category='answer',
message=answer,
source=source)
session.add(question)
session.add(message)
session.commit()
session.refresh(message)
extra.update({'source': source, 'message_id': message.id})
task_result.update(extra)
if source != 0:
await process_source_document(source_documents, session_id, message.id, answer)
except Exception as e:
logger.error(e)

return ProcessResponse(
result=task_result,
# task=task_response,
session_id=session_id,
backend=task_service.backend_name,
)

except Exception as e:
# Log stack trace
logger.exception(e)
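Note: the reworked `process_flow` endpoint now takes `clear_cache`, `session_id`, and `sync` as embedded body fields alongside `inputs` and `tweaks`, runs the graph either inline or through the task service (Celery when the worker module is importable), persists the question/answer pair as `ChatMessage` rows with source attribution, and returns the result together with `session_id` and the task backend name. A hedged example call — the base URL and the `/api/v1` mount prefix are assumptions; the body fields mirror the new signature:

```python
# Sketch of calling the updated /process/{flow_id} endpoint; base URL and
# prefix are placeholders, the JSON fields follow the signature in the diff.
import requests

resp = requests.post(
    'http://localhost:7860/api/v1/process/<flow-id>',
    json={
        'inputs': {'query': 'hello'},
        'tweaks': None,
        'clear_cache': False,
        'session_id': None,   # let the server generate one on the first call
        'sync': True,         # False routes the run through the task service
    },
    timeout=60,
)
body = resp.json()
print(body.get('session_id'), body.get('backend'))
print(body.get('result'))
```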
4 changes: 1 addition & 3 deletions src/backend/bisheng/api/v1/flows.py
@@ -49,7 +49,6 @@ def read_flows(*,
"""Read all flows."""
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

try:
sql = select(Flow.id)
count_sql = select(func.count(Flow.id))
@@ -73,7 +72,7 @@ def read_flows(*,
sql = sql.where(Flow.status == status)
count_sql = count_sql.where(Flow.status == status)
total_count = session.scalar(count_sql)

logger.debug('flows_get end_count')
sql = sql.order_by(Flow.update_time.desc())
if page_num and page_size:
sql = sql.offset((page_num - 1) * page_size).limit(page_size)
@@ -82,7 +81,6 @@
if flows:
flows = session.exec(
select(Flow).where(Flow.id.in_(flows)).order_by(Flow.update_time.desc())).all()

res = [jsonable_encoder(flow) for flow in flows]
if flows:
db_user_ids = {flow.user_id for flow in flows}
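Note: the `read_flows` changes here are mostly cosmetic (dropped blank lines plus an extra debug log), but the surrounding code shows the pagination pattern the endpoint relies on: a separate count query and an offset/limit window ordered by `update_time`. A condensed sketch of that pattern, assuming the `Flow` model and an open sqlmodel `Session` from the surrounding code:

```python
# Condensed version of the pagination used in read_flows; `Flow` is the
# project model imported elsewhere in this commit.
from sqlalchemy import func
from sqlmodel import Session, select

from bisheng.database.models.flow import Flow

def page_flows(session: Session, page_num: int, page_size: int):
    total_count = session.scalar(select(func.count(Flow.id)))
    sql = (select(Flow)
           .order_by(Flow.update_time.desc())
           .offset((page_num - 1) * page_size)
           .limit(page_size))
    return session.exec(sql).all(), total_count
```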