rework llm
internal/llm/provider/anthropic.go (new file, +309)
@@ -0,0 +1,309 @@
package provider

import (
	"context"
	"encoding/json"
	"errors"
	"strings"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

type anthropicProvider struct {
	client        anthropic.Client
	model         models.Model
	maxTokens     int64
	apiKey        string
	systemMessage string
}

type AnthropicOption func(*anthropicProvider)

func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
	provider := &anthropicProvider{
		maxTokens: 1024,
	}

	for _, opt := range opts {
		opt(provider)
	}

	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
	return provider, nil
}

func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
		Model:       anthropic.Model(a.model.APIModel),
		MaxTokens:   a.maxTokens,
		Temperature: anthropic.Float(0),
		Messages:    anthropicMessages,
		Tools:       anthropicTools,
		System: []anthropic.TextBlockParam{
			{
				Text: a.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	})
	if err != nil {
		return nil, err
	}

	content := ""
	for _, block := range response.Content {
		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
			content += text.Text
		}
	}

	toolCalls := a.extractToolCalls(response.Content)
	tokenUsage := a.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	var thinkingParam anthropic.ThinkingConfigParamUnion
	lastMessage := messages[len(messages)-1]
	temperature := anthropic.Float(0)
	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
		thinkingParam = anthropic.ThinkingConfigParamUnion{
			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
				Type:         "enabled",
			},
		}
		temperature = anthropic.Float(1)
	}

	stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
		Model:       anthropic.Model(a.model.APIModel),
		MaxTokens:   a.maxTokens,
		Temperature: temperature,
		Messages:    anthropicMessages,
		Tools:       anthropicTools,
		Thinking:    thinkingParam,
		System: []anthropic.TextBlockParam{
			{
				Text: a.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	})

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		accumulatedMessage := anthropic.Message{}

		for stream.Next() {
			event := stream.Current()
			err := accumulatedMessage.Accumulate(event)
			if err != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}

			switch event := event.AsAny().(type) {
			case anthropic.ContentBlockStartEvent:
				eventChan <- ProviderEvent{Type: EventContentStart}

			case anthropic.ContentBlockDeltaEvent:
				if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
					eventChan <- ProviderEvent{
						Type:     EventThinkingDelta,
						Thinking: event.Delta.Thinking,
					}
				} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: event.Delta.Text,
					}
				}

			case anthropic.ContentBlockStopEvent:
				eventChan <- ProviderEvent{Type: EventContentStop}

			case anthropic.MessageStopEvent:
				content := ""
				for _, block := range accumulatedMessage.Content {
					if text, ok := block.AsAny().(anthropic.TextBlock); ok {
						content += text.Text
					}
				}

				toolCalls := a.extractToolCalls(accumulatedMessage.Content)
				tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:   content,
						ToolCalls: toolCalls,
						Usage:     tokenUsage,
					},
				}
			}
		}

		if stream.Err() != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
		}
	}()

	return eventChan, nil
}

func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:    variant.ID,
				Name:  variant.Name,
				Input: string(variant.Input),
				Type:  string(variant.Type),
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
	return TokenUsage{
		InputTokens:         usage.InputTokens,
		OutputTokens:        usage.OutputTokens,
		CacheCreationTokens: usage.CacheCreationInputTokens,
		CacheReadTokens:     usage.CacheReadInputTokens,
	}
}

func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
			},
		}

		if i == len(tools)-1 {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, len(messages))
	cachedBlocks := 0

	for i, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content)
			if cachedBlocks < 2 {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages[i] = anthropic.NewUserMessage(content)

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content != "" {
				content := anthropic.NewTextBlock(msg.Content)
				if cachedBlocks < 2 {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
			for i, toolResult := range msg.ToolResults {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages[i] = anthropic.NewUserMessage(results...)
		}
	}

	return anthropicMessages
}
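Usage sketch for review (illustrative, not part of the diff): wiring the options together and draining the stream from a hypothetical cmd/main. The model literal, env-var name, and prompt are assumptions; only Model.APIModel is actually read by this provider.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/message"
)

func main() {
	p, err := provider.NewAnthropicProvider(
		provider.WithAnthropicKey(os.Getenv("ANTHROPIC_API_KEY")),
		provider.WithAnthropicSystemMessage("You are a helpful terminal assistant."),
		// Hypothetical model value; only the APIModel field is read above.
		provider.WithAnthropicModel(models.Model{APIModel: "claude-3-7-sonnet-latest"}),
		provider.WithAnthropicMaxTokens(2048),
	)
	if err != nil {
		log.Fatal(err)
	}

	msgs := []message.Message{{Role: message.User, Content: "say hi"}}
	events, err := p.StreamResponse(context.Background(), msgs, nil)
	if err != nil {
		log.Fatal(err)
	}
	for ev := range events {
		switch ev.Type {
		case provider.EventContentDelta:
			fmt.Print(ev.Content)
		case provider.EventError:
			log.Fatal(ev.Error)
		}
	}
}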
internal/llm/provider/gemini.go (new file, +443)
@@ -0,0 +1,443 @@
package provider

import (
	"context"
	"encoding/json"
	"errors"
	"log"

	"github.com/google/generative-ai-go/genai"
	"github.com/google/uuid"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"google.golang.org/api/googleapi"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

type geminiProvider struct {
	client        *genai.Client
	model         models.Model
	maxTokens     int32
	apiKey        string
	systemMessage string
}

type GeminiOption func(*geminiProvider)

func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
	provider := &geminiProvider{
		maxTokens: 5000,
	}

	for _, opt := range opts {
		opt(provider)
	}

	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
	if err != nil {
		return nil, err
	}
	provider.client = client

	return provider, nil
}

func WithGeminiSystemMessage(message string) GeminiOption {
	return func(p *geminiProvider) {
		p.systemMessage = message
	}
}

func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
	return func(p *geminiProvider) {
		p.maxTokens = maxTokens
	}
}

func WithGeminiModel(model models.Model) GeminiOption {
	return func(p *geminiProvider) {
		p.model = model
	}
}

func WithGeminiKey(apiKey string) GeminiOption {
	return func(p *geminiProvider) {
		p.apiKey = apiKey
	}
}

func (p *geminiProvider) Close() {
	if p.client != nil {
		p.client.Close()
	}
}

// convertToGeminiHistory converts the message history to Gemini's format
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
	var history []*genai.Content

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			history = append(history, &genai.Content{
				Parts: []genai.Part{genai.Text(msg.Content)},
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []genai.Part{},
			}

			// Handle regular content
			if msg.Content != "" {
				content.Parts = append(content.Parts, genai.Text(msg.Content))
			}

			// Handle tool calls if any
			if len(msg.ToolCalls) > 0 {
				for _, call := range msg.ToolCalls {
					args, _ := parseJsonToMap(call.Input)
					content.Parts = append(content.Parts, genai.FunctionCall{
						Name: call.Name,
						Args: args,
					})
				}
			}

			history = append(history, content)
		case message.Tool:
			for _, result := range msg.ToolResults {
				// Parse response content to map if possible
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}
				// Look up the originating tool call so the function response
				// carries its name.
				var toolCall message.ToolCall
				for _, prev := range messages {
					if prev.Role == message.Assistant {
						for _, call := range prev.ToolCalls {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []genai.Part{genai.FunctionResponse{
						Name:     toolCall.Name,
						Response: response,
					}},
					Role: "function",
				})
			}
		}
	}

	return history
}

// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
	declarations := make([]*genai.FunctionDeclaration, len(tools))

	for i, tool := range tools {
		info := tool.Info()

		// Convert parameters to genai.Schema format
		properties := make(map[string]*genai.Schema)
		for name, param := range info.Parameters {
			// Try to extract type and description from the parameter
			paramMap, ok := param.(map[string]interface{})
			if !ok {
				// Default to string if unable to determine type
				properties[name] = &genai.Schema{Type: genai.TypeString}
				continue
			}

			schemaType := genai.TypeString // Default
			var description string
			var itemsTypeSchema *genai.Schema
			if typeVal, found := paramMap["type"]; found {
				if typeStr, ok := typeVal.(string); ok {
					switch typeStr {
					case "string":
						schemaType = genai.TypeString
					case "number":
						schemaType = genai.TypeNumber
					case "integer":
						schemaType = genai.TypeInteger
					case "boolean":
						schemaType = genai.TypeBoolean
					case "array":
						schemaType = genai.TypeArray
						items, found := paramMap["items"]
						if found {
							itemsMap, ok := items.(map[string]interface{})
							if ok {
								itemsType, found := itemsMap["type"]
								if found {
									itemsTypeStr, ok := itemsType.(string)
									if ok {
										switch itemsTypeStr {
										case "string":
											itemsTypeSchema = &genai.Schema{
												Type: genai.TypeString,
											}
										case "number":
											itemsTypeSchema = &genai.Schema{
												Type: genai.TypeNumber,
											}
										case "integer":
											itemsTypeSchema = &genai.Schema{
												Type: genai.TypeInteger,
											}
										case "boolean":
											itemsTypeSchema = &genai.Schema{
												Type: genai.TypeBoolean,
											}
										}
									}
								}
							}
						}
					case "object":
						schemaType = genai.TypeObject
						if _, found := paramMap["properties"]; !found {
							continue
						}
						// TODO: Add support for other types
					}
				}
			}

			if desc, found := paramMap["description"]; found {
				if descStr, ok := desc.(string); ok {
					description = descStr
				}
			}

			properties[name] = &genai.Schema{
				Type:        schemaType,
				Description: description,
				Items:       itemsTypeSchema,
			}
		}

		declarations[i] = &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: properties,
				Required:   info.Required,
			},
		}
	}

	return declarations
}

// extractTokenUsage extracts token usage information from Gemini's response
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

// SendMessages sends a batch of messages to Gemini and returns the response
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Create a generative model
	model := p.client.GenerativeModel(p.model.APIModel)
	model.SetMaxOutputTokens(p.maxTokens)

	// Set system instruction
	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))

	// Set up tools if provided
	if len(tools) > 0 {
		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
		model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
	}

	// Create chat session and set history
	chat := model.StartChat()
	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message

	// Get the most recent user message
	var lastUserMsg message.Message
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Role == message.User {
			lastUserMsg = messages[i]
			break
		}
	}

	// Send the message
	resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
	if err != nil {
		return nil, err
	}

	// Process the response
	var content string
	var toolCalls []message.ToolCall

	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			switch p := part.(type) {
			case genai.Text:
				content = string(p)
			case genai.FunctionCall:
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(p.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  p.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	// Extract token usage
	tokenUsage := p.extractTokenUsage(resp)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

// StreamResponse streams the response from Gemini
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	// Create a generative model
	model := p.client.GenerativeModel(p.model.APIModel)
	model.SetMaxOutputTokens(p.maxTokens)

	// Set system instruction
	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))

	// Set up tools if provided
	if len(tools) > 0 {
		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
		for _, declaration := range declarations {
			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
		}
	}

	// Create chat session and set history
	chat := model.StartChat()
	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message

	lastUserMsg := messages[len(messages)-1]

	// Start streaming
	iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		var finalResp *genai.GenerateContentResponse
		currentContent := ""
		toolCalls := []message.ToolCall{}

		for {
			resp, err := iter.Next()
			if err == iterator.Done {
				break
			}
			if err != nil {
				var apiErr *googleapi.Error
				if errors.As(err, &apiErr) {
					log.Printf("%s", apiErr.Body)
				}
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: err,
				}
				return
			}

			finalResp = resp

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch p := part.(type) {
					case genai.Text:
						newText := string(p)
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: newText,
						}
						currentContent += newText
					case genai.FunctionCall:
						// For function calls, we assume they come complete, not streamed in parts
						id := "call_" + uuid.New().String()
						args, _ := json.Marshal(p.Args)
						newCall := message.ToolCall{
							ID:    id,
							Name:  p.Name,
							Input: string(args),
							Type:  "function",
						}

						// Check if this is a new tool call
						isNew := true
						for _, existing := range toolCalls {
							if existing.Name == newCall.Name && existing.Input == newCall.Input {
								isNew = false
								break
							}
						}

						if isNew {
							toolCalls = append(toolCalls, newCall)
						}
					}
				}
			}
		}

		// Extract token usage from the final response
		tokenUsage := p.extractTokenUsage(finalResp)

		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}

// Helper function to parse JSON string into map
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}
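To make the schema mapping in convertToolsToGeminiFunctionDeclarations concrete, here is a small illustrative sketch (not from the diff; the parameter shape is an assumption about what Info().Parameters carries, matching the type switches above). It shows one input entry and, in comments, the genai.Schema it produces.

package main

import "fmt"

func main() {
	// One entry of info.Parameters in JSON-Schema style, as the converter
	// above expects it (assumed shape).
	param := map[string]interface{}{
		"type":        "array",
		"description": "Files to read",
		"items":       map[string]interface{}{"type": "string"},
	}

	// The converter walks "type", "items", and "description" and builds:
	//
	//	&genai.Schema{
	//		Type:        genai.TypeArray,
	//		Description: "Files to read",
	//		Items:       &genai.Schema{Type: genai.TypeString},
	//	}
	fmt.Printf("input parameter map: %v\n", param)
}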
internal/llm/provider/openai.go (new file, +278)
@@ -0,0 +1,278 @@
package provider

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

type openaiProvider struct {
	client        openai.Client
	model         models.Model
	maxTokens     int64
	baseURL       string
	apiKey        string
	systemMessage string
}

type OpenAIOption func(*openaiProvider)

func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
	provider := &openaiProvider{
		maxTokens: 5000,
	}

	for _, opt := range opts {
		opt(provider)
	}

	clientOpts := []option.RequestOption{
		option.WithAPIKey(provider.apiKey),
	}
	if provider.baseURL != "" {
		clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
	}

	provider.client = openai.NewClient(clientOpts...)
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	return provider, nil
}

func WithOpenAISystemMessage(message string) OpenAIOption {
	return func(p *openaiProvider) {
		p.systemMessage = message
	}
}

func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
	return func(p *openaiProvider) {
		p.maxTokens = maxTokens
	}
}

func WithOpenAIModel(model models.Model) OpenAIOption {
	return func(p *openaiProvider) {
		p.model = model
	}
}

func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(p *openaiProvider) {
		p.baseURL = baseURL
	}
}

func WithOpenAIKey(apiKey string) OpenAIOption {
	return func(p *openaiProvider) {
		p.apiKey = apiKey
	}
}

func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
	var chatMessages []openai.ChatCompletionMessageParamUnion

	chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			chatMessages = append(chatMessages, openai.UserMessage(msg.Content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				}
			}

			if len(msg.ToolCalls) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
				for i, call := range msg.ToolCalls {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults {
				chatMessages = append(chatMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return chatMessages
}

func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
	cachedTokens := usage.PromptTokensDetails.CachedTokens
	inputTokens := usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
	}

	response, err := p.client.Chat.Completions.New(ctx, params)
	if err != nil {
		return nil, err
	}

	content := ""
	if response.Choices[0].Message.Content != "" {
		content = response.Choices[0].Message.Content
	}

	var toolCalls []message.ToolCall
	if len(response.Choices[0].Message.ToolCalls) > 0 {
		toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
		for i, call := range response.Choices[0].Message.ToolCalls {
			toolCalls[i] = message.ToolCall{
				ID:    call.ID,
				Name:  call.Function.Name,
				Input: call.Function.Arguments,
				Type:  "function",
			}
		}
	}

	tokenUsage := p.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}

	stream := p.client.Chat.Completions.NewStreaming(ctx, params)

	eventChan := make(chan ProviderEvent)

	toolCalls := make([]message.ToolCall, 0)
	go func() {
		defer close(eventChan)

		acc := openai.ChatCompletionAccumulator{}
		currentContent := ""

		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)

			if tool, ok := acc.JustFinishedToolCall(); ok {
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    tool.Id,
					Name:  tool.Name,
					Input: tool.Arguments,
					Type:  "function",
				})
			}

			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: choice.Delta.Content,
					}
					currentContent += choice.Delta.Content
				}
			}
		}

		if err := stream.Err(); err != nil {
			eventChan <- ProviderEvent{
				Type:  EventError,
				Error: err,
			}
			return
		}

		tokenUsage := p.extractTokenUsage(acc.Usage)

		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}
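A hedged sketch of the message shapes convertToOpenAIMessages expects for a complete tool-call round trip. The helper, the tool name "view", and the message.ToolResult type name are illustrative assumptions inferred from the fields referenced above.

package provider

import (
	"context"

	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

// exampleToolRoundTrip (hypothetical) builds one user turn, the assistant's
// tool call, and the tool's result, in the order the converter expects them.
func exampleToolRoundTrip(ctx context.Context, p Provider, toolList []tools.BaseTool) (*ProviderResponse, error) {
	history := []message.Message{
		{Role: message.User, Content: "What's in main.go?"},
		{Role: message.Assistant, ToolCalls: []message.ToolCall{{
			ID:    "call_123", // matched against ToolCallID below
			Name:  "view",     // hypothetical tool name
			Input: `{"path":"main.go"}`,
			Type:  "function",
		}}},
		{Role: message.Tool, ToolResults: []message.ToolResult{{
			ToolCallID: "call_123",
			Content:    "package main ...",
		}}},
	}
	return p.SendMessages(ctx, history, toolList)
}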
internal/llm/provider/provider.go (new file, +48)
@@ -0,0 +1,48 @@
package provider

import (
	"context"

	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

// EventType represents the type of streaming event
type EventType string

const (
	EventContentStart  EventType = "content_start"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
)

type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}

type ProviderResponse struct {
	Content   string
	ToolCalls []message.ToolCall
	Usage     TokenUsage
}

type ProviderEvent struct {
	Type     EventType
	Content  string
	Thinking string
	ToolCall *message.ToolCall
	Error    error
	Response *ProviderResponse
}

type Provider interface {
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
}
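A minimal consumer sketch for the interface above (hypothetical helper, not part of the diff): it blocks on the event channel until EventComplete or EventError, which is how all three streaming implementations terminate.

package provider

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

// collectStream drains a provider's event channel and returns the final
// response, turning EventError into a plain Go error.
func collectStream(ctx context.Context, p Provider, msgs []message.Message, ts []tools.BaseTool) (*ProviderResponse, error) {
	events, err := p.StreamResponse(ctx, msgs, ts)
	if err != nil {
		return nil, err
	}
	for ev := range events {
		switch ev.Type {
		case EventError:
			return nil, ev.Error
		case EventComplete:
			return ev.Response, nil
		}
	}
	return nil, errors.New("stream closed without a complete event")
}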