Prompt caching #30

Merged · 5 commits · Jan 12, 2025
503 changes: 27 additions & 476 deletions node/chat/__snapshots__/chat.spec.ts.snap

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions node/chat/chat.spec.ts
@@ -35,7 +35,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"initial render of chat works",
).toEqual(["Stopped (end_turn)"] as Line[]);
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

app.dispatch({
type: "add-message",
@@ -74,7 +74,7 @@ describe("tea/chat.spec.ts", () => {
"Sure, let me use the list_buffers tool.",
"⚙️ Grabbing buffers...",
"",
"Stopped (end_turn)",
"Stopped (end_turn) [input: 0, output: 0]",
] as Line[]);

expect(
@@ -111,7 +111,7 @@ describe("tea/chat.spec.ts", () => {
"Sure, let me use the list_buffers tool.",
"✅ Finished getting buffers.",
"",
"Stopped (end_turn)",
"Stopped (end_turn) [input: 0, output: 0]",
] as Line[]);
});
});
@@ -144,7 +144,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"initial render of chat works",
).toEqual(["Stopped (end_turn)"] as Line[]);
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

app.dispatch({
type: "add-message",
@@ -169,7 +169,7 @@
"# assistant:",
"Sure, let me use the list_buffers tool.",
"",
"Stopped (end_turn)",
"Stopped (end_turn) [input: 0, output: 0]",
] as Line[]);

// expect(
@@ -184,7 +184,7 @@
expect(
await buffer.getLines({ start: 0, end: -1 }),
"finished render is as expected",
).toEqual(["Stopped (end_turn)"] as Line[]);
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

// expect(
// await extractMountTree(mountedApp.getMountedNode()),
16 changes: 14 additions & 2 deletions node/chat/chat.ts
@@ -20,6 +20,7 @@ import {
type ProviderMessageContent,
type ProviderName,
type StopReason,
type Usage,
} from "../providers/provider.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";
import { DEFAULT_OPTIONS, type MagentaOptions } from "../options.ts";
@@ -34,6 +35,7 @@ export type ConversationState =
| {
state: "stopped";
stopReason: StopReason;
usage: Usage;
};

export type Model = {
@@ -115,7 +117,11 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
lastUserMessageId: counter.last() as Message.MessageId,
options: DEFAULT_OPTIONS,
activeProvider: "anthropic",
conversation: { state: "stopped", stopReason: "end_turn" },
conversation: {
state: "stopped",
stopReason: "end_turn",
usage: { inputTokens: 0, outputTokens: 0 },
},
messages: [],
toolManager: toolManagerModel.initModel(),
contextManager: contextManagerModel.initModel(),
@@ -509,6 +515,7 @@ ${msg.error.stack}`,
conversation: {
state: "stopped",
stopReason: res?.stopReason || "end_turn",
usage: res?.usage || { inputTokens: 0, outputTokens: 0 },
},
});
}
@@ -549,7 +556,12 @@
) % MESSAGE_ANIMATION.length
]
}`
: d`Stopped (${model.conversation.stopReason || ""})`
: d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`
}${
model.conversation.state == "stopped" &&
!contextManagerModel.isContextEmpty(model.contextManager)
8 changes: 5 additions & 3 deletions node/magenta.spec.ts
@@ -22,10 +22,12 @@ hello
# assistant:
sup?

Stopped (end_turn)`);
Stopped (end_turn) [input: 0, output: 0]`);

await driver.clear();
await driver.assertDisplayBufferContent(`Stopped (end_turn)`);
await driver.assertDisplayBufferContent(
`Stopped (end_turn) [input: 0, output: 0]`,
);
await driver.inputMagentaText(`hello again`);
await driver.send();
await driver.mockAnthropic.respond({
@@ -41,7 +43,7 @@ hello again
# assistant:
huh?

Stopped (end_turn)`);
Stopped (end_turn) [input: 0, output: 0]`);
});
});

78 changes: 78 additions & 0 deletions node/providers/anthropic.spec.ts
@@ -0,0 +1,78 @@
import { describe, it, expect } from "vitest";
import { placeCacheBreakpoints } from "./anthropic.ts";
import type { MessageParam } from "./anthropic.ts";

describe("anthropic.ts", () => {
it("placeCacheBreakpoints should add cache markers at appropriate positions", () => {
const messages: MessageParam[] = [
{
role: "user",
content: [
{
type: "text",
text: "a".repeat(4096), // ~1024 tokens
},
{
type: "text",
text: "b".repeat(4096), // Another ~1024 tokens
},
{
type: "text",
text: "c".repeat(8192), // Another ~2048 tokens
},
],
},
];

placeCacheBreakpoints(messages);

expect(messages[0].content[0].cache_control).toBeUndefined();

expect(messages[0].content[1].cache_control).toEqual({ type: "ephemeral" });

expect(messages[0].content[2].cache_control).toEqual({ type: "ephemeral" });
});

it("placeCacheBreakpoints should handle mixed content types", () => {
const messages: MessageParam[] = [
{
role: "user",
content: [
{
type: "text",
text: "a".repeat(4096),
},
{
type: "tool_use",
name: "test_tool",
id: "123",
input: { param: "a".repeat(4096) },
},
],
},
];

placeCacheBreakpoints(messages);

expect(messages[0].content[0].cache_control).toBeUndefined();
expect(messages[0].content[1].cache_control).toEqual({ type: "ephemeral" });
});

it("placeCacheBreakpoints should not add cache markers for small content", () => {
const messages: MessageParam[] = [
{
role: "user",
content: [
{
type: "text",
text: "short message",
},
],
},
];

placeCacheBreakpoints(messages);

expect(messages[0].content[0].cache_control).toBeUndefined();
});
});
135 changes: 128 additions & 7 deletions node/providers/anthropic.ts
Original file line number Diff line number Diff line change
@@ -6,12 +6,17 @@ import {
type StopReason,
type Provider,
type ProviderMessage,
type Usage,
} from "./provider.ts";
import type { ToolRequestId } from "../tools/toolManager.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";
import type { MessageStream } from "@anthropic-ai/sdk/lib/MessageStream.mjs";
import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts";

export type MessageParam = Omit<Anthropic.MessageParam, "content"> & {
content: Array<Anthropic.Messages.ContentBlockParam>;
};
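// Note (my reading of this type, not from the PR itself): this widens
// Anthropic.MessageParam so `content` is always an array of blocks, which lets
// placeCacheBreakpoints (below) attach per-block cache_control markers.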

export type AnthropicOptions = {
model: "claude-3-5-sonnet-20241022";
};
@@ -49,6 +54,7 @@ export class AnthropicProvider implements Provider {
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;
@@ -69,10 +75,15 @@
}
};

const anthropicMessages = messages.map((m): Anthropic.MessageParam => {
let content: Anthropic.MessageParam["content"];
const anthropicMessages = messages.map((m): MessageParam => {
let content: Anthropic.Messages.ContentBlockParam[];
if (typeof m.content == "string") {
content = m.content;
content = [
{
type: "text",
text: m.content,
},
];
} else {
content = m.content.map((c): Anthropic.ContentBlockParam => {
switch (c.type) {
@@ -105,6 +116,17 @@
};
});

placeCacheBreakpoints(anthropicMessages);

const tools: Anthropic.Tool[] = ToolManager.TOOL_SPECS.map(
(t): Anthropic.Tool => {
return {
...t,
input_schema: t.input_schema as Anthropic.Messages.Tool.InputSchema,
};
},
);

try {
this.request = this.client.messages
.stream({
@@ -116,7 +138,7 @@
type: "auto",
disable_parallel_tool_use: false,
},
tools: ToolManager.TOOL_SPECS as Anthropic.Tool[],
tools,
})
.on("text", (text: string) => {
buf.push(text);
@@ -203,11 +225,110 @@
return extendError(result, { rawRequest: req });
});

this.nvim.logger?.debug("toolRequests: " + JSON.stringify(toolRequests));
this.nvim.logger?.debug("stopReason: " + response.stop_reason);
return { toolRequests, stopReason: response.stop_reason || "end_turn" };
const usage: Usage = {
inputTokens: response.usage.input_tokens,
outputTokens: response.usage.output_tokens,
};
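// Anthropic reports cache usage as two separate counters: cache_read_input_tokens
// is tokens served from an existing cache entry (hits), and
// cache_creation_input_tokens is tokens written to a new cache entry (misses).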
if (response.usage.cache_read_input_tokens) {
usage.cacheHits = response.usage.cache_read_input_tokens;
}
if (response.usage.cache_creation_input_tokens) {
usage.cacheMisses = response.usage.cache_creation_input_tokens;
}

return {
toolRequests,
stopReason: response.stop_reason || "end_turn",
usage,
};
} finally {
this.request = undefined;
}
}
}

export function placeCacheBreakpoints(messages: MessageParam[]) {
// As we scan the messages, record the running character total at which each block ends.
const blocks: { block: Anthropic.Messages.ContentBlockParam; acc: number }[] =
[];

let lengthAcc = 0;
for (const message of messages) {
for (const block of message.content) {
switch (block.type) {
case "text":
lengthAcc += block.text.length;
break;
case "image":
lengthAcc += block.source.data.length;
break;
case "tool_use":
lengthAcc += JSON.stringify(block.input).length;
break;
case "tool_result":
if (block.content) {
if (typeof block.content == "string") {
lengthAcc += block.content.length;
} else {
let blockLength = 0;
for (const blockContent of block.content) {
switch (blockContent.type) {
case "text":
blockLength += blockContent.text.length;
break;
case "image":
blockLength += blockContent.source.data.length;
break;
}
}

lengthAcc += blockLength;
}
}
break;
case "document":
lengthAcc += block.source.data.length;
}

blocks.push({ block, acc: lengthAcc });
}
}

// estimating 4 characters per token.
const tokens = Math.floor(lengthAcc / STR_CHARS_PER_TOKEN);

// Anthropic allows placing up to 4 cache control markers.
// It will not cache anything shorter than 1024 tokens for Sonnet 3.5.
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
// This is a pretty rough estimate, due to the conversion between string length and tokens.
// However, since we are not accounting for tools or the system prompt, and since code and
// technical writing tend to average fewer characters per token (about 3.5), the estimate
// undershoots the true token count, so the first cache marker should land past the
// 1024-token mark and be cacheable.
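// Worked example with made-up numbers: if lengthAcc is 20000 chars, tokens ≈ 5000;
// highestPowersOfTwo(5000, 4) yields [4096, 2048, 1024, 512], and filtering >= 1024
// leaves markers targeting the 4096-, 2048-, and 1024-token boundaries.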
const powers = highestPowersOfTwo(tokens, 4).filter((n) => n >= 1024);
if (powers.length) {
for (const power of powers) {
const targetLength = power * STR_CHARS_PER_TOKEN; // power is in tokens, but we want string chars instead
// find the first block where we are past the target power
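// Note: several powers can resolve to the same block when one block spans
// multiple boundaries; assigning cache_control again is a harmless no-op
// since the value is identical.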
const blockEntry = blocks.find((b) => b.acc > targetLength);
if (blockEntry) {
blockEntry.block.cache_control = { type: "ephemeral" };
}
}
}
}

const STR_CHARS_PER_TOKEN = 4;

export function highestPowersOfTwo(n: number, len: number): number[] {
const result: number[] = [];
let currentPower = Math.floor(Math.log2(n));

while (result.length < len && currentPower >= 0) {
const value = Math.pow(2, currentPower);
if (value <= n) {
result.push(value);
}
currentPower--;
}
return result;
}
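// Usage sketch: highestPowersOfTwo(100, 3) returns [64, 32, 16].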