Allow multiple messages (context window) to be passed
Mikey Stengel committed Nov 29, 2023
1 parent f12b8a5 commit f8da70e
Showing 6 changed files with 68 additions and 22 deletions.
__tests__/schema/ai.ts (10 changes: 7 additions & 3 deletions)
@@ -33,16 +33,20 @@ const user = { ...baseUser, roles: ['de_reviewer'] }
 const query = new Client({ userId: user.id }).prepareQuery({
   query: gql`
-    query ($prompt: String!) {
+    query ($messages: [ChatCompletionMessageParam!]!) {
       ai {
-        executePrompt(prompt: $prompt) {
+        executePrompt(messages: $messages) {
           success
           record
         }
       }
     }
   `,
-  variables: { prompt: 'Generate exercise for 7th grade math in json' },
+  variables: {
+    messages: [
+      { role: 'user', content: 'Generate exercise for 7th grade math in json' },
+    ],
+  },
 })

 beforeAll(() => {
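
The payoff of this change is that a caller can now ship an entire context window rather than one prompt string. As a rough sketch (not part of the commit), the same query could be issued over plain HTTP with a system message steering the model; the endpoint URL and Node 18+ fetch are assumptions here, and the repository's tests go through the Client helper shown above instead:

// demo.ts - illustrative sketch only; endpoint and wiring are assumptions
const query = /* GraphQL */ `
  query ($messages: [ChatCompletionMessageParam!]!) {
    ai {
      executePrompt(messages: $messages) {
        success
        record
      }
    }
  }
`

// A context window: a steering system message plus the user prompt.
const variables = {
  messages: [
    { role: 'system', content: 'Answer with exactly one JSON object.' },
    { role: 'user', content: 'Generate exercise for 7th grade math in json' },
  ],
}

async function main() {
  const response = await fetch('https://api.serlo.org/graphql', {
    method: 'POST',
    headers: { 'content-type': 'application/json' },
    body: JSON.stringify({ query, variables }),
  })
  console.log(JSON.stringify(await response.json(), null, 2))
}

void main()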
packages/server/src/model/ai.ts (29 changes: 14 additions & 15 deletions)
@@ -22,12 +22,16 @@ export async function executePrompt(args: {
   // OpenAI, we can pass the user to the model. See
   // https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
   userId: number | null
-  prompt: string
+  messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[]
 }): Promise<Record<string, unknown>> {
-  const { userId, prompt } = args
+  const { userId, messages } = args

-  if (!prompt || prompt.trim() === '') {
-    throw new UserInputError('Missing prompt parameter')
+  if (
+    !messages[0].content ||
+    (typeof messages[0].content === 'string' &&
+      messages[0].content.trim() === '')
+  ) {
+    throw new UserInputError('Missing prompt within message')
   }

   try {
@@ -40,28 +44,23 @@ export async function executePrompt(args: {
     const response = await openai.chat.completions.create({
       model: 'gpt-4-1106-preview',
-      messages: [
-        {
-          role: 'user',
-          content: prompt,
-        },
-      ],
+      messages,
       temperature: 0.4,
       user: String(userId),
       response_format: { type: 'json_object' },
     })

-    // As we now have the response_format defined as json_object, we shouldn't
-    // need to call JSON.parse on the stringMessage. However, right now the OpenAI
-    // types seem to be broken (thinking the API is returning a string or null).
-    // Instead of fighting the types, we can simply adjust this in the next
-    // version.
     const stringMessage = response.choices[0].message.content

     if (!stringMessage) {
       throw new Error('No content received from LLM!')
     }

+    // As we now have the response_format defined as json_object, we shouldn't
+    // need to call JSON.parse on the stringMessage. However, right now the OpenAI
+    // types seem to be broken (thinking the API is returning a string or null).
+    // Instead of fighting the types, we can simply adjust this in the next
+    // version.
     const message = JSON.parse(stringMessage) as unknown

     if (!t.UnknownRecord.is(message)) {
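
Seen from the caller's side, the reworked function now accepts a whole conversation. A minimal sketch, assuming executePrompt is exported from this module (the import path and the surrounding wiring are illustrative, not from the commit):

import type OpenAI from 'openai'

// Path assumed for illustration; in the repo the function is reached
// through dataSources.model.serlo instead.
import { executePrompt } from './model/ai'

async function demo() {
  const messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
    // The guard above only inspects messages[0].content, so the first
    // message of the window must carry non-empty content.
    { role: 'system', content: 'You write math exercises as JSON.' },
    { role: 'user', content: 'Generate exercise for 7th grade math in json' },
  ]

  // Resolves to the JSON object parsed from the model's reply.
  const record = await executePrompt({ userId: null, messages })
  console.log(record)
}

void demo()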
packages/server/src/schema/ai/resolvers.ts (24 changes: 24 additions & 0 deletions)
@@ -1,13 +1,28 @@
 import * as auth from '@serlo/authorization'
 import { Scope } from '@serlo/authorization'
+import { either as E } from 'fp-ts'
+import * as t from 'io-ts'

+import { UserInputError } from '~/errors'
 import {
   assertUserIsAuthenticated,
   assertUserIsAuthorized,
   createNamespace,
   Queries,
 } from '~/internals/graphql'

+const ChatCompletionMessageParamType = t.type({
+  // Restricts role to 'user' or 'system'. Right now, we don't want to allow
+  // assistant-, tool-, or function calls. See
+  // https://github.com/openai/openai-node/blob/a048174c0e53269a01993a573a10f96c4c9ec79e/src/resources/chat/completions.ts#L405
+  role: t.keyof({ user: null, system: null }),
+  content: t.string,
+})
+
+const ExecutePromptRequestType = t.type({
+  messages: t.array(ChatCompletionMessageParamType),
+})
+
 export const resolvers: Queries<'ai'> = {
   Query: {
     ai: createNamespace(),
@@ -21,9 +36,18 @@ export const resolvers: Queries<'ai'> = {
         message: 'Insufficient role to execute the prompt.',
         dataSources,
       })
+      const { messages } = payload

+      const validationResult = ExecutePromptRequestType.decode({ messages })
+      if (E.isLeft(validationResult)) {
+        throw new UserInputError(
+          'Must contain exclusively user or system messages',
+        )
+      }
+
       const record = await dataSources.model.serlo.executePrompt({
         ...payload,
+        messages: validationResult.right.messages,
         userId,
       })
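
To make the runtime narrowing concrete, here is a self-contained sketch of how the two codecs behave under decode (restated verbatim so the snippet runs on its own with fp-ts and io-ts installed):

import { either as E } from 'fp-ts'
import * as t from 'io-ts'

const ChatCompletionMessageParamType = t.type({
  role: t.keyof({ user: null, system: null }),
  content: t.string,
})
const ExecutePromptRequestType = t.type({
  messages: t.array(ChatCompletionMessageParamType),
})

// Accepted: the window contains only 'user' and 'system' roles.
const ok = ExecutePromptRequestType.decode({
  messages: [{ role: 'system', content: 'Reply as JSON.' }],
})
console.log(E.isRight(ok)) // true

// Rejected: an 'assistant' message would let clients inject fake model turns.
const bad = ExecutePromptRequestType.decode({
  messages: [{ role: 'assistant', content: 'Here is my earlier answer.' }],
})
console.log(E.isLeft(bad)) // true

Note that the GraphQL input type in the next file still declares role as String!, so this decode call is the only place where the role set is actually narrowed.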
packages/server/src/schema/ai/types.graphql (9 changes: 8 additions & 1 deletion)
@@ -1,9 +1,16 @@
+input ChatCompletionMessageParam {
+  role: String!
+  content: String!
+}
+
 extend type Query {
   ai: AiQuery!
 }

 type AiQuery {
-  executePrompt(prompt: String!): ExecutePromptResponse!
+  executePrompt(
+    messages: [ChatCompletionMessageParam!]!
+  ): ExecutePromptResponse!
 }

 type ExecutePromptResponse {
packages/server/src/types.ts (11 changes: 9 additions & 2 deletions)
@@ -276,7 +276,7 @@ export type AiQuery = {


 export type AiQueryExecutePromptArgs = {
-  prompt: Scalars['String']['input'];
+  messages: Array<ChatCompletionMessageParam>;
 };

 export type AliasInput = {
@@ -519,6 +519,11 @@ export type CacheUpdateResponse = {
   success: Scalars['Boolean']['output'];
 };

+export type ChatCompletionMessageParam = {
+  content: Scalars['String']['input'];
+  role: Scalars['String']['input'];
+};
+
 export type CheckoutRevisionInput = {
   reason: Scalars['String']['input'];
   revisionId: Scalars['Int']['input'];
@@ -2927,6 +2932,7 @@ export type ResolversTypes = {
   CacheSetResponse: ResolverTypeWrapper<ModelOf<CacheSetResponse>>;
   CacheUpdateInput: ResolverTypeWrapper<ModelOf<CacheUpdateInput>>;
   CacheUpdateResponse: ResolverTypeWrapper<ModelOf<CacheUpdateResponse>>;
+  ChatCompletionMessageParam: ResolverTypeWrapper<ModelOf<ChatCompletionMessageParam>>;
   CheckoutRevisionInput: ResolverTypeWrapper<ModelOf<CheckoutRevisionInput>>;
   CheckoutRevisionNotificationEvent: ResolverTypeWrapper<ModelOf<CheckoutRevisionNotificationEvent>>;
   CheckoutRevisionResponse: ResolverTypeWrapper<ModelOf<CheckoutRevisionResponse>>;
@@ -3132,6 +3138,7 @@ export type ResolversParentTypes = {
   CacheSetResponse: ModelOf<CacheSetResponse>;
   CacheUpdateInput: ModelOf<CacheUpdateInput>;
   CacheUpdateResponse: ModelOf<CacheUpdateResponse>;
+  ChatCompletionMessageParam: ModelOf<ChatCompletionMessageParam>;
   CheckoutRevisionInput: ModelOf<CheckoutRevisionInput>;
   CheckoutRevisionNotificationEvent: ModelOf<CheckoutRevisionNotificationEvent>;
   CheckoutRevisionResponse: ModelOf<CheckoutRevisionResponse>;
@@ -3447,7 +3454,7 @@ export type AddRevisionResponseResolvers<ContextType = Context, ParentType exten
 };

 export type AiQueryResolvers<ContextType = Context, ParentType extends ResolversParentTypes['AiQuery'] = ResolversParentTypes['AiQuery']> = {
-  executePrompt?: Resolver<ResolversTypes['ExecutePromptResponse'], ParentType, ContextType, RequireFields<AiQueryExecutePromptArgs, 'prompt'>>;
+  executePrompt?: Resolver<ResolversTypes['ExecutePromptResponse'], ParentType, ContextType, RequireFields<AiQueryExecutePromptArgs, 'messages'>>;
   __isTypeOf?: IsTypeOfResolverFn<ParentType, ContextType>;
 };
packages/types/src/index.ts (7 changes: 6 additions & 1 deletion)
@@ -272,7 +272,7 @@ export type AiQuery = {


 export type AiQueryExecutePromptArgs = {
-  prompt: Scalars['String']['input'];
+  messages: Array<ChatCompletionMessageParam>;
 };

 export type AliasInput = {
@@ -515,6 +515,11 @@ export type CacheUpdateResponse = {
   success: Scalars['Boolean']['output'];
 };

+export type ChatCompletionMessageParam = {
+  content: Scalars['String']['input'];
+  role: Scalars['String']['input'];
+};
+
 export type CheckoutRevisionInput = {
   reason: Scalars['String']['input'];
   revisionId: Scalars['Int']['input'];
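
With the regenerated types, the breaking change also surfaces at compile time: code that still passes prompt no longer type-checks. A sketch, assuming the package built from packages/types exports these types under an importable name (the package name below is an assumption):

import type {
  AiQueryExecutePromptArgs,
  ChatCompletionMessageParam,
} from '@serlo/api' // package name assumed for illustration

const message: ChatCompletionMessageParam = {
  role: 'user',
  content: 'Generate exercise for 7th grade math in json',
}

// Compiles against the regenerated argument type; a `prompt` property
// would now be rejected by the compiler.
const args: AiQueryExecutePromptArgs = { messages: [message] }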
