Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt caching improvements, debug #32

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add stop message debug view, token counting impl
  • Loading branch information
dlants committed Jan 13, 2025
commit a0fd132093024a68eae479b0b8336433e877a9ed
71 changes: 61 additions & 10 deletions node/chat/chat.ts
Original file line number Diff line number Diff line change
@@ -8,14 +8,14 @@ import {
type Update,
wrapThunk,
} from "../tea/tea.ts";
import { d, type View } from "../tea/view.ts";
import { d, withBindings, type View } from "../tea/view.ts";
import * as ToolManager from "../tools/toolManager.ts";
import { type Result } from "../utils/result.ts";
import { Counter } from "../utils/uniqueId.ts";
import type { Nvim } from "nvim-node";
import type { Lsp } from "../lsp.ts";
import {
getClient,
getClient as getProvider,
type ProviderMessage,
type ProviderMessageContent,
type ProviderName,
@@ -24,6 +24,7 @@ import {
} from "../providers/provider.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";
import { DEFAULT_OPTIONS, type MagentaOptions } from "../options.ts";
import { getOption } from "../nvim/nvim.ts";

export type Role = "user" | "assistant";

@@ -103,6 +104,9 @@ export type Msg =
| {
type: "set-opts";
options: MagentaOptions;
}
| {
type: "show-message-debug-info";
};

export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
@@ -421,7 +425,7 @@ ${msg.error.stack}`,
model,
// eslint-disable-next-line @typescript-eslint/require-await
async () => {
getClient(nvim, model.activeProvider, model.options).abort();
getProvider(nvim, model.activeProvider, model.options).abort();
},
];
}
@@ -430,6 +434,10 @@ ${msg.error.stack}`,
return [{ ...model, options: msg.options }];
}

case "show-message-debug-info": {
return [model, () => showDebugInfo(model)];
}

default:
assertUnreachable(msg);
}
@@ -490,7 +498,7 @@ ${msg.error.stack}`,
});
let res;
try {
res = await getClient(
res = await getProvider(
nvim,
model.activeProvider,
model.options,
@@ -562,12 +570,17 @@ ${msg.error.stack}`,
) % MESSAGE_ANIMATION.length
]
}`
: d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`
: withBindings(
d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`,
{
"<CR>": () => dispatch({ type: "show-message-debug-info" }),
},
)
: ""
}${
model.conversation.state == "stopped" &&
@@ -642,6 +655,44 @@ ${msg.error.stack}`,
return messages.map((m) => m.message);
}

/**
 * Open a centered floating window that shows the exact stream parameters the
 * active provider would send for the current conversation, plus the provider's
 * token count for those messages. Debug aid dispatched by the
 * "show-message-debug-info" message.
 */
async function showDebugInfo(model: Model) {
  const messages = await getMessages(model);
  const provider = getProvider(nvim, model.activeProvider, model.options);
  const params = provider.createStreamParameters(messages);
  const nTokens = await provider.countTokens(messages);

  // Scratch buffer: unlisted, wiped when its window closes.
  const bufnr = await nvim.call("nvim_create_buf", [false, true]);
  await nvim.call("nvim_buf_set_option", [bufnr, "bufhidden", "wipe"]);

  // Fetch both dimensions in parallel. The original awaited the second
  // getOption inside the array literal, which resolved it sequentially and
  // made Promise.all pointless for that element.
  const [editorWidth, editorHeight] = (await Promise.all([
    getOption("columns", nvim),
    getOption("lines", nvim),
  ])) as [number, number];

  const width = 80;
  // Clamp so the window stays valid on short editors (height must be >= 1).
  const height = Math.max(1, editorHeight - 20);
  await nvim.call("nvim_open_win", [
    bufnr,
    true,
    {
      relative: "editor",
      width,
      height,
      col: Math.floor((editorWidth - width) / 2),
      row: Math.floor((editorHeight - height) / 2),
      style: "minimal",
      border: "single",
    },
  ]);

  const lines = JSON.stringify(params, null, 2).split("\n");
  lines.push(`nTokens: ${nTokens}`);
  await nvim.call("nvim_buf_set_lines", [bufnr, 0, -1, false, lines]);

  // Freeze the buffer and highlight it as JSON.
  await nvim.call("nvim_buf_set_option", [bufnr, "modifiable", false]);
  await nvim.call("nvim_buf_set_option", [bufnr, "filetype", "json"]);
}

