wip: refactoring tui

This commit is contained in:
adamdottv
2025-05-30 15:34:22 -05:00
parent f5e2c596d4
commit c69c9327da
13 changed files with 244 additions and 263 deletions

View File

@@ -3,7 +3,6 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"log/slog" "log/slog"
@@ -20,20 +19,14 @@ import (
type App struct { type App struct {
Client *client.ClientWithResponses Client *client.ClientWithResponses
Events *client.Client Events *client.Client
Provider *client.ProviderInfo
Model *client.ProviderModel
Session *client.SessionInfo Session *client.SessionInfo
Messages []client.MessageInfo Messages []client.MessageInfo
Status status.Service
LogsOLD any // TODO: Define LogService interface when needed
HistoryOLD any // TODO: Define HistoryService interface when needed
PermissionsOLD any // TODO: Define PermissionService interface when needed
Status status.Service
PrimaryAgentOLD AgentService PrimaryAgentOLD AgentService
watcherCancelFuncs []context.CancelFunc
cancelFuncsMutex sync.Mutex
watcherWG sync.WaitGroup
// UI state // UI state
filepickerOpen bool filepickerOpen bool
completionDialogOpen bool completionDialogOpen bool
@@ -70,13 +63,9 @@ func New(ctx context.Context) (*App, error) {
Client: httpClient, Client: httpClient,
Events: eventClient, Events: eventClient,
Session: &client.SessionInfo{}, Session: &client.SessionInfo{},
Messages: []client.MessageInfo{},
PrimaryAgentOLD: agentBridge, PrimaryAgentOLD: agentBridge,
Status: status.GetService(), Status: status.GetService(),
// TODO: These services need API endpoints:
LogsOLD: nil, // logging.GetService(),
HistoryOLD: nil, // history.GetService(),
PermissionsOLD: nil, // permission.GetService(),
} }
// Initialize theme based on configuration // Initialize theme based on configuration
@@ -128,13 +117,12 @@ func (a *App) SendChatMessage(ctx context.Context, text string, attachments []At
go a.Client.PostSessionChatWithResponse(ctx, client.PostSessionChatJSONRequestBody{ go a.Client.PostSessionChatWithResponse(ctx, client.PostSessionChatJSONRequestBody{
SessionID: a.Session.Id, SessionID: a.Session.Id,
Parts: parts, Parts: parts,
ProviderID: "anthropic", ProviderID: a.Provider.Id,
ModelID: "claude-sonnet-4-20250514", ModelID: a.Model.Id,
}) })
// The actual response will come through SSE // The actual response will come through SSE
// For now, just return success // For now, just return success
return tea.Batch(cmds...) return tea.Batch(cmds...)
} }
@@ -169,6 +157,22 @@ func (a *App) ListMessages(ctx context.Context, sessionId string) ([]client.Mess
return messages, nil return messages, nil
} }
func (a *App) ListProviders(ctx context.Context) ([]client.ProviderInfo, error) {
resp, err := a.Client.PostProviderListWithResponse(ctx)
if err != nil {
return nil, err
}
if resp.StatusCode() != 200 {
return nil, fmt.Errorf("failed to list sessions: %d", resp.StatusCode())
}
if resp.JSON200 == nil {
return []client.ProviderInfo{}, nil
}
providers := *resp.JSON200
return providers, nil
}
// initTheme sets the application theme based on the configuration // initTheme sets the application theme based on the configuration
func (app *App) initTheme() { func (app *App) initTheme() {
cfg := config.Get() cfg := config.Get()
@@ -207,11 +211,5 @@ func (app *App) SetCompletionDialogOpen(open bool) {
// Shutdown performs a clean shutdown of the application // Shutdown performs a clean shutdown of the application
func (app *App) Shutdown() { func (app *App) Shutdown() {
// Cancel all watcher goroutines // TODO: cleanup?
app.cancelFuncsMutex.Lock()
for _, cancel := range app.watcherCancelFuncs {
cancel()
}
app.cancelFuncsMutex.Unlock()
app.watcherWG.Wait()
} }

View File

@@ -17,33 +17,6 @@ func NewAgentServiceBridge(client *client.ClientWithResponses) *AgentServiceBrid
return &AgentServiceBridge{client: client} return &AgentServiceBridge{client: client}
} }
// Run sends a message to the chat API
func (a *AgentServiceBridge) Run(ctx context.Context, sessionID string, text string, attachments ...Attachment) (string, error) {
// TODO: Handle attachments when API supports them
if len(attachments) > 0 {
// For now, ignore attachments
// return "", fmt.Errorf("attachments not supported yet")
}
part := client.MessagePart{}
part.FromMessagePartText(client.MessagePartText{
Type: "text",
Text: text,
})
parts := []client.MessagePart{part}
go a.client.PostSessionChatWithResponse(ctx, client.PostSessionChatJSONRequestBody{
SessionID: sessionID,
Parts: parts,
ProviderID: "anthropic",
ModelID: "claude-sonnet-4-20250514",
})
// The actual response will come through SSE
// For now, just return success
return "", nil
}
// Cancel cancels the current generation - NOT IMPLEMENTED IN API YET // Cancel cancels the current generation - NOT IMPLEMENTED IN API YET
func (a *AgentServiceBridge) Cancel(sessionID string) error { func (a *AgentServiceBridge) Cancel(sessionID string) error {
// TODO: Not implemented in TypeScript API yet // TODO: Not implemented in TypeScript API yet

View File

@@ -6,7 +6,6 @@ import (
// AgentService defines the interface for agent operations // AgentService defines the interface for agent operations
type AgentService interface { type AgentService interface {
Run(ctx context.Context, sessionID string, text string, attachments ...Attachment) (string, error)
Cancel(sessionID string) error Cancel(sessionID string) error
IsBusy() bool IsBusy() bool
IsSessionBusy(sessionID string) bool IsSessionBusy(sessionID string) bool

View File

@@ -335,7 +335,10 @@ func (m *statusCmp) projectDiagnostics() string {
func (m statusCmp) model() string { func (m statusCmp) model() string {
t := theme.CurrentTheme() t := theme.CurrentTheme()
model := "Claude Sonnet 4" // models.SupportedModels[coder.Model] model := "None"
if m.app.Model != nil {
model = *m.app.Model.Name
}
return styles.Padded(). return styles.Padded().
Background(t.Secondary()). Background(t.Secondary()).

View File

@@ -1,14 +1,18 @@
package dialog package dialog
import ( import (
"context"
"fmt"
"github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/sst/opencode/internal/config" "github.com/sst/opencode/internal/tui/app"
"github.com/sst/opencode/internal/tui/layout" "github.com/sst/opencode/internal/tui/layout"
"github.com/sst/opencode/internal/tui/styles" "github.com/sst/opencode/internal/tui/styles"
"github.com/sst/opencode/internal/tui/theme" "github.com/sst/opencode/internal/tui/theme"
"github.com/sst/opencode/internal/tui/util" "github.com/sst/opencode/internal/tui/util"
"github.com/sst/opencode/pkg/client"
) )
const ( const (
@@ -16,24 +20,25 @@ const (
maxDialogWidth = 40 maxDialogWidth = 40
) )
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
// Model models.Model
}
// CloseModelDialogMsg is sent when a model is selected // CloseModelDialogMsg is sent when a model is selected
type CloseModelDialogMsg struct{} type CloseModelDialogMsg struct {
Provider *client.ProviderInfo
Model *client.ProviderModel
}
// ModelDialog interface for the model selection dialog // ModelDialog interface for the model selection dialog
type ModelDialog interface { type ModelDialog interface {
tea.Model tea.Model
layout.Bindings layout.Bindings
SetProviders(providers []client.ProviderInfo)
} }
type modelDialogCmp struct { type modelDialogCmp struct {
// models []models.Model app *app.App
// provider models.ModelProvider availableProviders []client.ProviderInfo
// availableProviders []models.ModelProvider provider client.ProviderInfo
model *client.ProviderModel
selectedIdx int selectedIdx int
width int width int
@@ -100,10 +105,28 @@ var modelKeys = modelKeyMap{
} }
func (m *modelDialogCmp) Init() tea.Cmd { func (m *modelDialogCmp) Init() tea.Cmd {
m.setupModels() // cfg := config.Get()
// modelInfo := GetSelectedModel(cfg)
// m.availableProviders = getEnabledProviders(cfg)
// m.hScrollPossible = len(m.availableProviders) > 1
// m.provider = modelInfo.Provider
// m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
// m.setupModelsForProvider(m.provider)
m.availableProviders, _ = m.app.ListProviders(context.Background())
m.hScrollOffset = 0
m.hScrollPossible = len(m.availableProviders) > 1
m.provider = m.availableProviders[m.hScrollOffset]
return nil return nil
} }
func (m *modelDialogCmp) SetProviders(providers []client.ProviderInfo) {
m.availableProviders = providers
}
func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.KeyMsg: case tea.KeyMsg:
@@ -121,7 +144,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.switchProvider(1) m.switchProvider(1)
} }
case key.Matches(msg, modelKeys.Enter): case key.Matches(msg, modelKeys.Enter):
// return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]}) return m, util.CmdHandler(CloseModelDialogMsg{Provider: &m.provider, Model: &m.provider.Models[m.selectedIdx]})
case key.Matches(msg, modelKeys.Escape): case key.Matches(msg, modelKeys.Escape):
return m, util.CmdHandler(CloseModelDialogMsg{}) return m, util.CmdHandler(CloseModelDialogMsg{})
} }
@@ -138,8 +161,8 @@ func (m *modelDialogCmp) moveSelectionUp() {
if m.selectedIdx > 0 { if m.selectedIdx > 0 {
m.selectedIdx-- m.selectedIdx--
} else { } else {
// m.selectedIdx = len(m.models) - 1 m.selectedIdx = len(m.provider.Models) - 1
// m.scrollOffset = max(0, len(m.models)-numVisibleModels) m.scrollOffset = max(0, len(m.provider.Models)-numVisibleModels)
} }
// Keep selection visible // Keep selection visible
@@ -150,12 +173,12 @@ func (m *modelDialogCmp) moveSelectionUp() {
// moveSelectionDown moves the selection down or wraps to top // moveSelectionDown moves the selection down or wraps to top
func (m *modelDialogCmp) moveSelectionDown() { func (m *modelDialogCmp) moveSelectionDown() {
// if m.selectedIdx < len(m.models)-1 { if m.selectedIdx < len(m.provider.Models)-1 {
// m.selectedIdx++ m.selectedIdx++
// } else { } else {
// m.selectedIdx = 0 m.selectedIdx = 0
// m.scrollOffset = 0 m.scrollOffset = 0
// } }
// Keep selection visible // Keep selection visible
if m.selectedIdx >= m.scrollOffset+numVisibleModels { if m.selectedIdx >= m.scrollOffset+numVisibleModels {
@@ -167,16 +190,16 @@ func (m *modelDialogCmp) switchProvider(offset int) {
newOffset := m.hScrollOffset + offset newOffset := m.hScrollOffset + offset
// Ensure we stay within bounds // Ensure we stay within bounds
// if newOffset < 0 { if newOffset < 0 {
// newOffset = len(m.availableProviders) - 1 newOffset = len(m.availableProviders) - 1
// } }
// if newOffset >= len(m.availableProviders) { if newOffset >= len(m.availableProviders) {
// newOffset = 0 newOffset = 0
// } }
m.hScrollOffset = newOffset m.hScrollOffset = newOffset
// m.provider = m.availableProviders[m.hScrollOffset] m.provider = m.availableProviders[m.hScrollOffset]
// m.setupModelsForProvider(m.provider) m.setupModelsForProvider(m.provider.Id)
} }
func (m *modelDialogCmp) View() string { func (m *modelDialogCmp) View() string {
@@ -184,33 +207,32 @@ func (m *modelDialogCmp) View() string {
baseStyle := styles.BaseStyle() baseStyle := styles.BaseStyle()
// Capitalize first letter of provider name // Capitalize first letter of provider name
// providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:]) title := baseStyle.
// title := baseStyle. Foreground(t.Primary()).
// Foreground(t.Primary()). Bold(true).
// Bold(true). Width(maxDialogWidth).
// Width(maxDialogWidth). Padding(0, 0, 1).
// Padding(0, 0, 1). Render(fmt.Sprintf("Select %s Model", m.provider.Name))
// Render(fmt.Sprintf("Select %s Model", providerName))
// Render visible models // Render visible models
// endIdx := min(m.scrollOffset+numVisibleModels, len(m.models)) endIdx := min(m.scrollOffset+numVisibleModels, len(m.provider.Models))
// modelItems := make([]string, 0, endIdx-m.scrollOffset) modelItems := make([]string, 0, endIdx-m.scrollOffset)
//
// for i := m.scrollOffset; i < endIdx; i++ { for i := m.scrollOffset; i < endIdx; i++ {
// itemStyle := baseStyle.Width(maxDialogWidth) itemStyle := baseStyle.Width(maxDialogWidth)
// if i == m.selectedIdx { if i == m.selectedIdx {
// itemStyle = itemStyle.Background(t.Primary()). itemStyle = itemStyle.Background(t.Primary()).
// Foreground(t.Background()).Bold(true) Foreground(t.Background()).Bold(true)
// } }
// modelItems = append(modelItems, itemStyle.Render(m.models[i].Name)) modelItems = append(modelItems, itemStyle.Render(*m.provider.Models[i].Name))
// } }
scrollIndicator := m.getScrollIndicators(maxDialogWidth) scrollIndicator := m.getScrollIndicators(maxDialogWidth)
content := lipgloss.JoinVertical( content := lipgloss.JoinVertical(
lipgloss.Left, lipgloss.Left,
// title, title,
// baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)), baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
scrollIndicator, scrollIndicator,
) )
@@ -225,22 +247,22 @@ func (m *modelDialogCmp) View() string {
func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string { func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
var indicator string var indicator string
// if len(m.models) > numVisibleModels { if len(m.provider.Models) > numVisibleModels {
// if m.scrollOffset > 0 { if m.scrollOffset > 0 {
// indicator += "↑ " indicator += "↑ "
// } }
// if m.scrollOffset+numVisibleModels < len(m.models) { if m.scrollOffset+numVisibleModels < len(m.provider.Models) {
// indicator += "↓ " indicator += "↓ "
// } }
// } }
if m.hScrollPossible { if m.hScrollPossible {
if m.hScrollOffset > 0 { if m.hScrollOffset > 0 {
indicator = "← " + indicator indicator = "← " + indicator
} }
// if m.hScrollOffset < len(m.availableProviders)-1 { if m.hScrollOffset < len(m.availableProviders)-1 {
// indicator += "→" indicator += "→"
// } }
} }
if indicator == "" { if indicator == "" {
@@ -262,70 +284,26 @@ func (m *modelDialogCmp) BindingKeys() []key.Binding {
return layout.KeyMapToSlice(modelKeys) return layout.KeyMapToSlice(modelKeys)
} }
func (m *modelDialogCmp) setupModels() {
// cfg := config.Get()
// modelInfo := GetSelectedModel(cfg)
// m.availableProviders = getEnabledProviders(cfg)
// m.hScrollPossible = len(m.availableProviders) > 1
//
// m.provider = modelInfo.Provider
// m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
//
// m.setupModelsForProvider(m.provider)
}
func GetSelectedModel(cfg *config.Config) string {
return "Claude Sonnet 4"
// agentCfg := cfg.Agents[config.AgentPrimary]
// selectedModelId := agentCfg.Model
// return models.SupportedModels[selectedModelId]
}
func getEnabledProviders(cfg *config.Config) []string {
return []string{"anthropic", "openai", "google"}
// var providers []models.ModelProvider
// for providerId, provider := range cfg.Providers {
// if !provider.Disabled {
// providers = append(providers, providerId)
// }
// }
//
// // Sort by provider popularity
// slices.SortFunc(providers, func(a, b models.ModelProvider) int {
// rA := models.ProviderPopularity[a]
// rB := models.ProviderPopularity[b]
//
// // models not included in popularity ranking default to last
// if rA == 0 {
// rA = 999
// }
// if rB == 0 {
// rB = 999
// }
// return rA - rB
// })
// return providers
}
// findProviderIndex returns the index of the provider in the list, or -1 if not found // findProviderIndex returns the index of the provider in the list, or -1 if not found
func findProviderIndex(providers []string, provider string) int { // func findProviderIndex(providers []string, provider string) int {
for i, p := range providers { // for i, p := range providers {
if p == provider { // if p == provider {
return i // return i
} // }
} // }
return -1 // return -1
} // }
func (m *modelDialogCmp) setupModelsForProvider(_ string) {
m.selectedIdx = 0
m.scrollOffset = 0
func (m *modelDialogCmp) setupModelsForProvider(provider string) {
// cfg := config.Get() // cfg := config.Get()
// agentCfg := cfg.Agents[config.AgentPrimary] // agentCfg := cfg.Agents[config.AgentPrimary]
// selectedModelId := agentCfg.Model // selectedModelId := agentCfg.Model
// m.provider = provider // m.provider = provider
// m.models = getModelsForProvider(provider) // m.models = getModelsForProvider(provider)
m.selectedIdx = 0
m.scrollOffset = 0
// Try to select the current model if it belongs to this provider // Try to select the current model if it belongs to this provider
// if provider == models.SupportedModels[selectedModelId].Provider { // if provider == models.SupportedModels[selectedModelId].Provider {
@@ -342,28 +320,8 @@ func (m *modelDialogCmp) setupModelsForProvider(provider string) {
// } // }
} }
func getModelsForProvider(provider string) []string { func NewModelDialogCmp(app *app.App) ModelDialog {
return []string{"Claude Sonnet 4"} return &modelDialogCmp{
// var providerModels []models.Model app: app,
// for _, model := range models.SupportedModels { }
// if model.Provider == provider {
// providerModels = append(providerModels, model)
// }
// }
// reverse alphabetical order (if llm naming was consistent latest would appear first)
// slices.SortFunc(providerModels, func(a, b models.Model) int {
// if a.Name > b.Name {
// return -1
// } else if a.Name < b.Name {
// return 1
// }
// return 0
// })
// return providerModels
}
func NewModelDialogCmp() ModelDialog {
return &modelDialogCmp{}
} }

View File

@@ -5,8 +5,15 @@ import (
) )
type SessionSelectedMsg = *client.SessionInfo type SessionSelectedMsg = *client.SessionInfo
type ModelSelectedMsg struct {
Provider client.ProviderInfo
Model client.ProviderModel
}
type SessionClearedMsg struct{} type SessionClearedMsg struct{}
type CompactSessionMsg struct{} type CompactSessionMsg struct{}
// TODO: remove
type StateUpdatedMsg struct { type StateUpdatedMsg struct {
State map[string]any State map[string]any
} }

View File

@@ -168,16 +168,27 @@ func (a appModel) Init() tea.Cmd {
return dialog.ShowInitDialogMsg{Show: shouldShow} return dialog.ShowInitDialogMsg{Show: shouldShow}
}) })
cmds = append(cmds, func() tea.Msg {
providers, _ := a.app.ListProviders(context.Background())
return state.ModelSelectedMsg{Provider: providers[0], Model: providers[0].Models[0]}
})
return tea.Batch(cmds...) return tea.Batch(cmds...)
} }
func (a appModel) updateAllPages(msg tea.Msg) (tea.Model, tea.Cmd) { func (a appModel) updateAllPages(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd var cmds []tea.Cmd
var cmd tea.Cmd var cmd tea.Cmd
for id := range a.pages { for id := range a.pages {
a.pages[id], cmd = a.pages[id].Update(msg) a.pages[id], cmd = a.pages[id].Update(msg)
cmds = append(cmds, cmd) cmds = append(cmds, cmd)
} }
s, cmd := a.status.Update(msg)
cmds = append(cmds, cmd)
a.status = s.(core.StatusCmp)
return a, tea.Batch(cmds...) return a, tea.Batch(cmds...)
} }
@@ -201,12 +212,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
for i, m := range a.app.Messages { for i, m := range a.app.Messages {
if m.Id == msg.Properties.Info.Id { if m.Id == msg.Properties.Info.Id {
a.app.Messages[i] = msg.Properties.Info a.app.Messages[i] = msg.Properties.Info
slog.Debug("Updated message", "message", msg.Properties.Info)
return a.updateAllPages(state.StateUpdatedMsg{State: nil}) return a.updateAllPages(state.StateUpdatedMsg{State: nil})
} }
} }
a.app.Messages = append(a.app.Messages, msg.Properties.Info) a.app.Messages = append(a.app.Messages, msg.Properties.Info)
slog.Debug("Appended message", "message", msg.Properties.Info)
return a.updateAllPages(state.StateUpdatedMsg{State: nil}) return a.updateAllPages(state.StateUpdatedMsg{State: nil})
} }
@@ -287,6 +296,19 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.app.Messages, _ = a.app.ListMessages(context.Background(), msg.Id) a.app.Messages, _ = a.app.ListMessages(context.Background(), msg.Id)
return a.updateAllPages(msg) return a.updateAllPages(msg)
case dialog.CloseModelDialogMsg:
a.showModelDialog = false
slog.Debug("closing model dialog", "msg", msg)
if msg.Provider != nil && msg.Model != nil {
return a, util.CmdHandler(state.ModelSelectedMsg{Provider: *msg.Provider, Model: *msg.Model})
}
return a, nil
case state.ModelSelectedMsg:
a.app.Provider = &msg.Provider
a.app.Model = &msg.Model
return a.updateAllPages(msg)
case dialog.CloseCommandDialogMsg: case dialog.CloseCommandDialogMsg:
a.showCommandDialog = false a.showCommandDialog = false
return a, nil return a, nil
@@ -309,24 +331,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
status.Info("Theme changed to: " + msg.ThemeName) status.Info("Theme changed to: " + msg.ThemeName)
return a, cmd return a, cmd
case dialog.CloseModelDialogMsg:
a.showModelDialog = false
return a, nil
case dialog.ModelSelectedMsg:
a.showModelDialog = false
// TODO: Agent model update not implemented in API yet
// model, err := a.app.PrimaryAgent.Update(config.AgentPrimary, msg.Model.ID)
// if err != nil {
// status.Error(err.Error())
// return a, nil
// }
// status.Info(fmt.Sprintf("Model changed to %s", model.Name))
status.Info("Model selection not implemented in API yet")
return a, nil
case dialog.ShowInitDialogMsg: case dialog.ShowInitDialogMsg:
a.showInitDialog = msg.Show a.showInitDialog = msg.Show
return a, nil return a, nil
@@ -476,6 +480,18 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.showThemeDialog = false a.showThemeDialog = false
a.showFilepicker = false a.showFilepicker = false
// Load providers and show the dialog
providers, err := a.app.ListProviders(context.Background())
if err != nil {
status.Error(err.Error())
return a, nil
}
if len(providers) == 0 {
status.Warn("No providers available")
return a, nil
}
a.modelDialog.SetProviders(providers)
a.showModelDialog = true a.showModelDialog = true
return a, nil return a, nil
} }
@@ -907,7 +923,7 @@ func New(app *app.App) tea.Model {
quit: dialog.NewQuitCmp(), quit: dialog.NewQuitCmp(),
sessionDialog: dialog.NewSessionDialogCmp(), sessionDialog: dialog.NewSessionDialogCmp(),
commandDialog: dialog.NewCommandDialogCmp(), commandDialog: dialog.NewCommandDialogCmp(),
modelDialog: dialog.NewModelDialogCmp(), modelDialog: dialog.NewModelDialogCmp(app),
permissions: dialog.NewPermissionDialogCmp(), permissions: dialog.NewPermissionDialogCmp(),
initDialog: dialog.NewInitDialogCmp(), initDialog: dialog.NewInitDialogCmp(),
themeDialog: dialog.NewThemeDialogCmp(), themeDialog: dialog.NewThemeDialogCmp(),

View File

@@ -14,7 +14,7 @@ export namespace Config {
export const Info = z export const Info = z
.object({ .object({
providers: z.record(z.string(), Provider.Info).optional(), providers: Provider.Info.array().optional(),
}) })
.strict(); .strict();

View File

@@ -1,6 +1,6 @@
import { App } from "../app/app"; import { App } from "../app/app";
import { Log } from "../util/log"; import { Log } from "../util/log";
import { mergeDeep } from "remeda"; import { concat } from "remeda";
import path from "path"; import path from "path";
import { Provider } from "../provider/provider"; import { Provider } from "../provider/provider";
@@ -19,26 +19,32 @@ export namespace LLM {
} }
} }
const NATIVE_PROVIDERS: Record<string, Provider.Info> = { const NATIVE_PROVIDERS: Provider.Info[] = [
anthropic: { {
models: { id: "anthropic",
"claude-sonnet-4-20250514": { name: "Anthropic",
name: "Claude 4 Sonnet", models: [
{
id: "claude-sonnet-4-20250514",
name: "Claude Sonnet 4",
cost: { cost: {
input: 3.0 / 1_000_000, input: 3.0 / 1_000_000,
output: 15.0 / 1_000_000, output: 15.0 / 1_000_000,
inputCached: 3.75 / 1_000_000, inputCached: 3.75 / 1_000_000,
outputCached: 0.3 / 1_000_000, outputCached: 0.3 / 1_000_000,
}, },
contextWindow: 200000, contextWindow: 200_000,
maxTokens: 50000, maxOutputTokens: 50_000,
attachment: true, attachment: true,
}, },
}, ],
}, },
openai: { {
models: { id: "openai",
"codex-mini-latest": { name: "OpenAI",
models: [
{
id: "codex-mini-latest",
name: "Codex Mini", name: "Codex Mini",
cost: { cost: {
input: 1.5 / 1_000_000, input: 1.5 / 1_000_000,
@@ -46,16 +52,19 @@ export namespace LLM {
output: 6.0 / 1_000_000, output: 6.0 / 1_000_000,
outputCached: 0.0 / 1_000_000, outputCached: 0.0 / 1_000_000,
}, },
contextWindow: 200000, contextWindow: 200_000,
maxTokens: 100000, maxOutputTokens: 100_000,
attachment: true, attachment: true,
reasoning: true, reasoning: true,
}, },
}, ],
}, },
google: { {
models: { id: "google",
"gemini-2.5-pro-preview-03-25": { name: "Google",
models: [
{
id: "gemini-2.5-pro-preview-03-25",
name: "Gemini 2.5 Pro", name: "Gemini 2.5 Pro",
cost: { cost: {
input: 1.25 / 1_000_000, input: 1.25 / 1_000_000,
@@ -63,18 +72,18 @@ export namespace LLM {
output: 10 / 1_000_000, output: 10 / 1_000_000,
outputCached: 0 / 1_000_000, outputCached: 0 / 1_000_000,
}, },
contextWindow: 1000000, contextWindow: 1_000_000,
maxTokens: 50000, maxOutputTokens: 50_000,
attachment: true, attachment: true,
}, },
}, ],
}, },
}; ];
const AUTODETECT: Record<string, string[]> = { const AUTODETECT: Record<string, string[]> = {
anthropic: ["ANTHROPIC_API_KEY"], anthropic: ["ANTHROPIC_API_KEY"],
openai: ["OPENAI_API_KEY"], openai: ["OPENAI_API_KEY"],
google: ["GOOGLE_GENERATIVE_AI_API_KEY"], google: ["GOOGLE_GENERATIVE_AI_API_KEY", "GEMINI_API_KEY"],
}; };
const state = App.state("llm", async () => { const state = App.state("llm", async () => {
@@ -91,33 +100,33 @@ export namespace LLM {
{ info: Provider.Model; instance: LanguageModel } { info: Provider.Model; instance: LanguageModel }
>(); >();
const list = mergeDeep(NATIVE_PROVIDERS, config.providers ?? {}); const list = concat(NATIVE_PROVIDERS, config.providers ?? []);
for (const [providerID, providerInfo] of Object.entries(list)) { for (const provider of list) {
if ( if (
!config.providers?.[providerID] && !config.providers?.find((p) => p.id === provider.id) &&
!AUTODETECT[providerID]?.some((env) => process.env[env]) !AUTODETECT[provider.id]?.some((env) => process.env[env])
) )
continue; continue;
const dir = path.join( const dir = path.join(
Global.cache(), Global.cache(),
`node_modules`, `node_modules`,
`@ai-sdk`, `@ai-sdk`,
providerID, provider.id,
); );
if (!(await Bun.file(path.join(dir, "package.json")).exists())) { if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
BunProc.run(["add", "--exact", `@ai-sdk/${providerID}@alpha`], { BunProc.run(["add", "--exact", `@ai-sdk/${provider.id}@alpha`], {
cwd: Global.cache(), cwd: Global.cache(),
}); });
} }
const mod = await import( const mod = await import(
path.join(Global.cache(), `node_modules`, `@ai-sdk`, providerID) path.join(Global.cache(), `node_modules`, `@ai-sdk`, provider.id)
); );
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]; const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!];
const loaded = fn(providerInfo.options); const loaded = fn(provider.options);
log.info("loaded", { provider: providerID }); log.info("loaded", { provider: provider.id });
providers[providerID] = { providers[provider.id] = {
info: providerInfo, info: provider,
instance: loaded, instance: loaded,
}; };
} }
@@ -142,7 +151,7 @@ export namespace LLM {
providerID, providerID,
modelID, modelID,
}); });
const info = provider.info.models[modelID]; const info = provider.info.models.find((m) => m.id === modelID);
if (!info) throw new ModelNotFoundError(modelID); if (!info) throw new ModelNotFoundError(modelID);
try { try {
const match = provider.instance.languageModel(modelID); const match = provider.instance.languageModel(modelID);

View File

@@ -3,6 +3,7 @@ import z from "zod";
export namespace Provider { export namespace Provider {
export const Model = z export const Model = z
.object({ .object({
id: z.string(),
name: z.string().optional(), name: z.string().optional(),
cost: z.object({ cost: z.object({
input: z.number(), input: z.number(),
@@ -22,8 +23,10 @@ export namespace Provider {
export const Info = z export const Info = z
.object({ .object({
id: z.string(),
name: z.string(),
options: z.record(z.string(), z.any()).optional(), options: z.record(z.string(), z.any()).optional(),
models: z.record(z.string(), Model), models: Model.array(),
}) })
.openapi({ .openapi({
ref: "Provider.Info", ref: "Provider.Info",

View File

@@ -263,7 +263,7 @@ export namespace Server {
description: "List of providers", description: "List of providers",
content: { content: {
"application/json": { "application/json": {
schema: resolver(z.record(z.string(), Provider.Info)), schema: resolver(Provider.Info.array()),
}, },
}, },
}, },
@@ -271,9 +271,9 @@ export namespace Server {
}), }),
async (c) => { async (c) => {
const providers = await LLM.providers(); const providers = await LLM.providers();
const result: Record<string, Provider.Info> = {}; const result = [] as (Provider.Info & { key: string })[];
for (const [providerID, provider] of Object.entries(providers)) { for (const [key, provider] of Object.entries(providers)) {
result[providerID] = provider.info; result.push({ ...provider.info, key });
} }
return c.json(result); return c.json(result);
}, },

View File

@@ -280,8 +280,8 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"type": "object", "type": "array",
"additionalProperties": { "items": {
"$ref": "#/components/schemas/Provider.Info" "$ref": "#/components/schemas/Provider.Info"
} }
} }
@@ -818,24 +818,35 @@
"Provider.Info": { "Provider.Info": {
"type": "object", "type": "object",
"properties": { "properties": {
"id": {
"type": "string"
},
"name": {
"type": "string"
},
"options": { "options": {
"type": "object", "type": "object",
"additionalProperties": {} "additionalProperties": {}
}, },
"models": { "models": {
"type": "object", "type": "array",
"additionalProperties": { "items": {
"$ref": "#/components/schemas/Provider.Model" "$ref": "#/components/schemas/Provider.Model"
} }
} }
}, },
"required": [ "required": [
"id",
"name",
"models" "models"
] ]
}, },
"Provider.Model": { "Provider.Model": {
"type": "object", "type": "object",
"properties": { "properties": {
"id": {
"type": "string"
},
"name": { "name": {
"type": "string" "type": "string"
}, },
@@ -876,6 +887,7 @@
} }
}, },
"required": [ "required": [
"id",
"cost", "cost",
"contextWindow", "contextWindow",
"attachment" "attachment"

View File

@@ -173,8 +173,10 @@ type MessageToolInvocationToolResult struct {
// ProviderInfo defines model for Provider.Info. // ProviderInfo defines model for Provider.Info.
type ProviderInfo struct { type ProviderInfo struct {
Models map[string]ProviderModel `json:"models"` Id string `json:"id"`
Options *map[string]interface{} `json:"options,omitempty"` Models []ProviderModel `json:"models"`
Name string `json:"name"`
Options *map[string]interface{} `json:"options,omitempty"`
} }
// ProviderModel defines model for Provider.Model. // ProviderModel defines model for Provider.Model.
@@ -187,6 +189,7 @@ type ProviderModel struct {
Output float32 `json:"output"` Output float32 `json:"output"`
OutputCached float32 `json:"outputCached"` OutputCached float32 `json:"outputCached"`
} `json:"cost"` } `json:"cost"`
Id string `json:"id"`
MaxOutputTokens *float32 `json:"maxOutputTokens,omitempty"` MaxOutputTokens *float32 `json:"maxOutputTokens,omitempty"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
Reasoning *bool `json:"reasoning,omitempty"` Reasoning *bool `json:"reasoning,omitempty"`
@@ -1421,7 +1424,7 @@ func (r GetEventResponse) StatusCode() int {
type PostProviderListResponse struct { type PostProviderListResponse struct {
Body []byte Body []byte
HTTPResponse *http.Response HTTPResponse *http.Response
JSON200 *map[string]ProviderInfo JSON200 *[]ProviderInfo
} }
// Status returns HTTPResponse.Status // Status returns HTTPResponse.Status
@@ -1756,7 +1759,7 @@ func ParsePostProviderListResponse(rsp *http.Response) (*PostProviderListRespons
switch { switch {
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200:
var dest map[string]ProviderInfo var dest []ProviderInfo
if err := json.Unmarshal(bodyBytes, &dest); err != nil { if err := json.Unmarshal(bodyBytes, &dest); err != nil {
return nil, err return nil, err
} }