From 654e15bb6b0b307ef5bc8d934c3c19e48b15e671 Mon Sep 17 00:00:00 2001 From: Eric Bower Date: Sat, 25 Jan 2025 14:11:52 -0500 Subject: [PATCH] refactor(pgs): decouple from rest of codebase --- auth/__snapshots__/api_test.snap | 2 +- cmd/pgs/ssh/main.go | 25 ++- cmd/pgs/web/main.go | 24 ++- cmd/scripts/clean-object-store/clean.go | 6 +- db/db.go | 70 ++++--- db/postgres/storage.go | 217 ++-------------------- db/stub/stub.go | 32 +--- feeds/ssh.go | 2 +- go.mod | 7 +- go.sum | 8 +- pastes/ssh.go | 2 +- pgs/cli.go | 17 +- pgs/{wish.go => cli_wish.go} | 9 +- pgs/config.go | 92 ++++++++-- pgs/db/db.go | 26 +++ pgs/db/memory.go | 171 +++++++++++++++++ pgs/db/postgres.go | 232 ++++++++++++++++++++++++ pgs/ssh.go | 53 ++---- pgs/ssh_test.go | 188 +++++++++++++++++++ pgs/tunnel.go | 30 +-- pgs/uploader.go | 84 +++++---- pgs/web.go | 200 ++++++++++---------- pgs/web_asset_handler.go | 8 +- pgs/web_cache.go | 4 +- pgs/web_test.go | 102 +++-------- pico/ssh.go | 2 +- pipe/ssh.go | 2 +- prose/ssh.go | 2 +- shared/ssh.go | 54 +++--- tui/analytics/analytics.go | 6 +- 30 files changed, 1073 insertions(+), 604 deletions(-) rename pgs/{wish.go => cli_wish.go} (96%) create mode 100644 pgs/db/db.go create mode 100644 pgs/db/memory.go create mode 100644 pgs/db/postgres.go create mode 100644 pgs/ssh_test.go diff --git a/auth/__snapshots__/api_test.snap b/auth/__snapshots__/api_test.snap index 35b27fac..41e24a9a 100755 --- a/auth/__snapshots__/api_test.snap +++ b/auth/__snapshots__/api_test.snap @@ -87,7 +87,7 @@ successfully added pico+ user --- [TestUser - 1] -[{"id":"1","user_id":"user-1","name":"my-key","key":"nice-pubkey","created_at":"0001-01-01T00:00:00Z"}] +[{"id":"1","user_id":"user-1","name":"my-key","public_key":"nice-pubkey","created_at":"0001-01-01T00:00:00Z"}] --- [TestAuthApi/rss - 1] diff --git a/cmd/pgs/ssh/main.go b/cmd/pgs/ssh/main.go index 0488f563..cc714de6 100644 --- a/cmd/pgs/ssh/main.go +++ b/cmd/pgs/ssh/main.go @@ -1,7 +1,28 @@ package main -import 
"github.com/picosh/pico/pgs" +import ( + "github.com/picosh/pico/pgs" + pgsdb "github.com/picosh/pico/pgs/db" + "github.com/picosh/pico/shared" + "github.com/picosh/pico/shared/storage" + "github.com/picosh/utils" +) func main() { - pgs.StartSshServer() + minioURL := utils.GetEnv("MINIO_URL", "") + minioUser := utils.GetEnv("MINIO_ROOT_USER", "") + minioPass := utils.GetEnv("MINIO_ROOT_PASSWORD", "") + dbURL := utils.GetEnv("DATABASE_URL", "") + logger := shared.CreateLogger("pgs") + dbpool, err := pgsdb.NewDB(dbURL, logger) + if err != nil { + panic(err) + } + st, err := storage.NewStorageMinio(logger, minioURL, minioUser, minioPass) + if err != nil { + panic(err) + } + cfg := pgs.NewPgsConfig(logger, dbpool, st) + killCh := make(chan error) + pgs.StartSshServer(cfg, killCh) } diff --git a/cmd/pgs/web/main.go b/cmd/pgs/web/main.go index 38289b45..726f0730 100644 --- a/cmd/pgs/web/main.go +++ b/cmd/pgs/web/main.go @@ -1,7 +1,27 @@ package main -import "github.com/picosh/pico/pgs" +import ( + "github.com/picosh/pico/pgs" + pgsdb "github.com/picosh/pico/pgs/db" + "github.com/picosh/pico/shared" + "github.com/picosh/pico/shared/storage" + "github.com/picosh/utils" +) func main() { - pgs.StartApiServer() + minioURL := utils.GetEnv("MINIO_URL", "") + minioUser := utils.GetEnv("MINIO_ROOT_USER", "") + minioPass := utils.GetEnv("MINIO_ROOT_PASSWORD", "") + dbURL := utils.GetEnv("DATABASE_URL", "") + logger := shared.CreateLogger("pgs") + dbpool, err := pgsdb.NewDB(dbURL, logger) + if err != nil { + panic(err) + } + st, err := storage.NewStorageMinio(logger, minioURL, minioUser, minioPass) + if err != nil { + panic(err) + } + cfg := pgs.NewPgsConfig(logger, dbpool, st) + pgs.StartApiServer(cfg) } diff --git a/cmd/scripts/clean-object-store/clean.go b/cmd/scripts/clean-object-store/clean.go index 5c0db3bc..6ab37ba1 100644 --- a/cmd/scripts/clean-object-store/clean.go +++ b/cmd/scripts/clean-object-store/clean.go @@ -6,8 +6,8 @@ import ( "strings" "github.com/picosh/pico/db" 
- "github.com/picosh/pico/db/postgres" "github.com/picosh/pico/pgs" + pgsdb "github.com/picosh/pico/pgs/db" "github.com/picosh/pico/shared" "github.com/picosh/pico/shared/storage" "github.com/picosh/utils" @@ -41,10 +41,10 @@ func main() { picoCfg.MinioURL = os.Getenv("MINIO_URL") picoCfg.MinioUser = os.Getenv("MINIO_ROOT_USER") picoCfg.MinioPass = os.Getenv("MINIO_ROOT_PASSWORD") - picoDb := postgres.NewDB(picoCfg.DbURL, picoCfg.Logger) + picoDb, err := pgsdb.NewDB(picoCfg.DbURL, picoCfg.Logger) + bail(err) var st storage.StorageServe - var err error st, err = storage.NewStorageMinio(logger, picoCfg.MinioURL, picoCfg.MinioUser, picoCfg.MinioPass) bail(err) diff --git a/db/db.go b/db/db.go index eea5799e..e6be8bcd 100644 --- a/db/db.go +++ b/db/db.go @@ -15,18 +15,18 @@ var ErrNameInvalid = errors.New("username has invalid characters in it") var ErrPublicKeyTaken = errors.New("public key is already associated with another user") type PublicKey struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Key string `json:"key"` - CreatedAt *time.Time `json:"created_at"` + ID string `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id"` + Name string `json:"name" db:"name"` + Key string `json:"public_key" db:"public_key"` + CreatedAt *time.Time `json:"created_at" db:"created_at"` } type User struct { - ID string `json:"id"` - Name string `json:"name"` - PublicKey *PublicKey `json:"public_key,omitempty"` - CreatedAt *time.Time `json:"created_at"` + ID string `json:"id" db:"id"` + Name string `json:"name" db:"name"` + PublicKey *PublicKey `json:"public_key,omitempty" db:"public_key,omitempty"` + CreatedAt *time.Time `json:"created_at" db:"created_at"` } type PostData struct { @@ -53,20 +53,20 @@ func (p *PostData) Scan(value interface{}) error { } type Project struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - ProjectDir string `json:"project_dir"` - Username string 
`json:"username"` - Acl ProjectAcl `json:"acl"` - Blocked string `json:"blocked"` - CreatedAt *time.Time `json:"created_at"` - UpdatedAt *time.Time `json:"updated_at"` + ID string `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id"` + Name string `json:"name" db:"name"` + ProjectDir string `json:"project_dir" db:"project_dir"` + Username string `json:"username" db:"username"` + Acl ProjectAcl `json:"acl" db:"acl"` + Blocked string `json:"blocked" db:"blocked"` + CreatedAt *time.Time `json:"created_at" db:"created_at"` + UpdatedAt *time.Time `json:"updated_at" db:"updated_at"` } type ProjectAcl struct { - Type string `json:"type"` - Data []string `json:"data"` + Type string `json:"type" db:"type"` + Data []string `json:"data" db:"data"` } // Make the Attrs struct implement the driver.Valuer interface. This method @@ -218,13 +218,13 @@ type Token struct { } type FeatureFlag struct { - ID string `json:"id"` - UserID string `json:"user_id"` - PaymentHistoryID string `json:"payment_history_id"` - Name string `json:"name"` - CreatedAt *time.Time `json:"created_at"` - ExpiresAt *time.Time `json:"expires_at"` - Data FeatureFlagData `json:"data"` + ID string `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id"` + PaymentHistoryID string `json:"payment_history_id" db:"payment_history_id"` + Name string `json:"name" db:"name"` + CreatedAt *time.Time `json:"created_at" db:"created_at"` + ExpiresAt *time.Time `json:"expires_at" db:"expires_at"` + Data FeatureFlagData `json:"data" db:"data"` } func NewFeatureFlag(userID, name string, storageMax uint64, fileMax int64, specialFileMax int64) *FeatureFlag { @@ -268,9 +268,9 @@ func (ff *FeatureFlag) IsValid() bool { } type FeatureFlagData struct { - StorageMax uint64 `json:"storage_max"` - FileMax int64 `json:"file_max"` - SpecialFileMax int64 `json:"special_file_max"` + StorageMax uint64 `json:"storage_max" db:"storage_max"` + FileMax int64 `json:"file_max" db:"file_max"` + SpecialFileMax int64 
`json:"special_file_max" db:"special_file_max"` } // Make the Attrs struct implement the driver.Valuer interface. This method @@ -354,6 +354,7 @@ type DB interface { FindUserForName(name string) (*User, error) FindUserForNameAndKey(name string, pubkey string) (*User, error) FindUserForKey(name string, pubkey string) (*User, error) + FindUserByPubkey(pubkey string) (*User, error) FindUser(userID string) (*User, error) ValidateName(name string) (bool, error) SetUserName(userID string, name string) error @@ -405,16 +406,7 @@ type DB interface { FindFeedItemsByPostID(postID string) ([]*FeedItem, error) UpsertProject(userID, name, projectDir string) (*Project, error) - InsertProject(userID, name, projectDir string) (string, error) - UpdateProject(userID, name string) error - UpdateProjectAcl(userID, name string, acl ProjectAcl) error - LinkToProject(userID, projectID, projectDir string, commit bool) error - RemoveProject(projectID string) error FindProjectByName(userID, name string) (*Project, error) - FindProjectLinks(userID, name string) ([]*Project, error) - FindProjectsByUser(userID string) ([]*Project, error) - FindProjectsByPrefix(userID, name string) ([]*Project, error) - FindAllProjects(page *Pager, by string) (*Paginate[*Project], error) Close() error } diff --git a/db/postgres/storage.go b/db/postgres/storage.go index 28d43d06..4905f95f 100644 --- a/db/postgres/storage.go +++ b/db/postgres/storage.go @@ -256,7 +256,6 @@ const ( sqlInsertProject = `INSERT INTO projects (user_id, name, project_dir) VALUES ($1, $2, $3) RETURNING id;` sqlUpdateProject = `UPDATE projects SET updated_at = $3 WHERE user_id = $1 AND name = $2;` - sqlUpdateProjectAcl = `UPDATE projects SET acl = $3, updated_at = $4 WHERE user_id = $1 AND name = $2;` sqlFindProjectByName = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = $2;` sqlSelectProjectCount = `SELECT count(id) FROM projects` sqlFindProjectsByUser = `SELECT 
id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 ORDER BY name ASC, updated_at DESC;` @@ -624,6 +623,22 @@ func (me *PsqlDB) FindUserForKey(username string, key string) (*db.User, error) return nil, err } +func (me *PsqlDB) FindUserByPubkey(key string) (*db.User, error) { + me.Logger.Info("attempting to find user with only public key", "key", key) + pk, err := me.FindPublicKeyForKey(key) + if err != nil { + return nil, err + } + + me.Logger.Info("found pubkey, looking for user", "key", key, "userId", pk.UserID) + user, err := me.FindUser(pk.UserID) + if err != nil { + return nil, err + } + user.PublicKey = pk + return user, nil +} + func (me *PsqlDB) FindUser(userID string) (*db.User, error) { user := &db.User{} var un sql.NullString @@ -1639,60 +1654,6 @@ func (me *PsqlDB) UpdateProject(userID, name string) error { return err } -func (me *PsqlDB) UpdateProjectAcl(userID, name string, acl db.ProjectAcl) error { - _, err := me.Db.Exec(sqlUpdateProjectAcl, userID, name, acl, time.Now()) - return err -} - -func (me *PsqlDB) LinkToProject(userID, projectID, projectDir string, commit bool) error { - linkToProject, err := me.FindProjectByName(userID, projectDir) - if err != nil { - return err - } - isAlreadyLinked := linkToProject.Name != linkToProject.ProjectDir - sameProject := linkToProject.ID == projectID - - /* - A project linked to another project which is also linked to a - project is forbidden. CI/CD Example: - - ProjectProd links to ProjectStaging - - ProjectStaging links to ProjectMain - - We merge `main` and trigger a deploy which uploads to ProjectMain - - All three get updated immediately - This scenario was not the intent of our CI/CD. What we actually - wanted was to create a snapshot of ProjectMain and have ProjectStaging - link to the snapshot, but that's not the intended design of pgs. - - So we want to close that gap here. 
- - We ensure that `project.Name` and `project.ProjectDir` are identical - when there is no aliasing. - */ - if !sameProject && isAlreadyLinked { - return fmt.Errorf( - "cannot link (%s) to (%s) because it is also a link to (%s)", - projectID, - projectDir, - linkToProject.ProjectDir, - ) - } - - if commit { - _, err = me.Db.Exec( - sqlLinkToProject, - projectDir, - time.Now(), - projectID, - ) - } - return err -} - -func (me *PsqlDB) RemoveProject(projectID string) error { - _, err := me.Db.Exec(sqlRemoveProject, projectID) - return err -} - func (me *PsqlDB) FindProjectByName(userID, name string) (*db.Project, error) { project := &db.Project{} r := me.Db.QueryRow(sqlFindProjectByName, userID, name) @@ -1713,152 +1674,6 @@ func (me *PsqlDB) FindProjectByName(userID, name string) (*db.Project, error) { return project, nil } -func (me *PsqlDB) FindProjectLinks(userID, name string) ([]*db.Project, error) { - var projects []*db.Project - rs, err := me.Db.Query(sqlFindProjectLinks, userID, name) - if err != nil { - return nil, err - } - for rs.Next() { - project := &db.Project{} - err := rs.Scan( - &project.ID, - &project.UserID, - &project.Name, - &project.ProjectDir, - &project.Acl, - &project.Blocked, - &project.CreatedAt, - &project.UpdatedAt, - ) - if err != nil { - return nil, err - } - - projects = append(projects, project) - } - - if rs.Err() != nil { - return nil, rs.Err() - } - - return projects, nil -} - -func (me *PsqlDB) FindProjectsByPrefix(userID, prefix string) ([]*db.Project, error) { - var projects []*db.Project - rs, err := me.Db.Query(sqlFindProjectsByPrefix, userID, prefix+"%") - if err != nil { - return nil, err - } - for rs.Next() { - project := &db.Project{} - err := rs.Scan( - &project.ID, - &project.UserID, - &project.Name, - &project.ProjectDir, - &project.Acl, - &project.Blocked, - &project.CreatedAt, - &project.UpdatedAt, - ) - if err != nil { - return nil, err - } - - projects = append(projects, project) - } - - if rs.Err() != nil { - 
return nil, rs.Err() - } - - return projects, nil -} - -func (me *PsqlDB) FindProjectsByUser(userID string) ([]*db.Project, error) { - var projects []*db.Project - rs, err := me.Db.Query(sqlFindProjectsByUser, userID) - if err != nil { - return nil, err - } - for rs.Next() { - project := &db.Project{} - err := rs.Scan( - &project.ID, - &project.UserID, - &project.Name, - &project.ProjectDir, - &project.Acl, - &project.Blocked, - &project.CreatedAt, - &project.UpdatedAt, - ) - if err != nil { - return nil, err - } - - projects = append(projects, project) - } - - if rs.Err() != nil { - return nil, rs.Err() - } - - return projects, nil -} - -func (me *PsqlDB) FindAllProjects(page *db.Pager, by string) (*db.Paginate[*db.Project], error) { - var projects []*db.Project - sqlFindAllProjects := fmt.Sprintf(` - SELECT projects.id, user_id, app_users.name as username, projects.name, project_dir, projects.acl, projects.blocked, projects.created_at, projects.updated_at - FROM projects - LEFT JOIN app_users ON app_users.id = projects.user_id - ORDER BY %s DESC - LIMIT $1 OFFSET $2`, by) - rs, err := me.Db.Query(sqlFindAllProjects, page.Num, page.Num*page.Page) - if err != nil { - return nil, err - } - for rs.Next() { - project := &db.Project{} - err := rs.Scan( - &project.ID, - &project.UserID, - &project.Username, - &project.Name, - &project.ProjectDir, - &project.Acl, - &project.Blocked, - &project.CreatedAt, - &project.UpdatedAt, - ) - if err != nil { - return nil, err - } - - projects = append(projects, project) - } - - if rs.Err() != nil { - return nil, rs.Err() - } - - var count int - err = me.Db.QueryRow(sqlSelectProjectCount).Scan(&count) - if err != nil { - return nil, err - } - - pager := &db.Paginate[*db.Project]{ - Data: projects, - Total: int(math.Ceil(float64(count) / float64(page.Num))), - } - - return pager, nil -} - func (me *PsqlDB) InsertToken(userID, name string) (string, error) { var token string err := me.Db.QueryRow(sqlInsertToken, userID, 
name).Scan(&token) diff --git a/db/stub/stub.go b/db/stub/stub.go index 31707d7c..ade42ae6 100644 --- a/db/stub/stub.go +++ b/db/stub/stub.go @@ -69,6 +69,10 @@ func (me *StubDB) FindUserForKey(username string, key string) (*db.User, error) return nil, notImpl } +func (me *StubDB) FindUserByPubkey(key string) (*db.User, error) { + return nil, notImpl +} + func (me *StubDB) FindUser(userID string) (*db.User, error) { return nil, notImpl } @@ -225,38 +229,10 @@ func (me *StubDB) UpdateProject(userID, name string) error { return notImpl } -func (me *StubDB) UpdateProjectAcl(userID, name string, acl db.ProjectAcl) error { - return notImpl -} - -func (me *StubDB) LinkToProject(userID, projectID, projectDir string, commit bool) error { - return notImpl -} - -func (me *StubDB) RemoveProject(projectID string) error { - return notImpl -} - func (me *StubDB) FindProjectByName(userID, name string) (*db.Project, error) { return &db.Project{}, notImpl } -func (me *StubDB) FindProjectLinks(userID, name string) ([]*db.Project, error) { - return []*db.Project{}, notImpl -} - -func (me *StubDB) FindProjectsByPrefix(userID, prefix string) ([]*db.Project, error) { - return []*db.Project{}, notImpl -} - -func (me *StubDB) FindProjectsByUser(userID string) ([]*db.Project, error) { - return []*db.Project{}, notImpl -} - -func (me *StubDB) FindAllProjects(page *db.Pager, by string) (*db.Paginate[*db.Project], error) { - return &db.Paginate[*db.Project]{}, notImpl -} - func (me *StubDB) InsertToken(userID, name string) (string, error) { return "", notImpl } diff --git a/feeds/ssh.go b/feeds/ssh.go index cea5ce77..625faace 100644 --- a/feeds/ssh.go +++ b/feeds/ssh.go @@ -69,7 +69,7 @@ func StartSshServer() { } handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap) - sshAuth := shared.NewSshAuthHandler(dbh, logger, cfg) + sshAuth := shared.NewSshAuthHandler(dbh, logger) s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), 
wish.WithHostKeyPath("ssh_data/term_info_ed25519"), diff --git a/go.mod b/go.mod index 3f8b8749..17aa9c60 100644 --- a/go.mod +++ b/go.mod @@ -42,8 +42,10 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/feeds v1.2.0 + github.com/jmoiron/sqlx v1.4.0 github.com/lib/pq v1.10.9 github.com/microcosm-cc/bluemonday v1.0.27 + github.com/minio/minio-go/v7 v7.0.80 github.com/mmcdole/gofeed v1.3.0 github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.15.3-0.20240912151726-82936c5ea257 @@ -53,6 +55,7 @@ require ( github.com/picosh/send v0.0.0-20250121195737-daab6db117d5 github.com/picosh/tunkit v0.0.0-20240905223921-532404cef9d9 github.com/picosh/utils v0.0.0-20241120033529-8ca070c09bf4 + github.com/pkg/sftp v1.13.7 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/sendgrid/sendgrid-go v3.16.0+incompatible github.com/simplesurance/go-ip-anonymizer v0.0.0-20200429124537-35a880f8e87d @@ -173,7 +176,7 @@ require ( github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect - github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-xmlfmt/xmlfmt v1.1.2 // indirect github.com/goccy/go-json v0.10.3 // indirect @@ -234,7 +237,6 @@ require ( github.com/miekg/dns v1.1.62 // indirect github.com/minio/madmin-go/v3 v3.0.77 // indirect github.com/minio/md5-simd v1.1.2 // indirect - github.com/minio/minio-go/v7 v7.0.80 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect @@ -256,7 +258,6 @@ require ( github.com/picosh/go-rsync-receiver v0.0.0-20250121150813-93b4f1b7aa4b // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pkg/sftp 
v1.13.7 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/pquerna/cachecontrol v0.2.0 // indirect github.com/prometheus/client_golang v1.20.5 // indirect diff --git a/go.sum b/go.sum index 0e6f1bed..057486b3 100644 --- a/go.sum +++ b/go.sum @@ -355,8 +355,8 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= -github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= @@ -544,6 +544,8 @@ github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go 
v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -613,6 +615,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sixel v0.0.5 h1:55w2FR5ncuhKhXrM5ly1eiqMQfZsnAHIpYNGZX03Cv8= github.com/mattn/go-sixel v0.0.5/go.mod h1:h2Sss+DiUEHy0pUqcIB6PFXo5Cy8sTQEFr3a9/5ZLNw= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= diff --git a/pastes/ssh.go b/pastes/ssh.go index 30502873..96b8077c 100644 --- a/pastes/ssh.go +++ b/pastes/ssh.go @@ -67,7 +67,7 @@ func StartSshServer() { "fallback": filehandlers.NewScpPostHandler(dbh, cfg, hooks), } handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap) - sshAuth := shared.NewSshAuthHandler(dbh, logger, cfg) + sshAuth := shared.NewSshAuthHandler(dbh, logger) s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), wish.WithHostKeyPath("ssh_data/term_info_ed25519"), diff --git a/pgs/cli.go b/pgs/cli.go index cd7cb5a2..56c1c38b 100644 --- a/pgs/cli.go +++ b/pgs/cli.go @@ -7,10 +7,12 @@ import ( "log/slog" "path/filepath" "strings" + "time" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss/table" "github.com/picosh/pico/db" + pgsdb "github.com/picosh/pico/pgs/db" "github.com/picosh/pico/shared" "github.com/picosh/pico/tui/common" sst "github.com/picosh/pobj/storage" @@ -124,12 +126,12 @@ type Cmd struct 
{ Session utils.CmdSession Log *slog.Logger Store sst.ObjectStorage - Dbpool db.DB + Dbpool pgsdb.PgsDB Write bool Styles common.Styles Width int Height int - Cfg *shared.ConfigSite + Cfg *PgsConfig } func (c *Cmd) output(out string) { @@ -205,7 +207,7 @@ func (c *Cmd) help() { } func (c *Cmd) stats(cfgMaxSize uint64) error { - ff, err := c.Dbpool.FindFeatureForUser(c.User.ID, "plus") + ff, err := c.Dbpool.FindFeature(c.User.ID, "plus") if err != nil { ff = db.NewFeatureFlag(c.User.ID, "plus", cfgMaxSize, 0, 0) } @@ -539,7 +541,14 @@ func (c *Cmd) cache(projectName string) error { } func (c *Cmd) cacheAll() error { - isAdmin := c.Dbpool.HasFeatureForUser(c.User.ID, "admin") + isAdmin := false + ff, _ := c.Dbpool.FindFeature(c.User.ID, "admin") + if ff != nil { + if ff.ExpiresAt.Before(time.Now()) { + isAdmin = true + } + } + if !isAdmin { return fmt.Errorf("must be admin to use this command") } diff --git a/pgs/wish.go b/pgs/cli_wish.go similarity index 96% rename from pgs/wish.go rename to pgs/cli_wish.go index a6f3f458..1e9c3e01 100644 --- a/pgs/wish.go +++ b/pgs/cli_wish.go @@ -11,19 +11,20 @@ import ( bm "github.com/charmbracelet/wish/bubbletea" "github.com/muesli/termenv" "github.com/picosh/pico/db" + pgsdb "github.com/picosh/pico/pgs/db" "github.com/picosh/pico/tui/common" sendutils "github.com/picosh/send/utils" "github.com/picosh/utils" ) -func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) { +func getUser(s ssh.Session, dbpool pgsdb.PgsDB) (*db.User, error) { if s.PublicKey() == nil { return nil, fmt.Errorf("key not found") } key := utils.KeyForKeyText(s.PublicKey()) - user, err := dbpool.FindUserForKey(s.User(), key) + user, err := dbpool.FindUserByPubkey(key) if err != nil { return nil, err } @@ -64,10 +65,10 @@ func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool { } func WishMiddleware(handler *UploadAssetHandler) wish.Middleware { - dbpool := handler.DBPool + dbpool := handler.Cfg.DB log := handler.Cfg.Logger cfg := 
handler.Cfg - store := handler.Storage + store := handler.Cfg.Storage return func(next ssh.Handler) ssh.Handler { return func(sesh ssh.Session) { diff --git a/pgs/config.go b/pgs/config.go index 9101b09e..22116a9e 100644 --- a/pgs/config.go +++ b/pgs/config.go @@ -2,19 +2,72 @@ package pgs import ( "fmt" + "log/slog" + "path/filepath" "time" - "github.com/picosh/pico/shared" + pgsdb "github.com/picosh/pico/pgs/db" + "github.com/picosh/pico/shared/storage" "github.com/picosh/utils" ) +type PgsConfig struct { + CacheControl string + CacheTTL time.Duration + Domain string + MaxAssetSize int64 + MaxSize uint64 + MaxSpecialFileSize int64 + SshHost string + SshPort string + StorageDir string + TxtPrefix string + WebPort string + WebProtocol string + + // This channel will receive the surrogate key for a project (e.g. static site) + // which will inform the caching layer to clear the cache for that site. + CacheClearingQueue chan string + // Database layer; it's just an interface that could be implemented + // with anything. + DB pgsdb.PgsDB + Logger *slog.Logger + // Where we store the static assets uploaded to our service. + Storage storage.StorageServe +} + +func (c *PgsConfig) AssetURL(username, projectName, fpath string) string { + if username == projectName { + return fmt.Sprintf( + "%s://%s.%s/%s", + c.WebProtocol, + username, + c.Domain, + fpath, + ) + } + + return fmt.Sprintf( + "%s://%s-%s.%s/%s", + c.WebProtocol, + username, + projectName, + c.Domain, + fpath, + ) +} + +func (c *PgsConfig) StaticPath(fname string) string { + return filepath.Join("pgs", fname) +} + var maxSize = uint64(25 * utils.MB) var maxAssetSize = int64(10 * utils.MB) // Needs to be small for caching files like _headers and _redirects. 
var maxSpecialFileSize = int64(5 * utils.KB) -func NewConfigSite() *shared.ConfigSite { +func NewPgsConfig(logger *slog.Logger, dbpool pgsdb.PgsDB, st storage.StorageServe) *PgsConfig { domain := utils.GetEnv("PGS_DOMAIN", "pgs.sh") port := utils.GetEnv("PGS_WEB_PORT", "3000") protocol := utils.GetEnv("PGS_PROTOCOL", "https") @@ -26,27 +79,28 @@ func NewConfigSite() *shared.ConfigSite { cacheControl := utils.GetEnv( "PGS_CACHE_CONTROL", fmt.Sprintf("max-age=%d", int(cacheTTL.Seconds()))) - minioURL := utils.GetEnv("MINIO_URL", "") - minioUser := utils.GetEnv("MINIO_ROOT_USER", "") - minioPass := utils.GetEnv("MINIO_ROOT_PASSWORD", "") - dbURL := utils.GetEnv("DATABASE_URL", "") - cfg := shared.ConfigSite{ - Domain: domain, - Port: port, - Protocol: protocol, - DbURL: dbURL, - StorageDir: storageDir, - CacheTTL: cacheTTL, + sshHost := utils.GetEnv("PGS_SSH_HOST", "0.0.0.0") + sshPort := utils.GetEnv("PGS_SSH_PORT", "2222") + + cfg := PgsConfig{ CacheControl: cacheControl, - MinioURL: minioURL, - MinioUser: minioUser, - MinioPass: minioPass, - Space: "pgs", - MaxSize: maxSize, + CacheTTL: cacheTTL, + Domain: domain, MaxAssetSize: maxAssetSize, + MaxSize: maxSize, MaxSpecialFileSize: maxSpecialFileSize, - Logger: shared.CreateLogger("pgs"), + SshHost: sshHost, + SshPort: sshPort, + StorageDir: storageDir, + TxtPrefix: "pgs", + WebPort: port, + WebProtocol: protocol, + + CacheClearingQueue: make(chan string, 100), + DB: dbpool, + Logger: logger, + Storage: st, } return &cfg diff --git a/pgs/db/db.go b/pgs/db/db.go new file mode 100644 index 00000000..b8d904c6 --- /dev/null +++ b/pgs/db/db.go @@ -0,0 +1,26 @@ +package pgsdb + +import "github.com/picosh/pico/db" + +type PgsDB interface { + FindUser(userID string) (*db.User, error) + FindUserByName(name string) (*db.User, error) + FindUserByPubkey(pubkey string) (*db.User, error) + FindUsers() ([]*db.User, error) + + FindFeature(userID string, name string) (*db.FeatureFlag, error) + + InsertProject(userID, name, 
projectDir string) (string, error) + UpdateProject(userID, name string) error + UpdateProjectAcl(userID, name string, acl db.ProjectAcl) error + UpsertProject(userID, projectName, projectDir string) (*db.Project, error) + RemoveProject(projectID string) error + LinkToProject(userID, projectID, projectDir string, commit bool) error + FindProjectByName(userID, name string) (*db.Project, error) + FindProjectLinks(userID, name string) ([]*db.Project, error) + FindProjectsByUser(userID string) ([]*db.Project, error) + FindProjectsByPrefix(userID, name string) ([]*db.Project, error) + FindProjects(by string) ([]*db.Project, error) + + Close() error +} diff --git a/pgs/db/memory.go b/pgs/db/memory.go new file mode 100644 index 00000000..c9964046 --- /dev/null +++ b/pgs/db/memory.go @@ -0,0 +1,171 @@ +package pgsdb + +import ( + "fmt" + "log/slog" + "time" + + "github.com/google/uuid" + "github.com/picosh/pico/db" + "github.com/picosh/utils" +) + +type MemoryDB struct { + Logger *slog.Logger + Users []*db.User + Projects []*db.Project + Pubkeys []*db.PublicKey + Feature *db.FeatureFlag +} + +var _ PgsDB = (*MemoryDB)(nil) + +func NewDBMemory(logger *slog.Logger) *MemoryDB { + d := &MemoryDB{ + Logger: logger, + } + d.Logger.Info("connecting to our in-memory database. 
All data created during runtime will be lost on exit.") + return d +} + +func (me *MemoryDB) SetupTestData() { + user := &db.User{ + ID: uuid.NewString(), + Name: "testusr", + } + me.Users = append(me.Users, user) + feature := db.NewFeatureFlag( + user.ID, + "plus", + uint64(25*utils.MB), + int64(10*utils.MB), + int64(5*utils.KB), + ) + expiresAt := time.Now().Add(time.Hour * 24) + feature.ExpiresAt = &expiresAt + me.Feature = feature +} + +var notImpl = fmt.Errorf("not implemented") + +func (me *MemoryDB) FindUsers() ([]*db.User, error) { + users := []*db.User{} + return users, notImpl +} + +func (me *MemoryDB) FindUserByPubkey(key string) (*db.User, error) { + for _, pk := range me.Pubkeys { + if pk.Key == key { + return me.FindUser(pk.UserID) + } + } + return nil, fmt.Errorf("user not found") +} + +func (me *MemoryDB) FindUser(userID string) (*db.User, error) { + for _, user := range me.Users { + if user.ID == userID { + return user, nil + } + } + return nil, fmt.Errorf("user not found") +} + +func (me *MemoryDB) FindUserByName(name string) (*db.User, error) { + for _, user := range me.Users { + if user.Name == name { + return user, nil + } + } + return nil, fmt.Errorf("user not found") +} + +func (me *MemoryDB) FindFeature(userID, name string) (*db.FeatureFlag, error) { + return me.Feature, nil +} + +func (me *MemoryDB) Close() error { + return nil +} + +func (me *MemoryDB) FindTotalSizeForUser(userID string) (int, error) { + return 0, notImpl +} + +func (me *MemoryDB) InsertProject(userID, name, projectDir string) (string, error) { + id := uuid.NewString() + me.Projects = append(me.Projects, &db.Project{ + ID: id, + UserID: userID, + Name: name, + ProjectDir: projectDir, + }) + return id, nil +} + +func (me *MemoryDB) UpdateProject(userID, name string) error { + return notImpl +} + +func (me *MemoryDB) UpsertProject(userID, projectName, projectDir string) (*db.Project, error) { + project, err := me.FindProjectByName(userID, projectName) + if err == nil { + // 
this just updates the `createdAt` timestamp, useful for book-keeping + err = me.UpdateProject(userID, projectName) + if err != nil { + me.Logger.Error("could not update project", "err", err) + return nil, err + } + return project, nil + } + + _, err = me.InsertProject(userID, projectName, projectName) + if err != nil { + me.Logger.Error("could not create project", "err", err) + return nil, err + } + return me.FindProjectByName(userID, projectName) +} + +func (me *MemoryDB) LinkToProject(userID, projectID, projectDir string, commit bool) error { + return notImpl +} + +func (me *MemoryDB) RemoveProject(projectID string) error { + return notImpl +} + +func (me *MemoryDB) FindProjectByName(userID, name string) (*db.Project, error) { + for _, project := range me.Projects { + if project.UserID != userID { + continue + } + + if project.Name != name { + continue + } + + return project, nil + } + return nil, fmt.Errorf("project not found by name %s", name) +} + +func (me *MemoryDB) FindProjectLinks(userID, name string) ([]*db.Project, error) { + return []*db.Project{}, notImpl +} + +func (me *MemoryDB) FindProjectsByPrefix(userID, prefix string) ([]*db.Project, error) { + return []*db.Project{}, notImpl +} + +func (me *MemoryDB) FindProjectsByUser(userID string) ([]*db.Project, error) { + return []*db.Project{}, notImpl +} + +func (me *MemoryDB) FindProjects(userID string) ([]*db.Project, error) { + return []*db.Project{}, notImpl +} + +func (me *MemoryDB) UpdateProjectAcl(userID, name string, acl db.ProjectAcl) error { + return notImpl +} diff --git a/pgs/db/postgres.go b/pgs/db/postgres.go new file mode 100644 index 00000000..8224d107 --- /dev/null +++ b/pgs/db/postgres.go @@ -0,0 +1,232 @@ +package pgsdb + +import ( + "fmt" + "log/slog" + "time" + + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" + "github.com/picosh/pico/db" + "github.com/picosh/utils" +) + +type PgsPsqlDB struct { + Logger *slog.Logger + Db *sqlx.DB +} + +var _ PgsDB = (*PgsPsqlDB)(nil) + +func 
NewDB(databaseUrl string, logger *slog.Logger) (*PgsPsqlDB, error) { + var err error + d := &PgsPsqlDB{ + Logger: logger, + } + d.Logger.Info("connecting to postgres", "databaseUrl", databaseUrl) + + db, err := sqlx.Connect("postgres", databaseUrl) + if err != nil { + return nil, err + } + + d.Db = db + return d, nil +} + +func (me *PgsPsqlDB) Close() error { + return me.Db.Close() +} + +func (me *PgsPsqlDB) FindUsers() ([]*db.User, error) { + users := []*db.User{} + err := me.Db.Select(&users, "SELECT * FROM app_users") + return users, err +} + +func (me *PgsPsqlDB) FindUserByPubkey(key string) (*db.User, error) { + pk := []db.PublicKey{} + err := me.Db.Select(&pk, "SELECT * FROM public_keys WHERE public_key=$1", key) + if err != nil { + return nil, err + } + if len(pk) == 0 { + return nil, fmt.Errorf("pubkey not found in our database: [%s]", key) + } + // When we run PublicKeyForKey and there are multiple public keys returned from the database + // that should mean that we don't have the correct username for this public key. + // When that happens we need to reject the authentication and ask the user to provide the correct + // username when using ssh. 
So instead of `ssh ` it should be `ssh user@` + if len(pk) > 1 { + return nil, &db.ErrMultiplePublicKeys{} + } + + return me.FindUser(pk[0].UserID) +} + +func (me *PgsPsqlDB) FindUser(userID string) (*db.User, error) { + user := db.User{} + err := me.Db.Get(&user, "SELECT * FROM app_users WHERE id=$1", userID) + return &user, err +} + +func (me *PgsPsqlDB) FindUserByName(name string) (*db.User, error) { + user := db.User{} + err := me.Db.Get(&user, "SELECT * FROM app_users WHERE name=$1", name) + return &user, err +} + +func (me *PgsPsqlDB) FindFeature(userID, name string) (*db.FeatureFlag, error) { + ff := db.FeatureFlag{} + err := me.Db.Get(&ff, "SELECT * FROM feature_flags WHERE user_id=$1 AND name=$2 ORDER BY expires_at DESC LIMIT 1", userID, name) + return &ff, err +} + +func (me *PgsPsqlDB) InsertProject(userID, name, projectDir string) (string, error) { + if !utils.IsValidSubdomain(name) { + return "", fmt.Errorf("'%s' is not a valid project name, must match /^[a-z0-9-]+$/", name) + } + + var projectID string + row := me.Db.QueryRow( + "INSERT INTO projects (user_id, name, project_dir) VALUES ($1, $2, $3) RETURNING id", + userID, + name, + projectDir, + ) + err := row.Scan(&projectID) + return projectID, err +} + +func (me *PgsPsqlDB) UpdateProject(userID, name string) error { + _, err := me.Db.Exec("UPDATE projects SET updated_at=$1 WHERE user_id=$2 AND name=$3", time.Now(), userID, name) + return err +} + +func (me *PgsPsqlDB) UpsertProject(userID, projectName, projectDir string) (*db.Project, error) { + project, err := me.FindProjectByName(userID, projectName) + if err == nil { + // this just updates the `createdAt` timestamp, useful for book-keeping + err = me.UpdateProject(userID, projectName) + if err != nil { + me.Logger.Error("could not update project", "err", err) + return nil, err + } + return project, nil + } + + _, err = me.InsertProject(userID, projectName, projectName) + if err != nil { + me.Logger.Error("could not create project", "err", err) 
+ return nil, err + } + return me.FindProjectByName(userID, projectName) +} + +func (me *PgsPsqlDB) LinkToProject(userID, projectID, projectDir string, commit bool) error { + linkToProject, err := me.FindProjectByName(userID, projectDir) + if err != nil { + return err + } + isAlreadyLinked := linkToProject.Name != linkToProject.ProjectDir + sameProject := linkToProject.ID == projectID + + /* + A project linked to another project which is also linked to a + project is forbidden. CI/CD Example: + - ProjectProd links to ProjectStaging + - ProjectStaging links to ProjectMain + - We merge `main` and trigger a deploy which uploads to ProjectMain + - All three get updated immediately + This scenario was not the intent of our CI/CD. What we actually + wanted was to create a snapshot of ProjectMain and have ProjectStaging + link to the snapshot, but that's not the intended design of pgs. + + So we want to close that gap here. + + We ensure that `project.Name` and `project.ProjectDir` are identical + when there is no aliasing. 
+ */ + if !sameProject && isAlreadyLinked { + return fmt.Errorf( + "cannot link (%s) to (%s) because it is also a link to (%s)", + projectID, + projectDir, + linkToProject.ProjectDir, + ) + } + + if commit { + _, err = me.Db.Exec( + "UPDATE projects SET project_dir=$1, updated_at=$2 WHERE id=$3", + projectDir, + time.Now(), + projectID, + ) + } + return err +} + +func (me *PgsPsqlDB) RemoveProject(projectID string) error { + _, err := me.Db.Exec("DELETE FROM projects WHERE id=$1", projectID) + return err +} + +func (me *PgsPsqlDB) FindProjectByName(userID, name string) (*db.Project, error) { + project := db.Project{} + err := me.Db.Get(&project, "SELECT * FROM projects WHERE user_id=$1 AND name=$2", userID, name) + return &project, err +} + +func (me *PgsPsqlDB) FindProjectLinks(userID, name string) ([]*db.Project, error) { + projects := []*db.Project{} + err := me.Db.Select( + &projects, + "SELECT * FROM projects WHERE user_id=$1 AND name != project_dir AND project_dir=$2 ORDER BY name ASC", + userID, + name, + ) + return projects, err +} + +func (me *PgsPsqlDB) FindProjectsByPrefix(userID, prefix string) ([]*db.Project, error) { + projects := []*db.Project{} + err := me.Db.Select( + &projects, + "SELECT * FROM projects WHERE user_id=$1 AND name=project_dir AND name ILIKE $2 ORDER BY updated_at ASC, name ASC", + userID, + prefix+"%", + ) + return projects, err +} + +func (me *PgsPsqlDB) FindProjectsByUser(userID string) ([]*db.Project, error) { + projects := []*db.Project{} + err := me.Db.Select( + &projects, + "SELECT * FROM projects WHERE user_id=$1 ORDER BY name ASC", + userID, + ) + return projects, err +} + +func (me *PgsPsqlDB) FindProjects(by string) ([]*db.Project, error) { + projects := []*db.Project{} + err := me.Db.Select( + &projects, + `SELECT p.id, p.user_id, u.name as username, p.name, p.project_dir, p.acl, p.blocked, p.created_at, p.updated_at + FROM projects AS p + LEFT JOIN app_users AS u ON u.id = p.user_id + ORDER BY $1 DESC`, + by, + ) + 
return projects, err +} + +func (me *PgsPsqlDB) UpdateProjectAcl(userID, name string, acl db.ProjectAcl) error { + _, err := me.Db.Exec( + "UPDATE projects SET acl=$3, updated_at=$4 WHERE user_id=$1 AND name=$2", + userID, name, acl, time.Now(), + ) + return err +} diff --git a/pgs/ssh.go b/pgs/ssh.go index 0aa56d8b..bfd2e49a 100644 --- a/pgs/ssh.go +++ b/pgs/ssh.go @@ -11,9 +11,7 @@ import ( "github.com/charmbracelet/promwish" "github.com/charmbracelet/ssh" "github.com/charmbracelet/wish" - "github.com/picosh/pico/db/postgres" "github.com/picosh/pico/shared" - "github.com/picosh/pico/shared/storage" wsh "github.com/picosh/pico/wish" "github.com/picosh/send/auth" "github.com/picosh/send/list" @@ -52,49 +50,28 @@ func withProxy(handler *UploadAssetHandler, otherMiddleware ...wish.Middleware) } } -func StartSshServer() { +func StartSshServer(cfg *PgsConfig, killCh chan error) { host := utils.GetEnv("PGS_HOST", "0.0.0.0") port := utils.GetEnv("PGS_SSH_PORT", "2222") promPort := utils.GetEnv("PGS_PROM_PORT", "9222") - cfg := NewConfigSite() logger := cfg.Logger - dbpool := postgres.NewDB(cfg.DbURL, cfg.Logger) - defer dbpool.Close() - - var st storage.StorageServe - var err error - if cfg.MinioURL == "" { - st, err = storage.NewStorageFS(cfg.Logger, cfg.StorageDir) - } else { - st, err = storage.NewStorageMinio(cfg.Logger, cfg.MinioURL, cfg.MinioUser, cfg.MinioPass) - } - - if err != nil { - logger.Error(err.Error()) - return - } ctx := context.Background() defer ctx.Done() + + cacheClearingQueue := make(chan string, 100) handler := NewUploadAssetHandler( - dbpool, cfg, - st, + cacheClearingQueue, ctx, ) - apiConfig := &shared.ApiConfig{ - Cfg: cfg, - Dbpool: dbpool, - Storage: st, - } - webTunnel := &tunkit.WebTunnelHandler{ Logger: logger, - HttpHandler: createHttpHandler(apiConfig), + HttpHandler: createHttpHandler(cfg), } - sshAuth := shared.NewSshAuthHandler(dbpool, logger, cfg) + sshAuth := shared.NewSshAuthHandler(cfg.DB, logger) s, err := wish.NewServer( 
wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), wish.WithHostKeyPath("ssh_data/term_info_ed25519"), @@ -120,12 +97,16 @@ func StartSshServer() { } }() - <-done - logger.Info("stopping SSH server") - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer func() { cancel() }() - if err := s.Shutdown(ctx); err != nil { - logger.Error("shutdown", "err", err.Error()) - os.Exit(1) + select { + case <-done: + logger.Info("stopping ssh server") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer func() { cancel() }() + if err := s.Shutdown(ctx); err != nil { + logger.Error("shutdown", "err", err.Error()) + os.Exit(1) + } + case <-killCh: + logger.Info("stopping ssh server") } } diff --git a/pgs/ssh_test.go b/pgs/ssh_test.go new file mode 100644 index 00000000..d30fa310 --- /dev/null +++ b/pgs/ssh_test.go @@ -0,0 +1,188 @@ +package pgs + +import ( + "crypto/ed25519" + "crypto/rand" + "io" + "log/slog" + "os" + "strings" + "testing" + "time" + + "github.com/picosh/pico/db" + pgsdb "github.com/picosh/pico/pgs/db" + "github.com/picosh/pico/shared/storage" + "github.com/picosh/utils" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +func TestSshServer(t *testing.T) { + logger := slog.Default() + dbpool := pgsdb.NewDBMemory(logger) + // setup test data + dbpool.SetupTestData() + st, err := storage.NewStorageMemory(map[string]map[string]string{}) + if err != nil { + panic(err) + } + cfg := NewPgsConfig(logger, dbpool, st) + done := make(chan error) + go StartSshServer(cfg, done) + // Hack to wait for startup + time.Sleep(time.Millisecond * 100) + + user := GenerateUser() + // add user's pubkey to the default test account + dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{ + ID: "nice-pubkey", + UserID: dbpool.Users[0].ID, + Key: utils.KeyForKeyText(user.signer.PublicKey()), + }) + + client, err := user.NewClient() + if err != nil { + t.Error(err) + return + } + defer client.Close() + + _, err = 
WriteFileWithSftp(cfg, client) + if err != nil { + t.Error(err) + return + } + + done <- nil +} + +type UserSSH struct { + username string + signer ssh.Signer +} + +func NewUserSSH(username string, signer ssh.Signer) *UserSSH { + return &UserSSH{ + username: username, + signer: signer, + } +} + +func (s UserSSH) Public() string { + pubkey := s.signer.PublicKey() + return string(ssh.MarshalAuthorizedKey(pubkey)) +} + +func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string { + res, err := s.Cmd(client, patch, cmd) + if err != nil { + panic(err) + } + return res +} + +func (s UserSSH) NewClient() (*ssh.Client, error) { + host := "localhost:2222" + + config := &ssh.ClientConfig{ + User: s.username, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(s.signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + client, err := ssh.Dial("tcp", host, config) + return client, err +} + +func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) { + session, err := client.NewSession() + if err != nil { + return "", err + } + defer session.Close() + + stdinPipe, err := session.StdinPipe() + if err != nil { + return "", err + } + + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return "", err + } + + if err := session.Start(cmd); err != nil { + return "", err + } + + if patch != nil { + _, err = stdinPipe.Write(patch) + if err != nil { + return "", err + } + } + + stdinPipe.Close() + + if err := session.Wait(); err != nil { + return "", err + } + + buf := new(strings.Builder) + _, err = io.Copy(buf, stdoutPipe) + if err != nil { + return "", err + } + + return buf.String(), nil +} + +func GenerateUser() UserSSH { + _, userKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + userSigner, err := ssh.NewSignerFromKey(userKey) + if err != nil { + panic(err) + } + + return UserSSH{ + username: "testuser", + signer: userSigner, + } +} + +func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) 
(*os.FileInfo, error) { + // open an SFTP session over an existing ssh connection. + client, err := sftp.NewClient(conn) + if err != nil { + cfg.Logger.Error("could not create sftp client", "err", err) + return nil, err + } + defer client.Close() + + f, err := client.Create("test/hello.txt") + if err != nil { + cfg.Logger.Error("could not create file", "err", err) + return nil, err + } + if _, err := f.Write([]byte("Hello world!")); err != nil { + cfg.Logger.Error("could not write to file", "err", err) + return nil, err + } + f.Close() + + // check it's there + fi, err := client.Lstat("test/hello.txt") + if err != nil { + cfg.Logger.Error("could not get stat for file", "err", err) + return nil, err + } + + return &fi, nil +} diff --git a/pgs/tunnel.go b/pgs/tunnel.go index a42cd4a4..8e240a37 100644 --- a/pgs/tunnel.go +++ b/pgs/tunnel.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "strings" + "time" "github.com/charmbracelet/ssh" "github.com/picosh/pico/db" @@ -43,10 +44,9 @@ func getInfoFromUser(user string) (string, string) { return "", user } -func createHttpHandler(apiConfig *shared.ApiConfig) CtxHttpBridge { +func createHttpHandler(cfg *PgsConfig) CtxHttpBridge { return func(ctx ssh.Context) http.Handler { - dbh := apiConfig.Dbpool - logger := apiConfig.Cfg.Logger + logger := cfg.Logger asUser, subdomain := getInfoFromUser(ctx.User()) log := logger.With( "subdomain", subdomain, @@ -69,7 +69,7 @@ func createHttpHandler(apiConfig *shared.ApiConfig) CtxHttpBridge { return http.HandlerFunc(shared.UnauthorizedHandler) } - owner, err := dbh.FindUserForName(props.Username) + owner, err := cfg.DB.FindUserByName(props.Username) if err != nil { log.Error( "could not find user from name", @@ -82,13 +82,13 @@ func createHttpHandler(apiConfig *shared.ApiConfig) CtxHttpBridge { "owner", owner.Name, ) - project, err := dbh.FindProjectByName(owner.ID, props.ProjectName) + project, err := cfg.DB.FindProjectByName(owner.ID, props.ProjectName) if err != nil { log.Error("could 
not get project by name", "project", props.ProjectName, "err", err.Error()) return http.HandlerFunc(shared.UnauthorizedHandler) } - requester, _ := dbh.FindUserForKey("", pubkey) + requester, _ := cfg.DB.FindUserByPubkey(pubkey) if requester != nil { log = log.With( "requester", requester.Name, @@ -97,12 +97,19 @@ // impersonation logic if asUser != "" { - isAdmin := dbh.HasFeatureForUser(requester.ID, "admin") + isAdmin := false + ff, _ := cfg.DB.FindFeature(requester.ID, "admin") + if ff != nil { + if ff.ExpiresAt.After(time.Now()) { + isAdmin = true + } + } + if !isAdmin { log.Error("impersonation attempt failed") return http.HandlerFunc(shared.UnauthorizedHandler) } - requester, _ = dbh.FindUserForName(asUser) + requester, _ = cfg.DB.FindUserByName(asUser) } ctx.Permissions().Extensions["user_id"] = requester.ID @@ -118,12 +125,7 @@ log.Info("user has access to site") - routes := NewWebRouter( - apiConfig.Cfg, - logger, - apiConfig.Dbpool, - apiConfig.Storage, - ) + routes := NewWebRouter(cfg) tunnelRouter := TunnelWebRouter{routes, subdomain} tunnelRouter.initRouters() return &tunnelRouter diff --git a/pgs/uploader.go b/pgs/uploader.go index 9ed7b2b4..91d44cd1 100644 --- a/pgs/uploader.go +++ b/pgs/uploader.go @@ -18,6 +18,7 @@ import ( "github.com/charmbracelet/ssh" "github.com/charmbracelet/wish" "github.com/picosh/pico/db" + pgsdb "github.com/picosh/pico/pgs/db" "github.com/picosh/pico/shared" "github.com/picosh/pobj" sst "github.com/picosh/pobj/storage" @@ -99,21 +100,14 @@ type FileData struct { } type UploadAssetHandler struct { - DBPool db.DB - Cfg *shared.ConfigSite - Storage sst.ObjectStorage + Cfg *PgsConfig CacheClearingQueue chan string } -func NewUploadAssetHandler(dbpool db.DB, cfg *shared.ConfigSite, storage sst.ObjectStorage, ctx context.Context) *UploadAssetHandler { - // Enable buffering so we don't slow 
down uploads. - ch := make(chan string, 100) - go runCacheQueue(cfg, ctx, ch) - // publish all file uploads to a pipe topic +func NewUploadAssetHandler(cfg *PgsConfig, ch chan string, ctx context.Context) *UploadAssetHandler { + go runCacheQueue(cfg, ctx) return &UploadAssetHandler{ - DBPool: dbpool, Cfg: cfg, - Storage: storage, CacheClearingQueue: ch, } } @@ -123,7 +117,7 @@ func (h *UploadAssetHandler) GetLogger() *slog.Logger { } func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) { - user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) + user, err := h.Cfg.DB.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return nil, nil, err } @@ -135,13 +129,13 @@ func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os FModTime: time.Unix(entry.Mtime, 0), } - bucket, err := h.Storage.GetBucket(shared.GetAssetBucketName(user.ID)) + bucket, err := h.Cfg.Storage.GetBucket(shared.GetAssetBucketName(user.ID)) if err != nil { return nil, nil, err } fname := shared.GetAssetFileName(entry) - contents, info, err := h.Storage.GetObject(bucket, fname) + contents, info, err := h.Cfg.Storage.GetObject(bucket, fname) if err != nil { return nil, nil, err } @@ -157,7 +151,7 @@ func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) { var fileList []os.FileInfo - user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) + user, err := h.Cfg.DB.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return fileList, err } @@ -165,7 +159,7 @@ func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recur cleanFilename := fpath bucketName := shared.GetAssetBucketName(user.ID) - bucket, err := h.Storage.GetBucket(bucketName) + bucket, err := h.Cfg.Storage.GetBucket(bucketName) if err != nil { 
return fileList, err } @@ -187,7 +181,7 @@ func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recur cleanFilename += "/" } - foundList, err := h.Storage.ListObjects(bucket, cleanFilename, recursive) + foundList, err := h.Cfg.Storage.ListObjects(bucket, cleanFilename, recursive) if err != nil { return fileList, err } @@ -199,19 +193,19 @@ func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recur } func (h *UploadAssetHandler) Validate(s ssh.Session) error { - user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) + user, err := h.Cfg.DB.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return err } assetBucket := shared.GetAssetBucketName(user.ID) - bucket, err := h.Storage.UpsertBucket(assetBucket) + bucket, err := h.Cfg.Storage.UpsertBucket(assetBucket) if err != nil { return err } s.Context().SetValue(ctxBucketKey{}, bucket) - totalStorageSize, err := h.Storage.GetBucketQuota(bucket) + totalStorageSize, err := h.Cfg.Storage.GetBucketQuota(bucket) if err != nil { return err } @@ -225,14 +219,14 @@ func (h *UploadAssetHandler) Validate(s ssh.Session) error { h.Cfg.Logger.Info( "attempting to upload files", "user", user.Name, - "space", h.Cfg.Space, + "txtPrefix", h.Cfg.TxtPrefix, ) return nil } func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project, logger *slog.Logger) (string, error) { - fp, _, err := h.Storage.GetObject(bucket, filepath.Join(project.ProjectDir, "_pgs_ignore")) + fp, _, err := h.Cfg.Storage.GetObject(bucket, filepath.Join(project.ProjectDir, "_pgs_ignore")) if err != nil { return "", fmt.Errorf("_pgs_ignore not found") } @@ -249,8 +243,28 @@ func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project return str, nil } +func findPlusFF(dbpool pgsdb.PgsDB, cfg *PgsConfig, userID string) *db.FeatureFlag { + ff, _ := dbpool.FindFeature(userID, "plus") + // we have free tiers so users might not have a feature flag + // in which 
case we set sane defaults + if ff == nil { + ff = db.NewFeatureFlag( + userID, + "plus", + cfg.MaxSize, + cfg.MaxAssetSize, + cfg.MaxSpecialFileSize, + ) + } + // this is jank + ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize) + ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize) + ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize) + return ff +} + func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) { - user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) + user, err := h.Cfg.DB.FindUser(s.Permissions().Extensions["user_id"]) if user == nil || err != nil { h.Cfg.Logger.Error("user not found in ctx", "err", err.Error()) return "", err @@ -279,7 +293,7 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (s // find, create, or update project if we haven't already done it if project == nil { - project, err = h.DBPool.UpsertProject(user.ID, projectName, projectName) + project, err = h.Cfg.DB.UpsertProject(user.ID, projectName, projectName) if err != nil { logger.Error("upsert project", "err", err.Error()) return "", err @@ -293,7 +307,7 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (s } if entry.Mode.IsDir() { - _, _, err := h.Storage.PutObject( + _, _, err := h.Cfg.Storage.PutObject( bucket, path.Join(shared.GetAssetFileName(entry), "._pico_keep_dir"), bytes.NewReader([]byte{}), @@ -302,11 +316,11 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (s return "", err } - featureFlag := shared.FindPlusFF(h.DBPool, h.Cfg, user.ID) + featureFlag := findPlusFF(h.Cfg.DB, h.Cfg, user.ID) // calculate the filsize difference between the same file already // stored and the updated file being uploaded assetFilename := shared.GetAssetFileName(entry) - obj, info, _ := h.Storage.GetObject(bucket, assetFilename) + obj, info, _ := h.Cfg.Storage.GetObject(bucket, assetFilename) var curFileSize int64 if info != nil 
{ curFileSize = info.Size @@ -400,7 +414,7 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (s ) surrogate := getSurrogateKey(user.Name, projectName) - h.CacheClearingQueue <- surrogate + h.Cfg.CacheClearingQueue <- surrogate return str, err } @@ -411,7 +425,7 @@ func isSpecialFile(entry *sendutils.FileEntry) bool { } func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error { - user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) + user, err := h.Cfg.DB.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { h.Cfg.Logger.Error("user not found in ctx", "err", err.Error()) return err @@ -447,7 +461,7 @@ func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) e pathDir := filepath.Dir(assetFilepath) fileName := filepath.Base(assetFilepath) - sibs, err := h.Storage.ListObjects(bucket, pathDir+"/", false) + sibs, err := h.Cfg.Storage.ListObjects(bucket, pathDir+"/", false) if err != nil { return err } @@ -457,7 +471,7 @@ func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) e }) if len(sibs) == 0 { - _, _, err := h.Storage.PutObject( + _, _, err := h.Cfg.Storage.PutObject( bucket, filepath.Join(pathDir, "._pico_keep_dir"), bytes.NewReader([]byte{}), @@ -467,10 +481,10 @@ func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) e return err } } - err = h.Storage.DeleteObject(bucket, assetFilepath) + err = h.Cfg.Storage.DeleteObject(bucket, assetFilepath) surrogate := getSurrogateKey(user.Name, projectName) - h.CacheClearingQueue <- surrogate + h.Cfg.CacheClearingQueue <- surrogate if err != nil { return err @@ -514,7 +528,7 @@ func (h *UploadAssetHandler) writeAsset(reader io.Reader, data *FileData) (int64 "filename", assetFilepath, ) - _, fsize, err := h.Storage.PutObject( + _, fsize, err := h.Cfg.Storage.PutObject( data.Bucket, assetFilepath, reader, @@ -527,13 +541,13 @@ func (h *UploadAssetHandler) 
writeAsset(reader io.Reader, data *FileData) (int64 // One message arrives per file that is written/deleted during uploads. // Repeated messages for the same site are grouped so that we only flush once // per site per 5 seconds. -func runCacheQueue(cfg *shared.ConfigSite, ctx context.Context, ch chan string) { +func runCacheQueue(cfg *PgsConfig, ctx context.Context) { send := createPubCacheDrain(ctx, cfg.Logger) var pendingFlushes sync.Map tick := time.Tick(5 * time.Second) for { select { - case host := <-ch: + case host := <-cfg.CacheClearingQueue: pendingFlushes.Store(host, host) case <-tick: go func() { diff --git a/pgs/web.go b/pgs/web.go index f5293e8a..0e36bec1 100644 --- a/pgs/web.go +++ b/pgs/web.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "html/template" "log/slog" "net/http" "net/url" @@ -20,7 +21,6 @@ import ( "github.com/darkweak/storages/core" "github.com/gorilla/feeds" "github.com/picosh/pico/db" - "github.com/picosh/pico/db/postgres" "github.com/picosh/pico/shared" "github.com/picosh/pico/shared/storage" sst "github.com/picosh/pobj/storage" @@ -39,26 +39,9 @@ func (c *CachedHttp) ServeHTTP(writer http.ResponseWriter, req *http.Request) { }) } -func StartApiServer() { +func StartApiServer(cfg *PgsConfig) { ctx := context.Background() - cfg := NewConfigSite() - logger := cfg.Logger - dbpool := postgres.NewDB(cfg.DbURL, cfg.Logger) - defer dbpool.Close() - - var st storage.StorageServe - var err error - if cfg.MinioURL == "" { - st, err = storage.NewStorageFS(cfg.Logger, cfg.StorageDir) - } else { - st, err = storage.NewStorageMinio(cfg.Logger, cfg.MinioURL, cfg.MinioUser, cfg.MinioPass) - } - - if err != nil { - logger.Error("could not connect to object storage", "err", err.Error()) - return - } ttl := configurationtypes.Duration{Duration: cfg.CacheTTL} stale := configurationtypes.Duration{Duration: cfg.CacheTTL * 2} c := &middleware.BaseConfiguration{ @@ -81,10 +64,10 @@ func StartApiServer() { DefaultCacheControl: cfg.CacheControl, }, } - 
c.SetLogger(&CompatLogger{logger}) + c.SetLogger(&CompatLogger{cfg.Logger}) storages.InitFromConfiguration(c) httpCache := middleware.NewHTTPCacheHandler(c) - routes := NewWebRouter(cfg, logger, dbpool, st) + routes := NewWebRouter(cfg) cacher := &CachedHttp{ handler: httpCache, routes: routes, @@ -92,14 +75,14 @@ func StartApiServer() { go routes.cacheMgmt(ctx, httpCache) - portStr := fmt.Sprintf(":%s", cfg.Port) - logger.Info( + portStr := fmt.Sprintf(":%s", cfg.WebPort) + cfg.Logger.Info( "starting server on port", - "port", cfg.Port, + "port", cfg.WebPort, "domain", cfg.Domain, ) - err = http.ListenAndServe(portStr, cacher) - logger.Error( + err := http.ListenAndServe(portStr, cacher) + cfg.Logger.Error( "listen and serve", "err", err.Error(), ) @@ -108,20 +91,14 @@ func StartApiServer() { type HasPerm = func(proj *db.Project) bool type WebRouter struct { - Cfg *shared.ConfigSite - Logger *slog.Logger - Dbpool db.DB - Storage storage.StorageServe + Cfg *PgsConfig RootRouter *http.ServeMux UserRouter *http.ServeMux } -func NewWebRouter(cfg *shared.ConfigSite, logger *slog.Logger, dbpool db.DB, st storage.StorageServe) *WebRouter { +func NewWebRouter(cfg *PgsConfig) *WebRouter { router := &WebRouter{ - Cfg: cfg, - Logger: logger, - Dbpool: dbpool, - Storage: st, + Cfg: cfg, } router.initRouters() return router @@ -154,7 +131,7 @@ func (web *WebRouter) initRouters() { func (web *WebRouter) serveFile(file string, contentType string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger := web.Logger + logger := web.Cfg.Logger cfg := web.Cfg contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file))) @@ -180,11 +157,28 @@ func (web *WebRouter) serveFile(file string, contentType string) http.HandlerFun } } +func renderTemplate(cfg *PgsConfig, templates []string) (*template.Template, error) { + files := make([]string, len(templates)) + copy(files, templates) + files = append( + files, + 
cfg.StaticPath("html/footer.partial.tmpl"), + cfg.StaticPath("html/marketing-footer.partial.tmpl"), + cfg.StaticPath("html/base.layout.tmpl"), + ) + + ts, err := template.New("base").ParseFiles(files...) + if err != nil { + return nil, err + } + return ts, nil +} + func (web *WebRouter) createPageHandler(fname string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger := web.Logger + logger := web.Cfg.Logger cfg := web.Cfg - ts, err := shared.RenderTemplate(cfg, []string{cfg.StaticPath(fname)}) + ts, err := renderTemplate(cfg, []string{cfg.StaticPath(fname)}) if err != nil { logger.Error( @@ -197,7 +191,7 @@ func (web *WebRouter) createPageHandler(fname string) http.HandlerFunc { } data := shared.PageData{ - Site: *cfg.GetSiteData(), + Site: shared.SitePageData{Domain: template.URL(cfg.Domain), HomeURL: "/"}, } err = ts.Execute(w, data) if err != nil { @@ -212,54 +206,52 @@ func (web *WebRouter) createPageHandler(fname string) http.HandlerFunc { } func (web *WebRouter) checkHandler(w http.ResponseWriter, r *http.Request) { - dbpool := web.Dbpool + dbpool := web.Cfg.DB cfg := web.Cfg - logger := web.Logger - - if cfg.IsCustomdomains() { - hostDomain := r.URL.Query().Get("domain") - appDomain := strings.Split(cfg.Domain, ":")[0] - - if !strings.Contains(hostDomain, appDomain) { - subdomain := shared.GetCustomDomain(hostDomain, cfg.Space) - props, err := shared.GetProjectFromSubdomain(subdomain) - if err != nil { - logger.Error( - "could not get project from subdomain", - "subdomain", subdomain, - "err", err.Error(), - ) - w.WriteHeader(http.StatusNotFound) - return - } + logger := web.Cfg.Logger - u, err := dbpool.FindUserForName(props.Username) - if err != nil { - logger.Error("could not find user", "err", err.Error()) - w.WriteHeader(http.StatusNotFound) - return - } + hostDomain := r.URL.Query().Get("domain") + appDomain := strings.Split(cfg.Domain, ":")[0] + + if !strings.Contains(hostDomain, appDomain) { + subdomain := 
shared.GetCustomDomain(hostDomain, cfg.TxtPrefix) + props, err := shared.GetProjectFromSubdomain(subdomain) + if err != nil { + logger.Error( + "could not get project from subdomain", + "subdomain", subdomain, + "err", err.Error(), + ) + w.WriteHeader(http.StatusNotFound) + return + } + + u, err := dbpool.FindUserByName(props.Username) + if err != nil { + logger.Error("could not find user", "err", err.Error()) + w.WriteHeader(http.StatusNotFound) + return + } - logger = logger.With( + logger = logger.With( + "user", u.Name, + "project", props.ProjectName, + ) + p, err := dbpool.FindProjectByName(u.ID, props.ProjectName) + if err != nil { + logger.Error( + "could not find project for user", "user", u.Name, "project", props.ProjectName, + "err", err.Error(), ) - p, err := dbpool.FindProjectByName(u.ID, props.ProjectName) - if err != nil { - logger.Error( - "could not find project for user", - "user", u.Name, - "project", props.ProjectName, - "err", err.Error(), - ) - w.WriteHeader(http.StatusNotFound) - return - } + w.WriteHeader(http.StatusNotFound) + return + } - if u != nil && p != nil { - w.WriteHeader(http.StatusOK) - return - } + if u != nil && p != nil { + w.WriteHeader(http.StatusOK) + return } } @@ -268,21 +260,21 @@ func (web *WebRouter) checkHandler(w http.ResponseWriter, r *http.Request) { func (web *WebRouter) cacheMgmt(ctx context.Context, httpCache *middleware.SouinBaseHandler) { storer := httpCache.Storers[0] - drain := createSubCacheDrain(ctx, web.Logger) + drain := createSubCacheDrain(ctx, web.Cfg.Logger) for { scanner := bufio.NewScanner(drain) for scanner.Scan() { surrogateKey := strings.TrimSpace(scanner.Text()) - web.Logger.Info("received cache-drain item", "surrogateKey", surrogateKey) + web.Cfg.Logger.Info("received cache-drain item", "surrogateKey", surrogateKey) if surrogateKey == "*" { storer.DeleteMany(".+") err := httpCache.SurrogateKeyStorer.Destruct() if err != nil { - web.Logger.Error("could not clear cache and surrogate key store", 
"err", err) + web.Cfg.Logger.Error("could not clear cache and surrogate key store", "err", err) } else { - web.Logger.Info("successfully cleared cache and surrogate keys store") + web.Cfg.Logger.Info("successfully cleared cache and surrogate keys store") } continue } @@ -298,7 +290,7 @@ func (web *WebRouter) cacheMgmt(ctx context.Context, httpCache *middleware.Souin if e := proto.Unmarshal(b, &mapping); e == nil { for k := range mapping.GetMapping() { qkey, _ := url.QueryUnescape(k) - web.Logger.Info( + web.Cfg.Logger.Info( "deleting key from surrogate cache", "surrogateKey", surrogateKey, "key", qkey, @@ -309,7 +301,7 @@ func (web *WebRouter) cacheMgmt(ctx context.Context, httpCache *middleware.Souin } qkey, _ := url.QueryUnescape(key) - web.Logger.Info( + web.Cfg.Logger.Info( "deleting from cache", "surrogateKey", surrogateKey, "key", core.MappingKeyPrefix+qkey, @@ -322,11 +314,11 @@ func (web *WebRouter) cacheMgmt(ctx context.Context, httpCache *middleware.Souin func (web *WebRouter) createRssHandler(by string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - dbpool := web.Dbpool - logger := web.Logger + dbpool := web.Cfg.DB + logger := web.Cfg.Logger cfg := web.Cfg - pager, err := dbpool.FindAllProjects(&db.Pager{Num: 100, Page: 0}, by) + projects, err := dbpool.FindProjects(by) if err != nil { logger.Error("could not find projects", "err", err.Error()) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -335,14 +327,14 @@ func (web *WebRouter) createRssHandler(by string) http.HandlerFunc { feed := &feeds.Feed{ Title: fmt.Sprintf("%s discovery feed %s", cfg.Domain, by), - Link: &feeds.Link{Href: cfg.ReadURL()}, + Link: &feeds.Link{Href: "https://pgs.sh"}, Description: fmt.Sprintf("%s projects %s", cfg.Domain, by), Author: &feeds.Author{Name: cfg.Domain}, Created: time.Now(), } var feedItems []*feeds.Item - for _, project := range pager.Data { + for _, project := range projects { realUrl := strings.TrimSuffix( 
cfg.AssetURL(project.Username, project.Name, ""), "/", @@ -385,7 +377,7 @@ func (web *WebRouter) createRssHandler(by string) http.HandlerFunc { } func (web *WebRouter) Perm(proj *db.Project) bool { - return proj.Acl.Type == "public" + return proj.Acl.Type == "public" || proj.Acl.Type == "" } var imgRegex = regexp.MustCompile("(.+.(?:jpg|jpeg|png|gif|webp|svg))(/.+)") @@ -414,7 +406,7 @@ func (web *WebRouter) ImageRequest(w http.ResponseWriter, r *http.Request) { opts, err := storage.UriToImgProcessOpts(imgOpts) if err != nil { errMsg := fmt.Sprintf("error processing img options: %s", err.Error()) - web.Logger.Error("error processing img options", "err", errMsg) + web.Cfg.Logger.Error("error processing img options", "err", errMsg) http.Error(w, errMsg, http.StatusUnprocessableEntity) return } @@ -425,7 +417,7 @@ func (web *WebRouter) ImageRequest(w http.ResponseWriter, r *http.Request) { func (web *WebRouter) ServeAsset(fname string, opts *storage.ImgProcessOpts, fromImgs bool, hasPerm HasPerm, w http.ResponseWriter, r *http.Request) { subdomain := shared.GetSubdomain(r) - logger := web.Logger.With( + logger := web.Cfg.Logger.With( "subdomain", subdomain, "filename", fname, "url", fmt.Sprintf("%s%s", r.Host, r.URL.Path), @@ -447,7 +439,7 @@ func (web *WebRouter) ServeAsset(fname string, opts *storage.ImgProcessOpts, fro "user", props.Username, ) - user, err := web.Dbpool.FindUserForName(props.Username) + user, err := web.Cfg.DB.FindUserByName(props.Username) if err != nil { logger.Info("user not found") http.Error(w, "user not found", http.StatusNotFound) @@ -465,10 +457,10 @@ func (web *WebRouter) ServeAsset(fname string, opts *storage.ImgProcessOpts, fro var bucket sst.Bucket // imgs has a different bucket directory if fromImgs { - bucket, err = web.Storage.GetBucket(shared.GetImgsBucketName(user.ID)) + bucket, err = web.Cfg.Storage.GetBucket(shared.GetImgsBucketName(user.ID)) } else { - bucket, err = web.Storage.GetBucket(shared.GetAssetBucketName(user.ID)) - 
project, perr := web.Dbpool.FindProjectByName(user.ID, props.ProjectName) + bucket, err = web.Cfg.Storage.GetBucket(shared.GetAssetBucketName(user.ID)) + project, perr := web.Cfg.DB.FindProjectByName(user.ID, props.ProjectName) if perr != nil { logger.Info("project not found") http.Error(w, "project not found", http.StatusNotFound) @@ -500,7 +492,13 @@ func (web *WebRouter) ServeAsset(fname string, opts *storage.ImgProcessOpts, fro return } - hasPicoPlus := web.Dbpool.HasFeatureForUser(user.ID, "plus") + hasPicoPlus := false + ff, _ := web.Cfg.DB.FindFeature(user.ID, "plus") + if ff != nil { + if ff.ExpiresAt.After(time.Now()) { + hasPicoPlus = true + } + } asset := &ApiAssetHandler{ WebRouter: web, @@ -521,9 +519,9 @@ func (web *WebRouter) ServeAsset(fname string, opts *storage.ImgProcessOpts, fro } func (web *WebRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - subdomain := shared.GetSubdomainFromRequest(r, web.Cfg.Domain, web.Cfg.Space) + subdomain := shared.GetSubdomainFromRequest(r, web.Cfg.Domain, web.Cfg.TxtPrefix) if web.RootRouter == nil || web.UserRouter == nil { - web.Logger.Error("routers not initialized") + web.Cfg.Logger.Error("routers not initialized") http.Error(w, "routers not initialized", http.StatusInternalServerError) return } diff --git a/pgs/web_asset_handler.go b/pgs/web_asset_handler.go index cf58d6bd..f4ac4659 100644 --- a/pgs/web_asset_handler.go +++ b/pgs/web_asset_handler.go @@ -41,7 +41,7 @@ func hasProtocol(url string) bool { func (h *ApiAssetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger := h.Logger var redirects []*RedirectRule - redirectFp, redirectInfo, err := h.Storage.GetObject(h.Bucket, filepath.Join(h.ProjectDir, "_redirects")) + redirectFp, redirectInfo, err := h.Cfg.Storage.GetObject(h.Bucket, filepath.Join(h.ProjectDir, "_redirects")) if err == nil { defer redirectFp.Close() if redirectInfo != nil && redirectInfo.Size > h.Cfg.MaxSpecialFileSize { @@ -85,7 +85,7 @@ func (h *ApiAssetHandler) 
ServeHTTP(w http.ResponseWriter, r *http.Request) { // before redirecting, this saves a hop that will just end up a 404 if !hasProtocol(fp.Filepath) && strings.HasSuffix(fp.Filepath, "/") { next := filepath.Join(h.ProjectDir, fp.Filepath, "index.html") - obj, _, err := h.Storage.GetObject(h.Bucket, next) + obj, _, err := h.Cfg.Storage.GetObject(h.Bucket, next) if err != nil { continue } @@ -137,7 +137,7 @@ func (h *ApiAssetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { attempts = append(attempts, fpath) logger = logger.With("object", fpath) logger.Info("serving object") - c, info, err = h.Storage.ServeObject( + c, info, err = h.Cfg.Storage.ServeObject( h.Bucket, fpath, h.ImgProcessOpts, @@ -164,7 +164,7 @@ func (h *ApiAssetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer contents.Close() var headers []*HeaderRule - headersFp, headersInfo, err := h.Storage.GetObject(h.Bucket, filepath.Join(h.ProjectDir, "_headers")) + headersFp, headersInfo, err := h.Cfg.Storage.GetObject(h.Bucket, filepath.Join(h.ProjectDir, "_headers")) if err == nil { defer headersFp.Close() if headersInfo != nil && headersInfo.Size > h.Cfg.MaxSpecialFileSize { diff --git a/pgs/web_cache.go b/pgs/web_cache.go index e372c6c5..7763b755 100644 --- a/pgs/web_cache.go +++ b/pgs/web_cache.go @@ -48,13 +48,13 @@ func createSubCacheDrain(ctx context.Context, logger *slog.Logger) *pipe.Reconne // cached assets for a given subdomain are grouped under a single key (which is // separate from the "GET-https-example.com-/path" key used for serving files // from the cache). 
-func purgeCache(cfg *shared.ConfigSite, send *pipe.ReconnectReadWriteCloser, surrogate string) error { +func purgeCache(cfg *PgsConfig, send *pipe.ReconnectReadWriteCloser, surrogate string) error { cfg.Logger.Info("purging cache", "surrogate", surrogate) time.Sleep(1 * time.Second) _, err := send.Write([]byte(surrogate + "\n")) return err } -func purgeAllCache(cfg *shared.ConfigSite, send *pipe.ReconnectReadWriteCloser) error { +func purgeAllCache(cfg *PgsConfig, send *pipe.ReconnectReadWriteCloser) error { return purgeCache(cfg, send, "*") } diff --git a/pgs/web_test.go b/pgs/web_test.go index 6e7c7567..03d50c87 100644 --- a/pgs/web_test.go +++ b/pgs/web_test.go @@ -10,16 +10,12 @@ import ( "testing" "time" - "github.com/picosh/pico/db" - "github.com/picosh/pico/db/stub" + pgsdb "github.com/picosh/pico/pgs/db" "github.com/picosh/pico/shared" "github.com/picosh/pico/shared/storage" sst "github.com/picosh/pobj/storage" ) -var testUserID = "user-1" -var testUsername = "user" - type ApiExample struct { name string path string @@ -29,63 +25,34 @@ type ApiExample struct { status int contentType string - dbpool db.DB storage map[string]map[string]string } type PgsDb struct { - *stub.StubDB + *pgsdb.MemoryDB } func NewPgsDb(logger *slog.Logger) *PgsDb { - sb := stub.NewStubDB(logger) - return &PgsDb{ - StubDB: sb, + sb := pgsdb.NewDBMemory(logger) + sb.SetupTestData() + _, err := sb.InsertProject(sb.Users[0].ID, "test", "test") + if err != nil { + panic(err) } -} - -func (p *PgsDb) FindUserForName(name string) (*db.User, error) { - return &db.User{ - ID: testUserID, - Name: testUsername, - }, nil -} - -func (p *PgsDb) FindProjectByName(userID, name string) (*db.Project, error) { - return &db.Project{ - ID: "project-1", - UserID: userID, - Name: name, - ProjectDir: name, - Username: testUsername, - Acl: db.ProjectAcl{ - Type: "public", - }, - }, nil -} - -type PgsAnalyticsDb struct { - *PgsDb -} - -func NewPgsAnalticsDb(logger *slog.Logger) *PgsAnalyticsDb { - return 
&PgsAnalyticsDb{ - PgsDb: NewPgsDb(logger), + return &PgsDb{ + MemoryDB: sb, } } -func (p *PgsAnalyticsDb) HasFeatureForUser(userID, feature string) bool { - return true -} - -func mkpath(path string) string { - return fmt.Sprintf("https://%s-test.pgs.test%s", testUsername, path) +func (p *PgsDb) mkpath(path string) string { + return fmt.Sprintf("https://%s-test.pgs.test%s", p.Users[0].Name, path) } func TestApiBasic(t *testing.T) { - bucketName := shared.GetAssetBucketName(testUserID) - cfg := NewConfigSite() - cfg.Domain = "pgs.test" + logger := slog.Default() + dbpool := NewPgsDb(logger) + bucketName := shared.GetAssetBucketName(dbpool.Users[0].ID) + tt := []*ApiExample{ { name: "basic", @@ -94,7 +61,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "text/html", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/index.html": "hello world!", @@ -108,7 +74,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "text/html", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/test.html": "hello world!", @@ -122,7 +87,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusMovedPermanently, contentType: "text/html; charset=utf-8", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/subdir/index.html": "hello world!", @@ -136,7 +100,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusMovedPermanently, contentType: "text/html; charset=utf-8", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/_redirects": "/anything /about.html 301", @@ -151,7 +114,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "text/html", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/subdir/index.html": "hello world!", @@ -165,7 +127,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "text/html", - 
dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/_redirects": "/* /index.html 200", @@ -180,7 +141,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusNotFound, contentType: "text/plain; charset=utf-8", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: {}, }, @@ -192,7 +152,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusNotFound, contentType: "text/html", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/404.html": "boom!", @@ -206,7 +165,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "image/jpeg", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/profile.jpg": "image", @@ -221,7 +179,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusMovedPermanently, contentType: "text/html; charset=utf-8", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/_redirects": "/anything /about.html 301", @@ -239,7 +196,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusNotModified, contentType: "", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/test.html": "hello world!", @@ -256,7 +212,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "text/html", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/test.html": "hello world!", @@ -273,7 +228,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusNotModified, contentType: "", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/test.html": "hello world!", @@ -290,7 +244,6 @@ func TestApiBasic(t *testing.T) { status: http.StatusOK, contentType: "text/html", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/test.html": "hello world!", @@ -309,7 +262,6 @@ func TestApiBasic(t *testing.T) { status: 
http.StatusNotModified, contentType: "", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/test.html": "hello world!", @@ -320,14 +272,16 @@ func TestApiBasic(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("GET", mkpath(tc.path), strings.NewReader("")) + request := httptest.NewRequest("GET", dbpool.mkpath(tc.path), strings.NewReader("")) for key, val := range tc.reqHeaders { request.Header.Set(key, val) } responseRecorder := httptest.NewRecorder() st, _ := storage.NewStorageMemory(tc.storage) - router := NewWebRouter(cfg, cfg.Logger, tc.dbpool, st) + cfg := NewPgsConfig(logger, dbpool, st) + cfg.Domain = "pgs.test" + router := NewWebRouter(cfg) router.ServeHTTP(responseRecorder, request) if responseRecorder.Code != tc.status { @@ -349,6 +303,10 @@ func TestApiBasic(t *testing.T) { if err != nil { t.Errorf("err: %s", err.Error()) } + if location == nil { + t.Error("no location header in response") + return + } if tc.wantUrl != location.String() { t.Errorf("Want '%s', got '%s'", tc.wantUrl, location.String()) } @@ -374,9 +332,9 @@ func (s *ImageStorageMemory) ServeObject(bucket sst.Bucket, fpath string, opts * } func TestImageManipulation(t *testing.T) { - bucketName := shared.GetAssetBucketName(testUserID) - cfg := NewConfigSite() - cfg.Domain = "pgs.test" + logger := slog.Default() + dbpool := NewPgsDb(logger) + bucketName := shared.GetAssetBucketName(dbpool.Users[0].ID) tt := []ApiExample{ { @@ -386,7 +344,6 @@ func TestImageManipulation(t *testing.T) { status: http.StatusOK, contentType: "image/jpeg", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/app.jpg": "hello world!", @@ -400,7 +357,6 @@ func TestImageManipulation(t *testing.T) { status: http.StatusOK, contentType: "image/jpeg", - dbpool: NewPgsDb(cfg.Logger), storage: map[string]map[string]string{ bucketName: { "test/subdir/app.jpg": "hello world!", @@ -411,7 
+367,7 @@ func TestImageManipulation(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("GET", mkpath(tc.path), strings.NewReader("")) + request := httptest.NewRequest("GET", dbpool.mkpath(tc.path), strings.NewReader("")) responseRecorder := httptest.NewRecorder() memst, _ := storage.NewStorageMemory(tc.storage) @@ -421,7 +377,9 @@ func TestImageManipulation(t *testing.T) { Ratio: &storage.Ratio{}, }, } - router := NewWebRouter(cfg, cfg.Logger, tc.dbpool, st) + cfg := NewPgsConfig(logger, dbpool, st) + cfg.Domain = "pgs.test" + router := NewWebRouter(cfg) router.ServeHTTP(responseRecorder, request) if responseRecorder.Code != tc.status { diff --git a/pico/ssh.go b/pico/ssh.go index c0b204f3..bc1c1474 100644 --- a/pico/ssh.go +++ b/pico/ssh.go @@ -71,7 +71,7 @@ func StartSshServer() { DBPool: dbpool, } - sshAuth := shared.NewSshAuthHandler(dbpool, logger, cfg) + sshAuth := shared.NewSshAuthHandler(dbpool, logger) s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), wish.WithHostKeyPath("ssh_data/term_info_ed25519"), diff --git a/pipe/ssh.go b/pipe/ssh.go index d3e9b586..87f8cdfa 100644 --- a/pipe/ssh.go +++ b/pipe/ssh.go @@ -42,7 +42,7 @@ func StartSshServer() { Access: syncmap.New[string, []string](), } - sshAuth := shared.NewSshAuthHandler(dbh, logger, cfg) + sshAuth := shared.NewSshAuthHandler(dbh, logger) s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), wish.WithHostKeyPath("ssh_data/term_info_ed25519"), diff --git a/prose/ssh.go b/prose/ssh.go index a67c0729..102b66bf 100644 --- a/prose/ssh.go +++ b/prose/ssh.go @@ -86,7 +86,7 @@ func StartSshServer() { } handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap) - sshAuth := shared.NewSshAuthHandler(dbh, logger, cfg) + sshAuth := shared.NewSshAuthHandler(dbh, logger) s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), wish.WithHostKeyPath("ssh_data/term_info_ed25519"), 
diff --git a/shared/ssh.go b/shared/ssh.go index f0270e75..f1f96669 100644 --- a/shared/ssh.go +++ b/shared/ssh.go @@ -9,42 +9,24 @@ import ( ) type SshAuthHandler struct { - DBPool db.DB + DB AuthFindUser Logger *slog.Logger - Cfg *ConfigSite } -func NewSshAuthHandler(dbpool db.DB, logger *slog.Logger, cfg *ConfigSite) *SshAuthHandler { - return &SshAuthHandler{ - DBPool: dbpool, - Logger: logger, - Cfg: cfg, - } +type AuthFindUser interface { + FindUserByPubkey(key string) (*db.User, error) } -func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag { - ff, _ := dbpool.FindFeatureForUser(userID, "plus") - // we have free tiers so users might not have a feature flag - // in which case we set sane defaults - if ff == nil { - ff = db.NewFeatureFlag( - userID, - "plus", - cfg.MaxSize, - cfg.MaxAssetSize, - cfg.MaxSpecialFileSize, - ) +func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger) *SshAuthHandler { + return &SshAuthHandler{ + DB: dbh, + Logger: logger, } - // this is jank - ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize) - ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize) - ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize) - return ff } func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool { pubkey := utils.KeyForKeyText(key) - user, err := r.DBPool.FindUserForKey(ctx.User(), pubkey) + user, err := r.DB.FindUserByPubkey(pubkey) if err != nil { r.Logger.Error( "could not find user for key", @@ -66,3 +48,23 @@ func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) b ctx.Permissions().Extensions["pubkey"] = pubkey return true } + +func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag { + ff, _ := dbpool.FindFeatureForUser(userID, "plus") + // we have free tiers so users might not have a feature flag + // in which case we set sane defaults + if ff == nil { + ff = db.NewFeatureFlag( + userID, + "plus", + cfg.MaxSize, + 
cfg.MaxAssetSize, + cfg.MaxSpecialFileSize, + ) + } + // this is jank + ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize) + ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize) + ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize) + return ff +} diff --git a/tui/analytics/analytics.go b/tui/analytics/analytics.go index 7515f771..ad28eade 100644 --- a/tui/analytics/analytics.go +++ b/tui/analytics/analytics.go @@ -158,8 +158,12 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.KeyMsg: switch msg.String() { - case "q", "esc": + case "esc": return m, pages.Navigate(pages.MenuPage) + case "q": + if !m.input.Focused() { + return m, pages.Navigate(pages.MenuPage) + } case "tab": if m.input.Focused() { m.input.Blur()