better handling of aborting sessions
This commit is contained in:
@@ -552,167 +552,196 @@ export namespace Session {
|
||||
],
|
||||
}),
|
||||
})
|
||||
for await (const value of result.fullStream) {
|
||||
l.info("part", {
|
||||
type: value.type,
|
||||
})
|
||||
switch (value.type) {
|
||||
case "start":
|
||||
break
|
||||
try {
|
||||
for await (const value of result.fullStream) {
|
||||
l.info("part", {
|
||||
type: value.type,
|
||||
})
|
||||
switch (value.type) {
|
||||
case "start":
|
||||
break
|
||||
|
||||
case "tool-input-start":
|
||||
next.parts.push({
|
||||
type: "tool",
|
||||
tool: value.toolName,
|
||||
id: value.id,
|
||||
state: {
|
||||
status: "pending",
|
||||
},
|
||||
})
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: next.parts[next.parts.length - 1],
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
break
|
||||
|
||||
case "tool-input-delta":
|
||||
break
|
||||
|
||||
case "tool-call": {
|
||||
const match = next.parts.find((p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId)
|
||||
if (match) {
|
||||
match.state = {
|
||||
status: "running",
|
||||
input: value.input,
|
||||
time: {
|
||||
start: Date.now(),
|
||||
case "tool-input-start":
|
||||
next.parts.push({
|
||||
type: "tool",
|
||||
tool: value.toolName,
|
||||
id: value.id,
|
||||
state: {
|
||||
status: "pending",
|
||||
},
|
||||
}
|
||||
})
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: match,
|
||||
part: next.parts[next.parts.length - 1],
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
case "tool-result": {
|
||||
const match = next.parts.find((p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId)
|
||||
if (match && match.state.status === "running") {
|
||||
match.state = {
|
||||
status: "completed",
|
||||
input: value.input,
|
||||
output: value.output.output,
|
||||
metadata: value.output.metadata,
|
||||
title: value.output.title,
|
||||
time: {
|
||||
start: match.state.time.start,
|
||||
end: Date.now(),
|
||||
},
|
||||
}
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: match,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
break
|
||||
|
||||
case "tool-error": {
|
||||
const match = next.parts.find((p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId)
|
||||
if (match && match.state.status === "running") {
|
||||
match.state = {
|
||||
status: "error",
|
||||
input: value.input,
|
||||
error: (value.error as any).toString(),
|
||||
time: {
|
||||
start: match.state.time.start,
|
||||
end: Date.now(),
|
||||
},
|
||||
}
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: match,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
case "tool-input-delta":
|
||||
break
|
||||
|
||||
case "error":
|
||||
const e = value.error
|
||||
log.error("", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
next.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
next.error = new Provider.AuthError(
|
||||
{
|
||||
providerID: input.providerID,
|
||||
message: e.message,
|
||||
case "tool-call": {
|
||||
const match = next.parts.find(
|
||||
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
|
||||
)
|
||||
if (match) {
|
||||
match.state = {
|
||||
status: "running",
|
||||
input: value.input,
|
||||
time: {
|
||||
start: Date.now(),
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
}
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: match,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
Bus.publish(Event.Error, {
|
||||
error: next.error,
|
||||
})
|
||||
break
|
||||
|
||||
case "start-step":
|
||||
next.parts.push({
|
||||
type: "step-start",
|
||||
})
|
||||
break
|
||||
|
||||
case "finish-step":
|
||||
const usage = getUsage(model.info, value.usage, value.providerMetadata)
|
||||
next.cost += usage.cost
|
||||
next.tokens = usage.tokens
|
||||
break
|
||||
|
||||
case "text-start":
|
||||
text = {
|
||||
type: "text",
|
||||
text: "",
|
||||
case "tool-result": {
|
||||
const match = next.parts.find(
|
||||
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
|
||||
)
|
||||
if (match && match.state.status === "running") {
|
||||
match.state = {
|
||||
status: "completed",
|
||||
input: value.input,
|
||||
output: value.output.output,
|
||||
metadata: value.output.metadata,
|
||||
title: value.output.title,
|
||||
time: {
|
||||
start: match.state.time.start,
|
||||
end: Date.now(),
|
||||
},
|
||||
}
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: match,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
break
|
||||
|
||||
case "text":
|
||||
if (text.text === "") next.parts.push(text)
|
||||
text.text += value.text
|
||||
break
|
||||
case "tool-error": {
|
||||
const match = next.parts.find(
|
||||
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
|
||||
)
|
||||
if (match && match.state.status === "running") {
|
||||
match.state = {
|
||||
status: "error",
|
||||
input: value.input,
|
||||
error: (value.error as any).toString(),
|
||||
time: {
|
||||
start: match.state.time.start,
|
||||
end: Date.now(),
|
||||
},
|
||||
}
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: match,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case "text-end":
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: text,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
break
|
||||
case "error":
|
||||
throw value.error
|
||||
|
||||
case "finish":
|
||||
next.time.completed = Date.now()
|
||||
break
|
||||
case "start-step":
|
||||
next.parts.push({
|
||||
type: "step-start",
|
||||
})
|
||||
break
|
||||
|
||||
default:
|
||||
l.info("unhandled", {
|
||||
...value,
|
||||
})
|
||||
continue
|
||||
case "finish-step":
|
||||
const usage = getUsage(model.info, value.usage, value.providerMetadata)
|
||||
next.cost += usage.cost
|
||||
next.tokens = usage.tokens
|
||||
break
|
||||
|
||||
case "text-start":
|
||||
text = {
|
||||
type: "text",
|
||||
text: "",
|
||||
}
|
||||
break
|
||||
|
||||
case "text":
|
||||
if (text.text === "") next.parts.push(text)
|
||||
text.text += value.text
|
||||
break
|
||||
|
||||
case "text-end":
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part: text,
|
||||
sessionID: next.sessionID,
|
||||
messageID: next.id,
|
||||
})
|
||||
break
|
||||
|
||||
case "finish":
|
||||
next.time.completed = Date.now()
|
||||
break
|
||||
|
||||
default:
|
||||
l.info("unhandled", {
|
||||
...value,
|
||||
})
|
||||
continue
|
||||
}
|
||||
await updateMessage(next)
|
||||
}
|
||||
} catch (e) {
|
||||
log.error("", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
next.error = new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
},
|
||||
).toObject()
|
||||
break
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
next.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
next.error = new Provider.AuthError(
|
||||
{
|
||||
providerID: input.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
}
|
||||
Bus.publish(Event.Error, {
|
||||
error: next.error,
|
||||
})
|
||||
}
|
||||
for (const part of next.parts) {
|
||||
if (part.type === "tool" && part.state.status !== "completed") {
|
||||
part.state = {
|
||||
status: "error",
|
||||
error: "Tool execution aborted",
|
||||
time: {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
},
|
||||
input: {},
|
||||
}
|
||||
}
|
||||
await updateMessage(next)
|
||||
}
|
||||
next.time.completed = Date.now()
|
||||
await updateMessage(next)
|
||||
|
||||
@@ -7,6 +7,7 @@ import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
|
||||
|
||||
export namespace MessageV2 {
|
||||
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
||||
export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))
|
||||
|
||||
export const ToolStatePending = z
|
||||
.object({
|
||||
@@ -148,7 +149,12 @@ export namespace MessageV2 {
|
||||
completed: z.number().optional(),
|
||||
}),
|
||||
error: z
|
||||
.discriminatedUnion("name", [Provider.AuthError.Schema, NamedError.Unknown.Schema, OutputLengthError.Schema])
|
||||
.discriminatedUnion("name", [
|
||||
Provider.AuthError.Schema,
|
||||
NamedError.Unknown.Schema,
|
||||
OutputLengthError.Schema,
|
||||
AbortedError.Schema,
|
||||
])
|
||||
.optional(),
|
||||
system: z.string().array(),
|
||||
modelID: z.string(),
|
||||
|
||||
Reference in New Issue
Block a user