Skip to content

Commit

Permalink
v0.0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
JiauZhang committed May 21, 2024
1 parent d6908ed commit bc7ea5d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 20 deletions.
32 changes: 23 additions & 9 deletions chatchat/baidu.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from chatchat.base import Base
import httpx, json, time
import httpx, time

class Completion(Base):
def __init__(self, jfile, name='ERNIE-Speed-8K'):
def __init__(self, jfile, model='ERNIE-Speed-8K'):
# https://console.bce.baidu.com/qianfan/ais/console/onlineService
self.api_list = {
'ERNIE-Speed-8K': 'ernie_speed',
'ERNIE-Speed-128K': 'ernie-speed-128k',
'ERNIE Speed-AppBuilder': 'ai_apaas',
'ERNIE-Lite-8K': 'ernie-lite-8k',
'ERNIE-Lite-8K-0922': 'eb-instant',
'ERNIE-Bot-turbo-0922': 'eb-instant',
'ERNIE-Tiny-8K': 'ernie-tiny-8k',
'Yi-34B-Chat': 'yi_34b_chat',
}

if name not in self.api_list:
if model not in self.api_list:
raise RuntimeError(f'supported chat type: {self.api_list.keys()}')
self.api = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/' + self.api_list[name]
self.api = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/' + self.api_list[model]
self.client = httpx.Client()

# jfile: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
# {
Expand Down Expand Up @@ -57,7 +60,7 @@ def update_access_token(self):
# "session_secret": "f"
# }
cur_time = time.time()
r = httpx.post(url, headers=self.headers, params=params).json()
r = self.client.post(url, headers=self.headers, params=params).json()
self.jdata['access_token'] = r['access_token']
self.jdata['expires_in'] = cur_time + float(r['expires_in'])
jdata = self.load_json(self.jfile)
Expand All @@ -68,11 +71,22 @@ def get_access_token(self):
self.update_access_token()
return self.jdata['access_token']

def create(self, json):
def create(self, message):
jmsg = {
"messages": [
{
"role": "user",
"content": message,
}
]
}
url = f'{self.api}?access_token={self.get_access_token()}'
r = httpx.request("POST", url, headers=self.headers, json=json)
r = self.client.post(url, headers=self.headers, json=jmsg)
return r.json()

class Chat():
def __init__(self):
class Chat(Completion):
def __init__(self, jfile, model='ERNIE-Speed-8K'):
super().__init__(jfile, model=model)

def chat(self):
...
10 changes: 1 addition & 9 deletions examples/baidu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,7 @@
# }
# }
completion = cc.baidu.Completion('./data.json')
payload = {
"messages": [
{
"role": "user",
"content": "简单介绍一下你自己,控制在五十个字之内。"
}
]
}
r = completion.create(payload)
r = completion.create("简单介绍一下你自己,控制在五十个字之内。")
# {
# 'id': 'xxx',
# 'object': 'chat.completion',
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
setup(
name = 'chatchat',
packages = find_packages(exclude=['examples']),
version = '0.0.4',
version = '0.0.5',
license = 'GPL-2.0',
description = 'large language model api',
description = 'Large Language Model API',
author = 'JiauZhang',
author_email = 'jiauzhang@163.com',
url = 'https://github.com/JiauZhang/chatchat',
Expand Down

0 comments on commit bc7ea5d

Please sign in to comment.