Skip to content

Commit

Permalink
✨ feat: Refactor initialization and improve image upload handling
Browse files Browse the repository at this point in the history
- Refactored initialization of default services (DB, cache, LLM, queue) using init() functions
- Improved image upload handling in DefaultImgHook with better URL validation and async processing
- Updated queue service initialization to use default queue instance
- Added LoadUser function to auth package for flexible user loading
- Fixed SummarierWorker constructor return type
- Simplified main.go by using default service instances
  • Loading branch information
vaayne committed Jan 8, 2025
1 parent 6bb898a commit d2ee7d2
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 57 deletions.
2 changes: 1 addition & 1 deletion database/bindata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/core/bookmarks/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (s *Service) FetchContent(ctx context.Context, tx db.DBTX, id, userID uuid.
}

func (s *Service) SummarierContent(ctx context.Context, tx db.DBTX, id, userID uuid.UUID) (*ContentDTO, error) {
user, err := auth.LoadUserFromContext(ctx)
user, err := auth.LoadUser(ctx, tx, userID)
if err != nil {
return nil, err
}
Expand Down
17 changes: 8 additions & 9 deletions internal/core/queue/crawler_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,14 @@ func (w *CrawlerWorker) work(ctx context.Context, tx pgx.Tx, job *river.Job[Craw
return err
}

go func() {
ctx := logger.CopyContext(ctx)
if dto.Content != "" && dto.Summary == "" {
dto, err = svc.SummarierContent(ctx, tx, job.Args.ID, job.Args.UserID)
if err != nil {
logger.FromContext(ctx).Error("failed to summarise content", "err", err)
}
}
}()
if result, err := DefaultQueue.Insert(ctx, SummarierWorkerArgs{
ID: dto.ID,
UserID: dto.UserID,
}, nil); err != nil {
logger.FromContext(ctx).Error("failed to insert summaries job", "err", err, "content_id", dto.ID)
} else {
logger.FromContext(ctx).Info("success inserted summaries job", "result", result, "content_id", dto.ID)
}

logger.FromContext(ctx).Info("fetched bookmark", "id", dto.ID, "title", dto.Title, "url", dto.URL, "fetcher", job.Args.FetcherName)
return nil
Expand Down
27 changes: 27 additions & 0 deletions internal/core/queue/default.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package queue

import (
"fmt"
"recally/internal/pkg/db"
"recally/internal/pkg/llms"
"recally/internal/pkg/logger"
)

var DefaultQueue *Queue

func init() {
var err error
DefaultQueue, err = NewDefaultQueue()
if err != nil {
logger.Default.Error("failed to create default queue", "err", err)
}
}

func NewDefaultQueue() (*Queue, error) {
// start queue service
q, err := New(db.DefaultPool, llms.DefaultLLM)
if err != nil {
return nil, fmt.Errorf("failed to create new queue service: %w", err)
}
return q, nil
}
7 changes: 3 additions & 4 deletions internal/core/queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ type Service struct {
*Queue
}

func NewServer(pool *db.Pool, llm *llms.LLM) (*Service, error) {
q, err := New(pool, llm)
if err != nil {
return nil, err
func NewServer(q *Queue) (*Service, error) {
if q == nil {
return nil, fmt.Errorf("default queue is nil")
}
return &Service{
Queue: q,
Expand Down
4 changes: 2 additions & 2 deletions internal/core/queue/summarier_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ func (SummarierWorkerArgs) Kind() string {
return "content_summarier"
}

func NewSummarierWorker(llm *llms.LLM, dbPool *pgxpool.Pool) *CrawlerWorker {
return &CrawlerWorker{
func NewSummarierWorker(llm *llms.LLM, dbPool *pgxpool.Pool) *SummarierWorker {
return &SummarierWorker{
llm: llm,
dbPool: dbPool,
}
Expand Down
15 changes: 15 additions & 0 deletions internal/pkg/auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"fmt"
"recally/internal/pkg/contexts"
"recally/internal/pkg/db"

"github.com/google/uuid"
)

func LoadUserFromContext(ctx context.Context) (*UserDTO, error) {
Expand All @@ -17,3 +20,15 @@ func LoadUserFromContext(ctx context.Context) (*UserDTO, error) {
func SetUserToContext(ctx context.Context, user *UserDTO) context.Context {
return context.WithValue(ctx, contexts.ContextKey(contexts.ContextKeyUser), user)
}

func LoadUserByID(ctx context.Context, tx db.DBTX, userID uuid.UUID) (*UserDTO, error) {
return New().GetUserById(ctx, tx, userID)
}

func LoadUser(ctx context.Context, tx db.DBTX, userID uuid.UUID) (*UserDTO, error) {
user, err := LoadUserFromContext(ctx)
if err != nil {
user, err = LoadUserByID(ctx, tx, userID)
}
return user, err
}
6 changes: 6 additions & 0 deletions internal/pkg/cache/db_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

var DefaultDBCache *DbCache

func init() {
DefaultDBCache = NewDBCache(db.DefaultPool)
}

type CacheKey struct {
Domain string
Key string
Expand Down
10 changes: 10 additions & 0 deletions internal/pkg/db/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
)

var DefaultPool *Pool

func init() {
var err error
DefaultPool, err = NewPool(context.Background(), config.Settings.Database.URL())
if err != nil {
logger.Default.Fatal("failed to create default database pool", "err", err)
}
}

type Pool struct {
*pgxpool.Pool
}
Expand Down
8 changes: 8 additions & 0 deletions internal/pkg/llms/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"recally/internal/pkg/cache"
"recally/internal/pkg/config"
"recally/internal/pkg/logger"
"recally/internal/pkg/tools"
"strings"
Expand All @@ -16,6 +17,13 @@ import (
"golang.org/x/sync/errgroup"
)


var DefaultLLM * LLM

func init() {
DefaultLLM = New(config.Settings.OpenAI.BaseURL, config.Settings.OpenAI.ApiKey)
}

const (
IntermediateStepTool = "tool"
IntermediateStepRag = "rag"
Expand Down
53 changes: 29 additions & 24 deletions internal/pkg/webreader/processor/hooks/default_img_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,31 @@ type ImageHook struct{}
func (h *ImageHook) UploadToS3(selec *goquery.Selection) {
selec.Find("img").Each(func(i int, s *goquery.Selection) {
src := s.AttrOr("src", "")
img, host, imageType, size, err := h.loadImage(src)

// Validate and parse URL
u, err := url.Parse(src)
if err != nil {
return
}
logger.Default.Debug("image loaded", "host", host, "imageType", imageType, "size", size)
// Upload image to S3
objectKey := fmt.Sprintf("images/%s/%s.%s", host, uuid.New().String(), imageType)
if u.Scheme == "" || u.Host == "" {
return
}
host := u.Host
objectKey := fmt.Sprintf("images/%s/%s", host, uuid.New().String())

// Asynchronously load and upload the image for better performance
go func() {
info, err := s3.DefaultClient.Upload(context.Background(), objectKey, bytes.NewReader(img), size, minio.PutObjectOptions{})
// Load image
img, contentType, size, err := h.loadImage(src)
if err != nil {
return
}
logger.Default.Debug("image loaded", "host", host, "contentType", contentType, "size", size)
// Upload image to S3
info, err := s3.DefaultClient.Upload(context.Background(), objectKey, bytes.NewReader(img), size, minio.PutObjectOptions{
ContentType: contentType,
CacheControl: "max-age=31536000, public",
})
if err != nil {
logger.Default.Error("failed to upload image to s3", "err", err, "objectKey", objectKey, "info", info)
return
Expand All @@ -59,25 +75,15 @@ func (h *ImageHook) UploadToS3(selec *goquery.Selection) {
})
}

func (h *ImageHook) loadImage(uri string) (img []byte, host string, imageType string, size int64, err error) {
func (h *ImageHook) loadImage(uri string) (img []byte, contentType string, size int64, err error) {
// Create context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Validate and parse URL
u, err := url.Parse(uri)
if err != nil {
return nil, "", "", 0, fmt.Errorf("invalid image URL: %w", err)
}
if u.Scheme == "" || u.Host == "" {
return nil, "", "", 0, fmt.Errorf("invalid image URL format: %s", uri)
}
host = u.Host

// Create request with context
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, "", "", 0, fmt.Errorf("failed to create request: %w", err)
return nil, "", 0, fmt.Errorf("failed to create request: %w", err)
}

// Add user agent to avoid being blocked
Expand All @@ -89,30 +95,29 @@ func (h *ImageHook) loadImage(uri string) (img []byte, host string, imageType st
}
resp, err := client.Do(req)
if err != nil {
return nil, "", "", 0, fmt.Errorf("failed to download image: %w", err)
return nil, "", 0, fmt.Errorf("failed to download image: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, "", "", 0, fmt.Errorf("failed to download image: status %s", resp.Status)
return nil, "", 0, fmt.Errorf("failed to download image: status %s", resp.Status)
}

// Read with max size limit (e.g., 10MB)
const maxSize = 10 * 1024 * 1024
img, err = io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return nil, "", "", 0, fmt.Errorf("failed to read image: %w", err)
return nil, "", 0, fmt.Errorf("failed to read image: %w", err)
}

// Determine content type
contentType := resp.Header.Get("Content-Type")
contentType = resp.Header.Get("Content-Type")
if contentType == "" || !strings.HasPrefix(contentType, "image/") {
contentType = http.DetectContentType(img)
if !strings.HasPrefix(contentType, "image/") {
return nil, "", "", 0, fmt.Errorf("invalid content type: %s", contentType)
return nil, "", 0, fmt.Errorf("invalid content type: %s", contentType)
}
}
imageType = contentType[strings.Index(contentType, "/")+1:]

// Get size
size = resp.ContentLength
Expand All @@ -127,5 +132,5 @@ func (h *ImageHook) loadImage(uri string) (img []byte, host string, imageType st
size = int64(len(img))
}

return img, host, imageType, size, nil
return img, contentType, size, nil
}
23 changes: 7 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,15 @@ func main() {

services := make([]Service, 0)

// init basic services
// init db pool
pool, err := db.NewPool(ctx, config.Settings.Database.URL())
if err != nil {
logger.Default.Fatal("failed to create new database pool", "err", err)
}

// init cache service
cacheService := cache.NewDBCache(pool)

llm := llms.New(config.Settings.OpenAI.BaseURL, config.Settings.OpenAI.ApiKey)
s3Client, err := s3.New(config.Settings.S3)
if err != nil {
logger.Default.Fatal("failed to create new s3 client", "err", err)
}
// init basic services using default config
pool := db.DefaultPool
cacheService := cache.DefaultDBCache
llm := llms.DefaultLLM
s3Client := s3.DefaultClient
riverQueue := queue.DefaultQueue

// start queue service
queueService, err := queue.NewServer(pool, llm)
queueService, err := queue.NewServer(riverQueue)
if err != nil {
logger.Default.Fatal("failed to create new queue service", "err", err)
}
Expand Down

0 comments on commit d2ee7d2

Please sign in to comment.