Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edit user message #55

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 117 additions & 2 deletions node/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import * as Part from "./part.ts";
import * as Message from "./message.ts";
import * as ContextManager from "../context/context-manager.ts";
import {
chainThunks,
type Dispatch,
parallelThunks,
type Thunk,
Expand Down Expand Up @@ -70,6 +71,19 @@ export type Msg =
role: Role;
content?: string;
}
| {
type: "edit-message";
// delete will remove user message and all consecutive ones
// regenerate will regenerate assistant message based on closest previous user message
// edit will move message to the input buffer and remove all consecutive messages
action: "delete" | "regenerate" | "edit";
role: Role;
id: number;
}
| {
type: "prepend-input-buffer";
text: string;
}
| {
type: "stream-response";
text: string;
Expand Down Expand Up @@ -100,6 +114,17 @@ export type Msg =
type: "show-message-debug-info";
};

// Returns the message ID of the last user message or falls back to the counter's last value
function getLastUserMessageId(
lastUserMessage: Message.Model | undefined,
counter: Counter,
): Message.MessageId {
if (lastUserMessage !== undefined) {
return lastUserMessage.id;
}
return counter.last() as Message.MessageId;
}

export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
const counter = new Counter();
const partModel = Part.init({ nvim, lsp });
Expand All @@ -109,7 +134,7 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {

function initModel(): Model {
return {
lastUserMessageId: counter.last() as Message.MessageId,
lastUserMessageId: getLastUserMessageId(undefined, counter),
providerSetting: {
provider: "anthropic",
model: "claude-3-5-sonnet-latest",
Expand Down Expand Up @@ -171,6 +196,81 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
];
}

case "edit-message": {
if (model.conversation.state !== "stopped") return [model];
const messageIndex = model.messages.findIndex((m) => m.id === msg.id);
if (messageIndex === -1) {
return [model];
}

const { role, action } = msg;

switch (action) {
case "delete": {
model.messages = model.messages.slice(0, messageIndex);
const lastUserMessage = model.messages.findLast(
(m) => m.role === "user",
);
model.lastUserMessageId = getLastUserMessageId(
lastUserMessage,
counter,
);
return [model];
}
case "regenerate": {
const searchEndIndex = messageIndex + 1;
// if it's assistant message, then find last previous user message
const lastUserMessageIndex =
role === "user"
? messageIndex
: model.messages.findLastIndex(
(m, i) => m.role === "user" && i < searchEndIndex,
);

if (lastUserMessageIndex === -1) {
nvim.logger?.error(
"Cannot regenerate message, no previous user message found",
);
return [model];
}

model.messages = model.messages.slice(0, lastUserMessageIndex + 1);
model.lastUserMessageId = model.messages[lastUserMessageIndex].id;
return [model, sendMessage(model)];
}

case "edit": {
const { role, id } = msg;
if (role !== "user") return [model];

const messageIndex = model.messages.findIndex((m) => m.id === id);
if (messageIndex === -1) return [model];
const message = model.messages[messageIndex];

const text = message.parts
.filter((p) => p.type == "text")
.map((p) => p.text)
.join("");

const dispatchPrepend: Thunk<Msg> = async (dispatch) =>
Promise.resolve(dispatch({ type: "prepend-input-buffer", text }));
const dispatchDeleteMsg: Thunk<Msg> = async (dispatch) =>
Promise.resolve(
dispatch({ type: "edit-message", role, id, action: "delete" }),
);

return [
model,
chainThunks<Msg>(dispatchPrepend, dispatchDeleteMsg),
];
}

default:
assertUnreachable(action);
}
break;
}

case "conversation-state": {
model.conversation = msg.conversation;
if (msg.conversation.state == "stopped") {
Expand Down Expand Up @@ -259,7 +359,17 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
message: msg.msg.message,
};
}
return [model];
}

if (msg.msg.type === "edit-message") {
const message = msg.msg;
return [
model,
parallelThunks<Msg>(
wrapMessageThunk(msg.idx, messageThunk),
async (dispatch) => Promise.resolve(dispatch(message)),
),
];
}

return [model, wrapMessageThunk(msg.idx, messageThunk)];
Expand Down Expand Up @@ -423,6 +533,11 @@ ${msg.error.stack}`,
return [model, () => showDebugInfo(model)];
}

case "prepend-input-buffer": {
//NOTE: this is handled by the parent component
return [model];
}

default:
assertUnreachable(msg);
}
Expand Down
37 changes: 36 additions & 1 deletion node/chat/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ export type Msg =
| {
type: "init-edit";
filePath: string;
}
| {
type: "edit-message";
action: "delete" | "regenerate" | "edit";
role: Role;
id: number;
};

export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
Expand Down Expand Up @@ -151,6 +157,11 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
return [model];
}

case "edit-message": {
// NOTE: nothing to do, should be handled by parent (chat)
return [model];
}

case "init-edit": {
const edits = model.edits[msg.filePath];
if (!edits) {
Expand Down Expand Up @@ -233,8 +244,10 @@ export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
);
}

const role = createRoleHeader(model, dispatch);

return d`\
# ${model.role}:
${role}
${model.parts.map(
(part, partIdx) =>
d`${partModel.view({
Expand All @@ -253,3 +266,25 @@ ${fileEdits}`

return { update, view };
}

function createRoleHeader(model: Model, dispatch: Dispatch<Msg>) {
const { role, id } = model;
const dispatchEdit =
(action: Extract<Msg, { type: "edit-message" }>["action"]) => () =>
dispatch({ type: "edit-message", action, role, id });
switch (role) {
case "user":
return withBindings(d`# ${role}:`, {
d: dispatchEdit("delete"),
r: dispatchEdit("regenerate"),
e: dispatchEdit("edit"),
});
case "assistant":
return withBindings(d`# ${role}:`, {
r: dispatchEdit("regenerate"),
d: dispatchEdit("delete"),
});
default:
assertUnreachable(role);
}
}
48 changes: 46 additions & 2 deletions node/magenta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
import { getCurrentBuffer, getcwd, getpos, notifyErr } from "./nvim/nvim.ts";
import path from "node:path";
import type { BufNr, Line } from "./nvim/buffer.ts";
import { pos1col1to0 } from "./nvim/window.ts";
import { pos1col1to0, type Position1Indexed } from "./nvim/window.ts";
import { getMarkdownExt } from "./utils/markdown.ts";
import {
DEFAULT_OPTIONS,
Expand Down Expand Up @@ -64,7 +64,15 @@ export class Magenta {
this.chatApp = TEA.createApp({
nvim: this.nvim,
initialModel: this.chatModel.initModel(),
update: (msg, model) => this.chatModel.update(msg, model, { nvim }),
update: (msg, model) => {
if (msg.type === "prepend-input-buffer") {
const text = msg.text;
prependInputBufferHandler(text, this.sidebar, nvim).catch((err) => {
nvim.logger?.debug("error appending input buffer: ", err);
});
}
return this.chatModel.update(msg, model, { nvim });
},
View: this.chatModel.view,
});

Expand Down Expand Up @@ -358,3 +366,39 @@ ${lines.join("\n")}
return magenta;
}
}

async function prependInputBufferHandler(
text: string,
sidebar: Sidebar,
nvim: Nvim,
) {
const inputBuffer = sidebar.state.inputBuffer;
if (!inputBuffer) {
nvim.logger?.debug(`unable to init inputBuffer`);
return;
}
const lines = text.split("\n");
const lastLine = await inputBuffer.getLineCount();
const firstLineText = await inputBuffer.getLines({
start: lastLine - 1,
end: lastLine,
});

//Prepend the text to input bufer. If there is a content separate with a new line.
if (lastLine === 1 && firstLineText[0].trim() === "") {
await inputBuffer.setLines({ start: 0, end: -1, lines: lines as Line[] });
} else {
lines.push("");
await inputBuffer.setLines({ start: 0, end: 0, lines: lines as Line[] });
}

if (sidebar.state.state !== "visible") {
nvim.logger?.debug(`sidebar state is not in 'visible' state`);
return;
}
await sidebar.state.inputWindow.setWindowAsCurrent();
await sidebar.state.inputWindow.setCursor({
row: lastLine,
col: 0,
} as Position1Indexed);
}
4 changes: 4 additions & 0 deletions node/nvim/buffer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ export class NvimBuffer {
return this.nvim.call("nvim_buf_set_name", [this.id, name]);
}

getLineCount(): Promise<number> {
return this.nvim.call("nvim_buf_line_count", [this.id]);
}

static async create(listed: boolean, scratch: boolean, nvim: Nvim) {
const bufNr = (await nvim.call("nvim_create_buf", [
listed,
Expand Down
4 changes: 4 additions & 0 deletions node/nvim/window.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ export class NvimWindow {
return this.nvim.call("nvim_win_set_cursor", [this.id, [pos.row, pos.col]]);
}

async setWindowAsCurrent() {
return this.nvim.call("nvim_set_current_win", [this.id]);
}

zt() {
return this.nvim.call("nvim_exec2", [
`call win_execute(${this.id}, 'normal! zt')`,
Expand Down
2 changes: 1 addition & 1 deletion node/tea/bindings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { type MountedVDOM } from "./view.ts";
import { type Position0Indexed } from "../nvim/window.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";

export const BINDING_KEYS = ["<CR>"] as const;
export const BINDING_KEYS = ["<CR>", "d", "e", "r"] as const;

export type BindingKey = (typeof BINDING_KEYS)[number];
export type Bindings = Partial<{
Expand Down