Skip to content

Commit

Permalink
bot support chat in thread
Browse files Browse the repository at this point in the history
  • Loading branch information
vaayne committed Jul 26, 2024
1 parent 7975d83 commit 2f56419
Show file tree
Hide file tree
Showing 15 changed files with 369 additions and 129 deletions.
4 changes: 2 additions & 2 deletions database/bindata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions database/migrations/000004_create_assistant_threads.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ CREATE TABLE IF NOT EXISTS assistant_threads (
assistant_id UUID REFERENCES assistants(uuid),
name VARCHAR(255) NOT NULL,
description TEXT,
system_prompt TEXT,
model VARCHAR(32) NOT NULL,
is_long_term_memory BOOLEAN DEFAULT FALSE,
metadata JSONB DEFAULT '{}'::JSONB,
Expand Down
8 changes: 4 additions & 4 deletions database/queries/assistants.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ SELECT * FROM assistants WHERE user_id = $1 ORDER BY created_at DESC;

-- CRUD for assistant_threads
-- name: CreateAssistantThread :one
INSERT INTO assistant_threads (user_id, assistant_id, name, description, model, is_long_term_memory, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO assistant_threads (user_id, assistant_id, name, description, system_prompt, model, is_long_term_memory, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING *;

-- name: GetAssistantThread :one
SELECT * FROM assistant_threads WHERE uuid = $1;

-- name: UpdateAssistantThread :exec
UPDATE assistant_threads SET name = $2, description = $3, model = $4, is_long_term_memory = $5, metadata = $6
UPDATE assistant_threads SET name = $2, description = $3, model = $4, is_long_term_memory = $5, metadata = $6, system_prompt = $7
WHERE uuid = $1;

-- name: DeleteAssistantThread :exec
Expand Down Expand Up @@ -56,7 +56,7 @@ UPDATE assistant_messages SET text = $2, attachments = $3, metadata = $4 WHERE u
DELETE FROM assistant_messages WHERE uuid = $1;

-- name: ListThreadMessages :many
SELECT * FROM assistant_messages WHERE thread_id = $1 ORDER BY created_at DESC;
SELECT * FROM assistant_messages WHERE thread_id = $1 ORDER BY created_at ASC;

-- CRUD for assistant_attachments
-- name: CreateAttachment :one
Expand Down
53 changes: 47 additions & 6 deletions internal/core/assistants/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Repository interface {

CreateThread(ctx context.Context, thread *Thread) error
GetThread(ctx context.Context, id uuid.UUID) (*Thread, error)
CreateThreadMessage(ctx context.Context, threadID uuid.UUID, message ThreadMessage) error
}

type repository struct {
Expand Down Expand Up @@ -81,12 +82,52 @@ func (r *repository) GetThread(ctx context.Context, id uuid.UUID) (*Thread, erro
return nil, fmt.Errorf("failed to get thread: %w", err)
}
thread := &Thread{
Id: th.Uuid,
UserId: th.UserID.Bytes,
AssistantId: th.AssistantID.Bytes,
Name: th.Name,
Description: th.Description.String,
Model: th.Model,
Id: th.Uuid,
UserId: th.UserID.Bytes,
AssistantId: th.AssistantID.Bytes,
Name: th.Name,
Description: th.Description.String,
Model: th.Model,
SystemPrompt: th.SystemPrompt.String,
}

messages, err := r.ListThreadMessages(ctx, th.Uuid)
if err != nil {
return nil, fmt.Errorf("failed to get thread messages: %w", err)
}
thread.Messages = messages
return thread, nil
}

func (r *repository) ListThreadMessages(ctx context.Context, threadID uuid.UUID) ([]ThreadMessage, error) {
messages, err := r.db.ListThreadMessages(ctx, pgtype.UUID{Bytes: threadID, Valid: true})
if err != nil {
return nil, fmt.Errorf("failed to get thread messages: %w", err)
}

var result []ThreadMessage
for _, msg := range messages {
result = append(result, ThreadMessage{
Role: msg.Role,
Text: msg.Text.String,
CreatedAt: msg.CreatedAt.Time,
UpdatedAt: msg.UpdatedAt.Time,
})
}

return result, nil
}

func (r *repository) CreateThreadMessage(ctx context.Context, threadID uuid.UUID, message ThreadMessage) error {
_, err := r.db.CreateThreadMessage(ctx, db.CreateThreadMessageParams{
UserID: pgtype.UUID{Bytes: message.UserID, Valid: true},
ThreadID: pgtype.UUID{Bytes: threadID, Valid: true},
Model: pgtype.Text{String: message.Model, Valid: message.Model != ""},
Role: message.Role,
Text: pgtype.Text{String: message.Text, Valid: true},
})
if err != nil {
return fmt.Errorf("failed to save thread message: %w", err)
}
return nil
}
67 changes: 58 additions & 9 deletions internal/core/assistants/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,84 @@ package assistants

import (
"context"
"fmt"
"vibrain/internal/pkg/db"
"vibrain/internal/pkg/llms"

"github.com/google/uuid"
"github.com/sashabaranov/go-openai"
)

type Service struct {
db Repository
repository Repository
llm *llms.LLM
}

func NewService(db *db.Pool) (*Service, error) {
func NewService(db *db.Pool, llm *llms.LLM) (*Service, error) {
s := &Service{
db: NewRepository(db),
repository: NewRepository(db),
llm: llm,
}

return s, nil
}

func (s *Service) CreateAssistant(ctx context.Context, assistant *Assistant) error {
return s.db.CreateAssistant(ctx, assistant)
return s.repository.CreateAssistant(ctx, assistant)
}

func (s *Service) GetAssistant(ctx context.Context, id string) (*Assistant, error) {
return s.db.GetAssistant(ctx, uuid.MustParse(id))
func (s *Service) GetAssistant(ctx context.Context, id uuid.UUID) (*Assistant, error) {
return s.repository.GetAssistant(ctx, id)
}

func (s *Service) CreateThread(ctx context.Context, thread *Thread) error {
return s.db.CreateThread(ctx, thread)
return s.repository.CreateThread(ctx, thread)
}

func (s *Service) GetThread(ctx context.Context, id string) (*Thread, error) {
return s.db.GetThread(ctx, uuid.MustParse(id))
func (s *Service) GetThread(ctx context.Context, id uuid.UUID) (*Thread, error) {
return s.repository.GetThread(ctx, id)
}

func (s *Service) AddThreadMessage(ctx context.Context, thread *Thread, role, text string) error {
thread.AddMessage(role, text)
message := ThreadMessage{
UserID: thread.UserId,
ThreadID: thread.Id,
Model: thread.Model,
Role: role,
Text: text,
}
return s.repository.CreateThreadMessage(ctx, thread.Id, message)
}

func (s *Service) RunThread(ctx context.Context, thread *Thread) (*ThreadMessage, error) {
oaiMessages := make([]openai.ChatCompletionMessage, 0)
oaiMessages = append(oaiMessages, openai.ChatCompletionMessage{
Role: "system",
Content: thread.SystemPrompt,
})
for _, m := range thread.Messages {
oaiMessages = append(oaiMessages, openai.ChatCompletionMessage{
Role: m.Role,
Content: m.Text,
})
}
resp, usage, err := s.llm.GenerateContent(ctx, oaiMessages)
if err != nil {
return nil, err
}

message := ThreadMessage{
UserID: thread.UserId,
ThreadID: thread.Id,
Model: thread.Model,
Role: resp.Message.Role,
Text: resp.Message.Content,
Token: usage.TotalTokens,
}

if err := s.repository.CreateThreadMessage(ctx, thread.Id, message); err != nil {
return nil, fmt.Errorf("failed to save thread message: %w", err)
}
return &message, err
}
58 changes: 46 additions & 12 deletions internal/core/assistants/thread.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,53 @@
package assistants

import (
"time"

"github.com/google/uuid"
)

type ThreadMetaData struct{}

type ThreadMessage struct {
UserID uuid.UUID
ThreadID uuid.UUID
Model string
Token int
Role string
Text string
CreatedAt time.Time
UpdatedAt time.Time
}

type Thread struct {
Id uuid.UUID `json:"uuid"`
UserId uuid.UUID `json:"user_id"`
AssistantId uuid.UUID `json:"assistant_id"`
Name string `json:"name"`
Description string `json:"description"`
Model string `json:"model"`
MetaData ThreadMetaData `json:"metadata"`
Id uuid.UUID `json:"uuid"`
UserId uuid.UUID `json:"user_id"`
AssistantId uuid.UUID `json:"assistant_id"`
SystemPrompt string `json:"system_prompt"`
Name string `json:"name"`
Description string `json:"description"`
Model string `json:"model"`
MetaData ThreadMetaData `json:"metadata"`
Messages []ThreadMessage `json:"messages"`
}

func (t *Thread) AddMessage(role, text string) {
t.Messages = append(t.Messages, ThreadMessage{
Role: role,
Text: text,
})
}

type ThreadOption func(*Thread)

func NewThread(userId uuid.UUID, assistant Assistant, opts ...ThreadOption) *Thread {
t := &Thread{
UserId: userId,
AssistantId: assistant.Id,
Model: assistant.Model,
Name: "Thread" + uuid.New().String(),
Description: "I am a conversation thread.",
UserId: userId,
AssistantId: assistant.Id,
SystemPrompt: assistant.SystemPrompt,
Model: assistant.Model,
Name: "Thread" + uuid.New().String(),
Description: "I am a conversation thread.",
}

for _, opt := range opts {
Expand Down Expand Up @@ -57,3 +80,14 @@ func WithThreadMetaData(metaData ThreadMetaData) ThreadOption {
}
}

func AddThreadMessage(message ThreadMessage) ThreadOption {
return func(t *Thread) {
t.Messages = append(t.Messages, message)
}
}

func WithThreadMessages(messages []ThreadMessage) ThreadOption {
return func(t *Thread) {
t.Messages = messages
}
}
Loading

0 comments on commit 2f56419

Please sign in to comment.