Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt caching improvements, debug #32

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions node/chat/chat.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"initial render of chat works",
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
).toEqual(Chat.LOGO.split("\n") as Line[]);

app.dispatch({
type: "add-message",
Expand Down Expand Up @@ -144,7 +144,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"initial render of chat works",
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
).toEqual(Chat.LOGO.split("\n") as Line[]);

app.dispatch({
type: "add-message",
Expand Down Expand Up @@ -172,10 +172,6 @@ describe("tea/chat.spec.ts", () => {
"Stopped (end_turn) [input: 0, output: 0]",
] as Line[]);

// expect(
// await extractMountTree(mountedApp.getMountedNode()),
// ).toMatchSnapshot();

app.dispatch({
type: "clear",
});
Expand All @@ -184,11 +180,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"finished render is as expected",
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

// expect(
// await extractMountTree(mountedApp.getMountedNode()),
// ).toMatchSnapshot();
).toEqual(Chat.LOGO.split("\n") as Line[]);
});
});
});
88 changes: 78 additions & 10 deletions node/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import {
type Update,
wrapThunk,
} from "../tea/tea.ts";
import { d, type View } from "../tea/view.ts";
import { d, withBindings, type View } from "../tea/view.ts";
import * as ToolManager from "../tools/toolManager.ts";
import { type Result } from "../utils/result.ts";
import { Counter } from "../utils/uniqueId.ts";
import type { Nvim } from "nvim-node";
import type { Lsp } from "../lsp.ts";
import {
getClient,
getClient as getProvider,
type ProviderMessage,
type ProviderMessageContent,
type ProviderName,
Expand All @@ -24,6 +24,7 @@ import {
} from "../providers/provider.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";
import { DEFAULT_OPTIONS, type MagentaOptions } from "../options.ts";
import { getOption } from "../nvim/nvim.ts";

export type Role = "user" | "assistant";

Expand Down Expand Up @@ -103,6 +104,9 @@ export type Msg =
| {
type: "set-opts";
options: MagentaOptions;
}
| {
type: "show-message-debug-info";
};

export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
Expand Down Expand Up @@ -421,7 +425,7 @@ ${msg.error.stack}`,
model,
// eslint-disable-next-line @typescript-eslint/require-await
async () => {
getClient(nvim, model.activeProvider, model.options).abort();
getProvider(nvim, model.activeProvider, model.options).abort();
},
];
}
Expand All @@ -430,6 +434,10 @@ ${msg.error.stack}`,
return [{ ...model, options: msg.options }];
}

case "show-message-debug-info": {
return [model, () => showDebugInfo(model)];
}

