taler-mailbox

Service for asynchronous wallet-to-wallet payment messages
Log | Files | Refs | Submodules | README | LICENSE

commit f3bd8342954f4f15bc149c7464c4e5adcd9cbe2e
parent dc8e47b663891246cad3e60ac46fb20ef402e4f3
Author: Martin Schanzenbach <schanzen@gnunet.org>
Date:   Sun, 22 Mar 2026 15:54:43 +0100

refactor for prepared statements

Diffstat:
Mcmd/mailbox-server/main.go | 3+--
Mcmd/mailbox-server/main_test.go | 27++++++++++-----------------
Mpkg/rest/db.go | 476+++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
Mpkg/rest/mailbox.go | 124+++++++++++++++++++++----------------------------------------------------------
4 files changed, 387 insertions(+), 243 deletions(-)

diff --git a/cmd/mailbox-server/main.go b/cmd/mailbox-server/main.go @@ -19,7 +19,6 @@ package main import ( - "database/sql" "flag" "fmt" "log" @@ -97,7 +96,7 @@ func main() { } psqlconn := iniCfg.GetString("mailbox-pq", "connection_string", "postgres:///taler-mailbox") - db, err := sql.Open("postgres", psqlconn) + db, err := mailbox.OpenDatabase(psqlconn) if err != nil { log.Panic(err) } diff --git a/cmd/mailbox-server/main_test.go b/cmd/mailbox-server/main_test.go @@ -5,7 +5,6 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/sha512" - "database/sql" "encoding/binary" "encoding/json" "fmt" @@ -83,18 +82,12 @@ func TestMain(m *testing.M) { os.Exit(1) } psqlconn := cfg.GetString("mailbox-pq", "connection_string", "postgres:///taler-mailbox") - segments := strings.Split(strings.Split(psqlconn, "?")[0], "/") - dbName := segments[len(segments)-1] - db, err := sql.Open("postgres", psqlconn) + db, err := mailbox.OpenDatabase(psqlconn) if err != nil { log.Panic(err) } defer db.Close() - err = talerutil.DBInit(db, "../..", dbName, "taler-mailbox") - if err != nil { - log.Fatalf("Failed to apply versioning or patches: %v", err) - } merchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var orderResp merchant.PostOrderRequest if r.URL.Path == "/config" { @@ -143,12 +136,12 @@ func TestMain(m *testing.M) { code := m.Run() // Purge DB - mailbox.DeleteAllInboxEntriesFromDatabase(a.DB) + a.DB.DeleteAllInboxEntries() os.Exit(code) } func TestEmptyMailbox(t *testing.T) { - mailbox.DeleteAllInboxEntriesFromDatabase(a.DB) + a.DB.DeleteAllInboxEntries() req, _ := http.NewRequest("GET", "/"+testAliceHashedSigningKeyString, nil) response := executeRequest(req) @@ -162,7 +155,7 @@ func TestEmptyMailbox(t *testing.T) { func TestSendMessage(t *testing.T) { testMessage := make([]byte, 256) - mailbox.DeleteAllInboxEntriesFromDatabase(a.DB) + a.DB.DeleteAllInboxEntries() req, _ := http.NewRequest("POST", "/"+testAliceHashedSigningKeyString, bytes.NewReader(testMessage)) response := executeRequest(req) @@ -191,7 +184,7 @@ func TestSendMessagePaid(t *testing.T) { setMailboxPaid(true) // Cleanup - mailbox.DeleteAllInboxEntriesFromDatabase(a.DB) + a.DB.DeleteAllInboxEntries() testMessage := make([]byte, 256) rand.Read(testMessage) @@ -275,8 +268,8 @@ func TestMailboxRegistration(t *testing.T) { if respMsg.EncryptionKey != msg.MailboxMetadata.EncryptionKey { fmt.Printf("Keys mismatch! %v %v\n", respMsg, msg.MailboxMetadata) } - mailbox.DeleteAllMailboxesFromDatabase(a.DB) - mailbox.DeleteAllPendingRegistrationsFromDatabase(a.DB) + a.DB.DeleteAllMailboxes() + a.DB.DeleteAllPendingRegistrations() } func createMailboxMetadata(encKey []byte, signKey []byte, info string) mailbox.MailboxMetadata { @@ -344,8 +337,8 @@ func TestMailboxRegistrationWithInfo(t *testing.T) { if respMsg.Info != "Hello World" { fmt.Printf("Info field missing! %v %v\n", respMsg, msg.MailboxMetadata) } - mailbox.DeleteAllMailboxesFromDatabase(a.DB) - mailbox.DeleteAllPendingRegistrationsFromDatabase(a.DB) + a.DB.DeleteAllMailboxes() + a.DB.DeleteAllPendingRegistrations() } func TestMailboxRegistrationPaid(t *testing.T) { @@ -400,7 +393,7 @@ func TestPostThenDeleteMessage(t *testing.T) { numMessagesToPost := (a.MessageResponseLimit + 7) testMessages := make([]byte, 256*numMessagesToPost) _, _ = rand.Read(testMessages) - mailbox.DeleteAllInboxEntriesFromDatabase(a.DB) + a.DB.DeleteAllInboxEntries() for i := 0; i < int(numMessagesToPost); i++ { testMessage := testMessages[i*256 : (i+1)*256] diff --git a/pkg/rest/db.go b/pkg/rest/db.go @@ -22,8 +22,12 @@ import ( "context" "database/sql" "errors" - _ "github.com/lib/pq" + "log" + "strings" "time" + + _ "github.com/lib/pq" + talerutil "github.com/schanzen/taler-go/pkg/util" ) type Timestamp struct { @@ -95,10 +99,251 @@ type InboxEntry struct { HashedSigningKey string } -func InsertInboxEntryIntoDatabase(db *sql.DB, e *InboxEntry) error { - query := `INSERT INTO taler_mailbox.inbox_entries - VALUES (DEFAULT, $1, $2);` - rows, err := db.Query(query, e.HashedSigningKey, e.Body) +// MailboxDatabase is the main taldir database connection handle +type MailboxDatabase struct { + // SQL connection + db *sql.DB + + // Get mailbox by key + getMailboxMetadataBySigningKeyStmt *sql.Stmt + + // Get entry by key and body + getInboxEntryBySigningKeyAndBodyStmt *sql.Stmt + + // Get entry statement + getInboxEntryBySerialStmt *sql.Stmt + + // Insert entry statement + insertInboxEntryStmt *sql.Stmt + + // Delete entry statement + deleteInboxEntryBySerialStmt *sql.Stmt + + // Insert pending registration + insertPendingRegistrationStmt *sql.Stmt + + // Update pending registration oder ID + updatePendingRegistrationOrderIdStmt *sql.Stmt + + // Delete pending registration + deletePendingRegistrationStmt *sql.Stmt + + // Insert mailbox registration + insertMailboxRegistrationStmt *sql.Stmt + + // Update the expiration of a Mailbox + updateMailboxExpirationStmt *sql.Stmt + + // Get pending registration by signing key + getPendingMailboxRegistrationBySigningKeyStmt *sql.Stmt + + // Delete stale registrations + deleteStaleRegistrationsStmt *sql.Stmt + + // Delete stale pending registrations + deleteStalePendingRegistrationsStmt *sql.Stmt + + // Get messages + getMessagesBySigningKeyStmt *sql.Stmt + + // Count messages + countMessagesBySigningKeyStmt *sql.Stmt +} + +func (db *MailboxDatabase) Close() { + for _, s := range []*sql.Stmt{ + db.insertInboxEntryStmt, + db.getMailboxMetadataBySigningKeyStmt, + db.getInboxEntryBySerialStmt, + db.getInboxEntryBySigningKeyAndBodyStmt, + db.deleteInboxEntryBySerialStmt, + db.insertPendingRegistrationStmt, + db.updatePendingRegistrationOrderIdStmt, + db.deletePendingRegistrationStmt, + db.insertMailboxRegistrationStmt, + db.updateMailboxExpirationStmt, + db.getPendingMailboxRegistrationBySigningKeyStmt, + db.deleteStalePendingRegistrationsStmt, + db.deleteStaleRegistrationsStmt, + db.getMessagesBySigningKeyStmt, + db.countMessagesBySigningKeyStmt, + } { + if s != nil { + s.Close() + } + } + db.db.Close() +} + +func OpenDatabase(psqlconn string) (*MailboxDatabase, error) { + db, err := sql.Open("postgres", psqlconn) + if err != nil { + return nil, err + } + segments := strings.Split(strings.Split(psqlconn, "?")[0], "/") + dbName := segments[len(segments)-1] + + err = talerutil.DBInit(db, "../..", dbName, "taler-directory") + if err != nil { + log.Fatalf("Failed to apply versioning or patches: %v", err) + } + insertInboxEntryStmt, err := db.Prepare(`INSERT INTO taler_mailbox.inbox_entries + VALUES (DEFAULT, $1, $2);`) + if err != nil { + return nil, err + } + insertPendingRegistrationStmt, err := db.Prepare(`INSERT INTO taler_mailbox.pending_mailbox_registrations + VALUES (DEFAULT, $1, $2, $3, $4);`) + if err != nil { + return nil, err + } + insertMailboxRegistrationStmt, err := db.Prepare(`INSERT INTO taler_mailbox.mailbox_metadata + VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7);`) + if err != nil { + return nil, err + } + updatePendingRegistrationOrderIdStmt, err := db.Prepare(`UPDATE taler_mailbox.pending_mailbox_registrations + SET + "order_id" = $2 + WHERE "hashed_signing_key" = $1;`) + if err != nil { + return nil, err + } + updateMailboxExpirationStmt, err := db.Prepare(`UPDATE taler_mailbox.mailbox_metadata + SET + "expiration" = $2 + WHERE "hashed_signing_key" = $1;`) + getPendingRegistrationBySingingKeyStmt, err := db.Prepare(`SELECT + "serial", + "hashed_signing_key", + "registration_duration", + "order_id" + FROM taler_mailbox.pending_mailbox_registrations + WHERE + "hashed_signing_key"=$1 + LIMIT 1 + ;`) + if err != nil { + return nil, err + } + getMailboxMetadataBySigningKey, err := db.Prepare(`SELECT + "serial", + "hashed_signing_key", + "signing_key", + "signing_key_type", + "encryption_key", + "encryption_key_type", + "expiration", + "info" + FROM taler_mailbox.mailbox_metadata + WHERE + "hashed_signing_key"=$1 + LIMIT 1 + ;`) + if err != nil { + return nil, err + } + getInboxEntryBySigningKeyAndBodyStmt, err := db.Prepare(`SELECT + "serial", + "hashed_signing_key", + "body" + FROM taler_mailbox.inbox_entries + WHERE + "hashed_signing_key"=$1 AND + "body"=$2 + ;`) + if err != nil { + return nil, err + } + getInboxEntryBySerialStmt, err := db.Prepare(`SELECT + "serial", + "hashed_signing_key", + "body" + FROM taler_mailbox.inbox_entries + WHERE + "serial"=$1 AND + "hashed_signing_key"=$2 + ;`) + if err != nil { + return nil, err + } + deletePendingRegistrationStmt, err := db.Prepare(`DELETE + FROM taler_mailbox.pending_mailbox_registrations + WHERE + "serial" = $1 + ;`) + if err != nil { + return nil, err + } + deleteInboxEntryBySerialStmt, err := db.Prepare(`DELETE FROM taler_mailbox.inbox_entries + WHERE serial IN ( + SELECT serial FROM taler_mailbox.inbox_entries + WHERE + "hashed_signing_key"=$1 AND + "serial">=$2 + LIMIT $3 + );`) + if err != nil { + return nil, err + } + deleteStaleRegistrationsStmt, err := db.Prepare(`DELETE + FROM taler_mailbox.mailbox_metadata + WHERE + "expiration" < $1 + ;`) + if err != nil { + return nil, err + } + deleteStalePendingRegistrationsStmt, err := db.Prepare(`DELETE + FROM taler_mailbox.pending_mailbox_registrations + WHERE + "created_at" < $1 + ;`) + if err != nil { + return nil, err + } + getMessagesBySigningKeyStmt, err := db.Prepare(`SELECT + "serial", + "hashed_signing_key", + "body" + FROM taler_mailbox.inbox_entries + WHERE + "hashed_signing_key" = $1 + LIMIT $2 + ;`) + if err != nil { + return nil, err + } + countMessagesBySigningKeyStmt, err := db.Prepare(`SELECT COUNT(*) AS num_messages + FROM taler_mailbox.inbox_entries + WHERE + "hashed_signing_key"=$1 + ;`) + if err != nil { + return nil, err + } + return &MailboxDatabase{ + db: db, + insertInboxEntryStmt: insertInboxEntryStmt, + insertPendingRegistrationStmt: insertPendingRegistrationStmt, + insertMailboxRegistrationStmt: insertMailboxRegistrationStmt, + updatePendingRegistrationOrderIdStmt: updatePendingRegistrationOrderIdStmt, + updateMailboxExpirationStmt: updateMailboxExpirationStmt, + getPendingMailboxRegistrationBySigningKeyStmt: getPendingRegistrationBySingingKeyStmt, + getMailboxMetadataBySigningKeyStmt: getMailboxMetadataBySigningKey, + getInboxEntryBySigningKeyAndBodyStmt: getInboxEntryBySigningKeyAndBodyStmt, + getInboxEntryBySerialStmt: getInboxEntryBySerialStmt, + deletePendingRegistrationStmt: deletePendingRegistrationStmt, + deleteInboxEntryBySerialStmt: deleteInboxEntryBySerialStmt, + deleteStaleRegistrationsStmt: deleteStaleRegistrationsStmt, + deleteStalePendingRegistrationsStmt: deleteStalePendingRegistrationsStmt, + getMessagesBySigningKeyStmt: getMessagesBySigningKeyStmt, + countMessagesBySigningKeyStmt: countMessagesBySigningKeyStmt, + }, nil +} + +func (db *MailboxDatabase) InsertInboxEntry(e *InboxEntry) error { + rows, err := db.insertInboxEntryStmt.Query(e.HashedSigningKey, e.Body) if err != nil { return err } @@ -106,11 +351,9 @@ func InsertInboxEntryIntoDatabase(db *sql.DB, e *InboxEntry) error { return nil } -func InsertPendingRegistrationIntoDatabase(db *sql.DB, pr *PendingMailboxRegistration) error { +func (db *MailboxDatabase) InsertPendingMailboxRegistration(pr *PendingMailboxRegistration) error { pr.CreatedAt = time.Now().Unix() - query := `INSERT INTO taler_mailbox.pending_mailbox_registrations - VALUES (DEFAULT, $1, $2, $3, $4);` - rows, err := db.Query(query, pr.CreatedAt, pr.HashedSigningKey, pr.OrderID, pr.Duration) + rows, err := db.insertPendingRegistrationStmt.Query(pr.CreatedAt, pr.HashedSigningKey, pr.OrderID, pr.Duration) if err != nil { return err } @@ -118,10 +361,8 @@ func InsertPendingRegistrationIntoDatabase(db *sql.DB, pr *PendingMailboxRegistr return nil } -func InsertMailboxRegistrationIntoDatabase(db *sql.DB, mb *MailboxMetadata) error { - query := `INSERT INTO taler_mailbox.mailbox_metadata - VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7);` - rows, err := db.Query(query, mb.HashedSigningKey, mb.SigningKey, mb.SigningKeyType, mb.EncryptionKey, mb.EncryptionKeyType, mb.Expiration.Seconds, mb.Info) +func (db *MailboxDatabase) InsertMailboxRegistration(mb *MailboxMetadata) error { + rows, err := db.insertMailboxRegistrationStmt.Query(mb.HashedSigningKey, mb.SigningKey, mb.SigningKeyType, mb.EncryptionKey, mb.EncryptionKeyType, mb.Expiration.Seconds, mb.Info) if err != nil { return err } @@ -129,12 +370,8 @@ func InsertMailboxRegistrationIntoDatabase(db *sql.DB, mb *MailboxMetadata) erro return nil } -func UpdatePendingRegistrationOrderIdInDatabase(db *sql.DB, pr *PendingMailboxRegistration) error { - query := `UPDATE taler_mailbox.pending_mailbox_registrations - SET - "order_id" = $2 - WHERE "hashed_signing_key" = $1;` - rows, err := db.Query(query, pr.HashedSigningKey, pr.OrderID) +func (db *MailboxDatabase) UpdatePendingMailboxRegistrationOrderId(pr *PendingMailboxRegistration) error { + rows, err := db.updatePendingRegistrationOrderIdStmt.Query(pr.HashedSigningKey, pr.OrderID) if err != nil { return err } @@ -142,12 +379,8 @@ func UpdatePendingRegistrationOrderIdInDatabase(db *sql.DB, pr *PendingMailboxRe return nil } -func UpdateMailboxExpirationInDatabase(db *sql.DB, mb *MailboxMetadata) error { - query := `UPDATE taler_mailbox.mailbox_metadata - SET - "expiration" = $2 - WHERE "hashed_signing_key" = $1;` - rows, err := db.Query(query, mb.HashedSigningKey, mb.Expiration.Seconds) +func (db *MailboxDatabase) UpdateMailboxExpiration(mb *MailboxMetadata) error { + rows, err := db.updateMailboxExpirationStmt.Query(mb.HashedSigningKey, mb.Expiration.Seconds) if err != nil { return err } @@ -155,19 +388,9 @@ func UpdateMailboxExpirationInDatabase(db *sql.DB, mb *MailboxMetadata) error { return nil } -func GetPendingRegistrationFromDatabaseBySigningKey(db *sql.DB, pr *PendingMailboxRegistration, hashedKey string) error { - query := `SELECT - "serial", - "hashed_signing_key", - "registration_duration", - "order_id" - FROM taler_mailbox.pending_mailbox_registrations - WHERE - "hashed_signing_key"=$1 - LIMIT 1 - ;` +func (db *MailboxDatabase) GetPendingMailboxRegistrationBySigningKey(pr *PendingMailboxRegistration, hashedKey string) error { // Execute Query - rows, err := db.Query(query, hashedKey) + rows, err := db.getPendingMailboxRegistrationBySigningKeyStmt.Query(hashedKey) if err != nil { return err } @@ -184,23 +407,9 @@ func GetPendingRegistrationFromDatabaseBySigningKey(db *sql.DB, pr *PendingMailb ) } -func GetMailboxMetadataFromDatabaseBySigningKey(db *sql.DB, mb *MailboxMetadata, hashedKey string) error { - query := `SELECT - "serial", - "hashed_signing_key", - "signing_key", - "signing_key_type", - "encryption_key", - "encryption_key_type", - "expiration", - "info" - FROM taler_mailbox.mailbox_metadata - WHERE - "hashed_signing_key"=$1 - LIMIT 1 - ;` +func (db *MailboxDatabase) GetMailboxMetadataBySigningKey(mb *MailboxMetadata, hashedKey string) error { // Execute Query - rows, err := db.Query(query, hashedKey) + rows, err := db.getMailboxMetadataBySigningKeyStmt.Query(hashedKey) if err != nil { return err } @@ -221,18 +430,9 @@ func GetMailboxMetadataFromDatabaseBySigningKey(db *sql.DB, mb *MailboxMetadata, ) } -func GetInboxEntryFromDatabaseBySigningKeyAndBody(db *sql.DB, e *InboxEntry, hashedKey string, body []byte) error { - query := `SELECT - "serial", - "hashed_signing_key", - "body" - FROM taler_mailbox.inbox_entries - WHERE - "hashed_signing_key"=$1 AND - "body"=$2 - ;` +func (db *MailboxDatabase) GetInboxEntryBySigningKeyAndBody(e *InboxEntry, hashedKey string, body []byte) error { // Execute Query - rows, err := db.Query(query, hashedKey, body) + rows, err := db.getInboxEntryBySigningKeyAndBodyStmt.Query(hashedKey, body) if err != nil { return err } @@ -248,18 +448,9 @@ func GetInboxEntryFromDatabaseBySigningKeyAndBody(db *sql.DB, e *InboxEntry, has ) } -func GetInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, hashedKey string, serial int64) error { - query := `SELECT - "serial", - "hashed_signing_key", - "body" - FROM taler_mailbox.inbox_entries - WHERE - "serial"=$1 AND - "hashed_signing_key"=$2 - ;` +func (db *MailboxDatabase) GetInboxEntryBySerial(e *InboxEntry, hashedKey string, serial int64) error { // Execute Query - rows, err := db.Query(query, serial, hashedKey) + rows, err := db.getInboxEntryBySerialStmt.Query(serial, hashedKey) if err != nil { return err } @@ -275,22 +466,17 @@ func GetInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, hashedKey stri ) } -func DeletePendingRegistrationFromDatabase(db *sql.DB, pr *PendingMailboxRegistration) (int64, error) { +func (db *MailboxDatabase) DeletePendingRegistration(pr *PendingMailboxRegistration) (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() - query := `DELETE - FROM taler_mailbox.pending_mailbox_registrations - WHERE - "serial" = $1 - ;` // Execute Query - result, err := conn.ExecContext(ctx, query, pr.Serial) + result, err := db.deletePendingRegistrationStmt.ExecContext(ctx, pr.Serial) if err != nil { return 0, err } @@ -301,27 +487,18 @@ func DeletePendingRegistrationFromDatabase(db *sql.DB, pr *PendingMailboxRegistr return rows, nil } -// DeleteInboxEntryFromDatabaseBySerial Deletes all entries starting from given serial -func DeleteInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, count int) (int64, error) { +// DeleteInboxEntryBySerial Deletes all entries starting from given serial +func (db *MailboxDatabase) DeleteInboxEntryBySerial(e *InboxEntry, count int) (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() - query := `DELETE FROM taler_mailbox.inbox_entries - WHERE serial IN ( - SELECT serial FROM taler_mailbox.inbox_entries - WHERE - "hashed_signing_key"=$1 AND - "serial">=$2 - LIMIT $3 - ) - ;` // Execute Query - result, err := conn.ExecContext(ctx, query, e.HashedSigningKey, e.Serial, count) + result, err := db.deleteInboxEntryBySerialStmt.ExecContext(ctx, e.HashedSigningKey, e.Serial, count) if err != nil { return 0, err } @@ -332,23 +509,18 @@ func DeleteInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, count int) return rows, nil } -// DeleteStaleRegstrationsFromDatabase purges stale registrations -func DeleteStaleRegistrationsFromDatabase(db *sql.DB, registrationExpiration time.Time) (int64, error) { +// DeleteStaleRegstrations purges stale registrations +func (db *MailboxDatabase) DeleteStaleRegistrations(registrationExpiration time.Time) (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() - query := `DELETE - FROM taler_mailbox.mailbox_metadata - WHERE - "expiration" < $1 - ;` // Execute Query - result, err := conn.ExecContext(ctx, query, registrationExpiration.Unix()) + result, err := db.deleteStaleRegistrationsStmt.ExecContext(ctx, registrationExpiration.Unix()) if err != nil { return 0, err } @@ -359,23 +531,18 @@ func DeleteStaleRegistrationsFromDatabase(db *sql.DB, registrationExpiration tim return rows, nil } -// DeleteStalePendingRegstrationsFromDatabase purges stale registrations -func DeleteStalePendingRegistrationsFromDatabase(db *sql.DB, registrationExpiration time.Time) (int64, error) { +// DeleteStalePendingRegstrations purges stale registrations +func (db *MailboxDatabase) DeleteStalePendingRegistrations(registrationExpiration time.Time) (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() - query := `DELETE - FROM taler_mailbox.pending_mailbox_registrations - WHERE - "created_at" < $1 - ;` // Execute Query - result, err := conn.ExecContext(ctx, query, registrationExpiration.Unix()) + result, err := db.deleteStalePendingRegistrationsStmt.ExecContext(ctx, registrationExpiration.Unix()) if err != nil { return 0, err } @@ -386,20 +553,20 @@ func DeleteStalePendingRegistrationsFromDatabase(db *sql.DB, registrationExpirat return rows, nil } -func DeleteAllPendingRegistrationsFromDatabase(db *sql.DB) (int64, error) { +func (db *MailboxDatabase) DeleteAllPendingRegistrations() (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() query := `DELETE - FROM taler_mailbox.pending_mailbox_registrations - WHERE - 1=1 - ;` + FROM taler_mailbox.pending_mailbox_registrations + WHERE + 1=1 + ;` // Execute Query result, err := conn.ExecContext(ctx, query) if err != nil { @@ -412,20 +579,20 @@ func DeleteAllPendingRegistrationsFromDatabase(db *sql.DB) (int64, error) { return rows, nil } -func DeleteAllMailboxesFromDatabase(db *sql.DB) (int64, error) { +func (db *MailboxDatabase) DeleteAllMailboxes() (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() query := `DELETE - FROM taler_mailbox.mailbox_metadata - WHERE - 1=1 - ;` + FROM taler_mailbox.mailbox_metadata + WHERE + 1=1 + ;` // Execute Query result, err := conn.ExecContext(ctx, query) if err != nil { @@ -438,20 +605,20 @@ func DeleteAllMailboxesFromDatabase(db *sql.DB) (int64, error) { return rows, nil } -func DeleteAllInboxEntriesFromDatabase(db *sql.DB) (int64, error) { +func (db *MailboxDatabase) DeleteAllInboxEntries() (int64, error) { var ctx context.Context ctx, stop := context.WithCancel(context.Background()) defer stop() - conn, err := db.Conn(ctx) + conn, err := db.db.Conn(ctx) if err != nil { return 0, err } defer conn.Close() query := `DELETE - FROM taler_mailbox.inbox_entries - WHERE - 1=1 - ;` + FROM taler_mailbox.inbox_entries + WHERE + 1=1 + ;` // Execute Query result, err := conn.ExecContext(ctx, query) if err != nil { @@ -463,3 +630,48 @@ func DeleteAllInboxEntriesFromDatabase(db *sql.DB) (int64, error) { } return rows, nil } + +// Get Hash-salted alias from database +func (db *MailboxDatabase) GetMessages(hashedKey string, limit int) ([]InboxEntry, error) { + // Execute Query + rows, err := db.getMessagesBySigningKeyStmt.Query(hashedKey, limit) + if err != nil { + return []InboxEntry{}, err + } + defer rows.Close() + var entries = make([]InboxEntry, 0) + for rows.Next() { + var e InboxEntry + err = rows.Scan( + &e.Serial, + &e.HashedSigningKey, + &e.Body, + ) + if err != nil { + return entries, err + } + entries = append(entries, e) + } + return entries, nil +} + +// Get Hash-salted alias from database +func (db *MailboxDatabase) GetMessagesCount(hashedKey string) (int64, error) { + // Execute Query + rows, err := db.countMessagesBySigningKeyStmt.Query(hashedKey) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, nil + } + var res int64 + err = rows.Scan( + &res, + ) + if err != nil { + return 0, err + } + return res, nil +} diff --git a/pkg/rest/mailbox.go b/pkg/rest/mailbox.go @@ -22,7 +22,6 @@ package mailbox import ( "crypto/ed25519" "crypto/sha512" - "database/sql" "encoding/binary" "encoding/json" "fmt" @@ -74,7 +73,7 @@ type MailboxConfig struct { Ini talerutil.TalerConfiguration // The database connection to use - DB *sql.DB + DB *MailboxDatabase // Merchant connection Merchant merchant.Merchant @@ -90,7 +89,7 @@ type Mailbox struct { Router *mux.Router // The database connection to use - DB *sql.DB + DB *MailboxDatabase // Our configuration from the ini Cfg MailboxConfig @@ -205,7 +204,7 @@ type MailboxRateLimitedResponse struct { } func (m *Mailbox) configResponse(w http.ResponseWriter, r *http.Request) { - dp, err := m.Cfg.Ini.GetDuration("mailbox", "delivery_period", 3 * 24 * time.Hour) + dp, err := m.Cfg.Ini.GetDuration("mailbox", "delivery_period", 3*24*time.Hour) if err != nil { log.Fatal(err) } @@ -223,65 +222,6 @@ func (m *Mailbox) configResponse(w http.ResponseWriter, r *http.Request) { w.Write(response) } -// Get Hash-salted alias from database -func GetMessagesCountFromDatabase(db *sql.DB, hashedKey string) (int64, error) { - query := `SELECT COUNT(*) AS num_messages - FROM taler_mailbox.inbox_entries - WHERE - "hashed_signing_key"=$1 - ;` - // Execute Query - rows, err := db.Query(query, hashedKey) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, nil - } - var res int64 - err = rows.Scan( - &res, - ) - if err != nil { - return 0, err - } - return res, nil -} - -// Get Hash-salted alias from database -func GetMessagesFromDatabase(db *sql.DB, hashedKey string, limit int) ([]InboxEntry, error) { - query := `SELECT - "serial", - "hashed_signing_key", - "body" - FROM taler_mailbox.inbox_entries - WHERE - "hashed_signing_key" = $1 - LIMIT $2 - ;` - // Execute Query - rows, err := db.Query(query, hashedKey, limit) - if err != nil { - return []InboxEntry{}, err - } - defer rows.Close() - var entries = make([]InboxEntry, 0) - for rows.Next() { - var e InboxEntry - err = rows.Scan( - &e.Serial, - &e.HashedSigningKey, - &e.Body, - ) - if err != nil { - return entries, err - } - entries = append(entries, e) - } - return entries, nil -} - func (m *Mailbox) getMessagesResponse(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) //to, toSet := vars["timeout_ms"] @@ -289,8 +229,8 @@ func (m *Mailbox) getMessagesResponse(w http.ResponseWriter, r *http.Request) { // FIXME rate limit // FIXME timeout // FIXME possibly limit results here - m.checkPendingRegistrationUpdates(vars["h_mailbox"]) - entries, err := GetMessagesFromDatabase(m.DB, vars["h_mailbox"], int(m.MessageResponseLimit)) + m.checkPendingMailboxRegistrationUpdates(vars["h_mailbox"]) + entries, err := m.DB.GetMessages(vars["h_mailbox"], int(m.MessageResponseLimit)) if err != nil { m.Logf(LogError, "Error getting messages: %v", err) w.WriteHeader(http.StatusNotFound) @@ -327,7 +267,7 @@ func (m *Mailbox) sendMessageResponse(w http.ResponseWriter, r *http.Request) { } if !m.MessageFee.IsZero() { var count int64 - count, err = GetMessagesCountFromDatabase(m.DB, vars["h_mailbox"]) + count, err = m.DB.GetMessagesCount(vars["h_mailbox"]) if nil != err { m.Logf(LogError, "Error getting messages: %v", err) http.Error(w, "Cannot look for entries", http.StatusBadRequest) @@ -339,15 +279,15 @@ func (m *Mailbox) sendMessageResponse(w http.ResponseWriter, r *http.Request) { return } } - m.checkPendingRegistrationUpdates(vars["h_mailbox"]) - err = GetInboxEntryFromDatabaseBySigningKeyAndBody(m.DB, &entry, vars["h_mailbox"], body) + m.checkPendingMailboxRegistrationUpdates(vars["h_mailbox"]) + err = m.DB.GetInboxEntryBySigningKeyAndBody(&entry, vars["h_mailbox"], body) if err == nil { w.WriteHeader(http.StatusNotModified) return } entry.HashedSigningKey = vars["h_mailbox"] entry.Body = body - err = InsertInboxEntryIntoDatabase(m.DB, &entry) + err = m.DB.InsertInboxEntry(&entry) if err != nil { m.Logf(LogError, "Error storing message: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -359,8 +299,8 @@ func (m *Mailbox) sendMessageResponse(w http.ResponseWriter, r *http.Request) { func (m *Mailbox) getKeysResponse(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) var keyEntry MailboxMetadata - m.checkPendingRegistrationUpdates(vars["h_mailbox"]) - err := GetMailboxMetadataFromDatabaseBySigningKey(m.DB, &keyEntry, vars["h_mailbox"]) + m.checkPendingMailboxRegistrationUpdates(vars["h_mailbox"]) + err := m.DB.GetMailboxMetadataBySigningKey(&keyEntry, vars["h_mailbox"]) if err != nil { m.Logf(LogError, "Error finding mailbox: %v", err) w.WriteHeader(http.StatusNotFound) @@ -469,7 +409,7 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request reqExpiration := time.Unix(int64(msg.MailboxMetadata.Expiration.Seconds), 0) now := time.Now() reqDuration := reqExpiration.Sub(now).Round(monthDuration) - err = GetMailboxMetadataFromDatabaseBySigningKey(m.DB, &registrationEntry, hMailbox) + err = m.DB.GetMailboxMetadataBySigningKey(&registrationEntry, hMailbox) if err == nil { // This probably means the registration is modified or extended or both entryModified := (registrationEntry.EncryptionKey != msg.MailboxMetadata.EncryptionKey) @@ -486,25 +426,25 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request hAddr := sha512.New() hAddr.Write(pk) registrationEntry.HashedSigningKey = util.Base32CrockfordEncode(hAddr.Sum(nil)) - err = InsertMailboxRegistrationIntoDatabase(m.DB, &registrationEntry) + err = m.DB.InsertMailboxRegistration(&registrationEntry) if nil != err { m.Logf(LogError, "%v\n", err) w.WriteHeader(http.StatusInternalServerError) return } } - err = GetPendingRegistrationFromDatabaseBySigningKey(m.DB, &pendingRegistration, hMailbox) + err = m.DB.GetPendingMailboxRegistrationBySigningKey(&pendingRegistration, hMailbox) pendingRegistrationExists := (nil == err) if !pendingRegistrationExists { pendingRegistration.HashedSigningKey = hMailbox pendingRegistration.Duration = reqDuration.Microseconds() - err = InsertPendingRegistrationIntoDatabase(m.DB, &pendingRegistration) + err = m.DB.InsertPendingMailboxRegistration(&pendingRegistration) if nil != err { m.Logf(LogError, "Error inserting pending registration: %v\n", err) w.WriteHeader(http.StatusInternalServerError) return } - err = GetPendingRegistrationFromDatabaseBySigningKey(m.DB, &pendingRegistration, hMailbox) + err = m.DB.GetPendingMailboxRegistrationBySigningKey(&pendingRegistration, hMailbox) if nil != err { m.Logf(LogError, "Error getting pending registration: %v\n", err) w.WriteHeader(http.StatusInternalServerError) @@ -544,7 +484,7 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request return } if len(payto) != 0 { - err = UpdatePendingRegistrationOrderIdInDatabase(m.DB, &pendingRegistration) + err = m.DB.UpdatePendingMailboxRegistrationOrderId(&pendingRegistration) if err != nil { m.Logf(LogError, "Error updating pending registration: %v\n", err) w.WriteHeader(http.StatusInternalServerError) @@ -557,13 +497,13 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request } // Update expiration time of registration. registrationEntry.Expiration.Seconds += uint64(reqDuration.Seconds()) - _, err = DeletePendingRegistrationFromDatabase(m.DB, &pendingRegistration) + _, err = m.DB.DeletePendingRegistration(&pendingRegistration) if nil != err { m.Logf(LogError, "Error deleting pending registration: %v\n", err) w.WriteHeader(http.StatusInternalServerError) return } - err = UpdateMailboxExpirationInDatabase(m.DB, &registrationEntry) + err = m.DB.UpdateMailboxExpiration(&registrationEntry) if nil != err { m.Logf(LogError, "Error updating mailbox registration: %v\n", err) w.WriteHeader(http.StatusInternalServerError) @@ -572,10 +512,10 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request w.WriteHeader(http.StatusNoContent) } -func (m *Mailbox) checkPendingRegistrationUpdates(hMailbox string) { +func (m *Mailbox) checkPendingMailboxRegistrationUpdates(hMailbox string) { var pendingEntry PendingMailboxRegistration var registrationEntry MailboxMetadata - err := GetPendingRegistrationFromDatabaseBySigningKey(m.DB, &pendingEntry, hMailbox) + err := m.DB.GetPendingMailboxRegistrationBySigningKey(&pendingEntry, hMailbox) if err != nil { return } @@ -585,7 +525,7 @@ func (m *Mailbox) checkPendingRegistrationUpdates(hMailbox string) { if rc == http.StatusNotFound { m.Logf(LogInfo, "Registration order for `%s' not found, removing\n", hMailbox) } - _, err = DeletePendingRegistrationFromDatabase(m.DB, &pendingEntry) + _, err = m.DB.DeletePendingRegistration(&pendingEntry) if nil != err { m.Logf(LogInfo, "Error deleting pending registration: %v\n", err) } @@ -594,15 +534,15 @@ func (m *Mailbox) checkPendingRegistrationUpdates(hMailbox string) { m.Logf(LogDebug, "Order status for %s is %s", pendingEntry.HashedSigningKey, orderStatus) if merchant.OrderPaid == orderStatus { m.Logf(LogDebug, "Order for %v appears to be paid", pendingEntry) - err = GetMailboxMetadataFromDatabaseBySigningKey(m.DB, &registrationEntry, hMailbox) + err = m.DB.GetMailboxMetadataBySigningKey(&registrationEntry, hMailbox) if err == nil { m.Logf(LogDebug, "Adding %d seconds to entry expiration", pendingEntry.Duration) registrationEntry.Expiration.Seconds += uint64(pendingEntry.Duration) - err = UpdateMailboxExpirationInDatabase(m.DB, &registrationEntry) + err = m.DB.UpdateMailboxExpiration(&registrationEntry) if nil != err { m.Logf(LogInfo, "Error updating mailbox expiration: %v\n", err) } - _, err = DeletePendingRegistrationFromDatabase(m.DB, &pendingEntry) + _, err = m.DB.DeletePendingRegistration(&pendingEntry) if nil != err { m.Logf(LogInfo, "Error deleting pending registration: %v\n", err) } @@ -656,7 +596,7 @@ func (m *Mailbox) deleteMessagesResponse(w http.ResponseWriter, r *http.Request) h := sha512.New() h.Write(pkey) hMailbox := util.Base32CrockfordEncode(h.Sum(nil)) - m.checkPendingRegistrationUpdates(hMailbox) + m.checkPendingMailboxRegistrationUpdates(hMailbox) var signedMsg [4 * 4]byte binary.BigEndian.PutUint32(signedMsg[0:4], 4*4) binary.BigEndian.PutUint32(signedMsg[4:8], gana.TalerSignaturePurposeMailboxMessagesDelete) @@ -668,14 +608,14 @@ func (m *Mailbox) deleteMessagesResponse(w http.ResponseWriter, r *http.Request) } // Check that expectedETag actually exists var entry InboxEntry - err = GetInboxEntryFromDatabaseBySerial(m.DB, &entry, hMailbox, int64(expectedETag)) + err = m.DB.GetInboxEntryBySerial(&entry, hMailbox, int64(expectedETag)) if err != nil { m.Logf(LogDebug, "Message to delete not found with ID %d", expectedETag) w.WriteHeader(http.StatusNotFound) return } m.Logf(LogError, "Deleting from entry %v up to %d messages\n", entry, count) - num, err := DeleteInboxEntryFromDatabaseBySerial(m.DB, &entry, count) + num, err := m.DB.DeleteInboxEntryBySerial(&entry, count) if err != nil { m.Logf(LogDebug, "Failed to delete messages: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -742,7 +682,7 @@ func (m *Mailbox) Initialize(cfg MailboxConfig) { os.Exit(1) } m.MonthlyFee = monthlyFee - updateFee, err := cfg.Ini.GetAmount("mailbox", "registration_update_fee", &talerutil.Amount{}) + updateFee, err := cfg.Ini.GetAmount("mailbox", "registration_update_fee", &talerutil.Amount{}) if err != nil { fmt.Printf("Failed to parse update fee: %v", err) os.Exit(1) @@ -758,7 +698,7 @@ func (m *Mailbox) Initialize(cfg MailboxConfig) { m.DB = cfg.DB go func() { for { - num, err := DeleteStaleRegistrationsFromDatabase(m.DB, time.Now()) + num, err := m.DB.DeleteStaleRegistrations(time.Now()) if err != nil { m.Logf(LogDebug, "Error purging stale registrations: `%v'.\n", err) } @@ -767,14 +707,14 @@ func (m *Mailbox) Initialize(cfg MailboxConfig) { } }() // Clean up pending - pendingExp, err := cfg.Ini.GetDuration("mailbox", "pending_registration_expiration", 24 * time.Hour) + pendingExp, err := cfg.Ini.GetDuration("mailbox", "pending_registration_expiration", 24*time.Hour) if err != nil { fmt.Printf("Failed to parse pending registration expiration: %v", err) os.Exit(1) } go func() { for { - num, err := DeleteStalePendingRegistrationsFromDatabase(m.DB, time.Now().Add(-pendingExp)) + num, err := m.DB.DeleteStalePendingRegistrations(time.Now().Add(-pendingExp)) if err != nil { m.Logf(LogDebug, "Error purging stale registrations: `%v'.\n", err) }