Skip to content

Commit

Permalink
AI Prompt improvements (#32)
Browse files Browse the repository at this point in the history
* prompts json

* prompts improvements

* clean up

* Fixing texts

* Fixing texts

* tests

* bug fixes

* fixing ux

* fixing unit tests

* refactoring

* one database connection per session

* clean up

* fixing issues

---------

Co-authored-by: Maksym Bilan <>
  • Loading branch information
maximbilan authored Sep 16, 2024
1 parent c4535d4 commit a3cd83a
Show file tree
Hide file tree
Showing 18 changed files with 156 additions and 128 deletions.
23 changes: 13 additions & 10 deletions internal/analysis/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import (
)

// Reqeust an analysis of the user's journal entries
func Request(notes []string, locale translator.Locale) *string {
ctx := context.Background()
client := createClient(ctx)
func Request(notes []string, locale translator.Locale, ctx *context.Context, header *string) *string {
ai := createAI()

prompt := translator.Translate(locale, "ai_analysis_prompt")
systemPrompt := translator.Prompt(locale, "ai_analysis_system_message")
userPrompt := translator.Prompt(locale, "ai_analysis_user_message")
for index, note := range notes {
prompt += fmt.Sprintf("%d. %s ", index+1, note)
userPrompt += fmt.Sprintf("%d. %s ", index+1, note)
}

var responseSchema = generateSchema[Response]()
Expand All @@ -29,18 +29,17 @@ func Request(notes []string, locale translator.Locale) *string {
Strict: openai.Bool(true),
}

chat, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
chat, err := ai.Chat.Completions.New(*ctx, openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage(prompt),
openai.SystemMessage(systemPrompt),
openai.UserMessage(userPrompt),
}),
ResponseFormat: openai.F[openai.ChatCompletionNewParamsResponseFormatUnion](
openai.ResponseFormatJSONSchemaParam{
Type: openai.F(openai.ResponseFormatJSONSchemaTypeJSONSchema),
JSONSchema: openai.F(schemaParam),
},
),
// only certain models can perform structured outputs
// Model: openai.F(openai.ChatModelGPT4o2024_08_06),
Model: openai.F(openai.ChatModelGPT4oMini),
})