default:
assertUnreachable(msg);
}
Expand Down Expand Up @@ -490,7 +498,7 @@ ${msg.error.stack}`,
});
let res;
try {
res = await getClient(
res = await getProvider(
nvim,
model.activeProvider,
model.options,
Expand Down Expand Up @@ -537,6 +545,13 @@ ${msg.error.stack}`,
model,
dispatch,
}) => {
if (
model.messages.length == 0 &&
Object.keys(model.contextManager.files).length == 0
) {
return d`${LOGO}`;
}

return d`${model.messages.map(
(m, idx) =>
d`${messageModel.view({
Expand All @@ -556,12 +571,17 @@ ${msg.error.stack}`,
) % MESSAGE_ANIMATION.length
]
}`
: d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`
: withBindings(
d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`,
{
"<CR>": () => dispatch({ type: "show-message-debug-info" }),
},
)
}${
model.conversation.state == "stopped" &&
!contextManagerModel.isContextEmpty(model.contextManager)
Expand Down Expand Up @@ -634,10 +654,58 @@ ${msg.error.stack}`,
return messages.map((m) => m.message);
}

/**
 * Open a centered floating window showing the exact stream parameters that
 * would be sent to the active provider for this conversation, plus the
 * provider-reported token count. Debug aid for prompt-caching behavior.
 *
 * @param model current chat model; its messages/options determine the payload
 */
async function showDebugInfo(model: Model) {
  const messages = await getMessages(model);
  const provider = getProvider(nvim, model.activeProvider, model.options);
  const params = provider.createStreamParameters(messages);
  // countTokens needs the assembled messages, so it cannot be parallelized
  // with getMessages above.
  const nTokens = await provider.countTokens(messages);

  // Scratch buffer (unlisted, wiped once its window closes).
  const bufnr = await nvim.call("nvim_create_buf", [false, true]);
  await nvim.call("nvim_buf_set_option", [bufnr, "bufhidden", "wipe"]);

  // Fetch both editor dimensions concurrently. Note: no nested `await`
  // inside the Promise.all array — that would resolve the second option
  // serially and defeat the point of Promise.all.
  const [editorWidth, editorHeight] = (await Promise.all([
    getOption("columns", nvim),
    getOption("lines", nvim),
  ])) as [number, number];

  // Clamp the geometry so a small terminal cannot produce a non-positive
  // size or a negative position, which nvim_open_win rejects.
  const width = Math.max(1, Math.min(80, editorWidth - 2));
  const height = Math.max(1, editorHeight - 20);
  await nvim.call("nvim_open_win", [
    bufnr,
    true,
    {
      relative: "editor",
      width,
      height,
      col: Math.max(0, Math.floor((editorWidth - width) / 2)),
      row: Math.max(0, Math.floor((editorHeight - height) / 2)),
      style: "minimal",
      border: "single",
    },
  ]);

  // Pretty-print the request; the trailing `nTokens` line makes the buffer
  // not strictly JSON, which is fine — it exists for human inspection only.
  const lines = JSON.stringify(params, null, 2).split("\n");
  lines.push(`nTokens: ${nTokens}`);
  await nvim.call("nvim_buf_set_lines", [bufnr, 0, -1, false, lines]);

  // Lock the buffer and enable JSON highlighting for readability.
  await nvim.call("nvim_buf_set_option", [bufnr, "modifiable", false]);
  await nvim.call("nvim_buf_set_option", [bufnr, "filetype", "json"]);
}

return {
initModel,
update,
view,
getMessages,
};
}

// ASCII-art logo rendered in the display buffer whenever the chat is empty
// (no messages and no context files); the specs assert against LOGO.split("\n").
// NOTE(review): the art's exact whitespace is load-bearing for those specs —
// do not reflow this template literal.
export const LOGO = `\

 ________
╱      ╲
╱      ╱
╱      ╱
╲__╱__╱__╱

# magenta.nvim`;
5 changes: 2 additions & 3 deletions node/magenta.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
import { withDriver } from "./test/preamble";
import { pollUntil } from "./utils/async";
import type { Position0Indexed } from "./nvim/window";
import { LOGO } from "./chat/chat";

describe("node/magenta.spec.ts", () => {
it("clear command should work", async () => {
Expand All @@ -25,9 +26,7 @@ sup?
Stopped (end_turn) [input: 0, output: 0]`);

await driver.clear();
await driver.assertDisplayBufferContent(
`Stopped (end_turn) [input: 0, output: 0]`,
);
await driver.assertDisplayBufferContent(LOGO);
await driver.inputMagentaText(`hello again`);
await driver.send();
await driver.mockAnthropic.respond({
Expand Down
117 changes: 76 additions & 41 deletions node/providers/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,9 @@ export class AnthropicProvider implements Provider {
}
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
onError: (error: Error) => void,
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;

const flushBuffer = () => {
if (buf.length && !flushInProgress) {
const text = buf.join("");
buf.splice(0);

flushInProgress = true;

try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

createStreamParameters(
messages: ProviderMessage[],
): Anthropic.Messages.MessageStreamParams {
const anthropicMessages = messages.map((m): MessageParam => {
let content: Anthropic.Messages.ContentBlockParam[];
if (typeof m.content == "string") {
Expand Down Expand Up @@ -116,7 +91,7 @@ export class AnthropicProvider implements Provider {
};
});

placeCacheBreakpoints(anthropicMessages);
const cacheControlItemsPlaced = placeCacheBreakpoints(anthropicMessages);

const tools: Anthropic.Tool[] = ToolManager.TOOL_SPECS.map(
(t): Anthropic.Tool => {
Expand All @@ -127,19 +102,77 @@ export class AnthropicProvider implements Provider {
},
);

return {
messages: anthropicMessages,
model: this.options.model,
max_tokens: 4096,
system: [
{
type: "text",
text: DEFAULT_SYSTEM_PROMPT,
// the prompt appears in the following order:
// tools
// system
// messages
// This ensures the tools + system prompt (which is approx 1400 tokens) is cached.
cache_control:
cacheControlItemsPlaced < 4 ? { type: "ephemeral" } : null,
},
],
tool_choice: {
type: "auto",
disable_parallel_tool_use: false,
},
tools,
};
}

/**
 * Count the tokens of the full request we would actually stream — the same
 * system prompt, tools, and cache-annotated messages — via Anthropic's
 * count-tokens endpoint.
 *
 * @param messages conversation so far, in provider-neutral form
 * @returns the number of input tokens the API reports for this request
 */
async countTokens(messages: Array<ProviderMessage>): Promise<number> {
  const params = this.createStreamParameters(messages);

  // The count-tokens endpoint requires the transcript to end on a user
  // turn; append a placeholder user message when it does not.
  const last = params.messages[params.messages.length - 1];
  const endsWithUser = last !== undefined && last.role === "user";
  if (!endsWithUser) {
    params.messages.push({ role: "user", content: "test" });
  }

  const res = await this.client.messages.countTokens({
    messages: params.messages,
    model: params.model,
    system: params.system as Anthropic.TextBlockParam[],
    tools: params.tools as Anthropic.Tool[],
  });
  return res.input_tokens;
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
onError: (error: Error) => void,
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;

const flushBuffer = () => {
if (buf.length && !flushInProgress) {
const text = buf.join("");
buf.splice(0);

flushInProgress = true;

try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

try {
this.request = this.client.messages
.stream({
messages: anthropicMessages,
model: this.options.model,
max_tokens: 4096,
system: DEFAULT_SYSTEM_PROMPT,
tool_choice: {
type: "auto",
disable_parallel_tool_use: false,
},
tools,
})
.stream(this.createStreamParameters(messages))
.on("text", (text: string) => {
buf.push(text);
flushBuffer();
Expand Down Expand Up @@ -247,7 +280,7 @@ export class AnthropicProvider implements Provider {
}
}

export function placeCacheBreakpoints(messages: MessageParam[]) {
export function placeCacheBreakpoints(messages: MessageParam[]): number {
// when we scan the messages, keep track of where each part ends.
const blocks: { block: Anthropic.Messages.ContentBlockParam; acc: number }[] =
[];
Expand Down Expand Up @@ -315,6 +348,8 @@ export function placeCacheBreakpoints(messages: MessageParam[]) {
}
}
}

return powers.length;
}

const STR_CHARS_PER_TOKEN = 4;
Expand Down
9 changes: 9 additions & 0 deletions node/providers/mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ export class MockProvider implements Provider {
}
}

// Test double: echoes the messages back unchanged so specs can inspect the
// exact payload that would be handed to a real provider.
createStreamParameters(messages: Array<ProviderMessage>): unknown {
  return messages;
}

// Test double: reports the message count as the "token" count — cheap and
// deterministic for assertions, with no network call like the real provider.
// (async to satisfy the Provider interface, hence the lint suppression.)
// eslint-disable-next-line @typescript-eslint/require-await
async countTokens(messages: Array<ProviderMessage>): Promise<number> {
  return messages.length;
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
Expand Down
Loading
Loading