more efficient snapshots in parallel toolcalls
This commit is contained in:
@@ -735,7 +735,6 @@ export namespace Session {
|
|||||||
args,
|
args,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await processor.track(options.toolCallId)
|
|
||||||
const result = await item.execute(args, {
|
const result = await item.execute(args, {
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
abort: abort.signal,
|
abort: abort.signal,
|
||||||
@@ -784,7 +783,6 @@ export namespace Session {
|
|||||||
const execute = item.execute
|
const execute = item.execute
|
||||||
if (!execute) continue
|
if (!execute) continue
|
||||||
item.execute = async (args, opts) => {
|
item.execute = async (args, opts) => {
|
||||||
await processor.track(opts.toolCallId)
|
|
||||||
const result = await execute(args, opts)
|
const result = await execute(args, opts)
|
||||||
const output = result.content
|
const output = result.content
|
||||||
.filter((x: any) => x.type === "text")
|
.filter((x: any) => x.type === "text")
|
||||||
@@ -920,15 +918,11 @@ export namespace Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
|
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
|
||||||
const toolCalls: Record<string, MessageV2.ToolPart> = {}
|
const toolcalls: Record<string, MessageV2.ToolPart> = {}
|
||||||
const snapshots: Record<string, string> = {}
|
let snapshot: string | undefined
|
||||||
return {
|
return {
|
||||||
async track(toolCallID: string) {
|
|
||||||
const hash = await Snapshot.track()
|
|
||||||
if (hash) snapshots[toolCallID] = hash
|
|
||||||
},
|
|
||||||
partFromToolCall(toolCallID: string) {
|
partFromToolCall(toolCallID: string) {
|
||||||
return toolCalls[toolCallID]
|
return toolcalls[toolCallID]
|
||||||
},
|
},
|
||||||
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
||||||
try {
|
try {
|
||||||
@@ -944,7 +938,7 @@ export namespace Session {
|
|||||||
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
const part = await updatePart({
|
const part = await updatePart({
|
||||||
id: toolCalls[value.id]?.id ?? Identifier.ascending("part"),
|
id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
|
||||||
messageID: assistantMsg.id,
|
messageID: assistantMsg.id,
|
||||||
sessionID: assistantMsg.sessionID,
|
sessionID: assistantMsg.sessionID,
|
||||||
type: "tool",
|
type: "tool",
|
||||||
@@ -954,7 +948,7 @@ export namespace Session {
|
|||||||
status: "pending",
|
status: "pending",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
toolCalls[value.id] = part as MessageV2.ToolPart
|
toolcalls[value.id] = part as MessageV2.ToolPart
|
||||||
break
|
break
|
||||||
|
|
||||||
case "tool-input-delta":
|
case "tool-input-delta":
|
||||||
@@ -964,7 +958,7 @@ export namespace Session {
|
|||||||
break
|
break
|
||||||
|
|
||||||
case "tool-call": {
|
case "tool-call": {
|
||||||
const match = toolCalls[value.toolCallId]
|
const match = toolcalls[value.toolCallId]
|
||||||
if (match) {
|
if (match) {
|
||||||
const part = await updatePart({
|
const part = await updatePart({
|
||||||
...match,
|
...match,
|
||||||
@@ -976,12 +970,12 @@ export namespace Session {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
toolCalls[value.toolCallId] = part as MessageV2.ToolPart
|
toolcalls[value.toolCallId] = part as MessageV2.ToolPart
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
case "tool-result": {
|
case "tool-result": {
|
||||||
const match = toolCalls[value.toolCallId]
|
const match = toolcalls[value.toolCallId]
|
||||||
if (match && match.state.status === "running") {
|
if (match && match.state.status === "running") {
|
||||||
await updatePart({
|
await updatePart({
|
||||||
...match,
|
...match,
|
||||||
@@ -997,27 +991,13 @@ export namespace Session {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
delete toolCalls[value.toolCallId]
|
delete toolcalls[value.toolCallId]
|
||||||
const snapshot = snapshots[value.toolCallId]
|
|
||||||
if (snapshot) {
|
|
||||||
const patch = await Snapshot.patch(snapshot)
|
|
||||||
if (patch.files.length) {
|
|
||||||
await updatePart({
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "patch",
|
|
||||||
hash: patch.hash,
|
|
||||||
files: patch.files,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
case "tool-error": {
|
case "tool-error": {
|
||||||
const match = toolCalls[value.toolCallId]
|
const match = toolcalls[value.toolCallId]
|
||||||
if (match && match.state.status === "running") {
|
if (match && match.state.status === "running") {
|
||||||
await updatePart({
|
await updatePart({
|
||||||
...match,
|
...match,
|
||||||
@@ -1031,19 +1011,7 @@ export namespace Session {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
delete toolCalls[value.toolCallId]
|
delete toolcalls[value.toolCallId]
|
||||||
const snapshot = snapshots[value.toolCallId]
|
|
||||||
if (snapshot) {
|
|
||||||
const patch = await Snapshot.patch(snapshot)
|
|
||||||
await updatePart({
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "patch",
|
|
||||||
hash: patch.hash,
|
|
||||||
files: patch.files,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -1058,6 +1026,7 @@ export namespace Session {
|
|||||||
sessionID: assistantMsg.sessionID,
|
sessionID: assistantMsg.sessionID,
|
||||||
type: "step-start",
|
type: "step-start",
|
||||||
})
|
})
|
||||||
|
snapshot = await Snapshot.track()
|
||||||
break
|
break
|
||||||
|
|
||||||
case "finish-step":
|
case "finish-step":
|
||||||
@@ -1073,6 +1042,20 @@ export namespace Session {
|
|||||||
cost: usage.cost,
|
cost: usage.cost,
|
||||||
})
|
})
|
||||||
await updateMessage(assistantMsg)
|
await updateMessage(assistantMsg)
|
||||||
|
if (snapshot) {
|
||||||
|
const patch = await Snapshot.patch(snapshot)
|
||||||
|
if (patch.files.length) {
|
||||||
|
await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "patch",
|
||||||
|
hash: patch.hash,
|
||||||
|
files: patch.files,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
snapshot = undefined
|
||||||
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
case "text-start":
|
case "text-start":
|
||||||
|
|||||||
Reference in New Issue
Block a user