Commit 793fcea

Merge pull request #30 from dlants/prompt-caching
Prompt caching
2 parents 6514498 + 35be07a commit 793fcea

9 files changed: +295 -495 lines

node/chat/__snapshots__/chat.spec.ts.snap (+27 -476)

Large diffs are not rendered by default.

node/chat/chat.spec.ts (+6 -6)

@@ -35,7 +35,7 @@ describe("tea/chat.spec.ts", () => {
     expect(
       await buffer.getLines({ start: 0, end: -1 }),
       "initial render of chat works",
-    ).toEqual(["Stopped (end_turn)"] as Line[]);
+    ).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

     app.dispatch({
       type: "add-message",
@@ -74,7 +74,7 @@ describe("tea/chat.spec.ts", () => {
       "Sure, let me use the list_buffers tool.",
       "⚙️ Grabbing buffers...",
       "",
-      "Stopped (end_turn)",
+      "Stopped (end_turn) [input: 0, output: 0]",
     ] as Line[]);

     expect(
@@ -111,7 +111,7 @@ describe("tea/chat.spec.ts", () => {
       "Sure, let me use the list_buffers tool.",
       "✅ Finished getting buffers.",
       "",
-      "Stopped (end_turn)",
+      "Stopped (end_turn) [input: 0, output: 0]",
     ] as Line[]);
   });
 });
@@ -144,7 +144,7 @@ describe("tea/chat.spec.ts", () => {
     expect(
       await buffer.getLines({ start: 0, end: -1 }),
       "initial render of chat works",
-    ).toEqual(["Stopped (end_turn)"] as Line[]);
+    ).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

     app.dispatch({
       type: "add-message",
@@ -169,7 +169,7 @@ describe("tea/chat.spec.ts", () => {
       "# assistant:",
       "Sure, let me use the list_buffers tool.",
       "",
-      "Stopped (end_turn)",
+      "Stopped (end_turn) [input: 0, output: 0]",
     ] as Line[]);

     // expect(
@@ -184,7 +184,7 @@ describe("tea/chat.spec.ts", () => {
     expect(
       await buffer.getLines({ start: 0, end: -1 }),
       "finished render is as expected",
-    ).toEqual(["Stopped (end_turn)"] as Line[]);
+    ).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

     // expect(
     //   await extractMountTree(mountedApp.getMountedNode()),

node/chat/chat.ts (+14 -2)

@@ -20,6 +20,7 @@ import {
   type ProviderMessageContent,
   type ProviderName,
   type StopReason,
+  type Usage,
 } from "../providers/provider.ts";
 import { assertUnreachable } from "../utils/assertUnreachable.ts";
 import { DEFAULT_OPTIONS, type MagentaOptions } from "../options.ts";
@@ -34,6 +35,7 @@ export type ConversationState =
   | {
       state: "stopped";
       stopReason: StopReason;
+      usage: Usage;
     };

 export type Model = {
@@ -115,7 +117,11 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
       lastUserMessageId: counter.last() as Message.MessageId,
       options: DEFAULT_OPTIONS,
       activeProvider: "anthropic",
-      conversation: { state: "stopped", stopReason: "end_turn" },
+      conversation: {
+        state: "stopped",
+        stopReason: "end_turn",
+        usage: { inputTokens: 0, outputTokens: 0 },
+      },
       messages: [],
       toolManager: toolManagerModel.initModel(),
       contextManager: contextManagerModel.initModel(),
@@ -509,6 +515,7 @@ ${msg.error.stack}`,
           conversation: {
             state: "stopped",
             stopReason: res?.stopReason || "end_turn",
+            usage: res?.usage || { inputTokens: 0, outputTokens: 0 },
           },
         });
       }
@@ -549,7 +556,12 @@
                 ) % MESSAGE_ANIMATION.length
               ]
             }`
-          : d`Stopped (${model.conversation.stopReason || ""})`
+          : d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
+              model.conversation.usage.cacheHits !== undefined &&
+              model.conversation.usage.cacheMisses !== undefined
+                ? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
+                : ""
+            }]`
       }${
         model.conversation.state == "stopped" &&
        !contextManagerModel.isContextEmpty(model.contextManager)
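The Usage type imported above is defined in node/providers/provider.ts, which is not rendered on this page. Judging only from how it is consumed here and in anthropic.ts below, it presumably looks roughly like this sketch:

// Rough sketch of the Usage shape implied by this diff; the actual definition
// lives in node/providers/provider.ts, which is not shown in the rendered files.
export type Usage = {
  inputTokens: number;
  outputTokens: number;
  cacheHits?: number; // filled from cache_read_input_tokens when present
  cacheMisses?: number; // filled from cache_creation_input_tokens when present
};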

node/magenta.spec.ts (+5 -3)

@@ -22,10 +22,12 @@ hello
 # assistant:
 sup?

-Stopped (end_turn)`);
+Stopped (end_turn) [input: 0, output: 0]`);

     await driver.clear();
-    await driver.assertDisplayBufferContent(`Stopped (end_turn)`);
+    await driver.assertDisplayBufferContent(
+      `Stopped (end_turn) [input: 0, output: 0]`,
+    );
     await driver.inputMagentaText(`hello again`);
     await driver.send();
     await driver.mockAnthropic.respond({
@@ -41,7 +43,7 @@
 hello again
 # assistant:
 huh?
-Stopped (end_turn)`);
+Stopped (end_turn) [input: 0, output: 0]`);
   });
 });
node/providers/anthropic.spec.ts

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import { describe, it, expect } from "vitest";
2+
import { placeCacheBreakpoints } from "./anthropic.ts";
3+
import type { MessageParam } from "./anthropic.ts";
4+
5+
describe("anthropic.ts", () => {
6+
it("placeCacheBreakpoints should add cache markers at appropriate positions", () => {
7+
const messages: MessageParam[] = [
8+
{
9+
role: "user",
10+
content: [
11+
{
12+
type: "text",
13+
text: "a".repeat(4096), // ~1024 tokens
14+
},
15+
{
16+
type: "text",
17+
text: "b".repeat(4096), // Another ~1024 tokens
18+
},
19+
{
20+
type: "text",
21+
text: "c".repeat(8192), // Another ~2048 tokens
22+
},
23+
],
24+
},
25+
];
26+
27+
placeCacheBreakpoints(messages);
28+
29+
expect(messages[0].content[0].cache_control).toBeUndefined();
30+
31+
expect(messages[0].content[1].cache_control).toEqual({ type: "ephemeral" });
32+
33+
expect(messages[0].content[2].cache_control).toEqual({ type: "ephemeral" });
34+
});
35+
36+
it("placeCacheBreakpoints should handle mixed content types", () => {
37+
const messages: MessageParam[] = [
38+
{
39+
role: "user",
40+
content: [
41+
{
42+
type: "text",
43+
text: "a".repeat(4096),
44+
},
45+
{
46+
type: "tool_use",
47+
name: "test_tool",
48+
id: "123",
49+
input: { param: "a".repeat(4096) },
50+
},
51+
],
52+
},
53+
];
54+
55+
placeCacheBreakpoints(messages);
56+
57+
expect(messages[0].content[0].cache_control).toBeUndefined();
58+
expect(messages[0].content[1].cache_control).toEqual({ type: "ephemeral" });
59+
});
60+
61+
it("placeCacheBreakpoints should not add cache markers for small content", () => {
62+
const messages: MessageParam[] = [
63+
{
64+
role: "user",
65+
content: [
66+
{
67+
type: "text",
68+
text: "short message",
69+
},
70+
],
71+
},
72+
];
73+
74+
placeCacheBreakpoints(messages);
75+
76+
expect(messages[0].content[0].cache_control).toBeUndefined();
77+
});
78+
});
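A quick arithmetic check of the first test, assuming the 4-characters-per-token estimate used by placeCacheBreakpoints (implementation in node/providers/anthropic.ts below): the cumulative block lengths are 4096, 8192, and 16384 characters, or roughly 1024, 2048, and 4096 tokens. highestPowersOfTwo(4096, 4) filtered to >= 1024 yields breakpoint targets of 4096, 2048, and 1024 tokens (16384, 8192, and 4096 characters). No block's running total strictly exceeds 16384, the third block (16384) is the first to exceed 8192, and the second block (8192) is the first to exceed 4096, so only the second and third blocks receive cache_control: { type: "ephemeral" }, which is exactly what the assertions expect.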

node/providers/anthropic.ts (+128 -7)

@@ -6,12 +6,17 @@ import {
   type StopReason,
   type Provider,
   type ProviderMessage,
+  type Usage,
 } from "./provider.ts";
 import type { ToolRequestId } from "../tools/toolManager.ts";
 import { assertUnreachable } from "../utils/assertUnreachable.ts";
 import type { MessageStream } from "@anthropic-ai/sdk/lib/MessageStream.mjs";
 import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts";

+export type MessageParam = Omit<Anthropic.MessageParam, "content"> & {
+  content: Array<Anthropic.Messages.ContentBlockParam>;
+};
+
 export type AnthropicOptions = {
   model: "claude-3-5-sonnet-20241022";
 };
@@ -49,6 +54,7 @@ export class AnthropicProvider implements Provider {
   ): Promise<{
     toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
     stopReason: StopReason;
+    usage: Usage;
   }> {
     const buf: string[] = [];
     let flushInProgress: boolean = false;
@@ -69,10 +75,15 @@
       }
     };

-    const anthropicMessages = messages.map((m): Anthropic.MessageParam => {
-      let content: Anthropic.MessageParam["content"];
+    const anthropicMessages = messages.map((m): MessageParam => {
+      let content: Anthropic.Messages.ContentBlockParam[];
       if (typeof m.content == "string") {
-        content = m.content;
+        content = [
+          {
+            type: "text",
+            text: m.content,
+          },
+        ];
       } else {
         content = m.content.map((c): Anthropic.ContentBlockParam => {
           switch (c.type) {
@@ -105,6 +116,17 @@
       };
     });

+    placeCacheBreakpoints(anthropicMessages);
+
+    const tools: Anthropic.Tool[] = ToolManager.TOOL_SPECS.map(
+      (t): Anthropic.Tool => {
+        return {
+          ...t,
+          input_schema: t.input_schema as Anthropic.Messages.Tool.InputSchema,
+        };
+      },
+    );
+
     try {
       this.request = this.client.messages
         .stream({
@@ -116,7 +138,7 @@
           type: "auto",
           disable_parallel_tool_use: false,
         },
-        tools: ToolManager.TOOL_SPECS as Anthropic.Tool[],
+        tools,
       })
       .on("text", (text: string) => {
         buf.push(text);
@@ -203,11 +225,110 @@
         return extendError(result, { rawRequest: req });
       });

-      this.nvim.logger?.debug("toolRequests: " + JSON.stringify(toolRequests));
-      this.nvim.logger?.debug("stopReason: " + response.stop_reason);
-      return { toolRequests, stopReason: response.stop_reason || "end_turn" };
+      const usage: Usage = {
+        inputTokens: response.usage.input_tokens,
+        outputTokens: response.usage.output_tokens,
+      };
+      if (response.usage.cache_read_input_tokens) {
+        usage.cacheHits = response.usage.cache_read_input_tokens;
+      }
+      if (response.usage.cache_creation_input_tokens) {
+        usage.cacheMisses = response.usage.cache_creation_input_tokens;
+      }
+
+      return {
+        toolRequests,
+        stopReason: response.stop_reason || "end_turn",
+        usage,
+      };
     } finally {
       this.request = undefined;
     }
   }
 }
+
+export function placeCacheBreakpoints(messages: MessageParam[]) {
+  // when we scan the messages, keep track of where each part ends.
+  const blocks: { block: Anthropic.Messages.ContentBlockParam; acc: number }[] =
+    [];
+
+  let lengthAcc = 0;
+  for (const message of messages) {
+    for (const block of message.content) {
+      switch (block.type) {
+        case "text":
+          lengthAcc += block.text.length;
+          break;
+        case "image":
+          lengthAcc += block.source.data.length;
+          break;
+        case "tool_use":
+          lengthAcc += JSON.stringify(block.input).length;
+          break;
+        case "tool_result":
+          if (block.content) {
+            if (typeof block.content == "string") {
+              lengthAcc += block.content.length;
+            } else {
+              let blockLength = 0;
+              for (const blockContent of block.content) {
+                switch (blockContent.type) {
+                  case "text":
+                    blockLength += blockContent.text.length;
+                    break;
+                  case "image":
+                    blockLength += blockContent.source.data.length;
+                    break;
+                }
+              }
+
+              lengthAcc += blockLength;
+            }
+          }
+          break;
+        case "document":
+          lengthAcc += block.source.data.length;
+      }
+
+      blocks.push({ block, acc: lengthAcc });
+    }
+  }
+
+  // estimating 4 characters per token.
+  const tokens = Math.floor(lengthAcc / STR_CHARS_PER_TOKEN);
+
+  // Anthropic allows for placing up to 4 cache control markers.
+  // It will not cache anything less than 1024 tokens for sonnet 3.5.
+  // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
+  // This is a pretty rough estimate, due to the conversion between string length and tokens.
+  // However, since we are not accounting for tools or the system prompt, and code and technical
+  // writing generally have a lower ratio of string length to tokens (about 3.5 characters per
+  // token on average), the first cache control marker should land past the 1024-token mark
+  // and should be cached.
+  const powers = highestPowersOfTwo(tokens, 4).filter((n) => n >= 1024);
+  if (powers.length) {
+    for (const power of powers) {
+      const targetLength = power * STR_CHARS_PER_TOKEN; // power is in tokens, but we want string chars instead
+      // find the first block where we are past the target power
+      const blockEntry = blocks.find((b) => b.acc > targetLength);
+      if (blockEntry) {
+        blockEntry.block.cache_control = { type: "ephemeral" };
+      }
+    }
+  }
+}
+
+const STR_CHARS_PER_TOKEN = 4;
+
+export function highestPowersOfTwo(n: number, len: number): number[] {
+  const result: number[] = [];
+  let currentPower = Math.floor(Math.log2(n));
+
+  while (result.length < len && currentPower >= 0) {
+    const value = Math.pow(2, currentPower);
+    if (value <= n) {
+      result.push(value);
+    }
+    currentPower--;
+  }
+  return result;
+}
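As a rough illustration of how the two exported helpers above interact, here is a hypothetical snippet that is not part of this commit: a single 40,000-character text block is estimated at about 10,000 tokens, so all four breakpoint candidates survive the >= 1024 filter and the block ends up marked as ephemeral.

// Hypothetical usage sketch (not in this commit), exercising the exported helpers.
import { placeCacheBreakpoints, highestPowersOfTwo } from "./anthropic.ts";
import type { MessageParam } from "./anthropic.ts";

const messages: MessageParam[] = [
  // 40,000 characters ~ 10,000 tokens at the 4 chars/token estimate
  { role: "user", content: [{ type: "text", text: "x".repeat(40000) }] },
];

placeCacheBreakpoints(messages);

console.log(highestPowersOfTwo(10000, 4)); // [8192, 4096, 2048, 1024]
console.log(messages[0].content[0].cache_control); // { type: "ephemeral" }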
