bring back task tool
This commit is contained in:
@@ -21,7 +21,7 @@ import { AuthCopilot } from "../auth/copilot"
|
||||
import { ModelsDev } from "./models"
|
||||
import { NamedError } from "../util/error"
|
||||
import { Auth } from "../auth"
|
||||
// import { TaskTool } from "../tool/task"
|
||||
import { TaskTool } from "../tool/task"
|
||||
|
||||
export namespace Provider {
|
||||
const log = Log.create({ service: "provider" })
|
||||
@@ -456,7 +456,7 @@ export namespace Provider {
|
||||
WriteTool,
|
||||
TodoWriteTool,
|
||||
TodoReadTool,
|
||||
// TaskTool,
|
||||
TaskTool,
|
||||
]
|
||||
|
||||
const TOOL_MAPPING: Record<string, Tool.Info[]> = {
|
||||
@@ -531,12 +531,4 @@ export namespace Provider {
|
||||
providerID: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const AuthError = NamedError.create(
|
||||
"ProviderAuthError",
|
||||
z.object({
|
||||
providerID: z.string(),
|
||||
message: z.string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -443,7 +443,7 @@ export namespace Session {
|
||||
const result = await ReadTool.execute(args, {
|
||||
sessionID: input.sessionID,
|
||||
abort: abort.signal,
|
||||
messageID: "", // read tool doesn't use message ID
|
||||
messageID: userMsg.id,
|
||||
metadata: async () => {},
|
||||
})
|
||||
return [
|
||||
@@ -577,20 +577,22 @@ export namespace Session {
|
||||
await updateMessage(assistantMsg)
|
||||
const tools: Record<string, AITool> = {}
|
||||
|
||||
const processor = createProcessor(assistantMsg, model.info)
|
||||
|
||||
for (const item of await Provider.tools(input.providerID)) {
|
||||
if (mode.tools[item.id] === false) continue
|
||||
if (session.parentID && item.id === "task") continue
|
||||
tools[item.id] = tool({
|
||||
id: item.id as any,
|
||||
description: item.description,
|
||||
inputSchema: item.parameters as ZodSchema,
|
||||
async execute(args) {
|
||||
async execute(args, options) {
|
||||
const result = await item.execute(args, {
|
||||
sessionID: input.sessionID,
|
||||
abort: abort.signal,
|
||||
messageID: assistantMsg.id,
|
||||
metadata: async () => {
|
||||
/*
|
||||
const match = toolCalls[opts.toolCallId]
|
||||
metadata: async (val) => {
|
||||
const match = processor.partFromToolCall(options.toolCallId)
|
||||
if (match && match.state.status === "running") {
|
||||
await updatePart({
|
||||
...match,
|
||||
@@ -598,14 +600,13 @@ export namespace Session {
|
||||
title: val.title,
|
||||
metadata: val.metadata,
|
||||
status: "running",
|
||||
input: args.input,
|
||||
input: args,
|
||||
time: {
|
||||
start: Date.now(),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
*/
|
||||
},
|
||||
})
|
||||
return result
|
||||
@@ -676,18 +677,19 @@ export namespace Session {
|
||||
],
|
||||
}),
|
||||
})
|
||||
const result = await processStream(assistantMsg, model.info, stream)
|
||||
const result = await processor.process(stream)
|
||||
return result
|
||||
}
|
||||
|
||||
async function processStream(
|
||||
assistantMsg: MessageV2.Assistant,
|
||||
model: ModelsDev.Model,
|
||||
stream: StreamTextResult<Record<string, AITool>, never>,
|
||||
) {
|
||||
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
|
||||
const toolCalls: Record<string, MessageV2.ToolPart> = {}
|
||||
return {
|
||||
partFromToolCall(toolCallID: string) {
|
||||
return toolCalls[toolCallID]
|
||||
},
|
||||
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
||||
try {
|
||||
let currentText: MessageV2.TextPart | undefined
|
||||
const toolCalls: Record<string, MessageV2.ToolPart> = {}
|
||||
|
||||
for await (const value of stream.fullStream) {
|
||||
log.info("part", {
|
||||
@@ -888,7 +890,7 @@ export namespace Session {
|
||||
assistantMsg.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
assistantMsg.error = new Provider.AuthError(
|
||||
assistantMsg.error = new MessageV2.AuthError(
|
||||
{
|
||||
providerID: model.id,
|
||||
message: e.message,
|
||||
@@ -927,6 +929,8 @@ export namespace Session {
|
||||
assistantMsg.time.completed = Date.now()
|
||||
await updateMessage(assistantMsg)
|
||||
return { info: assistantMsg, parts: p }
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
|
||||
@@ -1006,6 +1010,7 @@ export namespace Session {
|
||||
}
|
||||
await updateMessage(next)
|
||||
|
||||
const processor = createProcessor(next, model.info)
|
||||
const stream = streamText({
|
||||
abortSignal: abort.signal,
|
||||
model: model.language,
|
||||
@@ -1029,7 +1034,7 @@ export namespace Session {
|
||||
],
|
||||
})
|
||||
|
||||
const result = await processStream(next, model.info, stream)
|
||||
const result = await processor.process(stream)
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import z from "zod"
|
||||
import { Bus } from "../bus"
|
||||
import { Provider } from "../provider/provider"
|
||||
import { NamedError } from "../util/error"
|
||||
import { Message } from "./message"
|
||||
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
|
||||
@@ -9,6 +8,13 @@ import { Identifier } from "../id/id"
|
||||
export namespace MessageV2 {
|
||||
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
||||
export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))
|
||||
export const AuthError = NamedError.create(
|
||||
"ProviderAuthError",
|
||||
z.object({
|
||||
providerID: z.string(),
|
||||
message: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const ToolStatePending = z
|
||||
.object({
|
||||
@@ -173,7 +179,7 @@ export namespace MessageV2 {
|
||||
}),
|
||||
error: z
|
||||
.discriminatedUnion("name", [
|
||||
Provider.AuthError.Schema,
|
||||
AuthError.Schema,
|
||||
NamedError.Unknown.Schema,
|
||||
OutputLengthError.Schema,
|
||||
AbortedError.Schema,
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import z from "zod"
|
||||
import { Provider } from "../provider/provider"
|
||||
import { NamedError } from "../util/error"
|
||||
|
||||
export namespace Message {
|
||||
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
|
||||
export const AuthError = NamedError.create(
|
||||
"ProviderAuthError",
|
||||
z.object({
|
||||
providerID: z.string(),
|
||||
message: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const ToolCall = z
|
||||
.object({
|
||||
@@ -134,11 +140,7 @@ export namespace Message {
|
||||
completed: z.number().optional(),
|
||||
}),
|
||||
error: z
|
||||
.discriminatedUnion("name", [
|
||||
Provider.AuthError.Schema,
|
||||
NamedError.Unknown.Schema,
|
||||
OutputLengthError.Schema,
|
||||
])
|
||||
.discriminatedUnion("name", [AuthError.Schema, NamedError.Unknown.Schema, OutputLengthError.Schema])
|
||||
.optional(),
|
||||
sessionID: z.string(),
|
||||
tool: z.record(
|
||||
|
||||
@@ -129,7 +129,7 @@ export namespace Storage {
|
||||
cwd: path.join(dir, prefix),
|
||||
onlyFiles: true,
|
||||
}),
|
||||
)
|
||||
).then((items) => items.map((item) => path.join(prefix, item.slice(0, -5))))
|
||||
result.sort()
|
||||
return result
|
||||
} catch {
|
||||
|
||||
@@ -15,21 +15,15 @@ export const TaskTool = Tool.define({
|
||||
}),
|
||||
async execute(params, ctx) {
|
||||
const session = await Session.create(ctx.sessionID)
|
||||
const msg = (await Session.getMessage(ctx.sessionID, ctx.messageID)) as MessageV2.Assistant
|
||||
|
||||
const parts: Record<string, MessageV2.Part> = {}
|
||||
function summary(input: MessageV2.Part[]) {
|
||||
const result = []
|
||||
for (const part of input) {
|
||||
if (part.type === "tool" && part.state.status === "completed") {
|
||||
result.push(part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
|
||||
if (msg.role !== "assistant") throw new Error("Not an assistant message")
|
||||
|
||||
const messageID = Identifier.ascending("message")
|
||||
const parts: Record<string, MessageV2.ToolPart> = {}
|
||||
const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
|
||||
if (evt.properties.part.sessionID !== session.id) return
|
||||
if (evt.properties.part.messageID === messageID) return
|
||||
if (evt.properties.part.type !== "tool") return
|
||||
parts[evt.properties.part.id] = evt.properties.part
|
||||
ctx.metadata({
|
||||
title: params.description,
|
||||
@@ -42,7 +36,6 @@ export const TaskTool = Tool.define({
|
||||
ctx.abort.addEventListener("abort", () => {
|
||||
Session.abort(session.id)
|
||||
})
|
||||
const messageID = Identifier.ascending("message")
|
||||
const result = await Session.chat({
|
||||
messageID,
|
||||
sessionID: session.id,
|
||||
@@ -62,7 +55,7 @@ export const TaskTool = Tool.define({
|
||||
return {
|
||||
title: params.description,
|
||||
metadata: {
|
||||
summary: summary(result.parts),
|
||||
summary: result.parts.filter((x) => x.type === "tool"),
|
||||
},
|
||||
output: result.parts.findLast((x) => x.type === "text")!.text,
|
||||
}
|
||||
|
||||
@@ -305,10 +305,8 @@ func renderToolDetails(
|
||||
return ""
|
||||
}
|
||||
|
||||
if toolCall.State.Status == opencode.ToolPartStateStatusPending ||
|
||||
toolCall.State.Status == opencode.ToolPartStateStatusRunning {
|
||||
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
|
||||
title := renderToolTitle(toolCall, width)
|
||||
title = styles.NewStyle().Width(width - 6).Render(title)
|
||||
return renderContentBlock(app, title, highlight, width)
|
||||
}
|
||||
|
||||
@@ -339,7 +337,6 @@ func renderToolDetails(
|
||||
borderColor = t.BorderActive()
|
||||
}
|
||||
|
||||
if toolCall.State.Status == opencode.ToolPartStateStatusCompleted {
|
||||
metadata := toolCall.State.Metadata.(map[string]any)
|
||||
switch toolCall.Tool {
|
||||
case "read":
|
||||
@@ -439,19 +436,17 @@ func renderToolDetails(
|
||||
if summary != nil {
|
||||
toolcalls := summary.([]any)
|
||||
steps := []string{}
|
||||
for _, toolcall := range toolcalls {
|
||||
call := toolcall.(map[string]any)
|
||||
if toolInvocation, ok := call["toolInvocation"].(map[string]any); ok {
|
||||
data, _ := json.Marshal(toolInvocation)
|
||||
for _, item := range toolcalls {
|
||||
data, _ := json.Marshal(item)
|
||||
var toolCall opencode.ToolPart
|
||||
_ = json.Unmarshal(data, &toolCall)
|
||||
step := renderToolTitle(toolCall, width)
|
||||
step = "∟ " + step
|
||||
steps = append(steps, step)
|
||||
}
|
||||
}
|
||||
body = strings.Join(steps, "\n")
|
||||
}
|
||||
body = styles.NewStyle().Width(width - 6).Render(body)
|
||||
default:
|
||||
if result == nil {
|
||||
empty := ""
|
||||
@@ -461,7 +456,6 @@ func renderToolDetails(
|
||||
body = util.TruncateHeight(body, 10)
|
||||
body = styles.NewStyle().Width(width - 6).Render(body)
|
||||
}
|
||||
}
|
||||
|
||||
error := ""
|
||||
if toolCall.State.Status == opencode.ToolPartStateStatusError {
|
||||
@@ -539,10 +533,9 @@ func renderToolTitle(
|
||||
toolCall opencode.ToolPart,
|
||||
width int,
|
||||
) string {
|
||||
// TODO: handle truncate to width
|
||||
|
||||
if toolCall.State.Status == opencode.ToolPartStateStatusPending {
|
||||
return renderToolAction(toolCall.Tool)
|
||||
title := renderToolAction(toolCall.Tool)
|
||||
return styles.NewStyle().Width(width - 6).Render(title)
|
||||
}
|
||||
|
||||
toolArgs := ""
|
||||
@@ -596,7 +589,7 @@ func renderToolTitle(
|
||||
func renderToolAction(name string) string {
|
||||
switch name {
|
||||
case "task":
|
||||
return "Searching..."
|
||||
return "Planning..."
|
||||
case "bash":
|
||||
return "Writing command..."
|
||||
case "edit":
|
||||
|
||||
Reference in New Issue
Block a user