more efficient snapshots in parallel toolcalls

This commit is contained in:
Dax Raad
2025-08-03 12:09:59 -04:00
parent 4b204fee58
commit 06830327e7

View File

@@ -735,7 +735,6 @@ export namespace Session {
args, args,
}, },
) )
await processor.track(options.toolCallId)
const result = await item.execute(args, { const result = await item.execute(args, {
sessionID: input.sessionID, sessionID: input.sessionID,
abort: abort.signal, abort: abort.signal,
@@ -784,7 +783,6 @@ export namespace Session {
const execute = item.execute const execute = item.execute
if (!execute) continue if (!execute) continue
item.execute = async (args, opts) => { item.execute = async (args, opts) => {
await processor.track(opts.toolCallId)
const result = await execute(args, opts) const result = await execute(args, opts)
const output = result.content const output = result.content
.filter((x: any) => x.type === "text") .filter((x: any) => x.type === "text")
@@ -920,15 +918,11 @@ export namespace Session {
} }
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) { function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
const toolCalls: Record<string, MessageV2.ToolPart> = {} const toolcalls: Record<string, MessageV2.ToolPart> = {}
const snapshots: Record<string, string> = {} let snapshot: string | undefined
return { return {
async track(toolCallID: string) {
const hash = await Snapshot.track()
if (hash) snapshots[toolCallID] = hash
},
partFromToolCall(toolCallID: string) { partFromToolCall(toolCallID: string) {
return toolCalls[toolCallID] return toolcalls[toolCallID]
}, },
async process(stream: StreamTextResult<Record<string, AITool>, never>) { async process(stream: StreamTextResult<Record<string, AITool>, never>) {
try { try {
@@ -944,7 +938,7 @@ export namespace Session {
case "tool-input-start": case "tool-input-start":
const part = await updatePart({ const part = await updatePart({
id: toolCalls[value.id]?.id ?? Identifier.ascending("part"), id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
messageID: assistantMsg.id, messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID, sessionID: assistantMsg.sessionID,
type: "tool", type: "tool",
@@ -954,7 +948,7 @@ export namespace Session {
status: "pending", status: "pending",
}, },
}) })
toolCalls[value.id] = part as MessageV2.ToolPart toolcalls[value.id] = part as MessageV2.ToolPart
break break
case "tool-input-delta": case "tool-input-delta":
@@ -964,7 +958,7 @@ export namespace Session {
break break
case "tool-call": { case "tool-call": {
const match = toolCalls[value.toolCallId] const match = toolcalls[value.toolCallId]
if (match) { if (match) {
const part = await updatePart({ const part = await updatePart({
...match, ...match,
@@ -976,12 +970,12 @@ export namespace Session {
}, },
}, },
}) })
toolCalls[value.toolCallId] = part as MessageV2.ToolPart toolcalls[value.toolCallId] = part as MessageV2.ToolPart
} }
break break
} }
case "tool-result": { case "tool-result": {
const match = toolCalls[value.toolCallId] const match = toolcalls[value.toolCallId]
if (match && match.state.status === "running") { if (match && match.state.status === "running") {
await updatePart({ await updatePart({
...match, ...match,
@@ -997,27 +991,13 @@ export namespace Session {
}, },
}, },
}) })
delete toolCalls[value.toolCallId] delete toolcalls[value.toolCallId]
const snapshot = snapshots[value.toolCallId]
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
if (patch.files.length) {
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
}
} }
break break
} }
case "tool-error": { case "tool-error": {
const match = toolCalls[value.toolCallId] const match = toolcalls[value.toolCallId]
if (match && match.state.status === "running") { if (match && match.state.status === "running") {
await updatePart({ await updatePart({
...match, ...match,
@@ -1031,19 +1011,7 @@ export namespace Session {
}, },
}, },
}) })
delete toolCalls[value.toolCallId] delete toolcalls[value.toolCallId]
const snapshot = snapshots[value.toolCallId]
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
} }
break break
} }
@@ -1058,6 +1026,7 @@ export namespace Session {
sessionID: assistantMsg.sessionID, sessionID: assistantMsg.sessionID,
type: "step-start", type: "step-start",
}) })
snapshot = await Snapshot.track()
break break
case "finish-step": case "finish-step":
@@ -1073,6 +1042,20 @@ export namespace Session {
cost: usage.cost, cost: usage.cost,
}) })
await updateMessage(assistantMsg) await updateMessage(assistantMsg)
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
if (patch.files.length) {
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
snapshot = undefined
}
break break
case "text-start": case "text-start":