bring back task tool

This commit is contained in:
Dax Raad
2025-07-15 00:05:54 -04:00
parent 294a11752e
commit 4b2ce14ff3
7 changed files with 387 additions and 396 deletions

View File

@@ -21,7 +21,7 @@ import { AuthCopilot } from "../auth/copilot"
import { ModelsDev } from "./models"
import { NamedError } from "../util/error"
import { Auth } from "../auth"
// import { TaskTool } from "../tool/task"
import { TaskTool } from "../tool/task"
export namespace Provider {
const log = Log.create({ service: "provider" })
@@ -456,7 +456,7 @@ export namespace Provider {
WriteTool,
TodoWriteTool,
TodoReadTool,
// TaskTool,
TaskTool,
]
const TOOL_MAPPING: Record<string, Tool.Info[]> = {
@@ -531,12 +531,4 @@ export namespace Provider {
providerID: z.string(),
}),
)
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
}

View File

@@ -443,7 +443,7 @@ export namespace Session {
const result = await ReadTool.execute(args, {
sessionID: input.sessionID,
abort: abort.signal,
messageID: "", // read tool doesn't use message ID
messageID: userMsg.id,
metadata: async () => {},
})
return [
@@ -577,20 +577,22 @@ export namespace Session {
await updateMessage(assistantMsg)
const tools: Record<string, AITool> = {}
const processor = createProcessor(assistantMsg, model.info)
for (const item of await Provider.tools(input.providerID)) {
if (mode.tools[item.id] === false) continue
if (session.parentID && item.id === "task") continue
tools[item.id] = tool({
id: item.id as any,
description: item.description,
inputSchema: item.parameters as ZodSchema,
async execute(args) {
async execute(args, options) {
const result = await item.execute(args, {
sessionID: input.sessionID,
abort: abort.signal,
messageID: assistantMsg.id,
metadata: async () => {
/*
const match = toolCalls[opts.toolCallId]
metadata: async (val) => {
const match = processor.partFromToolCall(options.toolCallId)
if (match && match.state.status === "running") {
await updatePart({
...match,
@@ -598,14 +600,13 @@ export namespace Session {
title: val.title,
metadata: val.metadata,
status: "running",
input: args.input,
input: args,
time: {
start: Date.now(),
},
},
})
}
*/
},
})
return result
@@ -676,18 +677,19 @@ export namespace Session {
],
}),
})
const result = await processStream(assistantMsg, model.info, stream)
const result = await processor.process(stream)
return result
}
async function processStream(
assistantMsg: MessageV2.Assistant,
model: ModelsDev.Model,
stream: StreamTextResult<Record<string, AITool>, never>,
) {
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
const toolCalls: Record<string, MessageV2.ToolPart> = {}
return {
partFromToolCall(toolCallID: string) {
return toolCalls[toolCallID]
},
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
try {
let currentText: MessageV2.TextPart | undefined
const toolCalls: Record<string, MessageV2.ToolPart> = {}
for await (const value of stream.fullStream) {
log.info("part", {
@@ -888,7 +890,7 @@ export namespace Session {
assistantMsg.error = e
break
case LoadAPIKeyError.isInstance(e):
assistantMsg.error = new Provider.AuthError(
assistantMsg.error = new MessageV2.AuthError(
{
providerID: model.id,
message: e.message,
@@ -927,6 +929,8 @@ export namespace Session {
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
return { info: assistantMsg, parts: p }
},
}
}
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
@@ -1006,6 +1010,7 @@ export namespace Session {
}
await updateMessage(next)
const processor = createProcessor(next, model.info)
const stream = streamText({
abortSignal: abort.signal,
model: model.language,
@@ -1029,7 +1034,7 @@ export namespace Session {
],
})
const result = await processStream(next, model.info, stream)
const result = await processor.process(stream)
return result
}

View File

@@ -1,6 +1,5 @@
import z from "zod"
import { Bus } from "../bus"
import { Provider } from "../provider/provider"
import { NamedError } from "../util/error"
import { Message } from "./message"
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
@@ -9,6 +8,13 @@ import { Identifier } from "../id/id"
export namespace MessageV2 {
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
export const ToolStatePending = z
.object({
@@ -173,7 +179,7 @@ export namespace MessageV2 {
}),
error: z
.discriminatedUnion("name", [
Provider.AuthError.Schema,
AuthError.Schema,
NamedError.Unknown.Schema,
OutputLengthError.Schema,
AbortedError.Schema,

View File

@@ -1,9 +1,15 @@
import z from "zod"
import { Provider } from "../provider/provider"
import { NamedError } from "../util/error"
export namespace Message {
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
export const ToolCall = z
.object({
@@ -134,11 +140,7 @@ export namespace Message {
completed: z.number().optional(),
}),
error: z
.discriminatedUnion("name", [
Provider.AuthError.Schema,
NamedError.Unknown.Schema,
OutputLengthError.Schema,
])
.discriminatedUnion("name", [AuthError.Schema, NamedError.Unknown.Schema, OutputLengthError.Schema])
.optional(),
sessionID: z.string(),
tool: z.record(

View File

@@ -129,7 +129,7 @@ export namespace Storage {
cwd: path.join(dir, prefix),
onlyFiles: true,
}),
)
).then((items) => items.map((item) => path.join(prefix, item.slice(0, -5))))
result.sort()
return result
} catch {

View File

@@ -15,21 +15,15 @@ export const TaskTool = Tool.define({
}),
async execute(params, ctx) {
const session = await Session.create(ctx.sessionID)
const msg = (await Session.getMessage(ctx.sessionID, ctx.messageID)) as MessageV2.Assistant
const parts: Record<string, MessageV2.Part> = {}
function summary(input: MessageV2.Part[]) {
const result = []
for (const part of input) {
if (part.type === "tool" && part.state.status === "completed") {
result.push(part)
}
}
return result
}
const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
if (msg.role !== "assistant") throw new Error("Not an assistant message")
const messageID = Identifier.ascending("message")
const parts: Record<string, MessageV2.ToolPart> = {}
const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
if (evt.properties.part.sessionID !== session.id) return
if (evt.properties.part.messageID === messageID) return
if (evt.properties.part.type !== "tool") return
parts[evt.properties.part.id] = evt.properties.part
ctx.metadata({
title: params.description,
@@ -42,7 +36,6 @@ export const TaskTool = Tool.define({
ctx.abort.addEventListener("abort", () => {
Session.abort(session.id)
})
const messageID = Identifier.ascending("message")
const result = await Session.chat({
messageID,
sessionID: session.id,
@@ -62,7 +55,7 @@ export const TaskTool = Tool.define({
return {
title: params.description,
metadata: {
summary: summary(result.parts),
summary: result.parts.filter((x) => x.type === "tool"),
},
output: result.parts.findLast((x) => x.type === "text")!.text,
}

View File

@@ -305,10 +305,8 @@ func renderToolDetails(
return ""
}
if toolCall.State.Status == opencode.ToolPartStateStatusPending ||
toolCall.State.Status == opencode.ToolPartStateStatusRunning {
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
title := renderToolTitle(toolCall, width)
title = styles.NewStyle().Width(width - 6).Render(title)
return renderContentBlock(app, title, highlight, width)
}
@@ -339,7 +337,6 @@ func renderToolDetails(
borderColor = t.BorderActive()
}
if toolCall.State.Status == opencode.ToolPartStateStatusCompleted {
metadata := toolCall.State.Metadata.(map[string]any)
switch toolCall.Tool {
case "read":
@@ -439,19 +436,17 @@ func renderToolDetails(
if summary != nil {
toolcalls := summary.([]any)
steps := []string{}
for _, toolcall := range toolcalls {
call := toolcall.(map[string]any)
if toolInvocation, ok := call["toolInvocation"].(map[string]any); ok {
data, _ := json.Marshal(toolInvocation)
for _, item := range toolcalls {
data, _ := json.Marshal(item)
var toolCall opencode.ToolPart
_ = json.Unmarshal(data, &toolCall)
step := renderToolTitle(toolCall, width)
step = "∟ " + step
steps = append(steps, step)
}
}
body = strings.Join(steps, "\n")
}
body = styles.NewStyle().Width(width - 6).Render(body)
default:
if result == nil {
empty := ""
@@ -461,7 +456,6 @@ func renderToolDetails(
body = util.TruncateHeight(body, 10)
body = styles.NewStyle().Width(width - 6).Render(body)
}
}
error := ""
if toolCall.State.Status == opencode.ToolPartStateStatusError {
@@ -539,10 +533,9 @@ func renderToolTitle(
toolCall opencode.ToolPart,
width int,
) string {
// TODO: handle truncate to width
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
return renderToolAction(toolCall.Tool)
title := renderToolAction(toolCall.Tool)
return styles.NewStyle().Width(width - 6).Render(title)
}
toolArgs := ""
@@ -596,7 +589,7 @@ func renderToolTitle(
func renderToolAction(name string) string {
switch name {
case "task":
return "Searching..."
return "Planning..."
case "bash":
return "Writing command..."
case "edit":