Skip to content

Commit

Permalink
feat: Compatible with lower version node 14 (#426)
Browse files Browse the repository at this point in the history
* fix:【js】限流依赖库变更

* fix:【js】文档更新

* fix: 解决流式返回数据乱码问题

* fix: 图生文预置服务

* fix: 更新版本号

* AIPD-DEV-22980 [Story] 【js】兼容低版本node16

* AIPD-DEV-22980 [Story] 【js】兼容低版本node16

* feat:【js】动态获取模型列表

---------

Co-authored-by: wangting31 <wangting31@baidu.com>
  • Loading branch information
wangting829 and wangting31 authored Apr 10, 2024
1 parent 888f507 commit ebb7476
Show file tree
Hide file tree
Showing 21 changed files with 880 additions and 405 deletions.
4 changes: 3 additions & 1 deletion javascript/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
"author": "",
"license": "ISC",
"dependencies": {
"@types/node-fetch": "^2.6.11",
"async-mutex": "^0.5.0",
"bottleneck": "^2.19.5",
"debug": "^3.1.0",
"dotenv": "^16.4.1",
"rollup": "^4.9.6",
"node-fetch": "2.7.0",
"rollup": "^3.29.4",
"tslib": "^2.6.2",
"typescript": "^5.3.3",
"underscore": "^1.9.1",
Expand Down
139 changes: 51 additions & 88 deletions javascript/src/Base/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
// limitations under the License.

import HttpClient from '../HttpClient';
import {Fetch, FetchConfig} from '../Fetch';
import {TokenLimiter} from '../Limiter';
import Fetch, {FetchConfig} from '../Fetch/fetch';
import {DEFAULT_HEADERS} from '../constant';
import {getAccessTokenUrl, getIAMConfig, getDefaultConfig, calculateRetryDelay, isOpenTpm} from '../utils';
import {Stream} from '../streaming';
import {Resp, AsyncIterableType, AccessTokenResp, RespBase} from '../interface';
import {getAccessTokenUrl, getIAMConfig, getDefaultConfig, getPath} from '../utils';
import {Resp, AsyncIterableType, AccessTokenResp} from '../interface';
import DynamicModelEndpoint from '../DynamicModelEndpoint';

export class BaseClient {
protected controller: AbortController;
Expand All @@ -27,6 +26,7 @@ export class BaseClient {
protected qianfanAccessKey?: string;
protected qianfanSecretKey?: string;
protected qianfanBaseUrl?: string;
protected qianfanConsoleApiBaseUrl?: string;
protected qianfanLlmApiRetryTimeout?: string;
protected qianfanLlmApiRetryBackoffFactor?: string;
protected qianfanLlmApiRetryCount?: string;
Expand All @@ -35,7 +35,6 @@ export class BaseClient {
protected headers = DEFAULT_HEADERS;
protected fetchInstance: Fetch;
protected fetchConfig: FetchConfig;
private tokenLimiter: TokenLimiter;
access_token = '';
expires_in = 0;

Expand All @@ -45,6 +44,7 @@ export class BaseClient {
QIANFAN_ACCESS_KEY?: string;
QIANFAN_SECRET_KEY?: string;
QIANFAN_BASE_URL?: string;
QIANFAN_CONSOLE_API_BASE_URL?: string;
QIANFAN_LLM_API_RETRY_TIMEOUT?: string;
QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR?: string;
QIANFAN_LLM_API_RETRY_COUNT?: string;
Expand All @@ -58,25 +58,21 @@ export class BaseClient {
this.qianfanSecretKey = options?.QIANFAN_SECRET_KEY ?? defaultConfig.QIANFAN_SECRET_KEY;
this.Endpoint = options?.Endpoint;
this.qianfanBaseUrl = options?.QIANFAN_BASE_URL ?? defaultConfig.QIANFAN_BASE_URL;
this.qianfanConsoleApiBaseUrl
= options?.QIANFAN_CONSOLE_API_BASE_URL ?? defaultConfig.QIANFAN_CONSOLE_API_BASE_URL;
this.qianfanLlmApiRetryTimeout
= options?.QIANFAN_LLM_API_RETRY_TIMEOUT ?? defaultConfig.QIANFAN_LLM_API_RETRY_TIMEOUT;
this.qianfanLlmApiRetryBackoffFactor
= options?.QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR ?? defaultConfig.QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR;
this.qianfanLlmApiRetryCount
= options?.QIANFAN_LLM_API_RETRY_COUNT ?? defaultConfig.QIANFAN_LLM_API_RETRY_COUNT;
this.controller = new AbortController();
this.fetchConfig = {
retries: Number(this.qianfanLlmApiRetryCount),
this.fetchInstance = new Fetch({
maxRetries: Number(this.qianfanLlmApiRetryCount),
timeout: Number(this.qianfanLlmApiRetryTimeout),
retryDelay: attempt =>
calculateRetryDelay(
attempt,
Number(this.qianfanLlmApiRetryBackoffFactor),
Number(this.qianfanLlmRetryMaxWaitInterval)
),
};
this.fetchInstance = new Fetch(this.fetchConfig);
this.tokenLimiter = new TokenLimiter();
backoffFactor: Number(this.qianfanLlmApiRetryBackoffFactor),
retryMaxWaitInterval: Number(this.qianfanLlmRetryMaxWaitInterval),
});
}

/**
Expand All @@ -87,16 +83,15 @@ export class BaseClient {
private async getAccessToken(): Promise<AccessTokenResp> {
const url = getAccessTokenUrl(this.qianfanAk, this.qianfanSk, this.qianfanBaseUrl);
try {
const resp = await this.fetchInstance.fetchWithRetry(url, {headers: this.headers, method: 'POST'});
const data = (await resp.json()) as AccessTokenResp;
if (data?.error) {
throw new Error(data?.error_description || 'Failed to get access token');
}
this.access_token = data.access_token ?? '';
this.expires_in = data.expires_in + Date.now() / 1000;
const resp = await this.fetchInstance.makeRequest(url, {
headers: this.headers,
method: 'POST',
});
this.access_token = resp.access_token ?? '';
this.expires_in = resp.expires_in + Date.now() / 1000;
return {
access_token: data.access_token,
expires_in: data.expires_in,
access_token: resp.access_token,
expires_in: resp.expires_in,
};
}
catch (error) {
Expand All @@ -106,7 +101,8 @@ export class BaseClient {
}

protected async sendRequest(
IAMpath: string,
type: string,
model: string,
AKPath: string,
requestBody: string,
stream = false
Expand All @@ -120,10 +116,31 @@ export class BaseClient {
// IAM鉴权
if (this.qianfanAccessKey && this.qianfanSecretKey) {
const config = getIAMConfig(this.qianfanAccessKey, this.qianfanSecretKey, this.qianfanBaseUrl);
console.log(config);
const client = new HttpClient(config);
const dynamicModelEndpoint = new DynamicModelEndpoint(
client,
this.qianfanConsoleApiBaseUrl,
this.qianfanBaseUrl
);
let IAMPath = '';
if (this.Endpoint) {
IAMPath = getPath({
Authentication: 'IAM',
api_base: this.qianfanBaseUrl,
endpoint: this.Endpoint,
type,
});
}
else {
IAMPath = await dynamicModelEndpoint.getEndpoint(type, model);
}
if (!IAMPath) {
throw new Error(`${model} is not supported`);
}
fetchOptions = await client.getSignature({
httpMethod: 'POST',
path: IAMpath,
path: IAMPath,
body: requestBody,
headers: this.headers,
});
Expand All @@ -141,67 +158,13 @@ export class BaseClient {
body: requestBody,
};
}

// 计算请求token
const tokens = this.tokenLimiter.calculateTokens(requestBody);
const hasToken = await this.tokenLimiter.acquireTokens(tokens);
// 满足token限制
if (hasToken) {
try {
const resp = await this.fetchInstance.fetchWithRetry(fetchOptions.url, fetchOptions);
const val = this.getTpmHeader(resp.headers);
let usedTokens = 0;
if (stream) {
const sseStream = Stream.fromSSEResponse(resp, this.controller);
const [stream1, stream2] = sseStream.tee();
if (isOpenTpm(val)) {
const updateTokensAsync = async () => {
for await (const data of stream1) {
const typedData = data as RespBase;
if (typedData.is_end) {
usedTokens = typedData?.usage?.total_tokens;
await this.tokenLimiter.acquireTokens(usedTokens - tokens);
break;
}
}
};
setTimeout(updateTokensAsync, 0);
}
return stream2 as AsyncIterableType;
}
const data = await resp.json();
setTimeout(async () => {
usedTokens = this.getUsedTokens(data);
await this.tokenLimiter.acquireTokens(usedTokens - tokens);
}, 0);
return data as any;
}
catch (error) {
throw error;
}
}
else {
throw new Error('Token limit exceeded');
try {
const {url, ...rest} = fetchOptions;
const resp = await this.fetchInstance.makeRequest(url, {...rest, stream});
return resp;
}
}

getTpmHeader(headers: any): void {
const val = headers.get('x-ratelimit-limit-tokens') ?? '0';
this.tokenLimiter.resetTokens(val);
return val;
}

async getStreamUsedTokens(data: AsyncIterableType): Promise<number> {
let usedTokens = 0;
for await (const chunk of data as AsyncIterableType) {
if (chunk.is_end) {
usedTokens = chunk?.usage?.total_tokens;
}
catch (error) {
throw error;
}
return usedTokens ?? 0;
}
getUsedTokens(data: Resp): number {
const usage = data?.usage?.total_tokens;
return usage ?? 0;
}
}
8 changes: 5 additions & 3 deletions javascript/src/ChatCompletion/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {BaseClient} from '../Base';
import {ChatBody, Resp} from '../interface';
import {modelInfoMap} from './utils';
import {getPathAndBody, getUpperCaseModelAndModelMap} from '../utils';
import {ModelType} from '../enum';

class ChatCompletion extends BaseClient {
/**
Expand All @@ -28,15 +29,16 @@ class ChatCompletion extends BaseClient {
public async chat(body: ChatBody, model = 'ERNIE-Bot-turbo'): Promise<Resp | AsyncIterable<Resp>> {
const stream = body.stream ?? false;
const {modelInfoMapUppercase, modelUppercase} = getUpperCaseModelAndModelMap(model, modelInfoMap);
const {IAMPath, AKPath, requestBody} = getPathAndBody({
const type = ModelType.CHAT;
const {AKPath, requestBody} = getPathAndBody({
model: modelUppercase,
modelInfoMap: modelInfoMapUppercase,
baseUrl: this.qianfanBaseUrl,
body,
endpoint: this.Endpoint,
type: 'chat',
type,
});
return this.sendRequest(IAMPath, AKPath, requestBody, stream);
return this.sendRequest(type, model, AKPath, requestBody, stream);
}
}

Expand Down
8 changes: 5 additions & 3 deletions javascript/src/Completions/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {BaseClient} from '../Base';
import {ChatBody, CompletionBody, Resp} from '../interface';
import {modelInfoMap, isCompletionBody} from './utils';
import {getPathAndBody, getUpperCaseModelAndModelMap} from '../utils';
import {ModelType} from '../enum';

class Completions extends BaseClient {
/**
Expand All @@ -33,6 +34,7 @@ class Completions extends BaseClient {
// 兼容Chat模型
const required_keys = modelInfoMapUppercase[modelUppercase]?.required_keys;
let reqBody: CompletionBody | ChatBody;
const type = ModelType.COMPLETIONS;
if (required_keys.includes('messages') && isCompletionBody(body)) {
const {prompt, ...restOfBody} = body;
reqBody = {
Expand All @@ -48,15 +50,15 @@ class Completions extends BaseClient {
else {
reqBody = body;
}
const {IAMPath, AKPath, requestBody} = getPathAndBody({
const {AKPath, requestBody} = getPathAndBody({
model: modelUppercase,
modelInfoMap: modelInfoMapUppercase,
baseUrl: this.qianfanBaseUrl,
body: reqBody,
endpoint: this.Endpoint,
type: 'completions',
type,
});
return this.sendRequest(IAMPath, AKPath, requestBody, stream);
return this.sendRequest(type, model, AKPath, requestBody, stream);
}
}

Expand Down
103 changes: 103 additions & 0 deletions javascript/src/DynamicModelEndpoint/__tests__/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import DynamicModelEndpoint from '../index';
import HttpClient from '../../HttpClient';
import Fetch from '../../Fetch/fetch';

// Swap the real HttpClient module for a mock constructor. Each constructed
// instance carries its own `getSignature` stub, which resolves to a fixed
// description of a signed POST request.
jest.mock('../../HttpClient', () => {
    // Builds one stubbed client; invoked once per `new HttpClient(...)`.
    const createClientStub = () => {
        const signedRequest = {
            httpMethod: 'POST',
            path: 'https://qianfan-console-api-base-url.com',
            body: JSON.stringify({apiTypefilter: ['type']}),
            headers: {
                'Content-Type': 'application/json',
                Accept: 'application/json',
            },
        };
        return {getSignature: jest.fn().mockResolvedValue(signedRequest)};
    };
    return jest.fn().mockImplementation(createClientStub);
});

// Swap the real Fetch module for a mock constructor. Each constructed instance
// gets a fresh `makeRequest` stub whose first call resolves with a two-model
// chat listing and whose second call rejects with a network error.
jest.mock('../../Fetch/fetch', () => {
    // Builds one stubbed fetcher; invoked once per `new Fetch(...)`.
    const createFetchStub = () => {
        const modelListResponse = {
            result: {
                common: [
                    {apiType: 'chat', name: 'model1', url: 'https://example.com/Model1'},
                    {apiType: 'chat', name: 'model2', url: 'https://example.com/Model2'},
                ],
            },
        };
        const makeRequest = jest.fn();
        // Queue the one-shot outcomes in order: success first, then failure.
        makeRequest.mockResolvedValueOnce(modelListResponse);
        makeRequest.mockRejectedValueOnce(new Error('Network error'));
        return {makeRequest};
    };
    return jest.fn().mockImplementation(createFetchStub);
});

/**
 * Builds a DynamicModelEndpoint wired to a mocked HttpClient.
 *
 * @param clientResponse value the mocked `getSignature` call resolves to
 * @returns a DynamicModelEndpoint pointed at the fixed test console/base URLs
 */
function setupDynamicModelEndpoint(clientResponse) {
    const consoleApiBaseUrl = 'https://qianfan-console-api-base-url.com';
    const llmBaseUrl = 'https://qianfan-base-url.com';

    const signingClient = new HttpClient({}) as jest.Mocked<HttpClient>;
    const detachedFetch = new Fetch() as jest.Mocked<Fetch>;

    // Stand-in for the signature response.
    signingClient.getSignature.mockResolvedValue(clientResponse);
    // Stand-in for the fetch response.
    // NOTE(review): this Fetch instance is never passed to the endpoint below, and the
    // module mock hands out a fresh stub per construction, so this stubbing presumably
    // does not influence the endpoint's own requests — confirm against DynamicModelEndpoint.
    detachedFetch.makeRequest.mockResolvedValue({});

    return new DynamicModelEndpoint(signingClient, consoleApiBaseUrl, llmBaseUrl);
}

// Test suite for DynamicModelEndpoint endpoint resolution.
describe('DynamicModelEndpoint', () => {
    beforeEach(() => {
        // Clear recorded mock calls before every test.
        jest.clearAllMocks();
    });

    // The dynamic mapping has not expired: the endpoint comes from the existing map.
    it('should return endpoint from dynamic mapping when not expired', async () => {
        const endpoint = setupDynamicModelEndpoint({});
        endpoint.setDynamicMapExpireAt(Date.now() / 1000 + 1000); // mark as not expired
        const dynamicTypeModelEndpointMap = endpoint.getDynamicTypeModelEndpointMap();
        dynamicTypeModelEndpointMap.set('chat', new Map([['model1', 'https://example.com/Model0']]));

        // NOTE(review): the lookup uses 'Model1' while the stored key is 'model1';
        // presumably getEndpoint normalizes the model name's case — confirm.
        await expect(endpoint.getEndpoint('chat', 'Model1')).resolves.toEqual('https://example.com/Model0');
    });

    // The dynamic mapping has expired: it is refreshed and the refreshed endpoint is returned.
    it('should update and return endpoint from dynamic mapping when expired', async () => {
        const endpoint = setupDynamicModelEndpoint({});
        endpoint.setDynamicMapExpireAt(Date.now() / 1000 - 1); // mark as expired
        const dynamicTypeModelEndpointMap = endpoint.getDynamicTypeModelEndpointMap();
        dynamicTypeModelEndpointMap.set('chat', new Map([['model1', 'https://example.com/Model0']]));

        await expect(endpoint.getEndpoint('chat', 'model1')).resolves.toEqual('https://example.com/Model1');
        // Optional chaining: a missing 'chat' entry fails the assertion instead of throwing a TypeError.
        expect(dynamicTypeModelEndpointMap.get('chat')?.get('model1')).toEqual('https://example.com/Model1');
    });

    // Refreshing the dynamic mapping fails: the previously stored configuration is kept.
    // NOTE(review): this test never calls getEndpoint, and the rejecting Fetch mock below is
    // not handed to the endpoint, so it only checks the value written into the map a few
    // lines earlier. It should exercise the failure path once the expected fallback value
    // is confirmed against DynamicModelEndpoint.
    it('should handle failure during dynamic mapping update', async () => {
        const client = new HttpClient({}) as jest.Mocked<HttpClient>;
        const fetchInstance = new Fetch() as jest.Mocked<Fetch>;
        client.getSignature.mockResolvedValue({}); // stand-in for the signature response
        fetchInstance.makeRequest.mockRejectedValue(new Error('Failed to fetch')); // simulated fetch failure
        const endpoint = new DynamicModelEndpoint(
            client,
            'https://qianfan-console-api-base-url.com',
            'https://qianfan-base-url.com'
        );
        const dynamicTypeModelEndpointMap = endpoint.getDynamicTypeModelEndpointMap();
        dynamicTypeModelEndpointMap.set('chat', new Map([['model1', 'https://example.com/Model0']]));
        endpoint.setDynamicMapExpireAt(Date.now() / 1000 - 1); // mark as expired

        expect(dynamicTypeModelEndpointMap.get('chat')?.get('model1')).toEqual('https://example.com/Model0');
    });

    // The model is absent from both the dynamic and the static mapping.
    // The title now matches the assertion: the resolved value is `undefined`, not ''.
    it('should return undefined when the model is not found in both mappings', async () => {
        const endpoint = setupDynamicModelEndpoint({});
        endpoint.setDynamicMapExpireAt(Date.now() / 1000 + 1000); // mark as not expired

        await expect(endpoint.getEndpoint('chat', 'NonExistentModel')).resolves.toEqual(undefined);
    });
});
Loading

0 comments on commit ebb7476

Please sign in to comment.