Skip to content

Commit

Permalink
feat: enhance upload files
Browse files Browse the repository at this point in the history
  • Loading branch information
vaayne committed Sep 1, 2024
1 parent eaa6cf6 commit 3d43acd
Show file tree
Hide file tree
Showing 19 changed files with 176 additions and 107 deletions.
4 changes: 2 additions & 2 deletions database/bindata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion database/queries/assistants_attachment.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ SELECT * FROM assistant_attachments WHERE uuid = $1;
SELECT * FROM assistant_attachments WHERE user_id = $1 ORDER BY created_at DESC;

-- name: ListAssistantAttachmentsByAssistantId :many
SELECT * FROM assistant_attachments WHERE assistant_id = $1 ORDER BY created_at DESC;
SELECT * FROM assistant_attachments WHERE assistant_id = $1 AND thread_id IS NULL ORDER BY created_at DESC;

-- name: ListAssistantAttachmentsByThreadId :many
SELECT * FROM assistant_attachments WHERE thread_id = $1 ORDER BY created_at DESC;
Expand Down
2 changes: 1 addition & 1 deletion internal/core/assistants/assistant_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (a *AssistantDTO) Load(dbo *db.Assistant) {
func (a *AssistantDTO) Dump() *db.Assistant {
metadata, _ := json.Marshal(a.Metadata)
return &db.Assistant{
UserID: pgtype.UUID{Bytes: a.UserId, Valid: true},
UserID: pgtype.UUID{Bytes: a.UserId, Valid: a.UserId != uuid.Nil},
Name: a.Name,
Description: pgtype.Text{String: a.Description, Valid: a.Description != ""},
SystemPrompt: pgtype.Text{String: a.SystemPrompt, Valid: a.SystemPrompt != ""},
Expand Down
14 changes: 7 additions & 7 deletions internal/core/assistants/attachment_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func (a *AttachmentDTO) Dump() *db.AssistantAttachment {
}
return &db.AssistantAttachment{
Uuid: a.Id,
UserID: pgtype.UUID{Bytes: a.UserId, Valid: true},
AssistantID: pgtype.UUID{Bytes: a.AssistantId, Valid: true},
ThreadID: pgtype.UUID{Bytes: a.ThreadId, Valid: true},
Name: pgtype.Text{String: a.Name, Valid: true},
Type: pgtype.Text{String: a.Type, Valid: true},
Url: pgtype.Text{String: a.URL, Valid: true},
Size: pgtype.Int4{Int32: int32(a.Size), Valid: true},
UserID: pgtype.UUID{Bytes: a.UserId, Valid: a.UserId != uuid.Nil},
AssistantID: pgtype.UUID{Bytes: a.AssistantId, Valid: a.AssistantId != uuid.Nil},
ThreadID: pgtype.UUID{Bytes: a.ThreadId, Valid: a.ThreadId != uuid.Nil},
Name: pgtype.Text{String: a.Name, Valid: a.Name != ""},
Type: pgtype.Text{String: a.Type, Valid: a.Type != ""},
Url: pgtype.Text{String: a.URL, Valid: a.URL != ""},
Size: pgtype.Int4{Int32: int32(a.Size), Valid: a.Size > 0},
Metadata: metadata,
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/core/assistants/embedding_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (e *EmbeddingDTO) Load(dbo *db.AssistantEmbeddding) {

func (e *EmbeddingDTO) Dump() *db.AssistantEmbeddding {
return &db.AssistantEmbeddding{
UserID: pgtype.UUID{Bytes: e.UserID, Valid: true},
UserID: pgtype.UUID{Bytes: e.UserID, Valid: e.UserID != uuid.Nil},
AttachmentID: pgtype.UUID{Bytes: e.AttachmentID, Valid: e.AttachmentID != uuid.Nil},
Text: e.Text,
Embeddings: pgvector.NewVector(e.Embeddings),
Expand Down
8 changes: 4 additions & 4 deletions internal/core/assistants/message_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ func (d *MessageDTO) Load(dbo *db.AssistantMessage) {
func (d *MessageDTO) Dump() *db.AssistantMessage {
metadata, _ := json.Marshal(d.Metadata)
if d.Embeddings == nil {
d.Embeddings = []float32{}
d.Embeddings = nil
}
return &db.AssistantMessage{
UserID: pgtype.UUID{Bytes: d.UserID, Valid: true},
AssistantID: pgtype.UUID{Bytes: d.AssistantID, Valid: true},
ThreadID: pgtype.UUID{Bytes: d.ThreadID, Valid: true},
UserID: pgtype.UUID{Bytes: d.UserID, Valid: d.UserID != uuid.Nil},
AssistantID: pgtype.UUID{Bytes: d.AssistantID, Valid: d.AssistantID != uuid.Nil},
ThreadID: pgtype.UUID{Bytes: d.ThreadID, Valid: d.ThreadID != uuid.Nil},
Model: pgtype.Text{String: d.Model, Valid: d.Model != ""},
Role: d.Role,
Text: pgtype.Text{String: d.Text, Valid: d.Text != ""},
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/db/assistants_attachment.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions internal/pkg/logger/color.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ func NewColorHandler(handlerOptions *slog.HandlerOptions, options ...Option) *Ha
return handler
}


type Option func(h *Handler)

func WithDestinationWriter(writer io.Writer) Option {
Expand All @@ -241,4 +240,4 @@ func WithOutputEmptyAttrs() Option {
return func(h *Handler) {
h.outputEmptyAttrs = true
}
}
}
66 changes: 66 additions & 0 deletions internal/port/httpserver/assistants_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type assistantService interface {
DeleteThreadMessage(ctx context.Context, tx db.DBTX, id uuid.UUID) error

CreateAttachment(ctx context.Context, tx db.DBTX, attachment *assistants.AttachmentDTO, docs []document.Document) (*assistants.AttachmentDTO, error)
ListAttachmentsByAssistant(ctx context.Context, tx db.DBTX, assistantId uuid.UUID) ([]assistants.AttachmentDTO, error)
ListAttachmentsByThread(ctx context.Context, tx db.DBTX, threadId uuid.UUID) ([]assistants.AttachmentDTO, error)

ListModels(ctx context.Context) ([]string, error)
ListTools(ctx context.Context) ([]tools.BaseTool, error)
Expand Down Expand Up @@ -64,7 +66,9 @@ func registerAssistantHandlers(e *echo.Group, s *Service) {
g.PUT("/:assistant-id/threads/:thread-id/messages/:message-id", h.updateThreadMessage)
g.DELETE("/:assistant-id/threads/:thread-id/messages/:message-id", h.deleteThreadMessage)

g.GET("/:assistant-id/attachments", h.listAttachmentsByAssistant)
g.POST("/:assistant-id/attachments", h.uploadAssistantAttachment)
g.GET("/:assistant-id/threads/:thread-id/attachments", h.listAttachmentsByThread)
g.POST("/:assistant-id/threads/:thread-id/attachments", h.uploadThreadAttachment)

g.GET("/models", h.listModels)
Expand Down Expand Up @@ -300,6 +304,37 @@ func (h *assistantHandler) deleteAssistant(c echo.Context) error {
return JsonResponse(c, http.StatusNoContent, nil)
}

// @Summary List Attachments by Assistant
// @Description Lists attachments for a specific assistant
// @Tags Assistants
// @Accept json
// @Produce json
// @Param assistant-id path string true "Assistant ID"
// @Success 200 {object} JSONResult{data=[]assistants.AttachmentDTO} "Success"
// @Failure 400 {object} JSONResult{data=nil} "Bad Request"
// @Failure 401 {object} JSONResult{data=nil} "Unauthorized"
// @Failure 500 {object} JSONResult{data=nil} "Internal Server Error"
// @Router /assistants/{assistant-id}/attachments [get]
func (h *assistantHandler) listAttachmentsByAssistant(c echo.Context) error {
ctx := c.Request().Context()
req := new(getAssistantRequest)
if err := bindAndValidate(c, req); err != nil {
return err
}

tx, _, err := initContext(ctx)
if err != nil {
return ErrorResponse(c, http.StatusInternalServerError, err)
}

attachments, err := h.service.ListAttachmentsByAssistant(ctx, tx, req.AssistantId)
if err != nil {
return ErrorResponse(c, http.StatusInternalServerError, err)
}

return JsonResponse(c, http.StatusOK, attachments)
}

type uploadAssistantAttachmentRequest struct {
AssistantId uuid.UUID `param:"assistant-id" validate:"required,uuid4"`
Type string `json:"type" validate:"required"`
Expand Down Expand Up @@ -349,6 +384,37 @@ func (h *assistantHandler) uploadAssistantAttachment(c echo.Context) error {
return JsonResponse(c, http.StatusCreated, attachment)
}

// @Summary List Attachments by Thread
// @Description Lists attachments for a specific thread
// @Tags Assistants
// @Accept json
// @Produce json
// @Param assistant-id path string true "Assistant ID"
// @Param thread-id path string true "Thread ID"
// @Success 200 {object} JSONResult{data=[]assistants.AttachmentDTO} "Success"
// @Failure 400 {object} JSONResult{data=nil} "Bad Request"
// @Failure 401 {object} JSONResult{data=nil} "Unauthorized"
// @Failure 500 {object} JSONResult{data=nil} "Internal Server Error"
// @Router /assistants/{assistant-id}/threads/{thread-id}/attachments [get]
func (h *assistantHandler) listAttachmentsByThread(c echo.Context) error {
ctx := c.Request().Context()
req := new(getThreadRequest)
if err := bindAndValidate(c, req); err != nil {
return err
}
tx, _, err := initContext(ctx)
if err != nil {
return ErrorResponse(c, http.StatusInternalServerError, err)
}

attachments, err := h.service.ListAttachmentsByThread(ctx, tx, req.ThreadId)
if err != nil {
return ErrorResponse(c, http.StatusInternalServerError, err)
}

return JsonResponse(c, http.StatusOK, attachments)
}

type uploadThreadAttachmentRequest struct {
AssistantId uuid.UUID `param:"assistant-id" validate:"required,uuid4"`
ThreadId uuid.UUID `param:"thread-id" validate:"required,uuid4"`
Expand Down
2 changes: 1 addition & 1 deletion internal/port/httpserver/file_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type getPresignedURLsRequest struct {
ThreadId uuid.UUID `query:"thread_id,omitempty" validate:"omitempty,uuid4"`
FileName string `query:"file_name" validate:"required"`
FileType string `query:"file_type" validate:"required"`
Action string `query:"action" validate:"required,oneof=put get"`
Action string `query:"action" validate:"required,oneof=PUT GET"`
// Expiration in seconds
Expiration int `query:"expiration" validate:"required,min=1,max=604800"`
}
Expand Down
2 changes: 1 addition & 1 deletion internal/port/httpserver/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func requestLoggerMiddleware() echo.MiddlewareFunc {
attrs = append(attrs, slog.String("error", v.Error.Error()))
logger.LogAttrs(ctx, slog.LevelError, "REQUEST", attrs...)
}
return nil
return v.Error
}

return middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
Expand Down
2 changes: 0 additions & 2 deletions internal/port/httpserver/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package httpserver

import (
"net/http"
"vibrain/internal/pkg/logger"

"github.com/labstack/echo/v4"
)
Expand Down Expand Up @@ -30,7 +29,6 @@ func JsonResponse(c echo.Context, code int, data interface{}) error {
}

func ErrorResponse(c echo.Context, code int, err error) error {
logger.FromContext(c.Request().Context()).Error("error response", "err", err)
_ = c.JSON(code, JSONResult{
Success: false,
Code: code,
Expand Down
15 changes: 7 additions & 8 deletions internal/port/httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package httpserver
import (
"context"
"fmt"
"os"
"os/exec"
"syscall"
"vibrain/internal/pkg/cache"
Expand Down Expand Up @@ -59,13 +58,13 @@ func New(pool *db.Pool, llm *llms.LLM, opts ...Option) (*Service, error) {
llm: llm,
}
s.Server.Validator = &CustomValidator{validator: validator.New()}
if config.Settings.DebugUI {
logger.Default.Info("debug ui enabled")
s.uiCmd = exec.Command("bun", "run", "dev")
s.uiCmd.Dir = "web"
s.uiCmd.Stdout = os.Stdout
s.uiCmd.Stderr = os.Stderr
}
// if config.Settings.DebugUI {
// logger.Default.Info("debug ui enabled")
// s.uiCmd = exec.Command("bun", "run", "dev")
// s.uiCmd.Dir = "web"
// s.uiCmd.Stdout = os.Stdout
// s.uiCmd.Stderr = os.Stderr
// }
for _, opt := range opts {
opt(s)
}
Expand Down
11 changes: 4 additions & 7 deletions web/src/components/thread-input.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ export function ThreadChatInput() {
const setTools = useStore((state) => state.setThreadTools);
const [modelSelecterValue, setModelSelecterValue] = useState("");

const [files, setFiles] = useStore((state) => [state.files, state.setFiles]);
const threadChatImages = useStore((state) => state.threadChatImages);
const setThreadChatImages = useStore((state) => state.setThreadChatImages);

Expand Down Expand Up @@ -129,9 +128,8 @@ export function ThreadChatInput() {
model: chatModel,
};

if (files.length > 0) {
const images = files.filter((file) => file.type.startsWith("image/"));
payload["metadata"] = { images: images };
if (threadChatImages.length > 0) {
payload["metadata"] = { images: threadChatImages };
}

const res = await post(
Expand All @@ -143,7 +141,6 @@ export function ThreadChatInput() {
},
onSuccess: (data) => {
addThreadMessage(data);
setFiles([]);
setThreadChatImages([]);
},
onError: (error) => {
Expand Down Expand Up @@ -201,7 +198,7 @@ export function ThreadChatInput() {
}}
onClick={() => {
// Function to remove image
setFiles((prevImages) =>
setThreadChatImages((prevImages) =>
prevImages.filter((image) => image !== imgUrl),
);
}}
Expand Down Expand Up @@ -252,7 +249,7 @@ export function ThreadChatInput() {
<Textarea
placeholder="Shift + Enter to send"
radius="lg"
leftSection={<UploadButton />}
leftSection={<UploadButton useButton={true} />}
minRows={1}
maxRows={5}
autosize
Expand Down
16 changes: 4 additions & 12 deletions web/src/components/thread-settings.jsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import { Icon } from "@iconify/react/dist/iconify.js";
import {
Button,
FileButton,
Divider,
Group,
Modal,
MultiSelect,
NativeSelect,
Slider,
Stack,
Text,
TextInput,
Textarea,
} from "@mantine/core";
Expand All @@ -17,6 +14,7 @@ import { useMutation } from "@tanstack/react-query";
import { useEffect } from "react";
import { put } from "../libs/api";
import useStore from "../libs/store";
import { UploadButton } from "./upload-button";

export function ThreadSettingsModal() {
const [isOpen, setIsOpen] = useStore((state) => [
Expand Down Expand Up @@ -99,6 +97,8 @@ export function ThreadSettingsModal() {
onClose={() => setIsOpen(false)}
title="Thread Settings"
>
<UploadButton />
<Divider my="sm" variant="dashed" />
<form
onSubmit={form.onSubmit(async (values) => {
console.log(values);
Expand Down Expand Up @@ -167,14 +167,6 @@ export function ThreadSettingsModal() {
searchable
/>
</Stack>
<FileButton
size="sm"
variant="transparent"
multiple
leftSection={<Icon icon="tabler:upload"></Icon>}
>
{(props) => <Button {...props}>Upload image</Button>}
</FileButton>
<Group justify="flex-end" mt="md">
<Button type="submit" onClick={() => setIsOpen(false)}>
Submit
Expand Down
Loading

0 comments on commit 3d43acd

Please sign in to comment.