diff --git a/node/chat/chat.spec.ts b/node/chat/chat.spec.ts
index f635c73..af3013c 100644
--- a/node/chat/chat.spec.ts
+++ b/node/chat/chat.spec.ts
@@ -35,7 +35,7 @@ describe("tea/chat.spec.ts", () => {
       expect(
         await buffer.getLines({ start: 0, end: -1 }),
         "initial render of chat works",
-      ).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
+      ).toEqual(Chat.LOGO.split("\n") as Line[]);
 
       app.dispatch({
         type: "add-message",
@@ -144,7 +144,7 @@ describe("tea/chat.spec.ts", () => {
       expect(
         await buffer.getLines({ start: 0, end: -1 }),
         "initial render of chat works",
-      ).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
+      ).toEqual(Chat.LOGO.split("\n") as Line[]);
 
       app.dispatch({
         type: "add-message",
@@ -172,10 +172,6 @@ describe("tea/chat.spec.ts", () => {
         "Stopped (end_turn) [input: 0, output: 0]",
       ] as Line[]);
 
-      // expect(
-      //   await extractMountTree(mountedApp.getMountedNode()),
-      // ).toMatchSnapshot();
-
       app.dispatch({
         type: "clear",
       });
@@ -184,11 +180,7 @@ describe("tea/chat.spec.ts", () => {
       expect(
         await buffer.getLines({ start: 0, end: -1 }),
         "finished render is as expected",
-      ).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
-
-      // expect(
-      //   await extractMountTree(mountedApp.getMountedNode()),
-      // ).toMatchSnapshot();
+      ).toEqual(Chat.LOGO.split("\n") as Line[]);
     });
   });
 });
diff --git a/node/chat/chat.ts b/node/chat/chat.ts
index d1eaa14..7ea1ee5 100644
--- a/node/chat/chat.ts
+++ b/node/chat/chat.ts
@@ -8,14 +8,14 @@ import {
   type Update,
   wrapThunk,
 } from "../tea/tea.ts";
-import { d, type View } from "../tea/view.ts";
+import { d, withBindings, type View } from "../tea/view.ts";
 import * as ToolManager from "../tools/toolManager.ts";
 import { type Result } from "../utils/result.ts";
 import { Counter } from "../utils/uniqueId.ts";
 import type { Nvim } from "nvim-node";
 import type { Lsp } from "../lsp.ts";
 import {
-  getClient,
+  getClient as getProvider,
   type ProviderMessage,
   type ProviderMessageContent,
   type ProviderName,
@@ -24,6 +24,7 @@ import {
 } from "../providers/provider.ts";
 import { assertUnreachable } from "../utils/assertUnreachable.ts";
 import { DEFAULT_OPTIONS, type MagentaOptions } from "../options.ts";
+import { getOption } from "../nvim/nvim.ts";
 
 export type Role = "user" | "assistant";
 
@@ -103,6 +104,9 @@ export type Msg =
   | {
       type: "set-opts";
       options: MagentaOptions;
+    }
+  | {
+      type: "show-message-debug-info";
     };
 
 export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
@@ -421,7 +425,7 @@ ${msg.error.stack}`,
         model,
         // eslint-disable-next-line @typescript-eslint/require-await
         async () => {
-          getClient(nvim, model.activeProvider, model.options).abort();
+          getProvider(nvim, model.activeProvider, model.options).abort();
         },
       ];
     }
@@ -430,6 +434,10 @@ ${msg.error.stack}`,
       return [{ ...model, options: msg.options }];
     }
 
+    case "show-message-debug-info": {
+      return [model, () => showDebugInfo(model)];
+    }
+
     default:
       assertUnreachable(msg);
   }
