rework llm

Kujtim Hoxha
2025-03-27 22:35:48 +01:00
parent 904061c243
commit afd9ad0560
61 changed files with 5882 additions and 2074 deletions


@@ -0,0 +1,309 @@
package provider

import (
	"context"
	"encoding/json"
	"errors"
	"strings"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

type anthropicProvider struct {
	client        anthropic.Client
	model         models.Model
	maxTokens     int64
	apiKey        string
	systemMessage string
}

type AnthropicOption func(*anthropicProvider)

func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
	provider := &anthropicProvider{
		maxTokens: 1024,
	}
	for _, opt := range opts {
		opt(provider)
	}
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}
	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
	return provider, nil
}

func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
		Model:       anthropic.Model(a.model.APIModel),
		MaxTokens:   a.maxTokens,
		Temperature: anthropic.Float(0),
		Messages:    anthropicMessages,
		Tools:       anthropicTools,
		System: []anthropic.TextBlockParam{
			{
				Text: a.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	})
	if err != nil {
		return nil, err
	}

	content := ""
	for _, block := range response.Content {
		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
			content += text.Text
		}
	}

	toolCalls := a.extractToolCalls(response.Content)
	tokenUsage := a.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	// Enable extended thinking when the latest user message mentions "think";
	// the API requires temperature 1 while thinking is enabled.
	var thinkingParam anthropic.ThinkingConfigParamUnion
	lastMessage := messages[len(messages)-1]
	temperature := anthropic.Float(0)
	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
		thinkingParam = anthropic.ThinkingConfigParamUnion{
			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
				Type:         "enabled",
			},
		}
		temperature = anthropic.Float(1)
	}

	stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
		Model:       anthropic.Model(a.model.APIModel),
		MaxTokens:   a.maxTokens,
		Temperature: temperature,
		Messages:    anthropicMessages,
		Tools:       anthropicTools,
		Thinking:    thinkingParam,
		System: []anthropic.TextBlockParam{
			{
				Text: a.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	})

	eventChan := make(chan ProviderEvent)
	go func() {
		defer close(eventChan)

		accumulatedMessage := anthropic.Message{}
		for stream.Next() {
			event := stream.Current()
			err := accumulatedMessage.Accumulate(event)
			if err != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}

			switch event := event.AsAny().(type) {
			case anthropic.ContentBlockStartEvent:
				eventChan <- ProviderEvent{Type: EventContentStart}
			case anthropic.ContentBlockDeltaEvent:
				if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
					eventChan <- ProviderEvent{
						Type:     EventThinkingDelta,
						Thinking: event.Delta.Thinking,
					}
				} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: event.Delta.Text,
					}
				}
			case anthropic.ContentBlockStopEvent:
				eventChan <- ProviderEvent{Type: EventContentStop}
			case anthropic.MessageStopEvent:
				// The message is complete; emit the accumulated response.
				content := ""
				for _, block := range accumulatedMessage.Content {
					if text, ok := block.AsAny().(anthropic.TextBlock); ok {
						content += text.Text
					}
				}
				toolCalls := a.extractToolCalls(accumulatedMessage.Content)
				tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:   content,
						ToolCalls: toolCalls,
						Usage:     tokenUsage,
					},
				}
			}
		}

		if stream.Err() != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
		}
	}()

	return eventChan, nil
}

func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
	var toolCalls []message.ToolCall
	for _, block := range content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:    variant.ID,
				Name:  variant.Name,
				Input: string(variant.Input),
				Type:  string(variant.Type),
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}
	return toolCalls
}

func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
	return TokenUsage{
		InputTokens:         usage.InputTokens,
		OutputTokens:        usage.OutputTokens,
		CacheCreationTokens: usage.CacheCreationInputTokens,
		CacheReadTokens:     usage.CacheReadInputTokens,
	}
}

func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
			},
		}
		// Mark the last tool as cacheable so the whole tool list is prompt-cached.
		if i == len(tools)-1 {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}
		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}
	return anthropicTools
}

func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, len(messages))
	// Anthropic allows up to four cache_control breakpoints per request; the
	// system prompt and tool list take one each, so cache at most two message
	// blocks here.
	cachedBlocks := 0
	for i, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content)
			if cachedBlocks < 2 {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages[i] = anthropic.NewUserMessage(content)
		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content != "" {
				content := anthropic.NewTextBlock(msg.Content)
				if cachedBlocks < 2 {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}
			for _, toolCall := range msg.ToolCalls {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}
			anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
			for j, toolResult := range msg.ToolResults {
				results[j] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages[i] = anthropic.NewUserMessage(results...)
		}
	}
	return anthropicMessages
}
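
For orientation, a minimal sketch of how a caller might wire these functional options together (this snippet is not part of the commit; the zero-value models.Model, the ANTHROPIC_API_KEY environment variable, the prompt text, and the internal/llm/provider import path are illustrative assumptions):

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/message"
)

func main() {
	// Placeholder: a real caller would pick a concrete model definition
	// from the models package.
	var claudeModel models.Model

	p, err := provider.NewAnthropicProvider(
		provider.WithAnthropicKey(os.Getenv("ANTHROPIC_API_KEY")), // assumed env var
		provider.WithAnthropicModel(claudeModel),
		provider.WithAnthropicMaxTokens(4096),
		provider.WithAnthropicSystemMessage("You are a helpful terminal assistant."),
	)
	if err != nil {
		log.Fatal(err)
	}

	msgs := []message.Message{{Role: message.User, Content: "Say hello."}}
	resp, err := p.SendMessages(context.Background(), msgs, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Content)
}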


@@ -0,0 +1,443 @@
package provider

import (
	"context"
	"encoding/json"
	"errors"
	"log"

	"github.com/google/generative-ai-go/genai"
	"github.com/google/uuid"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"google.golang.org/api/googleapi"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

type geminiProvider struct {
	client        *genai.Client
	model         models.Model
	maxTokens     int32
	apiKey        string
	systemMessage string
}

type GeminiOption func(*geminiProvider)

func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
	provider := &geminiProvider{
		maxTokens: 5000,
	}
	for _, opt := range opts {
		opt(provider)
	}
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}
	client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
	if err != nil {
		return nil, err
	}
	provider.client = client
	return provider, nil
}

func WithGeminiSystemMessage(message string) GeminiOption {
	return func(p *geminiProvider) {
		p.systemMessage = message
	}
}

func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
	return func(p *geminiProvider) {
		p.maxTokens = maxTokens
	}
}

func WithGeminiModel(model models.Model) GeminiOption {
	return func(p *geminiProvider) {
		p.model = model
	}
}

func WithGeminiKey(apiKey string) GeminiOption {
	return func(p *geminiProvider) {
		p.apiKey = apiKey
	}
}

func (p *geminiProvider) Close() {
	if p.client != nil {
		p.client.Close()
	}
}

// convertToGeminiHistory converts the message history to Gemini's format.
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			history = append(history, &genai.Content{
				Parts: []genai.Part{genai.Text(msg.Content)},
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []genai.Part{},
			}
			// Handle regular content
			if msg.Content != "" {
				content.Parts = append(content.Parts, genai.Text(msg.Content))
			}
			// Handle tool calls, if any
			for _, call := range msg.ToolCalls {
				args, _ := parseJsonToMap(call.Input)
				content.Parts = append(content.Parts, genai.FunctionCall{
					Name: call.Name,
					Args: args,
				})
			}
			history = append(history, content)
		case message.Tool:
			for _, result := range msg.ToolResults {
				// Parse the response content into a map if possible;
				// otherwise wrap the raw string.
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}
				// Gemini function responses are keyed by function name rather
				// than call ID, so look up the originating tool call.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role != message.Assistant {
						continue
					}
					for _, call := range m.ToolCalls {
						if call.ID == result.ToolCallID {
							toolCall = call
							break
						}
					}
				}
				history = append(history, &genai.Content{
					Parts: []genai.Part{genai.FunctionResponse{
						Name:     toolCall.Name,
						Response: response,
					}},
					Role: "function",
				})
			}
		}
	}
	return history
}

// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations.
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
	declarations := make([]*genai.FunctionDeclaration, len(tools))
	for i, tool := range tools {
		info := tool.Info()

		// Convert JSON Schema parameters to genai.Schema format.
		properties := make(map[string]*genai.Schema)
		for name, param := range info.Parameters {
			paramMap, ok := param.(map[string]interface{})
			if !ok {
				// Default to string if the type cannot be determined.
				properties[name] = &genai.Schema{Type: genai.TypeString}
				continue
			}

			schemaType := genai.TypeString // default
			var description string
			var itemsTypeSchema *genai.Schema

			if typeVal, found := paramMap["type"]; found {
				if typeStr, ok := typeVal.(string); ok {
					switch typeStr {
					case "string":
						schemaType = genai.TypeString
					case "number":
						schemaType = genai.TypeNumber
					case "integer":
						schemaType = genai.TypeInteger
					case "boolean":
						schemaType = genai.TypeBoolean
					case "array":
						schemaType = genai.TypeArray
						// Recover the element type for array parameters.
						if items, found := paramMap["items"]; found {
							if itemsMap, ok := items.(map[string]interface{}); ok {
								if itemsType, found := itemsMap["type"]; found {
									if itemsTypeStr, ok := itemsType.(string); ok {
										switch itemsTypeStr {
										case "string":
											itemsTypeSchema = &genai.Schema{Type: genai.TypeString}
										case "number":
											itemsTypeSchema = &genai.Schema{Type: genai.TypeNumber}
										case "integer":
											itemsTypeSchema = &genai.Schema{Type: genai.TypeInteger}
										case "boolean":
											itemsTypeSchema = &genai.Schema{Type: genai.TypeBoolean}
										}
									}
								}
							}
						}
					case "object":
						schemaType = genai.TypeObject
						if _, found := paramMap["properties"]; !found {
							continue
						}
						// TODO: Add support for other types
					}
				}
			}

			if desc, found := paramMap["description"]; found {
				if descStr, ok := desc.(string); ok {
					description = descStr
				}
			}

			properties[name] = &genai.Schema{
				Type:        schemaType,
				Description: description,
				Items:       itemsTypeSchema,
			}
		}

		declarations[i] = &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: properties,
				Required:   info.Required,
			},
		}
	}
	return declarations
}

// extractTokenUsage extracts token usage information from Gemini's response.
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}
	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

// SendMessages sends a batch of messages to Gemini and returns the response.
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Create a generative model with the system instruction applied.
	model := p.client.GenerativeModel(p.model.APIModel)
	model.SetMaxOutputTokens(p.maxTokens)
	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))

	// Set up tools if provided.
	if len(tools) > 0 {
		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
		model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
	}

	// Create a chat session seeded with everything but the last message.
	chat := model.StartChat()
	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1])

	// Find the most recent user message to send.
	var lastUserMsg message.Message
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Role == message.User {
			lastUserMsg = messages[i]
			break
		}
	}

	resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
	if err != nil {
		return nil, err
	}

	// Process the response.
	var content string
	var toolCalls []message.ToolCall
	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			switch part := part.(type) {
			case genai.Text:
				content += string(part)
			case genai.FunctionCall:
				// Gemini does not return call IDs, so synthesize one.
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(part.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  part.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	// Extract token usage.
	tokenUsage := p.extractTokenUsage(resp)
	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

// StreamResponse streams the response from Gemini.
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	// Create a generative model with the system instruction applied.
	model := p.client.GenerativeModel(p.model.APIModel)
	model.SetMaxOutputTokens(p.maxTokens)
	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))

	// Set up tools if provided.
	if len(tools) > 0 {
		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
		for _, declaration := range declarations {
			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
		}
	}

	// Create a chat session seeded with everything but the last message.
	chat := model.StartChat()
	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1])
	lastUserMsg := messages[len(messages)-1]

	// Start streaming.
	iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))

	eventChan := make(chan ProviderEvent)
	go func() {
		defer close(eventChan)

		var finalResp *genai.GenerateContentResponse
		currentContent := ""
		toolCalls := []message.ToolCall{}

		for {
			resp, err := iter.Next()
			if err == iterator.Done {
				break
			}
			if err != nil {
				var apiErr *googleapi.Error
				if errors.As(err, &apiErr) {
					log.Printf("%s", apiErr.Body)
				}
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: err,
				}
				return
			}
			finalResp = resp

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch part := part.(type) {
					case genai.Text:
						newText := string(part)
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: newText,
						}
						currentContent += newText
					case genai.FunctionCall:
						// Function calls are assumed to arrive complete, not streamed in parts.
						id := "call_" + uuid.New().String()
						args, _ := json.Marshal(part.Args)
						newCall := message.ToolCall{
							ID:    id,
							Name:  part.Name,
							Input: string(args),
							Type:  "function",
						}

						// Only record the call if it hasn't been seen yet.
						isNew := true
						for _, existing := range toolCalls {
							if existing.Name == newCall.Name && existing.Input == newCall.Input {
								isNew = false
								break
							}
						}
						if isNew {
							toolCalls = append(toolCalls, newCall)
						}
					}
				}
			}
		}

		// Extract token usage from the final response.
		tokenUsage := p.extractTokenUsage(finalResp)
		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}

// parseJsonToMap parses a JSON string into a map.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}


@@ -0,0 +1,278 @@
package provider

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

type openaiProvider struct {
	client        openai.Client
	model         models.Model
	maxTokens     int64
	baseURL       string
	apiKey        string
	systemMessage string
}

type OpenAIOption func(*openaiProvider)

func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
	provider := &openaiProvider{
		maxTokens: 5000,
	}
	for _, opt := range opts {
		opt(provider)
	}
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	clientOpts := []option.RequestOption{
		option.WithAPIKey(provider.apiKey),
	}
	if provider.baseURL != "" {
		clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
	}
	provider.client = openai.NewClient(clientOpts...)

	return provider, nil
}

func WithOpenAISystemMessage(message string) OpenAIOption {
	return func(p *openaiProvider) {
		p.systemMessage = message
	}
}

func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
	return func(p *openaiProvider) {
		p.maxTokens = maxTokens
	}
}

func WithOpenAIModel(model models.Model) OpenAIOption {
	return func(p *openaiProvider) {
		p.model = model
	}
}

func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(p *openaiProvider) {
		p.baseURL = baseURL
	}
}

func WithOpenAIKey(apiKey string) OpenAIOption {
	return func(p *openaiProvider) {
		p.apiKey = apiKey
	}
}

func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
	// The system message is always the first entry.
	var chatMessages []openai.ChatCompletionMessageParamUnion
	chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			chatMessages = append(chatMessages, openai.UserMessage(msg.Content))
		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			if msg.Content != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				}
			}
			if len(msg.ToolCalls) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
				for i, call := range msg.ToolCalls {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case message.Tool:
			for _, result := range msg.ToolResults {
				chatMessages = append(chatMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}
	return chatMessages
}

func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}
	return openaiTools
}

func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
	// OpenAI reports cached tokens as part of the prompt total, so subtract
	// them to get the uncached input count.
	cachedTokens := usage.PromptTokensDetails.CachedTokens
	return TokenUsage{
		InputTokens:         usage.PromptTokens - cachedTokens,
		OutputTokens:        usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
	}

	response, err := p.client.Chat.Completions.New(ctx, params)
	if err != nil {
		return nil, err
	}

	content := response.Choices[0].Message.Content

	var toolCalls []message.ToolCall
	if len(response.Choices[0].Message.ToolCalls) > 0 {
		toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
		for i, call := range response.Choices[0].Message.ToolCalls {
			toolCalls[i] = message.ToolCall{
				ID:    call.ID,
				Name:  call.Function.Name,
				Input: call.Function.Arguments,
				Type:  "function",
			}
		}
	}

	tokenUsage := p.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}

	stream := p.client.Chat.Completions.NewStreaming(ctx, params)
	eventChan := make(chan ProviderEvent)
	toolCalls := make([]message.ToolCall, 0)

	go func() {
		defer close(eventChan)

		acc := openai.ChatCompletionAccumulator{}
		currentContent := ""

		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)

			// The accumulator reassembles tool calls that arrive split across chunks.
			if tool, ok := acc.JustFinishedToolCall(); ok {
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    tool.Id,
					Name:  tool.Name,
					Input: tool.Arguments,
					Type:  "function",
				})
			}

			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: choice.Delta.Content,
					}
					currentContent += choice.Delta.Content
				}
			}
		}

		if err := stream.Err(); err != nil {
			eventChan <- ProviderEvent{
				Type:  EventError,
				Error: err,
			}
			return
		}

		tokenUsage := p.extractTokenUsage(acc.Usage)
		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}


@@ -0,0 +1,48 @@
package provider

import (
	"context"

	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

// EventType represents the type of streaming event
type EventType string

const (
	EventContentStart  EventType = "content_start"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
)

type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}

type ProviderResponse struct {
	Content   string
	ToolCalls []message.ToolCall
	Usage     TokenUsage
}

type ProviderEvent struct {
	Type     EventType
	Content  string
	Thinking string
	ToolCall *message.ToolCall
	Error    error
	Response *ProviderResponse
}

type Provider interface {
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
}
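
To make the event contract concrete, a minimal consumer of StreamResponse might look like the following sketch (not part of the commit; drainEvents is a hypothetical helper assumed to live alongside the interface in this package):

package provider

import (
	"context"
	"errors"
	"fmt"

	"github.com/kujtimiihoxha/termai/internal/message"
)

// drainEvents renders deltas as they arrive and returns the final response.
func drainEvents(ctx context.Context, p Provider, msgs []message.Message) (*ProviderResponse, error) {
	events, err := p.StreamResponse(ctx, msgs, nil)
	if err != nil {
		return nil, err
	}
	for event := range events {
		switch event.Type {
		case EventContentDelta:
			fmt.Print(event.Content) // assistant text, streamed
		case EventThinkingDelta:
			fmt.Print(event.Thinking) // thinking tokens, when the model emits them
		case EventError:
			return nil, event.Error
		case EventComplete:
			return event.Response, nil
		}
	}
	return nil, errors.New("stream closed without a complete event")
}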