bring back task tool

This commit is contained in:
Dax Raad
2025-07-15 00:05:54 -04:00
parent 294a11752e
commit 4b2ce14ff3
7 changed files with 387 additions and 396 deletions

View File

@@ -21,7 +21,7 @@ import { AuthCopilot } from "../auth/copilot"
import { ModelsDev } from "./models" import { ModelsDev } from "./models"
import { NamedError } from "../util/error" import { NamedError } from "../util/error"
import { Auth } from "../auth" import { Auth } from "../auth"
// import { TaskTool } from "../tool/task" import { TaskTool } from "../tool/task"
export namespace Provider { export namespace Provider {
const log = Log.create({ service: "provider" }) const log = Log.create({ service: "provider" })
@@ -456,7 +456,7 @@ export namespace Provider {
WriteTool, WriteTool,
TodoWriteTool, TodoWriteTool,
TodoReadTool, TodoReadTool,
// TaskTool, TaskTool,
] ]
const TOOL_MAPPING: Record<string, Tool.Info[]> = { const TOOL_MAPPING: Record<string, Tool.Info[]> = {
@@ -531,12 +531,4 @@ export namespace Provider {
providerID: z.string(), providerID: z.string(),
}), }),
) )
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
} }

View File

@@ -443,7 +443,7 @@ export namespace Session {
const result = await ReadTool.execute(args, { const result = await ReadTool.execute(args, {
sessionID: input.sessionID, sessionID: input.sessionID,
abort: abort.signal, abort: abort.signal,
messageID: "", // read tool doesn't use message ID messageID: userMsg.id,
metadata: async () => {}, metadata: async () => {},
}) })
return [ return [
@@ -577,20 +577,22 @@ export namespace Session {
await updateMessage(assistantMsg) await updateMessage(assistantMsg)
const tools: Record<string, AITool> = {} const tools: Record<string, AITool> = {}
const processor = createProcessor(assistantMsg, model.info)
for (const item of await Provider.tools(input.providerID)) { for (const item of await Provider.tools(input.providerID)) {
if (mode.tools[item.id] === false) continue if (mode.tools[item.id] === false) continue
if (session.parentID && item.id === "task") continue
tools[item.id] = tool({ tools[item.id] = tool({
id: item.id as any, id: item.id as any,
description: item.description, description: item.description,
inputSchema: item.parameters as ZodSchema, inputSchema: item.parameters as ZodSchema,
async execute(args) { async execute(args, options) {
const result = await item.execute(args, { const result = await item.execute(args, {
sessionID: input.sessionID, sessionID: input.sessionID,
abort: abort.signal, abort: abort.signal,
messageID: assistantMsg.id, messageID: assistantMsg.id,
metadata: async () => { metadata: async (val) => {
/* const match = processor.partFromToolCall(options.toolCallId)
const match = toolCalls[opts.toolCallId]
if (match && match.state.status === "running") { if (match && match.state.status === "running") {
await updatePart({ await updatePart({
...match, ...match,
@@ -598,14 +600,13 @@ export namespace Session {
title: val.title, title: val.title,
metadata: val.metadata, metadata: val.metadata,
status: "running", status: "running",
input: args.input, input: args,
time: { time: {
start: Date.now(), start: Date.now(),
}, },
}, },
}) })
} }
*/
}, },
}) })
return result return result
@@ -676,18 +677,19 @@ export namespace Session {
], ],
}), }),
}) })
const result = await processStream(assistantMsg, model.info, stream) const result = await processor.process(stream)
return result return result
} }
async function processStream( function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
assistantMsg: MessageV2.Assistant, const toolCalls: Record<string, MessageV2.ToolPart> = {}
model: ModelsDev.Model, return {
stream: StreamTextResult<Record<string, AITool>, never>, partFromToolCall(toolCallID: string) {
) { return toolCalls[toolCallID]
},
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
try { try {
let currentText: MessageV2.TextPart | undefined let currentText: MessageV2.TextPart | undefined
const toolCalls: Record<string, MessageV2.ToolPart> = {}
for await (const value of stream.fullStream) { for await (const value of stream.fullStream) {
log.info("part", { log.info("part", {
@@ -888,7 +890,7 @@ export namespace Session {
assistantMsg.error = e assistantMsg.error = e
break break
case LoadAPIKeyError.isInstance(e): case LoadAPIKeyError.isInstance(e):
assistantMsg.error = new Provider.AuthError( assistantMsg.error = new MessageV2.AuthError(
{ {
providerID: model.id, providerID: model.id,
message: e.message, message: e.message,
@@ -927,6 +929,8 @@ export namespace Session {
assistantMsg.time.completed = Date.now() assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg) await updateMessage(assistantMsg)
return { info: assistantMsg, parts: p } return { info: assistantMsg, parts: p }
},
}
} }
export async function revert(_input: { sessionID: string; messageID: string; part: number }) { export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
@@ -1006,6 +1010,7 @@ export namespace Session {
} }
await updateMessage(next) await updateMessage(next)
const processor = createProcessor(next, model.info)
const stream = streamText({ const stream = streamText({
abortSignal: abort.signal, abortSignal: abort.signal,
model: model.language, model: model.language,
@@ -1029,7 +1034,7 @@ export namespace Session {
], ],
}) })
const result = await processStream(next, model.info, stream) const result = await processor.process(stream)
return result return result
} }

View File

@@ -1,6 +1,5 @@
import z from "zod" import z from "zod"
import { Bus } from "../bus" import { Bus } from "../bus"
import { Provider } from "../provider/provider"
import { NamedError } from "../util/error" import { NamedError } from "../util/error"
import { Message } from "./message" import { Message } from "./message"
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai" import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
@@ -9,6 +8,13 @@ import { Identifier } from "../id/id"
export namespace MessageV2 { export namespace MessageV2 {
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({})) export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
export const AbortedError = NamedError.create("MessageAbortedError", z.object({})) export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
export const ToolStatePending = z export const ToolStatePending = z
.object({ .object({
@@ -173,7 +179,7 @@ export namespace MessageV2 {
}), }),
error: z error: z
.discriminatedUnion("name", [ .discriminatedUnion("name", [
Provider.AuthError.Schema, AuthError.Schema,
NamedError.Unknown.Schema, NamedError.Unknown.Schema,
OutputLengthError.Schema, OutputLengthError.Schema,
AbortedError.Schema, AbortedError.Schema,

View File

@@ -1,9 +1,15 @@
import z from "zod" import z from "zod"
import { Provider } from "../provider/provider"
import { NamedError } from "../util/error" import { NamedError } from "../util/error"
export namespace Message { export namespace Message {
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({})) export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
export const ToolCall = z export const ToolCall = z
.object({ .object({
@@ -134,11 +140,7 @@ export namespace Message {
completed: z.number().optional(), completed: z.number().optional(),
}), }),
error: z error: z
.discriminatedUnion("name", [ .discriminatedUnion("name", [AuthError.Schema, NamedError.Unknown.Schema, OutputLengthError.Schema])
Provider.AuthError.Schema,
NamedError.Unknown.Schema,
OutputLengthError.Schema,
])
.optional(), .optional(),
sessionID: z.string(), sessionID: z.string(),
tool: z.record( tool: z.record(

View File

@@ -129,7 +129,7 @@ export namespace Storage {
cwd: path.join(dir, prefix), cwd: path.join(dir, prefix),
onlyFiles: true, onlyFiles: true,
}), }),
) ).then((items) => items.map((item) => path.join(prefix, item.slice(0, -5))))
result.sort() result.sort()
return result return result
} catch { } catch {

View File

@@ -15,21 +15,15 @@ export const TaskTool = Tool.define({
}), }),
async execute(params, ctx) { async execute(params, ctx) {
const session = await Session.create(ctx.sessionID) const session = await Session.create(ctx.sessionID)
const msg = (await Session.getMessage(ctx.sessionID, ctx.messageID)) as MessageV2.Assistant const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
if (msg.role !== "assistant") throw new Error("Not an assistant message")
const parts: Record<string, MessageV2.Part> = {}
function summary(input: MessageV2.Part[]) {
const result = []
for (const part of input) {
if (part.type === "tool" && part.state.status === "completed") {
result.push(part)
}
}
return result
}
const messageID = Identifier.ascending("message")
const parts: Record<string, MessageV2.ToolPart> = {}
const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => { const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
if (evt.properties.part.sessionID !== session.id) return if (evt.properties.part.sessionID !== session.id) return
if (evt.properties.part.messageID === messageID) return
if (evt.properties.part.type !== "tool") return
parts[evt.properties.part.id] = evt.properties.part parts[evt.properties.part.id] = evt.properties.part
ctx.metadata({ ctx.metadata({
title: params.description, title: params.description,
@@ -42,7 +36,6 @@ export const TaskTool = Tool.define({
ctx.abort.addEventListener("abort", () => { ctx.abort.addEventListener("abort", () => {
Session.abort(session.id) Session.abort(session.id)
}) })
const messageID = Identifier.ascending("message")
const result = await Session.chat({ const result = await Session.chat({
messageID, messageID,
sessionID: session.id, sessionID: session.id,
@@ -62,7 +55,7 @@ export const TaskTool = Tool.define({
return { return {
title: params.description, title: params.description,
metadata: { metadata: {
summary: summary(result.parts), summary: result.parts.filter((x) => x.type === "tool"),
}, },
output: result.parts.findLast((x) => x.type === "text")!.text, output: result.parts.findLast((x) => x.type === "text")!.text,
} }

View File

@@ -305,10 +305,8 @@ func renderToolDetails(
return "" return ""
} }
if toolCall.State.Status == opencode.ToolPartStateStatusPending || if toolCall.State.Status == opencode.ToolPartStateStatusPending {
toolCall.State.Status == opencode.ToolPartStateStatusRunning {
title := renderToolTitle(toolCall, width) title := renderToolTitle(toolCall, width)
title = styles.NewStyle().Width(width - 6).Render(title)
return renderContentBlock(app, title, highlight, width) return renderContentBlock(app, title, highlight, width)
} }
@@ -339,7 +337,6 @@ func renderToolDetails(
borderColor = t.BorderActive() borderColor = t.BorderActive()
} }
if toolCall.State.Status == opencode.ToolPartStateStatusCompleted {
metadata := toolCall.State.Metadata.(map[string]any) metadata := toolCall.State.Metadata.(map[string]any)
switch toolCall.Tool { switch toolCall.Tool {
case "read": case "read":
@@ -439,19 +436,17 @@ func renderToolDetails(
if summary != nil { if summary != nil {
toolcalls := summary.([]any) toolcalls := summary.([]any)
steps := []string{} steps := []string{}
for _, toolcall := range toolcalls { for _, item := range toolcalls {
call := toolcall.(map[string]any) data, _ := json.Marshal(item)
if toolInvocation, ok := call["toolInvocation"].(map[string]any); ok {
data, _ := json.Marshal(toolInvocation)
var toolCall opencode.ToolPart var toolCall opencode.ToolPart
_ = json.Unmarshal(data, &toolCall) _ = json.Unmarshal(data, &toolCall)
step := renderToolTitle(toolCall, width) step := renderToolTitle(toolCall, width)
step = "∟ " + step step = "∟ " + step
steps = append(steps, step) steps = append(steps, step)
} }
}
body = strings.Join(steps, "\n") body = strings.Join(steps, "\n")
} }
body = styles.NewStyle().Width(width - 6).Render(body)
default: default:
if result == nil { if result == nil {
empty := "" empty := ""
@@ -461,7 +456,6 @@ func renderToolDetails(
body = util.TruncateHeight(body, 10) body = util.TruncateHeight(body, 10)
body = styles.NewStyle().Width(width - 6).Render(body) body = styles.NewStyle().Width(width - 6).Render(body)
} }
}
error := "" error := ""
if toolCall.State.Status == opencode.ToolPartStateStatusError { if toolCall.State.Status == opencode.ToolPartStateStatusError {
@@ -539,10 +533,9 @@ func renderToolTitle(
toolCall opencode.ToolPart, toolCall opencode.ToolPart,
width int, width int,
) string { ) string {
// TODO: handle truncate to width
if toolCall.State.Status == opencode.ToolPartStateStatusPending { if toolCall.State.Status == opencode.ToolPartStateStatusPending {
return renderToolAction(toolCall.Tool) title := renderToolAction(toolCall.Tool)
return styles.NewStyle().Width(width - 6).Render(title)
} }
toolArgs := "" toolArgs := ""
@@ -596,7 +589,7 @@ func renderToolTitle(
func renderToolAction(name string) string { func renderToolAction(name string) string {
switch name { switch name {
case "task": case "task":
return "Searching..." return "Planning..."
case "bash": case "bash":
return "Writing command..." return "Writing command..."
case "edit": case "edit":