@@ -490,7 +498,7 @@ ${msg.error.stack}`,
       });
       let res;
       try {
-        res = await getClient(
+        res = await getProvider(
           nvim,
           model.activeProvider,
           model.options,
@@ -537,6 +545,13 @@ ${msg.error.stack}`,
     model,
     dispatch,
   }) => {
+    if (
+      model.messages.length == 0 &&
+      Object.keys(model.contextManager.files).length == 0
+    ) {
+      return d`${LOGO}`;
+    }
+
     return d`${model.messages.map(
       (m, idx) =>
         d`${messageModel.view({
@@ -556,12 +571,17 @@ ${msg.error.stack}`,
             ) % MESSAGE_ANIMATION.length
           ]
         }`
-      : d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
-          model.conversation.usage.cacheHits !== undefined &&
-          model.conversation.usage.cacheMisses !== undefined
-            ? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
-            : ""
-        }]`
+      : withBindings(
+          d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
+            model.conversation.usage.cacheHits !== undefined &&
+            model.conversation.usage.cacheMisses !== undefined
+              ? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
+              : ""
+          }]`,
+          {
+            "<CR>": () => dispatch({ type: "show-message-debug-info" }),
+          },
+        )
     }${
       model.conversation.state == "stopped" &&
       !contextManagerModel.isContextEmpty(model.contextManager)
@@ -634,6 +654,44 @@
     return messages.map((m) => m.message);
   }
 
+  async function showDebugInfo(model: Model) {
+    const messages = await getMessages(model);
+    const provider = getProvider(nvim, model.activeProvider, model.options);
+    const params = provider.createStreamParameters(messages);
+    const nTokens = await provider.countTokens(messages);
+
+    // Create a floating window
+    const bufnr = await nvim.call("nvim_create_buf", [false, true]);
+    await nvim.call("nvim_buf_set_option", [bufnr, "bufhidden", "wipe"]);
+    const [editorWidth, editorHeight] = (await Promise.all([
+      getOption("columns", nvim),
+      await getOption("lines", nvim),
+    ])) as [number, number];
+    const width = 80;
+    const height = editorHeight - 20;
+    await nvim.call("nvim_open_win", [
+      bufnr,
+      true,
+      {
+        relative: "editor",
+        width,
+        height,
+        col: Math.floor((editorWidth - width) / 2),
+        row: Math.floor((editorHeight - height) / 2),
+        style: "minimal",
+        border: "single",
+      },
+    ]);
+
+    const lines = JSON.stringify(params, null, 2).split("\n");
+    lines.push(`nTokens: ${nTokens}`);
+    await nvim.call("nvim_buf_set_lines", [bufnr, 0, -1, false, lines]);
+
+    // Set buffer options
+    await nvim.call("nvim_buf_set_option", [bufnr, "modifiable", false]);
+    await nvim.call("nvim_buf_set_option", [bufnr, "filetype", "json"]);
+  }
+
   return {
     initModel,
     update,
@@ -641,3 +699,13 @@
     getMessages,
   };
 }
+
+export const LOGO = `\
+
+   ________
+  ╱        ╲
+ ╱         ╱
+╱         ╱
+╲__╱__╱__╱
+
+# magenta.nvim`;
diff --git a/node/magenta.spec.ts b/node/magenta.spec.ts
index 0440bb1..6bea6fb 100644
--- a/node/magenta.spec.ts
+++ b/node/magenta.spec.ts
@@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
 import { withDriver } from "./test/preamble";
 import { pollUntil } from "./utils/async";
 import type { Position0Indexed } from "./nvim/window";