Expand All @@ -58,7 +57,11 @@ func Request(notes []string, locale translator.Locale) *string {

var analysis string
if response.Text != "" {
analysis = fmt.Sprintf("%s%s", translator.Translate(locale, "weekly_analysis"), response.Text)
if header != nil {
analysis = fmt.Sprintf("%s%s", translator.Translate(locale, *header), response.Text)
} else {
analysis = response.Text
}
return &analysis
} else {
return nil
Expand Down
3 changes: 1 addition & 2 deletions internal/analysis/client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package analysis

import (
"context"
"os"

"github.com/invopop/jsonschema"
Expand All @@ -10,7 +9,7 @@ import (
)

// Create a client for the OpenAI API
func createClient(ctx context.Context) *openai.Client {
func createAI() *openai.Client {
client := openai.NewClient(
option.WithAPIKey(os.Getenv("CAPY_AI_KEY")),
)
Expand Down
2 changes: 1 addition & 1 deletion internal/bot/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func handleAnalysis(session *Session) {
sendMessage("analysis_waiting", session)

// Request the analysis
analysis := analysis.Request(strings, session.Locale())
analysis := analysis.Request(strings, session.Locale(), session.Context, nil)
if analysis != nil {
// Send the analysis
setOutputText(*analysis, session)
Expand Down
16 changes: 13 additions & 3 deletions internal/bot/bot.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package bot

import (
"context"
"log"
"net/http"

"github.com/capymind/internal/firestore"
"github.com/capymind/internal/telegram"
)

Expand All @@ -15,15 +17,21 @@ func Parse(w http.ResponseWriter, r *http.Request) {
return
}

// Create a context
ctx := context.Background()

// Creat a database connection
firestore.CreateClient(&ctx)

// Create a user
user := createUser(*update)
user := createUser(*update, &ctx)
if user == nil {
log.Printf("[Bot] No user to process: message_id=%d", update.Message.ID)
return
}

// Update the user's data in the database if necessary
updatedUser := updateUser(user)
updatedUser := updateUser(user, &ctx)

// Create a job
job := createJob(*update)
Expand All @@ -33,9 +41,11 @@ func Parse(w http.ResponseWriter, r *http.Request) {
}

// Create and start a session
session := createSession(job, updatedUser)
session := createSession(job, updatedUser, &ctx)
// Execute the job
handleSession(session)
// Send the response
finishSession(session)
// Close the database connection
firestore.CloseClient()
}
19 changes: 0 additions & 19 deletions internal/bot/database.go

This file was deleted.

16 changes: 3 additions & 13 deletions internal/bot/notes.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@ func finishNote(session *Session) {

// Handle the note command
func handleLastNote(session *Session) {
client, ctx := createClient()
defer client.Close()

userID := session.User.ID
note, err := firestore.LastNote(ctx, client, userID)
note, err := firestore.LastNote(session.Context, userID)
if err != nil {
log.Printf("[Bot] Error getting last note from firestore, %s", err.Error())
}
Expand All @@ -46,10 +43,6 @@ func handleLastNote(session *Session) {

// Save a note
func saveNote(text string, session *Session) {
// Setup the database connection
client, ctx := createClient()
defer client.Close()

// Note data
timestamp := time.Now()
var note = firestore.Note{
Expand All @@ -59,18 +52,15 @@ func saveNote(text string, session *Session) {
}

// Save the note
err := firestore.NewNote(ctx, client, *session.User, note)
err := firestore.NewNote(session.Context, *session.User, note)
if err != nil {
log.Printf("[Bot] Error saving note in firestore, %s", err.Error())
}
}

// Get the user's notes
func getNotes(session *Session) []firestore.Note {
client, ctx := createClient()
defer client.Close()

notes, err := firestore.GetNotes(ctx, client, session.User.ID)
notes, err := firestore.GetNotes(session.Context, session.User.ID)
if err != nil {
log.Printf("[Bot] Error getting notes from firestore, %s", err.Error())
}
Expand Down
16 changes: 10 additions & 6 deletions internal/bot/session.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package bot

import (
"context"

"github.com/capymind/internal/firestore"
"github.com/capymind/internal/translator"
)

type Session struct {
Job *Job
User *firestore.User
Job *Job
User *firestore.User
Context *context.Context
}

// Return the locale of the current user
Expand All @@ -20,14 +23,15 @@ func (session *Session) Locale() translator.Locale {

// Save the user's data
func (session *Session) SaveUser() {
saveUser(session.User)
saveUser(session.User, session.Context)
}

// Create a session
func createSession(job *Job, user *firestore.User) *Session {
func createSession(job *Job, user *firestore.User, context *context.Context) *Session {
session := Session{
Job: job,
User: user,
Job: job,
User: user,
Context: context,
}
return &session
}
Expand Down
18 changes: 6 additions & 12 deletions internal/bot/user.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bot

import (
"context"
"log"

"github.com/capymind/internal/firestore"
Expand All @@ -9,7 +10,7 @@ import (
)

// Create a user from an update
func createUser(update telegram.Update) *firestore.User {
func createUser(update telegram.Update, ctx *context.Context) *firestore.User {
var chatID int64
var telegramUser *telegram.User

Expand Down Expand Up @@ -43,17 +44,13 @@ func createUser(update telegram.Update) *firestore.User {
}

// Update the user's data in the database if necessary
func updateUser(user *firestore.User) *firestore.User {
func updateUser(user *firestore.User, ctx *context.Context) *firestore.User {
if user == nil {
return nil
}

// Setup the database connection
client, ctx := createClient()
defer client.Close()

// Check if the user exists
fetchedUser, err := firestore.GetUser(ctx, client, user.ID)
fetchedUser, err := firestore.GetUser(ctx, user.ID)
if err != nil {
log.Printf("[User] Error fetching user from firestore, %s", err.Error())

Expand Down Expand Up @@ -81,11 +78,8 @@ func updateUser(user *firestore.User) *firestore.User {
}

// Save a user to the database
func saveUser(user *firestore.User) {
client, ctx := createClient()
defer client.Close()

err := firestore.SaveUser(ctx, client, *user)
func saveUser(user *firestore.User, ctx *context.Context) {
err := firestore.SaveUser(ctx, *user)
if err != nil {
log.Printf("[User] Error saving user to firestore, %s", err.Error())
}
Expand Down
7 changes: 5 additions & 2 deletions internal/bot/user_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bot

import (
"context"
"testing"

"github.com/capymind/internal/telegram"
Expand All @@ -25,7 +26,8 @@ func TestCreateUserFromMessage(t *testing.T) {
},
}

user := createUser(update)
ctx := context.Background()
user := createUser(update, &ctx)
if user == nil {
t.Fatalf("User is nil")
}
Expand Down Expand Up @@ -71,7 +73,8 @@ func TestUserFromCallback(t *testing.T) {
},
}

user := createUser(update)
ctx := context.Background()
user := createUser(update, &ctx)
if user == nil {
t.Fatalf("User is nil")
}
Expand Down
23 changes: 20 additions & 3 deletions internal/firestore/firestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package firestore

import (
"context"
"log"
"os"

"cloud.google.com/go/firestore"
"google.golang.org/api/option"
)

var client *firestore.Client

// Path to the credentials file
func credentialsPath() string {
var path = "credentials.json"
Expand All @@ -20,16 +23,16 @@ func credentialsPath() string {
}

// Client for Firestore
func NewClient(ctx context.Context) (*firestore.Client, error) {
func newClient(ctx *context.Context) (*firestore.Client, error) {
projectID := os.Getenv("CAPY_PROJECT_ID")
var client *firestore.Client
var err error

if os.Getenv("CLOUD") == "true" {
client, err = firestore.NewClient(ctx, projectID)
client, err = firestore.NewClient(*ctx, projectID)
} else {
path := credentialsPath()
client, err = firestore.NewClient(ctx, projectID, option.WithCredentialsFile(path))
client, err = firestore.NewClient(*ctx, projectID, option.WithCredentialsFile(path))
}

if err != nil {
Expand All @@ -38,3 +41,17 @@ func NewClient(ctx context.Context) (*firestore.Client, error) {

return client, nil
}

// Create a new Firestore database connection
func CreateClient(ctx *context.Context) {
newClient, err := newClient(ctx)
if err != nil {
log.Printf("[Firestore] Error creating firestore client, %s", err.Error())
}
client = newClient
}

// Close the Firestore database connection
func CloseClient() {
client.Close()
}
Loading

0 comments on commit a3cd83a

Please sign in to comment.