Skip to content

Commit e4f8ed8

Browse files
authored
Merge pull request #24 from dlants/context-message-order
Context message order
2 parents ccf6b55 + 60e9927 commit e4f8ed8

13 files changed

+561
-97
lines changed

bun/chat/chat.ts

+63-22
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import {
1111
import { d, type View } from "../tea/view.ts";
1212
import * as ToolManager from "../tools/toolManager.ts";
1313
import { type Result } from "../utils/result.ts";
14-
import { IdCounter } from "../utils/uniqueId.ts";
14+
import { Counter } from "../utils/uniqueId.ts";
1515
import type { Nvim } from "bunvim";
1616
import type { Lsp } from "../lsp.ts";
1717
import {
@@ -37,6 +37,7 @@ export type ConversationState =
3737
};
3838

3939
export type Model = {
40+
lastUserMessageId: Message.MessageId;
4041
activeProvider: ProviderName;
4142
options: MagentaOptions;
4243
conversation: ConversationState;
@@ -58,6 +59,11 @@ export type Msg =
5859
type: "context-manager-msg";
5960
msg: ContextManager.Msg;
6061
}
62+
| {
63+
type: "add-file-context";
64+
absFilePath: string;
65+
relFilePath: string;
66+
}
6167
| {
6268
type: "add-message";
6369
role: Role;
@@ -98,14 +104,15 @@ export type Msg =
98104
};
99105

100106
export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
101-
const idCounter = new IdCounter("message_");
107+
const counter = new Counter();
102108
const partModel = Part.init({ nvim, lsp });
103109
const toolManagerModel = ToolManager.init({ nvim, lsp });
104110
const contextManagerModel = ContextManager.init({ nvim });
105111
const messageModel = Message.init({ nvim, lsp });
106112

107113
function initModel(): Model {
108114
return {
115+
lastUserMessageId: counter.last() as Message.MessageId,
109116
options: DEFAULT_OPTIONS,
110117
activeProvider: "anthropic",
111118
conversation: { state: "stopped", stopReason: "end_turn" },
@@ -134,12 +141,16 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
134141
return [{ ...model, activeProvider: msg.provider }];
135142
case "add-message": {
136143
let message: Message.Model = {
137-
id: idCounter.get() as Message.MessageId,
144+
id: counter.get() as Message.MessageId,
138145
role: msg.role,
139146
parts: [],
140147
edits: {},
141148
};
142149

150+
if (message.role == "user") {
151+
model.lastUserMessageId = message.id;
152+
}
153+
143154
let messageThunk;
144155
if (msg.content) {
145156
const [next, thunk] = messageModel.update(
@@ -231,7 +242,7 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
231242
const lastMessage = model.messages[model.messages.length - 1];
232243
if (lastMessage?.role !== "assistant") {
233244
model.messages.push({
234-
id: idCounter.get() as Message.MessageId,
245+
id: counter.get() as Message.MessageId,
235246
role: "assistant",
236247
parts: [],
237248
edits: {},
@@ -255,7 +266,7 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
255266
const lastMessage = model.messages[model.messages.length - 1];
256267
if (lastMessage?.role !== "assistant") {
257268
model.messages.push({
258-
id: idCounter.get() as Message.MessageId,
269+
id: counter.get() as Message.MessageId,
259270
role: "assistant",
260271
parts: [],
261272
edits: {},
@@ -283,7 +294,7 @@ ${msg.error.stack}`,
283294
const lastMessage = model.messages[model.messages.length - 1];
284295
if (lastMessage?.role !== "assistant") {
285296
model.messages.push({
286-
id: idCounter.get() as Message.MessageId,
297+
id: counter.get() as Message.MessageId,
287298
role: "assistant",
288299
parts: [],
289300
edits: {},
@@ -363,6 +374,24 @@ ${msg.error.stack}`,
363374
];
364375
}
365376

377+
case "add-file-context": {
378+
const [nextContextManager, contextManagerThunk] =
379+
contextManagerModel.update(
380+
{
381+
type: "add-file-context",
382+
absFilePath: msg.absFilePath,
383+
relFilePath: msg.relFilePath,
384+
messageId: model.lastUserMessageId,
385+
},
386+
model.contextManager,
387+
);
388+
model.contextManager = nextContextManager;
389+
return [
390+
model,
391+
parallelThunks(wrapThunk("context-manager-msg", contextManagerThunk)),
392+
];
393+
}
394+
366395
case "clear": {
367396
return [initModel()];
368397
}
@@ -519,19 +548,7 @@ ${msg.error.stack}`,
519548
};
520549

521550
async function getMessages(model: Model): Promise<ProviderMessage[]> {
522-
const messages = [];
523-
const contextMessage = await contextManagerModel.getContextMessage(
524-
model.contextManager,
525-
);
526-
527-
if (contextMessage) {
528-
nvim.logger?.debug(
529-
`Got context message: ${JSON.stringify(contextMessage)}`,
530-
);
531-
messages.push(contextMessage);
532-
}
533-
534-
const rest = model.messages.flatMap((msg) => {
551+
const messages = model.messages.flatMap((msg) => {
535552
const messageContent: ProviderMessageContent[] = [];
536553
const toolResponseContent: ProviderMessageContent[] = [];
537554

@@ -580,11 +597,35 @@ ${msg.error.stack}`,
580597
});
581598
}
582599

583-
return out;
600+
return out.map((m) => ({
601+
message: m,
602+
messageId: msg.id,
603+
}));
584604
});
585605

586-
messages.push(...rest);
587-
return messages;
606+
const contextMessages = await contextManagerModel.getContextMessages(
607+
counter.last() as Message.MessageId,
608+
model.contextManager,
609+
);
610+
611+
if (contextMessages) {
612+
nvim.logger?.debug(
613+
`Got context messages: ${JSON.stringify(contextMessages)}`,
614+
);
615+
616+
for (const contextMessage of contextMessages) {
617+
// we want to insert the contextMessage before the corresponding user message
618+
let idx = messages.findIndex(
619+
(m) => m.messageId >= contextMessage.messageId,
620+
);
621+
if (idx == -1) {
622+
idx = messages.length;
623+
}
624+
messages.splice(idx, 0, contextMessage);
625+
}
626+
}
627+
628+
return messages.map((m) => m.message);
588629
}
589630

590631
return {

bun/chat/message.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { displayDiffs } from "../tools/diff.ts";
88
import type { Lsp } from "../lsp.ts";
99
import type { Nvim } from "bunvim";
1010

11-
export type MessageId = string & { __messageId: true };
11+
export type MessageId = number & { __messageId: true };
1212
export type Model = {
1313
id: MessageId;
1414
role: Role;
@@ -160,7 +160,7 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
160160
await displayDiffs({
161161
context: { nvim },
162162
filePath: msg.filePath,
163-
diffId: model.id,
163+
diffId: `message_${model.id}`,
164164
edits: edits.requestIds.map((requestId) => {
165165
const toolWrapper = toolManager.toolWrappers[requestId];
166166
if (!toolWrapper) {

bun/context/context-manager.ts

+50-48
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
import fs from "node:fs";
2-
import path from "node:path";
31
import { d, withBindings, type View } from "../tea/view";
42
import type { Dispatch, Update } from "../tea/tea";
53
import { assertUnreachable } from "../utils/assertUnreachable";
64
import type { ProviderMessage } from "../providers/provider";
7-
import { getcwd } from "../nvim/nvim";
85
import type { Nvim } from "bunvim";
9-
import { getBufferIfOpen } from "../utils/buffers";
6+
import type { MessageId } from "../chat/message";
7+
import { BufferAndFileManager } from "./file-and-buffer-manager";
108

119
export type Model = {
1210
files: {
13-
[absFilePath: string]: { relFilePath: string };
11+
[absFilePath: string]: {
12+
relFilePath: string;
13+
initialMessageId: MessageId;
14+
};
1415
};
1516
};
1617

@@ -19,13 +20,16 @@ export type Msg =
1920
type: "add-file-context";
2021
relFilePath: string;
2122
absFilePath: string;
23+
messageId: MessageId;
2224
}
2325
| {
2426
type: "remove-file-context";
2527
absFilePath: string;
2628
};
2729

2830
export function init({ nvim }: { nvim: Nvim }) {
31+
const bufferAndFileManager = new BufferAndFileManager(nvim);
32+
2933
function initModel(): Model {
3034
return {
3135
files: {},
@@ -39,7 +43,10 @@ export function init({ nvim }: { nvim: Nvim }) {
3943
...model,
4044
files: {
4145
...model.files,
42-
[msg.absFilePath]: { relFilePath: msg.relFilePath },
46+
[msg.absFilePath]: {
47+
relFilePath: msg.relFilePath,
48+
initialMessageId: msg.messageId,
49+
},
4350
},
4451
},
4552
];
@@ -73,57 +80,56 @@ ${fileContext}`;
7380
return Object.keys(model.files).length == 0;
7481
}
7582

76-
async function getContextMessage(
83+
async function getContextMessages(
84+
currentMessageId: MessageId,
7785
model: Model,
78-
): Promise<ProviderMessage | undefined> {
86+
): Promise<{ messageId: MessageId; message: ProviderMessage }[] | undefined> {
7987
if (isContextEmpty(model)) {
8088
return undefined;
8189
}
8290

83-
const cwd = await getcwd(nvim);
84-
const fileContents = await Promise.all(
91+
return await Promise.all(
8592
Object.keys(model.files).map((absFilePath) =>
86-
getFileContents({ absFilePath, cwd }),
93+
getFileMessage({ absFilePath, currentMessageId }),
8794
),
8895
);
89-
90-
return {
91-
role: "user",
92-
content: `${FILE_PROMPT}
93-
94-
${fileContents.join("\n\n")}`,
95-
};
9696
}
9797

98-
async function getFileContents({
98+
async function getFileMessage({
9999
absFilePath,
100-
cwd,
100+
currentMessageId,
101101
}: {
102102
absFilePath: string;
103-
cwd: string;
104-
}): Promise<string> {
105-
const relativePath = path.relative(cwd, absFilePath);
106-
const bufferContents = await getBufferIfOpen({
107-
relativePath,
108-
context: { nvim },
109-
});
110-
111-
if (bufferContents.status == "ok") {
112-
return renderFile({
113-
relFilePath: relativePath,
114-
content: bufferContents.result,
115-
});
116-
} else if (bufferContents.status == "error") {
117-
return `\
118-
Error trying to read file \`${relativePath}\`: ${bufferContents.error}`;
119-
}
103+
currentMessageId: MessageId;
104+
}): Promise<{ messageId: MessageId; message: ProviderMessage }> {
105+
const res = await bufferAndFileManager.getFileContents(
106+
absFilePath,
107+
currentMessageId,
108+
);
120109

121-
try {
122-
const fileContent = await fs.promises.readFile(absFilePath, "utf-8");
123-
return renderFile({ relFilePath: relativePath, content: fileContent });
124-
} catch (error) {
125-
return `\
126-
Error trying to read file \`${relativePath}\`: ${(error as Error).message}`;
110+
switch (res.status) {
111+
case "ok":
112+
return {
113+
messageId: res.value.messageId,
114+
message: {
115+
role: "user",
116+
content: renderFile({
117+
relFilePath: res.value.relFilePath,
118+
content: res.value.content,
119+
}),
120+
},
121+
};
122+
123+
case "error":
124+
return {
125+
messageId: currentMessageId,
126+
message: {
127+
role: "user",
128+
content: `Error reading file \`${absFilePath}\`: ${res.error}`,
129+
},
130+
};
131+
default:
132+
assertUnreachable(res);
127133
}
128134
}
129135

@@ -146,10 +152,6 @@ ${content}
146152
initModel,
147153
update,
148154
view,
149-
getContextMessage,
155+
getContextMessages,
150156
};
151157
}
152-
153-
export const FILE_PROMPT = `Files.
154-
This is the most up-to-date content of these files.
155-
Any other mentions of code or snippets from these files may be out of date.`;

0 commit comments

Comments
 (0)