+import { LOGO } from "./chat/chat";
 
 describe("node/magenta.spec.ts", () => {
   it("clear command should work", async () => {
@@ -25,9 +26,7 @@ sup?
 Stopped (end_turn) [input: 0, output: 0]`);
 
       await driver.clear();
-      await driver.assertDisplayBufferContent(
-        `Stopped (end_turn) [input: 0, output: 0]`,
-      );
+      await driver.assertDisplayBufferContent(LOGO);
       await driver.inputMagentaText(`hello again`);
       await driver.send();
       await driver.mockAnthropic.respond({
diff --git a/node/providers/anthropic.ts b/node/providers/anthropic.ts
index ba3c85d..f679662 100644
--- a/node/providers/anthropic.ts
+++ b/node/providers/anthropic.ts
@@ -47,34 +47,9 @@ export class AnthropicProvider implements Provider {
     }
   }
 
-  async sendMessage(
-    messages: Array<ProviderMessage>,
-    onText: (text: string) => void,
-    onError: (error: Error) => void,
-  ): Promise<{
-    toolRequests: Result<ToolManager.ToolRequest>[];
-    stopReason: StopReason;
-    usage: Usage;
-  }> {
-    const buf: string[] = [];
-    let flushInProgress: boolean = false;
-
-    const flushBuffer = () => {
-      if (buf.length && !flushInProgress) {
-        const text = buf.join("");
-        buf.splice(0);
-
-        flushInProgress = true;
-
-        try {
-          onText(text);
-        } finally {
-          flushInProgress = false;
-          setInterval(flushBuffer, 1);
-        }
-      }
-    };
-
+  createStreamParameters(
+    messages: ProviderMessage[],
+  ): Anthropic.Messages.MessageStreamParams {
     const anthropicMessages = messages.map((m): MessageParam => {
       let content: Anthropic.Messages.ContentBlockParam[];
       if (typeof m.content == "string") {
@@ -116,7 +91,7 @@ export class AnthropicProvider implements Provider {
       };
     });
 
-    placeCacheBreakpoints(anthropicMessages);
+    const cacheControlItemsPlaced = placeCacheBreakpoints(anthropicMessages);
 
     const tools: Anthropic.Tool[] = ToolManager.TOOL_SPECS.map(
       (t): Anthropic.Tool => {
@@ -127,19 +102,77 @@ export class AnthropicProvider implements Provider {
       },
     );
 
+    return {
+      messages: anthropicMessages,
+      model: this.options.model,
+      max_tokens: 4096,
+      system: [
+        {
+          type: "text",
+          text: DEFAULT_SYSTEM_PROMPT,
+          // the prompt appears in the following order:
+          // tools
+          // system
+          // messages
+          // This ensures the tools + system prompt (which is approx 1400 tokens) is cached.
+          cache_control:
{ type: "ephemeral" } : null, + }, + ], + tool_choice: { + type: "auto", + disable_parallel_tool_use: false, + }, + tools, + }; + } + + async countTokens(messages: Array): Promise { + const params = this.createStreamParameters(messages); + const lastMessage = params.messages[params.messages.length - 1]; + if (!lastMessage || lastMessage.role != "user") { + params.messages.push({ role: "user", content: "test" }); + } + const res = await this.client.messages.countTokens({ + messages: params.messages, + model: params.model, + system: params.system as Anthropic.TextBlockParam[], + tools: params.tools as Anthropic.Tool[], + }); + return res.input_tokens; + } + + async sendMessage( + messages: Array, + onText: (text: string) => void, + onError: (error: Error) => void, + ): Promise<{ + toolRequests: Result[]; + stopReason: StopReason; + usage: Usage; + }> { + const buf: string[] = []; + let flushInProgress: boolean = false; + + const flushBuffer = () => { + if (buf.length && !flushInProgress) { + const text = buf.join(""); + buf.splice(0); + + flushInProgress = true; + + try { + onText(text); + } finally { + flushInProgress = false; + setInterval(flushBuffer, 1); + } + } + }; + try { this.request = this.client.messages - .stream({ - messages: anthropicMessages, - model: this.options.model, - max_tokens: 4096, - system: DEFAULT_SYSTEM_PROMPT, - tool_choice: { - type: "auto", - disable_parallel_tool_use: false, - }, - tools, - }) + .stream(this.createStreamParameters(messages)) .on("text", (text: string) => { buf.push(text); flushBuffer(); @@ -247,7 +280,7 @@ export class AnthropicProvider implements Provider { } } -export function placeCacheBreakpoints(messages: MessageParam[]) { +export function placeCacheBreakpoints(messages: MessageParam[]): number { // when we scan the messages, keep track of where each part ends. 
   const blocks: { block: Anthropic.Messages.ContentBlockParam; acc: number }[] =
     [];
@@ -315,6 +348,8 @@ export function placeCacheBreakpoints(messages: MessageParam[]) {
       }
     }
   }
+
+  return powers.length;
 }
 
 const STR_CHARS_PER_TOKEN = 4;
diff --git a/node/providers/mock.ts b/node/providers/mock.ts
index 41899f4..7277a16 100644
--- a/node/providers/mock.ts
+++ b/node/providers/mock.ts
@@ -39,6 +39,15 @@ export class MockProvider implements Provider {
     }
   }
 
+  createStreamParameters(messages: Array<ProviderMessage>): unknown {
+    return messages;
+  }
+
+  // eslint-disable-next-line @typescript-eslint/require-await
+  async countTokens(messages: Array<ProviderMessage>): Promise<number> {
+    return messages.length;
+  }
+
   async sendMessage(
     messages: Array<ProviderMessage>,
     onText: (text: string) => void,
diff --git a/node/providers/openai.ts b/node/providers/openai.ts
index 46557d7..4176a3b 100644
--- a/node/providers/openai.ts
+++ b/node/providers/openai.ts
@@ -12,6 +12,7 @@ import type { ToolName, ToolRequestId } from "../tools/toolManager.ts";
 import type { Nvim } from "nvim-node";
 import type { Stream } from "openai/streaming.mjs";
 import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts";
+import tiktoken from "tiktoken";
 
 export type OpenAIOptions = {
   model: "gpt-4o";
@@ -44,15 +45,50 @@ export class OpenAIProvider implements Provider {
     });
   }
 
-  async sendMessage(
+  // eslint-disable-next-line @typescript-eslint/require-await
+  async countTokens(messages: Array<ProviderMessage>): Promise<number> {
+    const enc = tiktoken.encoding_for_model("gpt-4o");
+    let totalTokens = 0;
+
+    // Count system message
+    totalTokens += enc.encode(DEFAULT_SYSTEM_PROMPT).length;
+
+    for (const message of messages) {
+      if (typeof message.content === "string") {
+        totalTokens += enc.encode(message.content).length;
+      } else {
+        for (const content of message.content) {
+          switch (content.type) {
+            case "text":
+              totalTokens += enc.encode(content.text).length;
+              break;
+            case "tool_use":
+              totalTokens += enc.encode(content.request.name).length;
+              totalTokens += enc.encode(
+                JSON.stringify(content.request.input),
+              ).length;
+              break;
+            case "tool_result":
+              totalTokens += enc.encode(
+                content.result.status === "ok"
+                  ? content.result.value
+                  : content.result.error,
+              ).length;
+              break;
+          }
+        }
+      }
+      // Add tokens for message format (role, etc)
+      totalTokens += 3;
+    }
+
+    enc.free();
+    return totalTokens;
+  }
+
+  createStreamParameters(
     messages: Array<ProviderMessage>,
-    onText: (text: string) => void,
-    _onError: (error: Error) => void,
-  ): Promise<{
-    toolRequests: Result<ToolManager.ToolRequest>[];
-    stopReason: StopReason;
-    usage: Usage;
-  }> {
+  ): OpenAI.ChatCompletionCreateParamsStreaming {
     const openaiMessages: OpenAI.ChatCompletionMessageParam[] = [
       {
         role: "system",
@@ -131,26 +167,40 @@ export class OpenAIProvider implements Provider {
       }
     }
 
+    return {
+      model: this.options.model,
+      stream: true,
+      messages: openaiMessages,
+      // see https://platform.openai.com/docs/guides/function-calling#parallel-function-calling-and-structured-outputs
+      // this recommends disabling parallel tool calls when strict adherence to schema is needed
+      parallel_tool_calls: false,
+      tools: ToolManager.TOOL_SPECS.map((s): OpenAI.ChatCompletionTool => {
+        return {
+          type: "function",
+          function: {
+            name: s.name,
+            description: s.description,
+            strict: true,
+            parameters: s.input_schema as OpenAI.FunctionParameters,
+          },
+        };
+      }),
+    };
+  }
+
+  async sendMessage(
+    messages: Array<ProviderMessage>,
+    onText: (text: string) => void,
+    _onError: (error: Error) => void,
+  ): Promise<{
+    toolRequests: Result<ToolManager.ToolRequest>[];
+    stopReason: StopReason;
+    usage: Usage;
+  }> {
     try {
-      const stream = (this.request = await this.client.chat.completions.create({
-        model: this.options.model,
-        stream: true,
-        messages: openaiMessages,
-        // see https://platform.openai.com/docs/guides/function-calling#parallel-function-calling-and-structured-outputs
-        // this recommends disabling parallel tool calls when strict adherence to schema is needed
-        parallel_tool_calls: false,
-        tools: ToolManager.TOOL_SPECS.map((s): OpenAI.ChatCompletionTool => {
-          return {
-            type: "function",
-            function: {
-              name: s.name,
-              description: s.description,
-              strict: true,
-              parameters: s.input_schema as OpenAI.FunctionParameters,
-            },
-          };
-        }),
-      }));
+      const stream = (this.request = await this.client.chat.completions.create(
+        this.createStreamParameters(messages),
+      ));
 
       const toolRequests = [];
       let stopReason: StopReason | undefined;
diff --git a/node/providers/provider.ts b/node/providers/provider.ts
index 194b57c..be3e07f 100644
--- a/node/providers/provider.ts
+++ b/node/providers/provider.ts
@@ -57,6 +57,9 @@ export type ProviderMessageContent =
   | ProviderToolResultContent;
 
 export interface Provider {
+  createStreamParameters(messages: Array<ProviderMessage>): unknown;
+  countTokens(messages: Array<ProviderMessage>): Promise<number>;
+
   sendMessage(
     messages: Array<ProviderMessage>,
     onText: (text: string) => void,
diff --git a/package-lock.json b/package-lock.json
index 42127d9..7017a43 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -14,6 +14,7 @@
         "ignore": "^7.0.0",
         "nvim-node": "0.0.2",
         "openai": "^4.77.0",
+        "tiktoken": "^1.0.18",
         "tsx": "^4.19.2",
         "typescript-eslint": "^8.19.0"
       },
@@ -2854,6 +2855,12 @@
       "integrity": "sha512-uuVGNWzgJ4yhRaNSiubPY7OjISw4sw4E5Uv0wbjp+OzcbmVU/rsT8ujgcXJhn9ypzsgr5vlzpPqP+MBBKcGvbg==",
      "license": "MIT"
    },
+    "node_modules/tiktoken": {
+      "version": "1.0.18",
+      "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.18.tgz",
+      "integrity": "sha512-DXJesdYwmBHtkmz1sji+UMZ4AOEE8F7Uw/PS/uy0XfkKOzZC4vXkYXHMYyDT+grdflvF4bggtPt9cYaqOMslBw==",
+      "license": "MIT"
+    },
     "node_modules/tinybench": {
       "version": "2.9.0",
       "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz",
diff --git a/package.json b/package.json
index 3c9feca..4f7ef31 100644
--- a/package.json
+++ b/package.json
@@ -13,6 +13,7 @@
     "ignore": "^7.0.0",
     "nvim-node": "0.0.2",
     "openai": "^4.77.0",
+    "tiktoken": "^1.0.18",
     "tsx": "^4.19.2",
     "typescript-eslint": "^8.19.0"
   },