initial working agent

This commit is contained in:
Kujtim Hoxha
2025-03-24 11:47:39 +01:00
parent e7258e38ae
commit 005b8ac167
6 changed files with 201 additions and 22 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
}
log.Printf("Request: %s", content)
agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
for _, m := range history {
messages = append(messages, &m.MessageData)
}
builder := callbacks.NewHandlerBuilder()
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
i, ok := input.(*eModel.CallbackInput)
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return ctx
})
out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
usage := out.ResponseMeta.Usage
s.messages.Create(sessionID, *out)
if usage != nil {
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
session, err := s.sessions.Get(sessionID)
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
session.PromptTokens += int64(usage.PromptTokens)
session.CompletionTokens += int64(usage.CompletionTokens)
// TODO: calculate cost
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
var newTitle string
if len(history) == 1 {
// first message generate the title
newTitle, err = agent.GenerateTitle(s.ctx, content)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
}
if newTitle != "" {
session.Title = newTitle
}
_, err = s.sessions.Save(session)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
}
s.messages.Create(sessionID, *out)
}
func (s *service) SendRequest(sessionID string, content string) {