@@ -6,12 +6,17 @@ import {
6
6
type StopReason ,
7
7
type Provider ,
8
8
type ProviderMessage ,
9
+ type Usage ,
9
10
} from "./provider.ts" ;
10
11
import type { ToolRequestId } from "../tools/toolManager.ts" ;
11
12
import { assertUnreachable } from "../utils/assertUnreachable.ts" ;
12
13
import type { MessageStream } from "@anthropic-ai/sdk/lib/MessageStream.mjs" ;
13
14
import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts" ;
14
15
16
/**
 * An Anthropic message whose `content` is always a list of content blocks.
 *
 * The SDK's own `Anthropic.MessageParam["content"]` is replaced here by an
 * array of `ContentBlockParam`, so code consuming this type can iterate over
 * blocks without first checking for another content shape.
 */
export type MessageParam = Omit<Anthropic.MessageParam, "content"> & {
  content: Anthropic.Messages.ContentBlockParam[];
};
19
+
15
20
/**
 * Options for the Anthropic provider.
 *
 * `model` is typed as a single string literal, so this is the only model id
 * the type currently admits.
 * NOTE(review): presumably consumed by AnthropicProvider — confirm against the
 * constructor, which is not visible in this view.
 */
export type AnthropicOptions = { model: "claude-3-5-sonnet-20241022" };
@@ -49,6 +54,7 @@ export class AnthropicProvider implements Provider {
49
54
) : Promise < {
50
55
toolRequests : Result < ToolManager . ToolRequest , { rawRequest : unknown } > [ ] ;
51
56
stopReason : StopReason ;
57
+ usage : Usage ;
52
58
} > {
53
59
const buf : string [ ] = [ ] ;
54
60
let flushInProgress : boolean = false ;
@@ -69,10 +75,15 @@ export class AnthropicProvider implements Provider {
69
75
}
70
76
} ;
71
77
72
- const anthropicMessages = messages . map ( ( m ) : Anthropic . MessageParam => {
73
- let content : Anthropic . MessageParam [ "content" ] ;
78
+ const anthropicMessages = messages . map ( ( m ) : MessageParam => {
79
+ let content : Anthropic . Messages . ContentBlockParam [ ] ;
74
80
if ( typeof m . content == "string" ) {
75
- content = m . content ;
81
+ content = [
82
+ {
83
+ type : "text" ,
84
+ text : m . content ,
85
+ } ,
86
+ ] ;
76
87
} else {
77
88
content = m . content . map ( ( c ) : Anthropic . ContentBlockParam => {
78
89
switch ( c . type ) {
@@ -105,6 +116,17 @@ export class AnthropicProvider implements Provider {
105
116
} ;
106
117
} ) ;
107
118
119
+ placeCacheBreakpoints ( anthropicMessages ) ;
120
+
121
+ const tools : Anthropic . Tool [ ] = ToolManager . TOOL_SPECS . map (
122
+ ( t ) : Anthropic . Tool => {
123
+ return {
124
+ ...t ,
125
+ input_schema : t . input_schema as Anthropic . Messages . Tool . InputSchema ,
126
+ } ;
127
+ } ,
128
+ ) ;
129
+
108
130
try {
109
131
this . request = this . client . messages
110
132
. stream ( {
@@ -116,7 +138,7 @@ export class AnthropicProvider implements Provider {
116
138
type : "auto" ,
117
139
disable_parallel_tool_use : false ,
118
140
} ,
119
- tools : ToolManager . TOOL_SPECS as Anthropic . Tool [ ] ,
141
+ tools,
120
142
} )
121
143
. on ( "text" , ( text : string ) => {
122
144
buf . push ( text ) ;
@@ -203,11 +225,110 @@ export class AnthropicProvider implements Provider {
203
225
return extendError ( result , { rawRequest : req } ) ;
204
226
} ) ;
205
227
206
- this . nvim . logger ?. debug ( "toolRequests: " + JSON . stringify ( toolRequests ) ) ;
207
- this . nvim . logger ?. debug ( "stopReason: " + response . stop_reason ) ;
208
- return { toolRequests, stopReason : response . stop_reason || "end_turn" } ;
228
+ const usage : Usage = {
229
+ inputTokens : response . usage . input_tokens ,
230
+ outputTokens : response . usage . output_tokens ,
231
+ } ;
232
+ if ( response . usage . cache_read_input_tokens ) {
233
+ usage . cacheHits = response . usage . cache_read_input_tokens ;
234
+ }
235
+ if ( response . usage . cache_creation_input_tokens ) {
236
+ usage . cacheMisses = response . usage . cache_creation_input_tokens ;
237
+ }
238
+
239
+ return {
240
+ toolRequests,
241
+ stopReason : response . stop_reason || "end_turn" ,
242
+ usage,
243
+ } ;
209
244
} finally {
210
245
this . request = undefined ;
211
246
}
212
247
}
213
248
}
249
+
250
/** Estimated number of string characters per token, used to convert lengths to token counts. */
const STR_CHARS_PER_TOKEN = 4;

/**
 * Returns the number of characters a single content block contributes to the
 * running prompt-length estimate.
 *
 * - text: length of the text
 * - image / document: length of the source data string
 * - tool_use: length of the JSON-serialized input
 * - tool_result: length of the string content, or the summed lengths of its
 *   text / image parts; 0 when there is no content
 */
function contentBlockCharLength(
  block: Anthropic.Messages.ContentBlockParam,
): number {
  let length = 0;

  switch (block.type) {
    case "text":
      length = block.text.length;
      break;
    case "image":
      length = block.source.data.length;
      break;
    case "tool_use": {
      // JSON.stringify returns undefined (not a string) when given undefined,
      // so reading `.length` directly would throw. Count such input as empty.
      const serialized: string | undefined = JSON.stringify(block.input);
      length = serialized?.length ?? 0;
      break;
    }
    case "tool_result":
      if (typeof block.content === "string") {
        length = block.content.length;
      } else if (block.content) {
        for (const part of block.content) {
          switch (part.type) {
            case "text":
              length += part.text.length;
              break;
            case "image":
              length += part.source.data.length;
              break;
          }
        }
      }
      break;
    case "document":
      length = block.source.data.length;
      break;
  }

  return length;
}

/**
 * Marks up to four content blocks with `cache_control: { type: "ephemeral" }`.
 *
 * The blocks of all messages are scanned in order while accumulating an
 * estimated length in characters. The total is converted to tokens using
 * STR_CHARS_PER_TOKEN, and a breakpoint is placed at each of the (up to four)
 * largest powers of two that fit in that token count and are at least 1024.
 *
 * The content blocks inside `messages` are mutated in place; nothing is returned.
 *
 * @param messages messages whose content blocks may receive cache markers
 */
export function placeCacheBreakpoints(messages: MessageParam[]): void {
  // when we scan the messages, keep track of where each part ends.
  const blocks: { block: Anthropic.Messages.ContentBlockParam; acc: number }[] =
    [];

  let lengthAcc = 0;
  for (const message of messages) {
    for (const block of message.content) {
      lengthAcc += contentBlockCharLength(block);
      blocks.push({ block, acc: lengthAcc });
    }
  }

  // estimating 4 characters per token.
  const tokens = Math.floor(lengthAcc / STR_CHARS_PER_TOKEN);

  // Anthropic allows for placing up to 4 cache control markers.
  // It will not cache anything less than 1024 tokens for sonnet 3.5
  // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
  // This is a pretty rough estimate, due to the conversion between string length and tokens.
  // However, since we are not accounting for tools or the system prompt, and generally code and technical writing
  // tend to have a lower coefficient of string length to tokens (about 3.5 average string length per token), this means
  // that the first cache control should be past the 1024 mark and should be cached.
  const powers = highestPowersOfTwo(tokens, 4).filter((n) => n >= 1024);

  for (const power of powers) {
    // power is in tokens, but we want string chars instead
    const targetLength = power * STR_CHARS_PER_TOKEN;

    // Find the first block that ends at or past the target. `>=` (rather than
    // `>`) matters when the running length lands exactly on the boundary:
    // otherwise a conversation whose total length is exactly
    // power * STR_CHARS_PER_TOKEN would get no marker for its largest power.
    // The chosen block is never empty, because the block before it (if any)
    // ended strictly before the target.
    const blockEntry = blocks.find((b) => b.acc >= targetLength);
    if (blockEntry) {
      blockEntry.block.cache_control = { type: "ephemeral" };
    }
  }
}
321
+
322
/**
 * Lists powers of two that do not exceed `n`, largest first.
 *
 * Starts at 2^floor(log2(n)) and walks the exponent down to 0, collecting
 * values until `len` of them have been gathered.
 *
 * @param n upper bound for the returned values
 * @param len maximum number of values to return
 * @returns at most `len` powers of two in descending order; empty when
 *   `n` is less than 1 or `len` is not positive
 */
export function highestPowersOfTwo(n: number, len: number): number[] {
  const powers: number[] = [];

  for (
    let exponent = Math.floor(Math.log2(n));
    exponent >= 0 && powers.length < len;
    exponent--
  ) {
    const candidate = 2 ** exponent;
    // Guards against log2 rounding up and overshooting n.
    if (candidate <= n) {
      powers.push(candidate);
    }
  }

  return powers;
}
0 commit comments