bring back task tool
This commit is contained in:
@@ -21,7 +21,7 @@ import { AuthCopilot } from "../auth/copilot"
|
|||||||
import { ModelsDev } from "./models"
|
import { ModelsDev } from "./models"
|
||||||
import { NamedError } from "../util/error"
|
import { NamedError } from "../util/error"
|
||||||
import { Auth } from "../auth"
|
import { Auth } from "../auth"
|
||||||
// import { TaskTool } from "../tool/task"
|
import { TaskTool } from "../tool/task"
|
||||||
|
|
||||||
export namespace Provider {
|
export namespace Provider {
|
||||||
const log = Log.create({ service: "provider" })
|
const log = Log.create({ service: "provider" })
|
||||||
@@ -456,7 +456,7 @@ export namespace Provider {
|
|||||||
WriteTool,
|
WriteTool,
|
||||||
TodoWriteTool,
|
TodoWriteTool,
|
||||||
TodoReadTool,
|
TodoReadTool,
|
||||||
// TaskTool,
|
TaskTool,
|
||||||
]
|
]
|
||||||
|
|
||||||
const TOOL_MAPPING: Record<string, Tool.Info[]> = {
|
const TOOL_MAPPING: Record<string, Tool.Info[]> = {
|
||||||
@@ -531,12 +531,4 @@ export namespace Provider {
|
|||||||
providerID: z.string(),
|
providerID: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
export const AuthError = NamedError.create(
|
|
||||||
"ProviderAuthError",
|
|
||||||
z.object({
|
|
||||||
providerID: z.string(),
|
|
||||||
message: z.string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -443,7 +443,7 @@ export namespace Session {
|
|||||||
const result = await ReadTool.execute(args, {
|
const result = await ReadTool.execute(args, {
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
abort: abort.signal,
|
abort: abort.signal,
|
||||||
messageID: "", // read tool doesn't use message ID
|
messageID: userMsg.id,
|
||||||
metadata: async () => {},
|
metadata: async () => {},
|
||||||
})
|
})
|
||||||
return [
|
return [
|
||||||
@@ -577,20 +577,22 @@ export namespace Session {
|
|||||||
await updateMessage(assistantMsg)
|
await updateMessage(assistantMsg)
|
||||||
const tools: Record<string, AITool> = {}
|
const tools: Record<string, AITool> = {}
|
||||||
|
|
||||||
|
const processor = createProcessor(assistantMsg, model.info)
|
||||||
|
|
||||||
for (const item of await Provider.tools(input.providerID)) {
|
for (const item of await Provider.tools(input.providerID)) {
|
||||||
if (mode.tools[item.id] === false) continue
|
if (mode.tools[item.id] === false) continue
|
||||||
|
if (session.parentID && item.id === "task") continue
|
||||||
tools[item.id] = tool({
|
tools[item.id] = tool({
|
||||||
id: item.id as any,
|
id: item.id as any,
|
||||||
description: item.description,
|
description: item.description,
|
||||||
inputSchema: item.parameters as ZodSchema,
|
inputSchema: item.parameters as ZodSchema,
|
||||||
async execute(args) {
|
async execute(args, options) {
|
||||||
const result = await item.execute(args, {
|
const result = await item.execute(args, {
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
abort: abort.signal,
|
abort: abort.signal,
|
||||||
messageID: assistantMsg.id,
|
messageID: assistantMsg.id,
|
||||||
metadata: async () => {
|
metadata: async (val) => {
|
||||||
/*
|
const match = processor.partFromToolCall(options.toolCallId)
|
||||||
const match = toolCalls[opts.toolCallId]
|
|
||||||
if (match && match.state.status === "running") {
|
if (match && match.state.status === "running") {
|
||||||
await updatePart({
|
await updatePart({
|
||||||
...match,
|
...match,
|
||||||
@@ -598,14 +600,13 @@ export namespace Session {
|
|||||||
title: val.title,
|
title: val.title,
|
||||||
metadata: val.metadata,
|
metadata: val.metadata,
|
||||||
status: "running",
|
status: "running",
|
||||||
input: args.input,
|
input: args,
|
||||||
time: {
|
time: {
|
||||||
start: Date.now(),
|
start: Date.now(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
@@ -676,257 +677,260 @@ export namespace Session {
|
|||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
const result = await processStream(assistantMsg, model.info, stream)
|
const result = await processor.process(stream)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
async function processStream(
|
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
|
||||||
assistantMsg: MessageV2.Assistant,
|
const toolCalls: Record<string, MessageV2.ToolPart> = {}
|
||||||
model: ModelsDev.Model,
|
return {
|
||||||
stream: StreamTextResult<Record<string, AITool>, never>,
|
partFromToolCall(toolCallID: string) {
|
||||||
) {
|
return toolCalls[toolCallID]
|
||||||
try {
|
},
|
||||||
let currentText: MessageV2.TextPart | undefined
|
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
||||||
const toolCalls: Record<string, MessageV2.ToolPart> = {}
|
try {
|
||||||
|
let currentText: MessageV2.TextPart | undefined
|
||||||
|
|
||||||
for await (const value of stream.fullStream) {
|
for await (const value of stream.fullStream) {
|
||||||
log.info("part", {
|
log.info("part", {
|
||||||
type: value.type,
|
type: value.type,
|
||||||
})
|
|
||||||
switch (value.type) {
|
|
||||||
case "start":
|
|
||||||
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
|
||||||
if (snapshot)
|
|
||||||
await updatePart({
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "snapshot",
|
|
||||||
snapshot,
|
|
||||||
})
|
|
||||||
break
|
|
||||||
|
|
||||||
case "tool-input-start":
|
|
||||||
const part = await updatePart({
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "tool",
|
|
||||||
tool: value.toolName,
|
|
||||||
callID: value.id,
|
|
||||||
state: {
|
|
||||||
status: "pending",
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
toolCalls[value.id] = part as MessageV2.ToolPart
|
switch (value.type) {
|
||||||
break
|
case "start":
|
||||||
|
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
||||||
|
if (snapshot)
|
||||||
|
await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "snapshot",
|
||||||
|
snapshot,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
case "tool-input-delta":
|
case "tool-input-start":
|
||||||
break
|
const part = await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "tool",
|
||||||
|
tool: value.toolName,
|
||||||
|
callID: value.id,
|
||||||
|
state: {
|
||||||
|
status: "pending",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
toolCalls[value.id] = part as MessageV2.ToolPart
|
||||||
|
break
|
||||||
|
|
||||||
case "tool-call": {
|
case "tool-input-delta":
|
||||||
const match = toolCalls[value.toolCallId]
|
break
|
||||||
if (match) {
|
|
||||||
const part = await updatePart({
|
case "tool-call": {
|
||||||
...match,
|
const match = toolCalls[value.toolCallId]
|
||||||
state: {
|
if (match) {
|
||||||
status: "running",
|
const part = await updatePart({
|
||||||
input: value.input,
|
...match,
|
||||||
|
state: {
|
||||||
|
status: "running",
|
||||||
|
input: value.input,
|
||||||
|
time: {
|
||||||
|
start: Date.now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
toolCalls[value.toolCallId] = part as MessageV2.ToolPart
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case "tool-result": {
|
||||||
|
const match = toolCalls[value.toolCallId]
|
||||||
|
if (match && match.state.status === "running") {
|
||||||
|
await updatePart({
|
||||||
|
...match,
|
||||||
|
state: {
|
||||||
|
status: "completed",
|
||||||
|
input: value.input,
|
||||||
|
output: value.output.output,
|
||||||
|
metadata: value.output.metadata,
|
||||||
|
title: value.output.title,
|
||||||
|
time: {
|
||||||
|
start: match.state.time.start,
|
||||||
|
end: Date.now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
delete toolCalls[value.toolCallId]
|
||||||
|
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
||||||
|
if (snapshot)
|
||||||
|
await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "snapshot",
|
||||||
|
snapshot,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case "tool-error": {
|
||||||
|
const match = toolCalls[value.toolCallId]
|
||||||
|
if (match && match.state.status === "running") {
|
||||||
|
await updatePart({
|
||||||
|
...match,
|
||||||
|
state: {
|
||||||
|
status: "error",
|
||||||
|
input: value.input,
|
||||||
|
error: (value.error as any).toString(),
|
||||||
|
time: {
|
||||||
|
start: match.state.time.start,
|
||||||
|
end: Date.now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
delete toolCalls[value.toolCallId]
|
||||||
|
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
||||||
|
if (snapshot)
|
||||||
|
await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "snapshot",
|
||||||
|
snapshot,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
throw value.error
|
||||||
|
|
||||||
|
case "start-step":
|
||||||
|
await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "step-start",
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
case "finish-step":
|
||||||
|
const usage = getUsage(model, value.usage, value.providerMetadata)
|
||||||
|
assistantMsg.cost += usage.cost
|
||||||
|
assistantMsg.tokens = usage.tokens
|
||||||
|
await updatePart({
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "step-finish",
|
||||||
|
tokens: usage.tokens,
|
||||||
|
cost: usage.cost,
|
||||||
|
})
|
||||||
|
await updateMessage(assistantMsg)
|
||||||
|
break
|
||||||
|
|
||||||
|
case "text-start":
|
||||||
|
currentText = {
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
messageID: assistantMsg.id,
|
||||||
|
sessionID: assistantMsg.sessionID,
|
||||||
|
type: "text",
|
||||||
|
text: "",
|
||||||
time: {
|
time: {
|
||||||
start: Date.now(),
|
start: Date.now(),
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
})
|
break
|
||||||
toolCalls[value.toolCallId] = part as MessageV2.ToolPart
|
|
||||||
}
|
case "text":
|
||||||
break
|
if (currentText) {
|
||||||
}
|
currentText.text += value.text
|
||||||
case "tool-result": {
|
await updatePart(currentText)
|
||||||
const match = toolCalls[value.toolCallId]
|
}
|
||||||
if (match && match.state.status === "running") {
|
break
|
||||||
await updatePart({
|
|
||||||
...match,
|
case "text-end":
|
||||||
state: {
|
if (currentText && currentText.text) {
|
||||||
status: "completed",
|
currentText.time = {
|
||||||
input: value.input,
|
start: Date.now(),
|
||||||
output: value.output.output,
|
|
||||||
metadata: value.output.metadata,
|
|
||||||
title: value.output.title,
|
|
||||||
time: {
|
|
||||||
start: match.state.time.start,
|
|
||||||
end: Date.now(),
|
end: Date.now(),
|
||||||
},
|
}
|
||||||
},
|
await updatePart(currentText)
|
||||||
})
|
}
|
||||||
delete toolCalls[value.toolCallId]
|
currentText = undefined
|
||||||
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
break
|
||||||
if (snapshot)
|
|
||||||
await updatePart({
|
case "finish":
|
||||||
id: Identifier.ascending("part"),
|
assistantMsg.time.completed = Date.now()
|
||||||
messageID: assistantMsg.id,
|
await updateMessage(assistantMsg)
|
||||||
sessionID: assistantMsg.sessionID,
|
break
|
||||||
type: "snapshot",
|
|
||||||
snapshot,
|
default:
|
||||||
|
log.info("unhandled", {
|
||||||
|
...value,
|
||||||
})
|
})
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
} catch (e) {
|
||||||
case "tool-error": {
|
log.error("", {
|
||||||
const match = toolCalls[value.toolCallId]
|
error: e,
|
||||||
if (match && match.state.status === "running") {
|
})
|
||||||
await updatePart({
|
switch (true) {
|
||||||
...match,
|
case e instanceof DOMException && e.name === "AbortError":
|
||||||
state: {
|
assistantMsg.error = new MessageV2.AbortedError(
|
||||||
status: "error",
|
{ message: e.message },
|
||||||
input: value.input,
|
{
|
||||||
error: (value.error as any).toString(),
|
cause: e,
|
||||||
time: {
|
|
||||||
start: match.state.time.start,
|
|
||||||
end: Date.now(),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
})
|
).toObject()
|
||||||
delete toolCalls[value.toolCallId]
|
break
|
||||||
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
case MessageV2.OutputLengthError.isInstance(e):
|
||||||
if (snapshot)
|
assistantMsg.error = e
|
||||||
await updatePart({
|
break
|
||||||
id: Identifier.ascending("part"),
|
case LoadAPIKeyError.isInstance(e):
|
||||||
messageID: assistantMsg.id,
|
assistantMsg.error = new MessageV2.AuthError(
|
||||||
sessionID: assistantMsg.sessionID,
|
{
|
||||||
type: "snapshot",
|
providerID: model.id,
|
||||||
snapshot,
|
message: e.message,
|
||||||
})
|
},
|
||||||
}
|
{ cause: e },
|
||||||
break
|
).toObject()
|
||||||
|
break
|
||||||
|
case e instanceof Error:
|
||||||
|
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||||
}
|
}
|
||||||
|
Bus.publish(Event.Error, {
|
||||||
case "error":
|
sessionID: assistantMsg.sessionID,
|
||||||
throw value.error
|
error: assistantMsg.error,
|
||||||
|
})
|
||||||
case "start-step":
|
|
||||||
await updatePart({
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "step-start",
|
|
||||||
})
|
|
||||||
break
|
|
||||||
|
|
||||||
case "finish-step":
|
|
||||||
const usage = getUsage(model, value.usage, value.providerMetadata)
|
|
||||||
assistantMsg.cost += usage.cost
|
|
||||||
assistantMsg.tokens = usage.tokens
|
|
||||||
await updatePart({
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "step-finish",
|
|
||||||
tokens: usage.tokens,
|
|
||||||
cost: usage.cost,
|
|
||||||
})
|
|
||||||
await updateMessage(assistantMsg)
|
|
||||||
break
|
|
||||||
|
|
||||||
case "text-start":
|
|
||||||
currentText = {
|
|
||||||
id: Identifier.ascending("part"),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
type: "text",
|
|
||||||
text: "",
|
|
||||||
time: {
|
|
||||||
start: Date.now(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
break
|
|
||||||
|
|
||||||
case "text":
|
|
||||||
if (currentText) {
|
|
||||||
currentText.text += value.text
|
|
||||||
await updatePart(currentText)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
|
|
||||||
case "text-end":
|
|
||||||
if (currentText && currentText.text) {
|
|
||||||
currentText.time = {
|
|
||||||
start: Date.now(),
|
|
||||||
end: Date.now(),
|
|
||||||
}
|
|
||||||
await updatePart(currentText)
|
|
||||||
}
|
|
||||||
currentText = undefined
|
|
||||||
break
|
|
||||||
|
|
||||||
case "finish":
|
|
||||||
assistantMsg.time.completed = Date.now()
|
|
||||||
await updateMessage(assistantMsg)
|
|
||||||
break
|
|
||||||
|
|
||||||
default:
|
|
||||||
log.info("unhandled", {
|
|
||||||
...value,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
const p = await parts(assistantMsg.sessionID, assistantMsg.id)
|
||||||
} catch (e) {
|
for (const part of p) {
|
||||||
log.error("", {
|
if (part.type === "tool" && part.state.status !== "completed") {
|
||||||
error: e,
|
updatePart({
|
||||||
})
|
...part,
|
||||||
switch (true) {
|
state: {
|
||||||
case e instanceof DOMException && e.name === "AbortError":
|
status: "error",
|
||||||
assistantMsg.error = new MessageV2.AbortedError(
|
error: "Tool execution aborted",
|
||||||
{ message: e.message },
|
time: {
|
||||||
{
|
start: Date.now(),
|
||||||
cause: e,
|
end: Date.now(),
|
||||||
},
|
},
|
||||||
).toObject()
|
input: {},
|
||||||
break
|
},
|
||||||
case MessageV2.OutputLengthError.isInstance(e):
|
})
|
||||||
assistantMsg.error = e
|
}
|
||||||
break
|
}
|
||||||
case LoadAPIKeyError.isInstance(e):
|
assistantMsg.time.completed = Date.now()
|
||||||
assistantMsg.error = new Provider.AuthError(
|
await updateMessage(assistantMsg)
|
||||||
{
|
return { info: assistantMsg, parts: p }
|
||||||
providerID: model.id,
|
},
|
||||||
message: e.message,
|
|
||||||
},
|
|
||||||
{ cause: e },
|
|
||||||
).toObject()
|
|
||||||
break
|
|
||||||
case e instanceof Error:
|
|
||||||
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
|
||||||
}
|
|
||||||
Bus.publish(Event.Error, {
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
error: assistantMsg.error,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
const p = await parts(assistantMsg.sessionID, assistantMsg.id)
|
|
||||||
for (const part of p) {
|
|
||||||
if (part.type === "tool" && part.state.status !== "completed") {
|
|
||||||
updatePart({
|
|
||||||
...part,
|
|
||||||
state: {
|
|
||||||
status: "error",
|
|
||||||
error: "Tool execution aborted",
|
|
||||||
time: {
|
|
||||||
start: Date.now(),
|
|
||||||
end: Date.now(),
|
|
||||||
},
|
|
||||||
input: {},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assistantMsg.time.completed = Date.now()
|
|
||||||
await updateMessage(assistantMsg)
|
|
||||||
return { info: assistantMsg, parts: p }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
|
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
|
||||||
@@ -1006,6 +1010,7 @@ export namespace Session {
|
|||||||
}
|
}
|
||||||
await updateMessage(next)
|
await updateMessage(next)
|
||||||
|
|
||||||
|
const processor = createProcessor(next, model.info)
|
||||||
const stream = streamText({
|
const stream = streamText({
|
||||||
abortSignal: abort.signal,
|
abortSignal: abort.signal,
|
||||||
model: model.language,
|
model: model.language,
|
||||||
@@ -1029,7 +1034,7 @@ export namespace Session {
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await processStream(next, model.info, stream)
|
const result = await processor.process(stream)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import z from "zod"
|
import z from "zod"
|
||||||
import { Bus } from "../bus"
|
import { Bus } from "../bus"
|
||||||
import { Provider } from "../provider/provider"
|
|
||||||
import { NamedError } from "../util/error"
|
import { NamedError } from "../util/error"
|
||||||
import { Message } from "./message"
|
import { Message } from "./message"
|
||||||
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
|
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
|
||||||
@@ -9,6 +8,13 @@ import { Identifier } from "../id/id"
|
|||||||
export namespace MessageV2 {
|
export namespace MessageV2 {
|
||||||
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
||||||
export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))
|
export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))
|
||||||
|
export const AuthError = NamedError.create(
|
||||||
|
"ProviderAuthError",
|
||||||
|
z.object({
|
||||||
|
providerID: z.string(),
|
||||||
|
message: z.string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
export const ToolStatePending = z
|
export const ToolStatePending = z
|
||||||
.object({
|
.object({
|
||||||
@@ -173,7 +179,7 @@ export namespace MessageV2 {
|
|||||||
}),
|
}),
|
||||||
error: z
|
error: z
|
||||||
.discriminatedUnion("name", [
|
.discriminatedUnion("name", [
|
||||||
Provider.AuthError.Schema,
|
AuthError.Schema,
|
||||||
NamedError.Unknown.Schema,
|
NamedError.Unknown.Schema,
|
||||||
OutputLengthError.Schema,
|
OutputLengthError.Schema,
|
||||||
AbortedError.Schema,
|
AbortedError.Schema,
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
import z from "zod"
|
import z from "zod"
|
||||||
import { Provider } from "../provider/provider"
|
|
||||||
import { NamedError } from "../util/error"
|
import { NamedError } from "../util/error"
|
||||||
|
|
||||||
export namespace Message {
|
export namespace Message {
|
||||||
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
||||||
|
export const AuthError = NamedError.create(
|
||||||
|
"ProviderAuthError",
|
||||||
|
z.object({
|
||||||
|
providerID: z.string(),
|
||||||
|
message: z.string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
export const ToolCall = z
|
export const ToolCall = z
|
||||||
.object({
|
.object({
|
||||||
@@ -134,11 +140,7 @@ export namespace Message {
|
|||||||
completed: z.number().optional(),
|
completed: z.number().optional(),
|
||||||
}),
|
}),
|
||||||
error: z
|
error: z
|
||||||
.discriminatedUnion("name", [
|
.discriminatedUnion("name", [AuthError.Schema, NamedError.Unknown.Schema, OutputLengthError.Schema])
|
||||||
Provider.AuthError.Schema,
|
|
||||||
NamedError.Unknown.Schema,
|
|
||||||
OutputLengthError.Schema,
|
|
||||||
])
|
|
||||||
.optional(),
|
.optional(),
|
||||||
sessionID: z.string(),
|
sessionID: z.string(),
|
||||||
tool: z.record(
|
tool: z.record(
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ export namespace Storage {
|
|||||||
cwd: path.join(dir, prefix),
|
cwd: path.join(dir, prefix),
|
||||||
onlyFiles: true,
|
onlyFiles: true,
|
||||||
}),
|
}),
|
||||||
)
|
).then((items) => items.map((item) => path.join(prefix, item.slice(0, -5))))
|
||||||
result.sort()
|
result.sort()
|
||||||
return result
|
return result
|
||||||
} catch {
|
} catch {
|
||||||
|
|||||||
@@ -15,21 +15,15 @@ export const TaskTool = Tool.define({
|
|||||||
}),
|
}),
|
||||||
async execute(params, ctx) {
|
async execute(params, ctx) {
|
||||||
const session = await Session.create(ctx.sessionID)
|
const session = await Session.create(ctx.sessionID)
|
||||||
const msg = (await Session.getMessage(ctx.sessionID, ctx.messageID)) as MessageV2.Assistant
|
const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
|
||||||
|
if (msg.role !== "assistant") throw new Error("Not an assistant message")
|
||||||
const parts: Record<string, MessageV2.Part> = {}
|
|
||||||
function summary(input: MessageV2.Part[]) {
|
|
||||||
const result = []
|
|
||||||
for (const part of input) {
|
|
||||||
if (part.type === "tool" && part.state.status === "completed") {
|
|
||||||
result.push(part)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
const messageID = Identifier.ascending("message")
|
||||||
|
const parts: Record<string, MessageV2.ToolPart> = {}
|
||||||
const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
|
const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
|
||||||
if (evt.properties.part.sessionID !== session.id) return
|
if (evt.properties.part.sessionID !== session.id) return
|
||||||
|
if (evt.properties.part.messageID === messageID) return
|
||||||
|
if (evt.properties.part.type !== "tool") return
|
||||||
parts[evt.properties.part.id] = evt.properties.part
|
parts[evt.properties.part.id] = evt.properties.part
|
||||||
ctx.metadata({
|
ctx.metadata({
|
||||||
title: params.description,
|
title: params.description,
|
||||||
@@ -42,7 +36,6 @@ export const TaskTool = Tool.define({
|
|||||||
ctx.abort.addEventListener("abort", () => {
|
ctx.abort.addEventListener("abort", () => {
|
||||||
Session.abort(session.id)
|
Session.abort(session.id)
|
||||||
})
|
})
|
||||||
const messageID = Identifier.ascending("message")
|
|
||||||
const result = await Session.chat({
|
const result = await Session.chat({
|
||||||
messageID,
|
messageID,
|
||||||
sessionID: session.id,
|
sessionID: session.id,
|
||||||
@@ -62,7 +55,7 @@ export const TaskTool = Tool.define({
|
|||||||
return {
|
return {
|
||||||
title: params.description,
|
title: params.description,
|
||||||
metadata: {
|
metadata: {
|
||||||
summary: summary(result.parts),
|
summary: result.parts.filter((x) => x.type === "tool"),
|
||||||
},
|
},
|
||||||
output: result.parts.findLast((x) => x.type === "text")!.text,
|
output: result.parts.findLast((x) => x.type === "text")!.text,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,10 +305,8 @@ func renderToolDetails(
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if toolCall.State.Status == opencode.ToolPartStateStatusPending ||
|
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
|
||||||
toolCall.State.Status == opencode.ToolPartStateStatusRunning {
|
|
||||||
title := renderToolTitle(toolCall, width)
|
title := renderToolTitle(toolCall, width)
|
||||||
title = styles.NewStyle().Width(width - 6).Render(title)
|
|
||||||
return renderContentBlock(app, title, highlight, width)
|
return renderContentBlock(app, title, highlight, width)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,128 +337,124 @@ func renderToolDetails(
|
|||||||
borderColor = t.BorderActive()
|
borderColor = t.BorderActive()
|
||||||
}
|
}
|
||||||
|
|
||||||
if toolCall.State.Status == opencode.ToolPartStateStatusCompleted {
|
metadata := toolCall.State.Metadata.(map[string]any)
|
||||||
metadata := toolCall.State.Metadata.(map[string]any)
|
switch toolCall.Tool {
|
||||||
switch toolCall.Tool {
|
case "read":
|
||||||
case "read":
|
preview := metadata["preview"]
|
||||||
preview := metadata["preview"]
|
if preview != nil && toolInputMap["filePath"] != nil {
|
||||||
if preview != nil && toolInputMap["filePath"] != nil {
|
filename := toolInputMap["filePath"].(string)
|
||||||
filename := toolInputMap["filePath"].(string)
|
body = preview.(string)
|
||||||
body = preview.(string)
|
body = util.RenderFile(filename, body, width, util.WithTruncate(6))
|
||||||
body = util.RenderFile(filename, body, width, util.WithTruncate(6))
|
}
|
||||||
}
|
case "edit":
|
||||||
case "edit":
|
if filename, ok := toolInputMap["filePath"].(string); ok {
|
||||||
if filename, ok := toolInputMap["filePath"].(string); ok {
|
diffField := metadata["diff"]
|
||||||
diffField := metadata["diff"]
|
if diffField != nil {
|
||||||
if diffField != nil {
|
patch := diffField.(string)
|
||||||
patch := diffField.(string)
|
var formattedDiff string
|
||||||
var formattedDiff string
|
formattedDiff, _ = diff.FormatUnifiedDiff(
|
||||||
formattedDiff, _ = diff.FormatUnifiedDiff(
|
filename,
|
||||||
filename,
|
patch,
|
||||||
patch,
|
diff.WithWidth(width-2),
|
||||||
diff.WithWidth(width-2),
|
)
|
||||||
)
|
body = strings.TrimSpace(formattedDiff)
|
||||||
body = strings.TrimSpace(formattedDiff)
|
style := styles.NewStyle().
|
||||||
style := styles.NewStyle().
|
Background(backgroundColor).
|
||||||
Background(backgroundColor).
|
Foreground(t.TextMuted()).
|
||||||
Foreground(t.TextMuted()).
|
Padding(1, 2).
|
||||||
Padding(1, 2).
|
Width(width - 4)
|
||||||
Width(width - 4)
|
if highlight {
|
||||||
if highlight {
|
style = style.Foreground(t.Text()).Bold(true)
|
||||||
style = style.Foreground(t.Text()).Bold(true)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if diagnostics := renderDiagnostics(metadata, filename); diagnostics != "" {
|
if diagnostics := renderDiagnostics(metadata, filename); diagnostics != "" {
|
||||||
diagnostics = style.Render(diagnostics)
|
diagnostics = style.Render(diagnostics)
|
||||||
body += "\n" + diagnostics
|
body += "\n" + diagnostics
|
||||||
}
|
}
|
||||||
|
|
||||||
title := renderToolTitle(toolCall, width)
|
title := renderToolTitle(toolCall, width)
|
||||||
title = style.Render(title)
|
title = style.Render(title)
|
||||||
content := title + "\n" + body
|
content := title + "\n" + body
|
||||||
content = renderContentBlock(
|
content = renderContentBlock(
|
||||||
app,
|
app,
|
||||||
content,
|
content,
|
||||||
highlight,
|
highlight,
|
||||||
width,
|
width,
|
||||||
WithPadding(0),
|
WithPadding(0),
|
||||||
WithBorderColor(borderColor),
|
WithBorderColor(borderColor),
|
||||||
)
|
)
|
||||||
return content
|
return content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "write":
|
||||||
|
if filename, ok := toolInputMap["filePath"].(string); ok {
|
||||||
|
if content, ok := toolInputMap["content"].(string); ok {
|
||||||
|
body = util.RenderFile(filename, content, width)
|
||||||
|
if diagnostics := renderDiagnostics(metadata, filename); diagnostics != "" {
|
||||||
|
body += "\n\n" + diagnostics
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "write":
|
}
|
||||||
if filename, ok := toolInputMap["filePath"].(string); ok {
|
case "bash":
|
||||||
if content, ok := toolInputMap["content"].(string); ok {
|
stdout := metadata["stdout"]
|
||||||
body = util.RenderFile(filename, content, width)
|
if stdout != nil {
|
||||||
if diagnostics := renderDiagnostics(metadata, filename); diagnostics != "" {
|
command := toolInputMap["command"].(string)
|
||||||
body += "\n\n" + diagnostics
|
body = fmt.Sprintf("```console\n> %s\n%s```", command, stdout)
|
||||||
}
|
body = util.ToMarkdown(body, width, backgroundColor)
|
||||||
}
|
}
|
||||||
}
|
case "webfetch":
|
||||||
case "bash":
|
if format, ok := toolInputMap["format"].(string); ok && result != nil {
|
||||||
stdout := metadata["stdout"]
|
|
||||||
if stdout != nil {
|
|
||||||
command := toolInputMap["command"].(string)
|
|
||||||
body = fmt.Sprintf("```console\n> %s\n%s```", command, stdout)
|
|
||||||
body = util.ToMarkdown(body, width, backgroundColor)
|
|
||||||
}
|
|
||||||
case "webfetch":
|
|
||||||
if format, ok := toolInputMap["format"].(string); ok && result != nil {
|
|
||||||
body = *result
|
|
||||||
body = util.TruncateHeight(body, 10)
|
|
||||||
if format == "html" || format == "markdown" {
|
|
||||||
body = util.ToMarkdown(body, width, backgroundColor)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "todowrite":
|
|
||||||
todos := metadata["todos"]
|
|
||||||
if todos != nil {
|
|
||||||
for _, item := range todos.([]any) {
|
|
||||||
todo := item.(map[string]any)
|
|
||||||
content := todo["content"].(string)
|
|
||||||
switch todo["status"] {
|
|
||||||
case "completed":
|
|
||||||
body += fmt.Sprintf("- [x] %s\n", content)
|
|
||||||
case "cancelled":
|
|
||||||
// strike through cancelled todo
|
|
||||||
body += fmt.Sprintf("- [~] ~~%s~~\n", content)
|
|
||||||
case "in_progress":
|
|
||||||
// highlight in progress todo
|
|
||||||
body += fmt.Sprintf("- [ ] `%s`\n", content)
|
|
||||||
default:
|
|
||||||
body += fmt.Sprintf("- [ ] %s\n", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
body = util.ToMarkdown(body, width, backgroundColor)
|
|
||||||
}
|
|
||||||
case "task":
|
|
||||||
summary := metadata["summary"]
|
|
||||||
if summary != nil {
|
|
||||||
toolcalls := summary.([]any)
|
|
||||||
steps := []string{}
|
|
||||||
for _, toolcall := range toolcalls {
|
|
||||||
call := toolcall.(map[string]any)
|
|
||||||
if toolInvocation, ok := call["toolInvocation"].(map[string]any); ok {
|
|
||||||
data, _ := json.Marshal(toolInvocation)
|
|
||||||
var toolCall opencode.ToolPart
|
|
||||||
_ = json.Unmarshal(data, &toolCall)
|
|
||||||
step := renderToolTitle(toolCall, width)
|
|
||||||
step = "∟ " + step
|
|
||||||
steps = append(steps, step)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
body = strings.Join(steps, "\n")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if result == nil {
|
|
||||||
empty := ""
|
|
||||||
result = &empty
|
|
||||||
}
|
|
||||||
body = *result
|
body = *result
|
||||||
body = util.TruncateHeight(body, 10)
|
body = util.TruncateHeight(body, 10)
|
||||||
body = styles.NewStyle().Width(width - 6).Render(body)
|
if format == "html" || format == "markdown" {
|
||||||
|
body = util.ToMarkdown(body, width, backgroundColor)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
case "todowrite":
|
||||||
|
todos := metadata["todos"]
|
||||||
|
if todos != nil {
|
||||||
|
for _, item := range todos.([]any) {
|
||||||
|
todo := item.(map[string]any)
|
||||||
|
content := todo["content"].(string)
|
||||||
|
switch todo["status"] {
|
||||||
|
case "completed":
|
||||||
|
body += fmt.Sprintf("- [x] %s\n", content)
|
||||||
|
case "cancelled":
|
||||||
|
// strike through cancelled todo
|
||||||
|
body += fmt.Sprintf("- [~] ~~%s~~\n", content)
|
||||||
|
case "in_progress":
|
||||||
|
// highlight in progress todo
|
||||||
|
body += fmt.Sprintf("- [ ] `%s`\n", content)
|
||||||
|
default:
|
||||||
|
body += fmt.Sprintf("- [ ] %s\n", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body = util.ToMarkdown(body, width, backgroundColor)
|
||||||
|
}
|
||||||
|
case "task":
|
||||||
|
summary := metadata["summary"]
|
||||||
|
if summary != nil {
|
||||||
|
toolcalls := summary.([]any)
|
||||||
|
steps := []string{}
|
||||||
|
for _, item := range toolcalls {
|
||||||
|
data, _ := json.Marshal(item)
|
||||||
|
var toolCall opencode.ToolPart
|
||||||
|
_ = json.Unmarshal(data, &toolCall)
|
||||||
|
step := renderToolTitle(toolCall, width)
|
||||||
|
step = "∟ " + step
|
||||||
|
steps = append(steps, step)
|
||||||
|
}
|
||||||
|
body = strings.Join(steps, "\n")
|
||||||
|
}
|
||||||
|
body = styles.NewStyle().Width(width - 6).Render(body)
|
||||||
|
default:
|
||||||
|
if result == nil {
|
||||||
|
empty := ""
|
||||||
|
result = &empty
|
||||||
|
}
|
||||||
|
body = *result
|
||||||
|
body = util.TruncateHeight(body, 10)
|
||||||
|
body = styles.NewStyle().Width(width - 6).Render(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
error := ""
|
error := ""
|
||||||
@@ -539,10 +533,9 @@ func renderToolTitle(
|
|||||||
toolCall opencode.ToolPart,
|
toolCall opencode.ToolPart,
|
||||||
width int,
|
width int,
|
||||||
) string {
|
) string {
|
||||||
// TODO: handle truncate to width
|
|
||||||
|
|
||||||
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
|
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
|
||||||
return renderToolAction(toolCall.Tool)
|
title := renderToolAction(toolCall.Tool)
|
||||||
|
return styles.NewStyle().Width(width - 6).Render(title)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolArgs := ""
|
toolArgs := ""
|
||||||
@@ -596,7 +589,7 @@ func renderToolTitle(
|
|||||||
func renderToolAction(name string) string {
|
func renderToolAction(name string) string {
|
||||||
switch name {
|
switch name {
|
||||||
case "task":
|
case "task":
|
||||||
return "Searching..."
|
return "Planning..."
|
||||||
case "bash":
|
case "bash":
|
||||||
return "Writing command..."
|
return "Writing command..."
|
||||||
case "edit":
|
case "edit":
|
||||||
|
|||||||
Reference in New Issue
Block a user