return {
initModel,
update,
123 changes: 72 additions & 51 deletions node/providers/anthropic.ts
Original file line number Diff line number Diff line change
@@ -47,34 +47,9 @@ export class AnthropicProvider implements Provider {
}
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
onError: (error: Error) => void,
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;

const flushBuffer = () => {
if (buf.length && !flushInProgress) {
const text = buf.join("");
buf.splice(0);

flushInProgress = true;

try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

createStreamParameters(
messages: ProviderMessage[],
): Anthropic.Messages.MessageStreamParams {
const anthropicMessages = messages.map((m): MessageParam => {
let content: Anthropic.Messages.ContentBlockParam[];
if (typeof m.content == "string") {
@@ -127,31 +102,77 @@ export class AnthropicProvider implements Provider {
},
);

return {
messages: anthropicMessages,
model: this.options.model,
max_tokens: 4096,
system: [
{
type: "text",
text: DEFAULT_SYSTEM_PROMPT,
// the prompt appears in the following order:
// tools
// system
// messages
// This ensures the tools + system prompt (which is approx 1400 tokens) is cached.
cache_control:
cacheControlItemsPlaced < 4 ? { type: "ephemeral" } : null,
},
],
tool_choice: {
type: "auto",
disable_parallel_tool_use: false,
},
tools,
};
}

/**
 * Ask Anthropic's count-tokens endpoint how many input tokens the current
 * conversation occupies, using the same parameters we would stream with.
 * A placeholder user message is appended when the conversation does not end
 * with a user turn — presumably because the endpoint rejects such
 * conversations; confirm against the Anthropic API reference.
 */
async countTokens(messages: Array<ProviderMessage>): Promise<number> {
  const params = this.createStreamParameters(messages);
  const finalMessage = params.messages[params.messages.length - 1];
  const endsWithUserTurn =
    finalMessage !== undefined && finalMessage.role == "user";
  if (!endsWithUserTurn) {
    params.messages.push({ role: "user", content: "test" });
  }
  const response = await this.client.messages.countTokens({
    messages: params.messages,
    model: params.model,
    system: params.system as Anthropic.TextBlockParam[],
    tools: params.tools as Anthropic.Tool[],
  });
  return response.input_tokens;
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
onError: (error: Error) => void,
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;

const flushBuffer = () => {
if (buf.length && !flushInProgress) {
const text = buf.join("");
buf.splice(0);

flushInProgress = true;

try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

try {
this.request = this.client.messages
.stream({
messages: anthropicMessages,
model: this.options.model,
max_tokens: 4096,
system: [
{
type: "text",
text: DEFAULT_SYSTEM_PROMPT,
// the prompt appears in the following order:
// tools
// system
// messages
// This ensures the tools + system prompt (which is approx 1400 tokens) is cached.
cache_control:
cacheControlItemsPlaced < 4 ? { type: "ephemeral" } : null,
},
],
tool_choice: {
type: "auto",
disable_parallel_tool_use: false,
},
tools,
})
.stream(this.createStreamParameters(messages))
.on("text", (text: string) => {
buf.push(text);
flushBuffer();
9 changes: 9 additions & 0 deletions node/providers/mock.ts
Original file line number Diff line number Diff line change
@@ -39,6 +39,15 @@ export class MockProvider implements Provider {
}
}

/** Mock implementation: the conversation itself stands in for the provider's stream parameters. */
createStreamParameters(conversation: Array<ProviderMessage>): unknown {
  return conversation;
}

/** Mock token count: one "token" per message in the conversation. */
// eslint-disable-next-line @typescript-eslint/require-await
async countTokens(conversation: Array<ProviderMessage>): Promise<number> {
  const { length } = conversation;
  return length;
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
Loading
Loading