initial working agent
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/logging"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
}
|
||||
|
||||
log.Printf("Request: %s", content)
|
||||
agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
||||
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
for _, m := range history {
|
||||
messages = append(messages, &m.MessageData)
|
||||
}
|
||||
|
||||
builder := callbacks.NewHandlerBuilder()
|
||||
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
i, ok := input.(*eModel.CallbackInput)
|
||||
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
return ctx
|
||||
})
|
||||
|
||||
out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
||||
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
return
|
||||
}
|
||||
usage := out.ResponseMeta.Usage
|
||||
s.messages.Create(sessionID, *out)
|
||||
if usage != nil {
|
||||
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
|
||||
session, err := s.sessions.Get(sessionID)
|
||||
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
session.PromptTokens += int64(usage.PromptTokens)
|
||||
session.CompletionTokens += int64(usage.CompletionTokens)
|
||||
// TODO: calculate cost
|
||||
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
|
||||
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
|
||||
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
|
||||
var newTitle string
|
||||
if len(history) == 1 {
|
||||
// first message generate the title
|
||||
newTitle, err = agent.GenerateTitle(s.ctx, content)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if newTitle != "" {
|
||||
session.Title = newTitle
|
||||
}
|
||||
|
||||
_, err = s.sessions.Save(session)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
return
|
||||
}
|
||||
}
|
||||
s.messages.Create(sessionID, *out)
|
||||
}
|
||||
|
||||
func (s *service) SendRequest(sessionID string, content string) {
|
||||
|
||||
Reference in New Issue
Block a user