diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index d0bf2298..af4a28c3 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -2,6 +2,7 @@ package verificationhelper_test import ( "context" + "database/sql" "fmt" "os" "testing" @@ -65,11 +66,20 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true) + senderVerificationDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) + require.NoError(t, err) + + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true) + receiverVerificationDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) + require.NoError(t, err) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } diff --git a/crypto/verificationhelper/verificationstore_test.go b/crypto/verificationhelper/verificationstore_test.go new file mode 100644 index 00000000..a3b1895d --- /dev/null +++ b/crypto/verificationhelper/verificationstore_test.go @@ -0,0 +1,87 @@ +package verificationhelper_test + +import ( + "context" + "database/sql" + + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/id" +) + +type SQLiteVerificationStore struct { + db *sql.DB +} + +const ( + selectVerifications = `SELECT transaction_data FROM verifications` + getVerificationByTransactionID = selectVerifications + ` WHERE transaction_id = ?1` + getVerificationByUserDeviceID = selectVerifications + ` + WHERE transaction_data->>'their_user_id' = ?1 + AND transaction_data->>'their_device_id' = ?2 + ` + deleteVerificationsQuery = `DELETE FROM verifications WHERE transaction_id = ?1` +) + +var _ verificationhelper.VerificationStore = (*SQLiteVerificationStore)(nil) + +func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerificationStore, error) { + _, err := db.ExecContext(ctx, ` + CREATE TABLE verifications ( + transaction_id TEXT PRIMARY KEY NOT NULL, + transaction_data JSONB NOT NULL + ); + CREATE INDEX verifications_user_device_id ON + verifications(transaction_data->>'their_user_id', transaction_data->>'their_device_id'); + `) + return &SQLiteVerificationStore{db}, err +} + +func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) { + rows, err := s.db.QueryContext(ctx, selectVerifications) + if err != nil { + return nil, err + } + return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { + err = rows.Scan(&dbutil.JSON{Data: &txn}) + return + }).AsList() +} + +func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) { + zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction") + row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID) + err = row.Scan(&dbutil.JSON{Data: &txn}) + if err == sql.ErrNoRows { + err = verificationhelper.ErrUnknownVerificationTransaction + } + return +} + +func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) { + row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID) + err = row.Scan(&dbutil.JSON{Data: &txn}) + if err == sql.ErrNoRows { + err = verificationhelper.ErrUnknownVerificationTransaction + } + return +} + +func (vq *SQLiteVerificationStore) SaveVerificationTransaction(ctx context.Context, txn verificationhelper.VerificationTransaction) (err error) { + zerolog.Ctx(ctx).Debug().Any("transaction", &txn).Msg("Saving verification transaction") + _, err = vq.db.ExecContext(ctx, ` + INSERT INTO verifications (transaction_id, transaction_data) + VALUES (?1, ?2) + ON CONFLICT (transaction_id) DO UPDATE + SET transaction_data=excluded.transaction_data + `, txn.TransactionID, &dbutil.JSON{Data: &txn}) + return +} + +func (vq *SQLiteVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) (err error) { + _, err = vq.db.ExecContext(ctx, deleteVerificationsQuery, txnID) + return +}