wip: session revert/unrevert
This commit is contained in:
@@ -34,6 +34,7 @@ import type { ModelsDev } from "../provider/models"
|
|||||||
import { Installation } from "../installation"
|
import { Installation } from "../installation"
|
||||||
import { Config } from "../config/config"
|
import { Config } from "../config/config"
|
||||||
import { ProviderTransform } from "../provider/transform"
|
import { ProviderTransform } from "../provider/transform"
|
||||||
|
import { Snapshot } from "../snapshot"
|
||||||
|
|
||||||
export namespace Session {
|
export namespace Session {
|
||||||
const log = Log.create({ service: "session" })
|
const log = Log.create({ service: "session" })
|
||||||
@@ -53,6 +54,13 @@ export namespace Session {
|
|||||||
created: z.number(),
|
created: z.number(),
|
||||||
updated: z.number(),
|
updated: z.number(),
|
||||||
}),
|
}),
|
||||||
|
revert: z
|
||||||
|
.object({
|
||||||
|
messageID: z.string(),
|
||||||
|
part: z.number(),
|
||||||
|
snapshot: z.string().optional(),
|
||||||
|
})
|
||||||
|
.optional(),
|
||||||
})
|
})
|
||||||
.openapi({
|
.openapi({
|
||||||
ref: "Session",
|
ref: "Session",
|
||||||
@@ -285,6 +293,37 @@ export namespace Session {
|
|||||||
l.info("chatting")
|
l.info("chatting")
|
||||||
const model = await Provider.getModel(input.providerID, input.modelID)
|
const model = await Provider.getModel(input.providerID, input.modelID)
|
||||||
let msgs = await messages(input.sessionID)
|
let msgs = await messages(input.sessionID)
|
||||||
|
const session = await get(input.sessionID)
|
||||||
|
|
||||||
|
if (session.revert) {
|
||||||
|
const trimmed = []
|
||||||
|
for (const msg of msgs) {
|
||||||
|
if (
|
||||||
|
msg.id > session.revert.messageID ||
|
||||||
|
(msg.id === session.revert.messageID && session.revert.part === 0)
|
||||||
|
) {
|
||||||
|
await Storage.remove(
|
||||||
|
"session/message/" + input.sessionID + "/" + msg.id,
|
||||||
|
)
|
||||||
|
await Bus.publish(Message.Event.Removed, {
|
||||||
|
sessionID: input.sessionID,
|
||||||
|
messageID: msg.id,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (msg.id === session.revert.messageID) {
|
||||||
|
if (session.revert.part === 0) break
|
||||||
|
msg.parts = msg.parts.slice(0, session.revert.part)
|
||||||
|
}
|
||||||
|
trimmed.push(msg)
|
||||||
|
}
|
||||||
|
msgs = trimmed
|
||||||
|
await update(input.sessionID, (draft) => {
|
||||||
|
draft.revert = undefined
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const previous = msgs.at(-1)
|
const previous = msgs.at(-1)
|
||||||
|
|
||||||
// auto summarize if too long
|
// auto summarize if too long
|
||||||
@@ -319,7 +358,6 @@ export namespace Session {
|
|||||||
if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
|
if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
|
||||||
|
|
||||||
const app = App.info()
|
const app = App.info()
|
||||||
const session = await get(input.sessionID)
|
|
||||||
if (msgs.length === 0 && !session.parentID) {
|
if (msgs.length === 0 && !session.parentID) {
|
||||||
generateText({
|
generateText({
|
||||||
maxTokens: input.providerID === "google" ? 1024 : 20,
|
maxTokens: input.providerID === "google" ? 1024 : 20,
|
||||||
@@ -349,6 +387,7 @@ export namespace Session {
|
|||||||
})
|
})
|
||||||
.catch(() => {})
|
.catch(() => {})
|
||||||
}
|
}
|
||||||
|
const snapshot = await Snapshot.create(input.sessionID)
|
||||||
const msg: Message.Info = {
|
const msg: Message.Info = {
|
||||||
role: "user",
|
role: "user",
|
||||||
id: Identifier.ascending("message"),
|
id: Identifier.ascending("message"),
|
||||||
@@ -359,6 +398,7 @@ export namespace Session {
|
|||||||
},
|
},
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
tool: {},
|
tool: {},
|
||||||
|
snapshot,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
await updateMessage(msg)
|
await updateMessage(msg)
|
||||||
@@ -373,6 +413,7 @@ export namespace Session {
|
|||||||
role: "assistant",
|
role: "assistant",
|
||||||
parts: [],
|
parts: [],
|
||||||
metadata: {
|
metadata: {
|
||||||
|
snapshot,
|
||||||
assistant: {
|
assistant: {
|
||||||
system,
|
system,
|
||||||
path: {
|
path: {
|
||||||
@@ -424,6 +465,7 @@ export namespace Session {
|
|||||||
})
|
})
|
||||||
next.metadata!.tool![opts.toolCallId] = {
|
next.metadata!.tool![opts.toolCallId] = {
|
||||||
...result.metadata,
|
...result.metadata,
|
||||||
|
snapshot: await Snapshot.create(input.sessionID),
|
||||||
time: {
|
time: {
|
||||||
start,
|
start,
|
||||||
end: Date.now(),
|
end: Date.now(),
|
||||||
@@ -436,6 +478,7 @@ export namespace Session {
|
|||||||
error: true,
|
error: true,
|
||||||
message: e.toString(),
|
message: e.toString(),
|
||||||
title: e.toString(),
|
title: e.toString(),
|
||||||
|
snapshot: await Snapshot.create(input.sessionID),
|
||||||
time: {
|
time: {
|
||||||
start,
|
start,
|
||||||
end: Date.now(),
|
end: Date.now(),
|
||||||
@@ -457,6 +500,7 @@ export namespace Session {
|
|||||||
const result = await execute(args, opts)
|
const result = await execute(args, opts)
|
||||||
next.metadata!.tool![opts.toolCallId] = {
|
next.metadata!.tool![opts.toolCallId] = {
|
||||||
...result.metadata,
|
...result.metadata,
|
||||||
|
snapshot: await Snapshot.create(input.sessionID),
|
||||||
time: {
|
time: {
|
||||||
start,
|
start,
|
||||||
end: Date.now(),
|
end: Date.now(),
|
||||||
@@ -471,6 +515,7 @@ export namespace Session {
|
|||||||
next.metadata!.tool![opts.toolCallId] = {
|
next.metadata!.tool![opts.toolCallId] = {
|
||||||
error: true,
|
error: true,
|
||||||
message: e.toString(),
|
message: e.toString(),
|
||||||
|
snapshot: await Snapshot.create(input.sessionID),
|
||||||
title: "mcp",
|
title: "mcp",
|
||||||
time: {
|
time: {
|
||||||
start,
|
start,
|
||||||
@@ -735,6 +780,51 @@ export namespace Session {
|
|||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function revert(input: {
|
||||||
|
sessionID: string
|
||||||
|
messageID: string
|
||||||
|
part: number
|
||||||
|
}) {
|
||||||
|
const message = await getMessage(input.sessionID, input.messageID)
|
||||||
|
if (!message) return
|
||||||
|
const part = message.parts[input.part]
|
||||||
|
if (!part) return
|
||||||
|
const session = await get(input.sessionID)
|
||||||
|
const snapshot =
|
||||||
|
session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
|
||||||
|
const old = (() => {
|
||||||
|
if (message.role === "assistant") {
|
||||||
|
const lastTool = message.parts.findLast(
|
||||||
|
(part, index) =>
|
||||||
|
part.type === "tool-invocation" && index < input.part,
|
||||||
|
)
|
||||||
|
if (lastTool && lastTool.type === "tool-invocation")
|
||||||
|
return message.metadata.tool[lastTool.toolInvocation.toolCallId]
|
||||||
|
.snapshot
|
||||||
|
}
|
||||||
|
return message.metadata.snapshot
|
||||||
|
})()
|
||||||
|
if (old) await Snapshot.restore(input.sessionID, old)
|
||||||
|
await update(input.sessionID, (draft) => {
|
||||||
|
draft.revert = {
|
||||||
|
messageID: input.messageID,
|
||||||
|
part: input.part,
|
||||||
|
snapshot,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function unrevert(sessionID: string) {
|
||||||
|
const session = await get(sessionID)
|
||||||
|
if (!session) return
|
||||||
|
if (!session.revert) return
|
||||||
|
if (session.revert.snapshot)
|
||||||
|
await Snapshot.restore(sessionID, session.revert.snapshot)
|
||||||
|
update(sessionID, (draft) => {
|
||||||
|
draft.revert = undefined
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
export async function summarize(input: {
|
export async function summarize(input: {
|
||||||
sessionID: string
|
sessionID: string
|
||||||
providerID: string
|
providerID: string
|
||||||
|
|||||||
@@ -159,6 +159,7 @@ export namespace Message {
|
|||||||
z
|
z
|
||||||
.object({
|
.object({
|
||||||
title: z.string(),
|
title: z.string(),
|
||||||
|
snapshot: z.string().optional(),
|
||||||
time: z.object({
|
time: z.object({
|
||||||
start: z.number(),
|
start: z.number(),
|
||||||
end: z.number(),
|
end: z.number(),
|
||||||
@@ -188,12 +189,8 @@ export namespace Message {
|
|||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
.optional(),
|
.optional(),
|
||||||
user: z
|
|
||||||
.object({
|
|
||||||
snapshot: z.string().optional(),
|
snapshot: z.string().optional(),
|
||||||
})
|
})
|
||||||
.optional(),
|
|
||||||
})
|
|
||||||
.openapi({ ref: "MessageMetadata" }),
|
.openapi({ ref: "MessageMetadata" }),
|
||||||
})
|
})
|
||||||
.openapi({
|
.openapi({
|
||||||
@@ -208,6 +205,13 @@ export namespace Message {
|
|||||||
info: Info,
|
info: Info,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
Removed: Bus.event(
|
||||||
|
"message.removed",
|
||||||
|
z.object({
|
||||||
|
sessionID: z.string(),
|
||||||
|
messageID: z.string(),
|
||||||
|
}),
|
||||||
|
),
|
||||||
PartUpdated: Bus.event(
|
PartUpdated: Bus.event(
|
||||||
"message.part.updated",
|
"message.part.updated",
|
||||||
z.object({
|
z.object({
|
||||||
|
|||||||
Reference in New Issue
Block a user