Skip to content

Commit

Permalink
verificationhelper: add tests for using SQLite store for verification
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Nov 26, 2024
1 parent 2a8e6fb commit f7e5f0a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
14 changes: 12 additions & 2 deletions crypto/verificationhelper/verificationhelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package verificationhelper_test

import (
"context"
"database/sql"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -65,11 +66,20 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ
func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) {
t.Helper()
sendingCallbacks = newAllVerificationCallbacks()
sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true)
senderVerificationDB, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB)
require.NoError(t, err)

sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true)
require.NoError(t, sendingHelper.Init(ctx))

receivingCallbacks = newAllVerificationCallbacks()
receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true)
receiverVerificationDB, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB)
require.NoError(t, err)
receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true)
require.NoError(t, receivingHelper.Init(ctx))
return
}
Expand Down
87 changes: 87 additions & 0 deletions crypto/verificationhelper/verificationstore_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package verificationhelper_test

import (
"context"
"database/sql"

_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"

"maunium.net/go/mautrix/crypto/verificationhelper"
"maunium.net/go/mautrix/id"
)

type SQLiteVerificationStore struct {
db *sql.DB
}

const (
selectVerifications = `SELECT transaction_data FROM verifications`
getVerificationByTransactionID = selectVerifications + ` WHERE transaction_id = ?1`
getVerificationByUserDeviceID = selectVerifications + `
WHERE transaction_data->>'their_user_id' = ?1
AND transaction_data->>'their_device_id' = ?2
`
deleteVerificationsQuery = `DELETE FROM verifications WHERE transaction_id = ?1`
)

var _ verificationhelper.VerificationStore = (*SQLiteVerificationStore)(nil)

func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerificationStore, error) {
_, err := db.ExecContext(ctx, `
CREATE TABLE verifications (
transaction_id TEXT PRIMARY KEY NOT NULL,
transaction_data JSONB NOT NULL
);
CREATE INDEX verifications_user_device_id ON
verifications(transaction_data->>'their_user_id', transaction_data->>'their_device_id');
`)
return &SQLiteVerificationStore{db}, err
}

func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) {
rows, err := s.db.QueryContext(ctx, selectVerifications)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) {
err = rows.Scan(&dbutil.JSON{Data: &txn})
return
}).AsList()
}

func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) {
zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction")
row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID)
err = row.Scan(&dbutil.JSON{Data: &txn})
if err == sql.ErrNoRows {
err = verificationhelper.ErrUnknownVerificationTransaction
}
return
}

func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) {
row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID)
err = row.Scan(&dbutil.JSON{Data: &txn})
if err == sql.ErrNoRows {
err = verificationhelper.ErrUnknownVerificationTransaction
}
return
}

func (vq *SQLiteVerificationStore) SaveVerificationTransaction(ctx context.Context, txn verificationhelper.VerificationTransaction) (err error) {
zerolog.Ctx(ctx).Debug().Any("transaction", &txn).Msg("Saving verification transaction")
_, err = vq.db.ExecContext(ctx, `
INSERT INTO verifications (transaction_id, transaction_data)
VALUES (?1, ?2)
ON CONFLICT (transaction_id) DO UPDATE
SET transaction_data=excluded.transaction_data
`, txn.TransactionID, &dbutil.JSON{Data: &txn})
return
}

func (vq *SQLiteVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) (err error) {
_, err = vq.db.ExecContext(ctx, deleteVerificationsQuery, txnID)
return
}

0 comments on commit f7e5f0a

Please sign in to comment.