commit f3bd8342954f4f15bc149c7464c4e5adcd9cbe2e
parent dc8e47b663891246cad3e60ac46fb20ef402e4f3
Author: Martin Schanzenbach <schanzen@gnunet.org>
Date: Sun, 22 Mar 2026 15:54:43 +0100
refactor for prepared statements
Diffstat:
4 files changed, 387 insertions(+), 243 deletions(-)
diff --git a/cmd/mailbox-server/main.go b/cmd/mailbox-server/main.go
@@ -19,7 +19,6 @@
package main
import (
- "database/sql"
"flag"
"fmt"
"log"
@@ -97,7 +96,7 @@ func main() {
}
psqlconn := iniCfg.GetString("mailbox-pq", "connection_string", "postgres:///taler-mailbox")
- db, err := sql.Open("postgres", psqlconn)
+ db, err := mailbox.OpenDatabase(psqlconn)
if err != nil {
log.Panic(err)
}
diff --git a/cmd/mailbox-server/main_test.go b/cmd/mailbox-server/main_test.go
@@ -5,7 +5,6 @@ import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha512"
- "database/sql"
"encoding/binary"
"encoding/json"
"fmt"
@@ -83,18 +82,12 @@ func TestMain(m *testing.M) {
os.Exit(1)
}
psqlconn := cfg.GetString("mailbox-pq", "connection_string", "postgres:///taler-mailbox")
- segments := strings.Split(strings.Split(psqlconn, "?")[0], "/")
- dbName := segments[len(segments)-1]
- db, err := sql.Open("postgres", psqlconn)
+ db, err := mailbox.OpenDatabase(psqlconn)
if err != nil {
log.Panic(err)
}
defer db.Close()
- err = talerutil.DBInit(db, "../..", dbName, "taler-mailbox")
- if err != nil {
- log.Fatalf("Failed to apply versioning or patches: %v", err)
- }
merchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var orderResp merchant.PostOrderRequest
if r.URL.Path == "/config" {
@@ -143,12 +136,12 @@ func TestMain(m *testing.M) {
code := m.Run()
// Purge DB
- mailbox.DeleteAllInboxEntriesFromDatabase(a.DB)
+ a.DB.DeleteAllInboxEntries()
os.Exit(code)
}
func TestEmptyMailbox(t *testing.T) {
- mailbox.DeleteAllInboxEntriesFromDatabase(a.DB)
+ a.DB.DeleteAllInboxEntries()
req, _ := http.NewRequest("GET", "/"+testAliceHashedSigningKeyString, nil)
response := executeRequest(req)
@@ -162,7 +155,7 @@ func TestEmptyMailbox(t *testing.T) {
func TestSendMessage(t *testing.T) {
testMessage := make([]byte, 256)
- mailbox.DeleteAllInboxEntriesFromDatabase(a.DB)
+ a.DB.DeleteAllInboxEntries()
req, _ := http.NewRequest("POST", "/"+testAliceHashedSigningKeyString, bytes.NewReader(testMessage))
response := executeRequest(req)
@@ -191,7 +184,7 @@ func TestSendMessagePaid(t *testing.T) {
setMailboxPaid(true)
// Cleanup
- mailbox.DeleteAllInboxEntriesFromDatabase(a.DB)
+ a.DB.DeleteAllInboxEntries()
testMessage := make([]byte, 256)
rand.Read(testMessage)
@@ -275,8 +268,8 @@ func TestMailboxRegistration(t *testing.T) {
if respMsg.EncryptionKey != msg.MailboxMetadata.EncryptionKey {
fmt.Printf("Keys mismatch! %v %v\n", respMsg, msg.MailboxMetadata)
}
- mailbox.DeleteAllMailboxesFromDatabase(a.DB)
- mailbox.DeleteAllPendingRegistrationsFromDatabase(a.DB)
+ a.DB.DeleteAllMailboxes()
+ a.DB.DeleteAllPendingRegistrations()
}
func createMailboxMetadata(encKey []byte, signKey []byte, info string) mailbox.MailboxMetadata {
@@ -344,8 +337,8 @@ func TestMailboxRegistrationWithInfo(t *testing.T) {
if respMsg.Info != "Hello World" {
fmt.Printf("Info field missing! %v %v\n", respMsg, msg.MailboxMetadata)
}
- mailbox.DeleteAllMailboxesFromDatabase(a.DB)
- mailbox.DeleteAllPendingRegistrationsFromDatabase(a.DB)
+ a.DB.DeleteAllMailboxes()
+ a.DB.DeleteAllPendingRegistrations()
}
func TestMailboxRegistrationPaid(t *testing.T) {
@@ -400,7 +393,7 @@ func TestPostThenDeleteMessage(t *testing.T) {
numMessagesToPost := (a.MessageResponseLimit + 7)
testMessages := make([]byte, 256*numMessagesToPost)
_, _ = rand.Read(testMessages)
- mailbox.DeleteAllInboxEntriesFromDatabase(a.DB)
+ a.DB.DeleteAllInboxEntries()
for i := 0; i < int(numMessagesToPost); i++ {
testMessage := testMessages[i*256 : (i+1)*256]
diff --git a/pkg/rest/db.go b/pkg/rest/db.go
@@ -22,8 +22,12 @@ import (
"context"
"database/sql"
"errors"
- _ "github.com/lib/pq"
+ "log"
+ "strings"
"time"
+
+ _ "github.com/lib/pq"
+ talerutil "github.com/schanzen/taler-go/pkg/util"
)
type Timestamp struct {
@@ -95,10 +99,251 @@ type InboxEntry struct {
HashedSigningKey string
}
-func InsertInboxEntryIntoDatabase(db *sql.DB, e *InboxEntry) error {
- query := `INSERT INTO taler_mailbox.inbox_entries
- VALUES (DEFAULT, $1, $2);`
- rows, err := db.Query(query, e.HashedSigningKey, e.Body)
+// MailboxDatabase is the main taldir database connection handle
+type MailboxDatabase struct {
+ // SQL connection
+ db *sql.DB
+
+ // Get mailbox by key
+ getMailboxMetadataBySigningKeyStmt *sql.Stmt
+
+ // Get entry by key and body
+ getInboxEntryBySigningKeyAndBodyStmt *sql.Stmt
+
+ // Get entry statement
+ getInboxEntryBySerialStmt *sql.Stmt
+
+ // Insert entry statement
+ insertInboxEntryStmt *sql.Stmt
+
+ // Delete entry statement
+ deleteInboxEntryBySerialStmt *sql.Stmt
+
+ // Insert pending registration
+ insertPendingRegistrationStmt *sql.Stmt
+
+ // Update pending registration oder ID
+ updatePendingRegistrationOrderIdStmt *sql.Stmt
+
+ // Delete pending registration
+ deletePendingRegistrationStmt *sql.Stmt
+
+ // Insert mailbox registration
+ insertMailboxRegistrationStmt *sql.Stmt
+
+ // Update the expiration of a Mailbox
+ updateMailboxExpirationStmt *sql.Stmt
+
+ // Get pending registration by signing key
+ getPendingMailboxRegistrationBySigningKeyStmt *sql.Stmt
+
+ // Delete stale registrations
+ deleteStaleRegistrationsStmt *sql.Stmt
+
+ // Delete stale pending registrations
+ deleteStalePendingRegistrationsStmt *sql.Stmt
+
+ // Get messages
+ getMessagesBySigningKeyStmt *sql.Stmt
+
+ // Count messages
+ countMessagesBySigningKeyStmt *sql.Stmt
+}
+
+func (db *MailboxDatabase) Close() {
+ for _, s := range []*sql.Stmt{
+ db.insertInboxEntryStmt,
+ db.getMailboxMetadataBySigningKeyStmt,
+ db.getInboxEntryBySerialStmt,
+ db.getInboxEntryBySigningKeyAndBodyStmt,
+ db.deleteInboxEntryBySerialStmt,
+ db.insertPendingRegistrationStmt,
+ db.updatePendingRegistrationOrderIdStmt,
+ db.deletePendingRegistrationStmt,
+ db.insertMailboxRegistrationStmt,
+ db.updateMailboxExpirationStmt,
+ db.getPendingMailboxRegistrationBySigningKeyStmt,
+ db.deleteStalePendingRegistrationsStmt,
+ db.deleteStaleRegistrationsStmt,
+ db.getMessagesBySigningKeyStmt,
+ db.countMessagesBySigningKeyStmt,
+ } {
+ if s != nil {
+ s.Close()
+ }
+ }
+ db.db.Close()
+}
+
+func OpenDatabase(psqlconn string) (*MailboxDatabase, error) {
+ db, err := sql.Open("postgres", psqlconn)
+ if err != nil {
+ return nil, err
+ }
+ segments := strings.Split(strings.Split(psqlconn, "?")[0], "/")
+ dbName := segments[len(segments)-1]
+
+ err = talerutil.DBInit(db, "../..", dbName, "taler-directory")
+ if err != nil {
+ log.Fatalf("Failed to apply versioning or patches: %v", err)
+ }
+ insertInboxEntryStmt, err := db.Prepare(`INSERT INTO taler_mailbox.inbox_entries
+ VALUES (DEFAULT, $1, $2);`)
+ if err != nil {
+ return nil, err
+ }
+ insertPendingRegistrationStmt, err := db.Prepare(`INSERT INTO taler_mailbox.pending_mailbox_registrations
+ VALUES (DEFAULT, $1, $2, $3, $4);`)
+ if err != nil {
+ return nil, err
+ }
+ insertMailboxRegistrationStmt, err := db.Prepare(`INSERT INTO taler_mailbox.mailbox_metadata
+ VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7);`)
+ if err != nil {
+ return nil, err
+ }
+ updatePendingRegistrationOrderIdStmt, err := db.Prepare(`UPDATE taler_mailbox.pending_mailbox_registrations
+ SET
+ "order_id" = $2
+ WHERE "hashed_signing_key" = $1;`)
+ if err != nil {
+ return nil, err
+ }
+ updateMailboxExpirationStmt, err := db.Prepare(`UPDATE taler_mailbox.mailbox_metadata
+ SET
+ "expiration" = $2
+ WHERE "hashed_signing_key" = $1;`)
+ getPendingRegistrationBySingingKeyStmt, err := db.Prepare(`SELECT
+ "serial",
+ "hashed_signing_key",
+ "registration_duration",
+ "order_id"
+ FROM taler_mailbox.pending_mailbox_registrations
+ WHERE
+ "hashed_signing_key"=$1
+ LIMIT 1
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ getMailboxMetadataBySigningKey, err := db.Prepare(`SELECT
+ "serial",
+ "hashed_signing_key",
+ "signing_key",
+ "signing_key_type",
+ "encryption_key",
+ "encryption_key_type",
+ "expiration",
+ "info"
+ FROM taler_mailbox.mailbox_metadata
+ WHERE
+ "hashed_signing_key"=$1
+ LIMIT 1
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ getInboxEntryBySigningKeyAndBodyStmt, err := db.Prepare(`SELECT
+ "serial",
+ "hashed_signing_key",
+ "body"
+ FROM taler_mailbox.inbox_entries
+ WHERE
+ "hashed_signing_key"=$1 AND
+ "body"=$2
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ getInboxEntryBySerialStmt, err := db.Prepare(`SELECT
+ "serial",
+ "hashed_signing_key",
+ "body"
+ FROM taler_mailbox.inbox_entries
+ WHERE
+ "serial"=$1 AND
+ "hashed_signing_key"=$2
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ deletePendingRegistrationStmt, err := db.Prepare(`DELETE
+ FROM taler_mailbox.pending_mailbox_registrations
+ WHERE
+ "serial" = $1
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ deleteInboxEntryBySerialStmt, err := db.Prepare(`DELETE FROM taler_mailbox.inbox_entries
+ WHERE serial IN (
+ SELECT serial FROM taler_mailbox.inbox_entries
+ WHERE
+ "hashed_signing_key"=$1 AND
+ "serial">=$2
+ LIMIT $3
+ );`)
+ if err != nil {
+ return nil, err
+ }
+ deleteStaleRegistrationsStmt, err := db.Prepare(`DELETE
+ FROM taler_mailbox.mailbox_metadata
+ WHERE
+ "expiration" < $1
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ deleteStalePendingRegistrationsStmt, err := db.Prepare(`DELETE
+ FROM taler_mailbox.pending_mailbox_registrations
+ WHERE
+ "created_at" < $1
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ getMessagesBySigningKeyStmt, err := db.Prepare(`SELECT
+ "serial",
+ "hashed_signing_key",
+ "body"
+ FROM taler_mailbox.inbox_entries
+ WHERE
+ "hashed_signing_key" = $1
+ LIMIT $2
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ countMessagesBySigningKeyStmt, err := db.Prepare(`SELECT COUNT(*) AS num_messages
+ FROM taler_mailbox.inbox_entries
+ WHERE
+ "hashed_signing_key"=$1
+ ;`)
+ if err != nil {
+ return nil, err
+ }
+ return &MailboxDatabase{
+ db: db,
+ insertInboxEntryStmt: insertInboxEntryStmt,
+ insertPendingRegistrationStmt: insertPendingRegistrationStmt,
+ insertMailboxRegistrationStmt: insertMailboxRegistrationStmt,
+ updatePendingRegistrationOrderIdStmt: updatePendingRegistrationOrderIdStmt,
+ updateMailboxExpirationStmt: updateMailboxExpirationStmt,
+ getPendingMailboxRegistrationBySigningKeyStmt: getPendingRegistrationBySingingKeyStmt,
+ getMailboxMetadataBySigningKeyStmt: getMailboxMetadataBySigningKey,
+ getInboxEntryBySigningKeyAndBodyStmt: getInboxEntryBySigningKeyAndBodyStmt,
+ getInboxEntryBySerialStmt: getInboxEntryBySerialStmt,
+ deletePendingRegistrationStmt: deletePendingRegistrationStmt,
+ deleteInboxEntryBySerialStmt: deleteInboxEntryBySerialStmt,
+ deleteStaleRegistrationsStmt: deleteStaleRegistrationsStmt,
+ deleteStalePendingRegistrationsStmt: deleteStalePendingRegistrationsStmt,
+ getMessagesBySigningKeyStmt: getMessagesBySigningKeyStmt,
+ countMessagesBySigningKeyStmt: countMessagesBySigningKeyStmt,
+ }, nil
+}
+
+func (db *MailboxDatabase) InsertInboxEntry(e *InboxEntry) error {
+ rows, err := db.insertInboxEntryStmt.Query(e.HashedSigningKey, e.Body)
if err != nil {
return err
}
@@ -106,11 +351,9 @@ func InsertInboxEntryIntoDatabase(db *sql.DB, e *InboxEntry) error {
return nil
}
-func InsertPendingRegistrationIntoDatabase(db *sql.DB, pr *PendingMailboxRegistration) error {
+func (db *MailboxDatabase) InsertPendingMailboxRegistration(pr *PendingMailboxRegistration) error {
pr.CreatedAt = time.Now().Unix()
- query := `INSERT INTO taler_mailbox.pending_mailbox_registrations
- VALUES (DEFAULT, $1, $2, $3, $4);`
- rows, err := db.Query(query, pr.CreatedAt, pr.HashedSigningKey, pr.OrderID, pr.Duration)
+ rows, err := db.insertPendingRegistrationStmt.Query(pr.CreatedAt, pr.HashedSigningKey, pr.OrderID, pr.Duration)
if err != nil {
return err
}
@@ -118,10 +361,8 @@ func InsertPendingRegistrationIntoDatabase(db *sql.DB, pr *PendingMailboxRegistr
return nil
}
-func InsertMailboxRegistrationIntoDatabase(db *sql.DB, mb *MailboxMetadata) error {
- query := `INSERT INTO taler_mailbox.mailbox_metadata
- VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7);`
- rows, err := db.Query(query, mb.HashedSigningKey, mb.SigningKey, mb.SigningKeyType, mb.EncryptionKey, mb.EncryptionKeyType, mb.Expiration.Seconds, mb.Info)
+func (db *MailboxDatabase) InsertMailboxRegistration(mb *MailboxMetadata) error {
+ rows, err := db.insertMailboxRegistrationStmt.Query(mb.HashedSigningKey, mb.SigningKey, mb.SigningKeyType, mb.EncryptionKey, mb.EncryptionKeyType, mb.Expiration.Seconds, mb.Info)
if err != nil {
return err
}
@@ -129,12 +370,8 @@ func InsertMailboxRegistrationIntoDatabase(db *sql.DB, mb *MailboxMetadata) erro
return nil
}
-func UpdatePendingRegistrationOrderIdInDatabase(db *sql.DB, pr *PendingMailboxRegistration) error {
- query := `UPDATE taler_mailbox.pending_mailbox_registrations
- SET
- "order_id" = $2
- WHERE "hashed_signing_key" = $1;`
- rows, err := db.Query(query, pr.HashedSigningKey, pr.OrderID)
+func (db *MailboxDatabase) UpdatePendingMailboxRegistrationOrderId(pr *PendingMailboxRegistration) error {
+ rows, err := db.updatePendingRegistrationOrderIdStmt.Query(pr.HashedSigningKey, pr.OrderID)
if err != nil {
return err
}
@@ -142,12 +379,8 @@ func UpdatePendingRegistrationOrderIdInDatabase(db *sql.DB, pr *PendingMailboxRe
return nil
}
-func UpdateMailboxExpirationInDatabase(db *sql.DB, mb *MailboxMetadata) error {
- query := `UPDATE taler_mailbox.mailbox_metadata
- SET
- "expiration" = $2
- WHERE "hashed_signing_key" = $1;`
- rows, err := db.Query(query, mb.HashedSigningKey, mb.Expiration.Seconds)
+func (db *MailboxDatabase) UpdateMailboxExpiration(mb *MailboxMetadata) error {
+ rows, err := db.updateMailboxExpirationStmt.Query(mb.HashedSigningKey, mb.Expiration.Seconds)
if err != nil {
return err
}
@@ -155,19 +388,9 @@ func UpdateMailboxExpirationInDatabase(db *sql.DB, mb *MailboxMetadata) error {
return nil
}
-func GetPendingRegistrationFromDatabaseBySigningKey(db *sql.DB, pr *PendingMailboxRegistration, hashedKey string) error {
- query := `SELECT
- "serial",
- "hashed_signing_key",
- "registration_duration",
- "order_id"
- FROM taler_mailbox.pending_mailbox_registrations
- WHERE
- "hashed_signing_key"=$1
- LIMIT 1
- ;`
+func (db *MailboxDatabase) GetPendingMailboxRegistrationBySigningKey(pr *PendingMailboxRegistration, hashedKey string) error {
// Execute Query
- rows, err := db.Query(query, hashedKey)
+ rows, err := db.getPendingMailboxRegistrationBySigningKeyStmt.Query(hashedKey)
if err != nil {
return err
}
@@ -184,23 +407,9 @@ func GetPendingRegistrationFromDatabaseBySigningKey(db *sql.DB, pr *PendingMailb
)
}
-func GetMailboxMetadataFromDatabaseBySigningKey(db *sql.DB, mb *MailboxMetadata, hashedKey string) error {
- query := `SELECT
- "serial",
- "hashed_signing_key",
- "signing_key",
- "signing_key_type",
- "encryption_key",
- "encryption_key_type",
- "expiration",
- "info"
- FROM taler_mailbox.mailbox_metadata
- WHERE
- "hashed_signing_key"=$1
- LIMIT 1
- ;`
+func (db *MailboxDatabase) GetMailboxMetadataBySigningKey(mb *MailboxMetadata, hashedKey string) error {
// Execute Query
- rows, err := db.Query(query, hashedKey)
+ rows, err := db.getMailboxMetadataBySigningKeyStmt.Query(hashedKey)
if err != nil {
return err
}
@@ -221,18 +430,9 @@ func GetMailboxMetadataFromDatabaseBySigningKey(db *sql.DB, mb *MailboxMetadata,
)
}
-func GetInboxEntryFromDatabaseBySigningKeyAndBody(db *sql.DB, e *InboxEntry, hashedKey string, body []byte) error {
- query := `SELECT
- "serial",
- "hashed_signing_key",
- "body"
- FROM taler_mailbox.inbox_entries
- WHERE
- "hashed_signing_key"=$1 AND
- "body"=$2
- ;`
+func (db *MailboxDatabase) GetInboxEntryBySigningKeyAndBody(e *InboxEntry, hashedKey string, body []byte) error {
// Execute Query
- rows, err := db.Query(query, hashedKey, body)
+ rows, err := db.getInboxEntryBySigningKeyAndBodyStmt.Query(hashedKey, body)
if err != nil {
return err
}
@@ -248,18 +448,9 @@ func GetInboxEntryFromDatabaseBySigningKeyAndBody(db *sql.DB, e *InboxEntry, has
)
}
-func GetInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, hashedKey string, serial int64) error {
- query := `SELECT
- "serial",
- "hashed_signing_key",
- "body"
- FROM taler_mailbox.inbox_entries
- WHERE
- "serial"=$1 AND
- "hashed_signing_key"=$2
- ;`
+func (db *MailboxDatabase) GetInboxEntryBySerial(e *InboxEntry, hashedKey string, serial int64) error {
// Execute Query
- rows, err := db.Query(query, serial, hashedKey)
+ rows, err := db.getInboxEntryBySerialStmt.Query(serial, hashedKey)
if err != nil {
return err
}
@@ -275,22 +466,17 @@ func GetInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, hashedKey stri
)
}
-func DeletePendingRegistrationFromDatabase(db *sql.DB, pr *PendingMailboxRegistration) (int64, error) {
+func (db *MailboxDatabase) DeletePendingRegistration(pr *PendingMailboxRegistration) (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
- query := `DELETE
- FROM taler_mailbox.pending_mailbox_registrations
- WHERE
- "serial" = $1
- ;`
// Execute Query
- result, err := conn.ExecContext(ctx, query, pr.Serial)
+ result, err := db.deletePendingRegistrationStmt.ExecContext(ctx, pr.Serial)
if err != nil {
return 0, err
}
@@ -301,27 +487,18 @@ func DeletePendingRegistrationFromDatabase(db *sql.DB, pr *PendingMailboxRegistr
return rows, nil
}
-// DeleteInboxEntryFromDatabaseBySerial Deletes all entries starting from given serial
-func DeleteInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, count int) (int64, error) {
+// DeleteInboxEntryBySerial Deletes all entries starting from given serial
+func (db *MailboxDatabase) DeleteInboxEntryBySerial(e *InboxEntry, count int) (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
- query := `DELETE FROM taler_mailbox.inbox_entries
- WHERE serial IN (
- SELECT serial FROM taler_mailbox.inbox_entries
- WHERE
- "hashed_signing_key"=$1 AND
- "serial">=$2
- LIMIT $3
- )
- ;`
// Execute Query
- result, err := conn.ExecContext(ctx, query, e.HashedSigningKey, e.Serial, count)
+ result, err := db.deleteInboxEntryBySerialStmt.ExecContext(ctx, e.HashedSigningKey, e.Serial, count)
if err != nil {
return 0, err
}
@@ -332,23 +509,18 @@ func DeleteInboxEntryFromDatabaseBySerial(db *sql.DB, e *InboxEntry, count int)
return rows, nil
}
-// DeleteStaleRegstrationsFromDatabase purges stale registrations
-func DeleteStaleRegistrationsFromDatabase(db *sql.DB, registrationExpiration time.Time) (int64, error) {
+// DeleteStaleRegstrations purges stale registrations
+func (db *MailboxDatabase) DeleteStaleRegistrations(registrationExpiration time.Time) (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
- query := `DELETE
- FROM taler_mailbox.mailbox_metadata
- WHERE
- "expiration" < $1
- ;`
// Execute Query
- result, err := conn.ExecContext(ctx, query, registrationExpiration.Unix())
+ result, err := db.deleteStaleRegistrationsStmt.ExecContext(ctx, registrationExpiration.Unix())
if err != nil {
return 0, err
}
@@ -359,23 +531,18 @@ func DeleteStaleRegistrationsFromDatabase(db *sql.DB, registrationExpiration tim
return rows, nil
}
-// DeleteStalePendingRegstrationsFromDatabase purges stale registrations
-func DeleteStalePendingRegistrationsFromDatabase(db *sql.DB, registrationExpiration time.Time) (int64, error) {
+// DeleteStalePendingRegstrations purges stale registrations
+func (db *MailboxDatabase) DeleteStalePendingRegistrations(registrationExpiration time.Time) (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
- query := `DELETE
- FROM taler_mailbox.pending_mailbox_registrations
- WHERE
- "created_at" < $1
- ;`
// Execute Query
- result, err := conn.ExecContext(ctx, query, registrationExpiration.Unix())
+ result, err := db.deleteStalePendingRegistrationsStmt.ExecContext(ctx, registrationExpiration.Unix())
if err != nil {
return 0, err
}
@@ -386,20 +553,20 @@ func DeleteStalePendingRegistrationsFromDatabase(db *sql.DB, registrationExpirat
return rows, nil
}
-func DeleteAllPendingRegistrationsFromDatabase(db *sql.DB) (int64, error) {
+func (db *MailboxDatabase) DeleteAllPendingRegistrations() (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
query := `DELETE
- FROM taler_mailbox.pending_mailbox_registrations
- WHERE
- 1=1
- ;`
+ FROM taler_mailbox.pending_mailbox_registrations
+ WHERE
+ 1=1
+ ;`
// Execute Query
result, err := conn.ExecContext(ctx, query)
if err != nil {
@@ -412,20 +579,20 @@ func DeleteAllPendingRegistrationsFromDatabase(db *sql.DB) (int64, error) {
return rows, nil
}
-func DeleteAllMailboxesFromDatabase(db *sql.DB) (int64, error) {
+func (db *MailboxDatabase) DeleteAllMailboxes() (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
query := `DELETE
- FROM taler_mailbox.mailbox_metadata
- WHERE
- 1=1
- ;`
+ FROM taler_mailbox.mailbox_metadata
+ WHERE
+ 1=1
+ ;`
// Execute Query
result, err := conn.ExecContext(ctx, query)
if err != nil {
@@ -438,20 +605,20 @@ func DeleteAllMailboxesFromDatabase(db *sql.DB) (int64, error) {
return rows, nil
}
-func DeleteAllInboxEntriesFromDatabase(db *sql.DB) (int64, error) {
+func (db *MailboxDatabase) DeleteAllInboxEntries() (int64, error) {
var ctx context.Context
ctx, stop := context.WithCancel(context.Background())
defer stop()
- conn, err := db.Conn(ctx)
+ conn, err := db.db.Conn(ctx)
if err != nil {
return 0, err
}
defer conn.Close()
query := `DELETE
- FROM taler_mailbox.inbox_entries
- WHERE
- 1=1
- ;`
+ FROM taler_mailbox.inbox_entries
+ WHERE
+ 1=1
+ ;`
// Execute Query
result, err := conn.ExecContext(ctx, query)
if err != nil {
@@ -463,3 +630,48 @@ func DeleteAllInboxEntriesFromDatabase(db *sql.DB) (int64, error) {
}
return rows, nil
}
+
+// Get Hash-salted alias from database
+func (db *MailboxDatabase) GetMessages(hashedKey string, limit int) ([]InboxEntry, error) {
+ // Execute Query
+ rows, err := db.getMessagesBySigningKeyStmt.Query(hashedKey, limit)
+ if err != nil {
+ return []InboxEntry{}, err
+ }
+ defer rows.Close()
+ var entries = make([]InboxEntry, 0)
+ for rows.Next() {
+ var e InboxEntry
+ err = rows.Scan(
+ &e.Serial,
+ &e.HashedSigningKey,
+ &e.Body,
+ )
+ if err != nil {
+ return entries, err
+ }
+ entries = append(entries, e)
+ }
+ return entries, nil
+}
+
+// Get Hash-salted alias from database
+func (db *MailboxDatabase) GetMessagesCount(hashedKey string) (int64, error) {
+ // Execute Query
+ rows, err := db.countMessagesBySigningKeyStmt.Query(hashedKey)
+ if err != nil {
+ return 0, err
+ }
+ defer rows.Close()
+ if !rows.Next() {
+ return 0, nil
+ }
+ var res int64
+ err = rows.Scan(
+ &res,
+ )
+ if err != nil {
+ return 0, err
+ }
+ return res, nil
+}
diff --git a/pkg/rest/mailbox.go b/pkg/rest/mailbox.go
@@ -22,7 +22,6 @@ package mailbox
import (
"crypto/ed25519"
"crypto/sha512"
- "database/sql"
"encoding/binary"
"encoding/json"
"fmt"
@@ -74,7 +73,7 @@ type MailboxConfig struct {
Ini talerutil.TalerConfiguration
// The database connection to use
- DB *sql.DB
+ DB *MailboxDatabase
// Merchant connection
Merchant merchant.Merchant
@@ -90,7 +89,7 @@ type Mailbox struct {
Router *mux.Router
// The database connection to use
- DB *sql.DB
+ DB *MailboxDatabase
// Our configuration from the ini
Cfg MailboxConfig
@@ -205,7 +204,7 @@ type MailboxRateLimitedResponse struct {
}
func (m *Mailbox) configResponse(w http.ResponseWriter, r *http.Request) {
- dp, err := m.Cfg.Ini.GetDuration("mailbox", "delivery_period", 3 * 24 * time.Hour)
+ dp, err := m.Cfg.Ini.GetDuration("mailbox", "delivery_period", 3*24*time.Hour)
if err != nil {
log.Fatal(err)
}
@@ -223,65 +222,6 @@ func (m *Mailbox) configResponse(w http.ResponseWriter, r *http.Request) {
w.Write(response)
}
-// Get Hash-salted alias from database
-func GetMessagesCountFromDatabase(db *sql.DB, hashedKey string) (int64, error) {
- query := `SELECT COUNT(*) AS num_messages
- FROM taler_mailbox.inbox_entries
- WHERE
- "hashed_signing_key"=$1
- ;`
- // Execute Query
- rows, err := db.Query(query, hashedKey)
- if err != nil {
- return 0, err
- }
- defer rows.Close()
- if !rows.Next() {
- return 0, nil
- }
- var res int64
- err = rows.Scan(
- &res,
- )
- if err != nil {
- return 0, err
- }
- return res, nil
-}
-
-// Get Hash-salted alias from database
-func GetMessagesFromDatabase(db *sql.DB, hashedKey string, limit int) ([]InboxEntry, error) {
- query := `SELECT
- "serial",
- "hashed_signing_key",
- "body"
- FROM taler_mailbox.inbox_entries
- WHERE
- "hashed_signing_key" = $1
- LIMIT $2
- ;`
- // Execute Query
- rows, err := db.Query(query, hashedKey, limit)
- if err != nil {
- return []InboxEntry{}, err
- }
- defer rows.Close()
- var entries = make([]InboxEntry, 0)
- for rows.Next() {
- var e InboxEntry
- err = rows.Scan(
- &e.Serial,
- &e.HashedSigningKey,
- &e.Body,
- )
- if err != nil {
- return entries, err
- }
- entries = append(entries, e)
- }
- return entries, nil
-}
-
func (m *Mailbox) getMessagesResponse(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
//to, toSet := vars["timeout_ms"]
@@ -289,8 +229,8 @@ func (m *Mailbox) getMessagesResponse(w http.ResponseWriter, r *http.Request) {
// FIXME rate limit
// FIXME timeout
// FIXME possibly limit results here
- m.checkPendingRegistrationUpdates(vars["h_mailbox"])
- entries, err := GetMessagesFromDatabase(m.DB, vars["h_mailbox"], int(m.MessageResponseLimit))
+ m.checkPendingMailboxRegistrationUpdates(vars["h_mailbox"])
+ entries, err := m.DB.GetMessages(vars["h_mailbox"], int(m.MessageResponseLimit))
if err != nil {
m.Logf(LogError, "Error getting messages: %v", err)
w.WriteHeader(http.StatusNotFound)
@@ -327,7 +267,7 @@ func (m *Mailbox) sendMessageResponse(w http.ResponseWriter, r *http.Request) {
}
if !m.MessageFee.IsZero() {
var count int64
- count, err = GetMessagesCountFromDatabase(m.DB, vars["h_mailbox"])
+ count, err = m.DB.GetMessagesCount(vars["h_mailbox"])
if nil != err {
m.Logf(LogError, "Error getting messages: %v", err)
http.Error(w, "Cannot look for entries", http.StatusBadRequest)
@@ -339,15 +279,15 @@ func (m *Mailbox) sendMessageResponse(w http.ResponseWriter, r *http.Request) {
return
}
}
- m.checkPendingRegistrationUpdates(vars["h_mailbox"])
- err = GetInboxEntryFromDatabaseBySigningKeyAndBody(m.DB, &entry, vars["h_mailbox"], body)
+ m.checkPendingMailboxRegistrationUpdates(vars["h_mailbox"])
+ err = m.DB.GetInboxEntryBySigningKeyAndBody(&entry, vars["h_mailbox"], body)
if err == nil {
w.WriteHeader(http.StatusNotModified)
return
}
entry.HashedSigningKey = vars["h_mailbox"]
entry.Body = body
- err = InsertInboxEntryIntoDatabase(m.DB, &entry)
+ err = m.DB.InsertInboxEntry(&entry)
if err != nil {
m.Logf(LogError, "Error storing message: %v", err)
w.WriteHeader(http.StatusInternalServerError)
@@ -359,8 +299,8 @@ func (m *Mailbox) sendMessageResponse(w http.ResponseWriter, r *http.Request) {
func (m *Mailbox) getKeysResponse(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
var keyEntry MailboxMetadata
- m.checkPendingRegistrationUpdates(vars["h_mailbox"])
- err := GetMailboxMetadataFromDatabaseBySigningKey(m.DB, &keyEntry, vars["h_mailbox"])
+ m.checkPendingMailboxRegistrationUpdates(vars["h_mailbox"])
+ err := m.DB.GetMailboxMetadataBySigningKey(&keyEntry, vars["h_mailbox"])
if err != nil {
m.Logf(LogError, "Error finding mailbox: %v", err)
w.WriteHeader(http.StatusNotFound)
@@ -469,7 +409,7 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request
reqExpiration := time.Unix(int64(msg.MailboxMetadata.Expiration.Seconds), 0)
now := time.Now()
reqDuration := reqExpiration.Sub(now).Round(monthDuration)
- err = GetMailboxMetadataFromDatabaseBySigningKey(m.DB, ®istrationEntry, hMailbox)
+ err = m.DB.GetMailboxMetadataBySigningKey(®istrationEntry, hMailbox)
if err == nil {
// This probably means the registration is modified or extended or both
entryModified := (registrationEntry.EncryptionKey != msg.MailboxMetadata.EncryptionKey)
@@ -486,25 +426,25 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request
hAddr := sha512.New()
hAddr.Write(pk)
registrationEntry.HashedSigningKey = util.Base32CrockfordEncode(hAddr.Sum(nil))
- err = InsertMailboxRegistrationIntoDatabase(m.DB, ®istrationEntry)
+ err = m.DB.InsertMailboxRegistration(®istrationEntry)
if nil != err {
m.Logf(LogError, "%v\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
- err = GetPendingRegistrationFromDatabaseBySigningKey(m.DB, &pendingRegistration, hMailbox)
+ err = m.DB.GetPendingMailboxRegistrationBySigningKey(&pendingRegistration, hMailbox)
pendingRegistrationExists := (nil == err)
if !pendingRegistrationExists {
pendingRegistration.HashedSigningKey = hMailbox
pendingRegistration.Duration = reqDuration.Microseconds()
- err = InsertPendingRegistrationIntoDatabase(m.DB, &pendingRegistration)
+ err = m.DB.InsertPendingMailboxRegistration(&pendingRegistration)
if nil != err {
m.Logf(LogError, "Error inserting pending registration: %v\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
- err = GetPendingRegistrationFromDatabaseBySigningKey(m.DB, &pendingRegistration, hMailbox)
+ err = m.DB.GetPendingMailboxRegistrationBySigningKey(&pendingRegistration, hMailbox)
if nil != err {
m.Logf(LogError, "Error getting pending registration: %v\n", err)
w.WriteHeader(http.StatusInternalServerError)
@@ -544,7 +484,7 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request
return
}
if len(payto) != 0 {
- err = UpdatePendingRegistrationOrderIdInDatabase(m.DB, &pendingRegistration)
+ err = m.DB.UpdatePendingMailboxRegistrationOrderId(&pendingRegistration)
if err != nil {
m.Logf(LogError, "Error updating pending registration: %v\n", err)
w.WriteHeader(http.StatusInternalServerError)
@@ -557,13 +497,13 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request
}
// Update expiration time of registration.
registrationEntry.Expiration.Seconds += uint64(reqDuration.Seconds())
- _, err = DeletePendingRegistrationFromDatabase(m.DB, &pendingRegistration)
+ _, err = m.DB.DeletePendingRegistration(&pendingRegistration)
if nil != err {
m.Logf(LogError, "Error deleting pending registration: %v\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
- err = UpdateMailboxExpirationInDatabase(m.DB, ®istrationEntry)
+ err = m.DB.UpdateMailboxExpiration(®istrationEntry)
if nil != err {
m.Logf(LogError, "Error updating mailbox registration: %v\n", err)
w.WriteHeader(http.StatusInternalServerError)
@@ -572,10 +512,10 @@ func (m *Mailbox) registerMailboxResponse(w http.ResponseWriter, r *http.Request
w.WriteHeader(http.StatusNoContent)
}
-func (m *Mailbox) checkPendingRegistrationUpdates(hMailbox string) {
+func (m *Mailbox) checkPendingMailboxRegistrationUpdates(hMailbox string) {
var pendingEntry PendingMailboxRegistration
var registrationEntry MailboxMetadata
- err := GetPendingRegistrationFromDatabaseBySigningKey(m.DB, &pendingEntry, hMailbox)
+ err := m.DB.GetPendingMailboxRegistrationBySigningKey(&pendingEntry, hMailbox)
if err != nil {
return
}
@@ -585,7 +525,7 @@ func (m *Mailbox) checkPendingRegistrationUpdates(hMailbox string) {
if rc == http.StatusNotFound {
m.Logf(LogInfo, "Registration order for `%s' not found, removing\n", hMailbox)
}
- _, err = DeletePendingRegistrationFromDatabase(m.DB, &pendingEntry)
+ _, err = m.DB.DeletePendingRegistration(&pendingEntry)
if nil != err {
m.Logf(LogInfo, "Error deleting pending registration: %v\n", err)
}
@@ -594,15 +534,15 @@ func (m *Mailbox) checkPendingRegistrationUpdates(hMailbox string) {
m.Logf(LogDebug, "Order status for %s is %s", pendingEntry.HashedSigningKey, orderStatus)
if merchant.OrderPaid == orderStatus {
m.Logf(LogDebug, "Order for %v appears to be paid", pendingEntry)
- err = GetMailboxMetadataFromDatabaseBySigningKey(m.DB, ®istrationEntry, hMailbox)
+ err = m.DB.GetMailboxMetadataBySigningKey(®istrationEntry, hMailbox)
if err == nil {
m.Logf(LogDebug, "Adding %d seconds to entry expiration", pendingEntry.Duration)
registrationEntry.Expiration.Seconds += uint64(pendingEntry.Duration)
- err = UpdateMailboxExpirationInDatabase(m.DB, ®istrationEntry)
+ err = m.DB.UpdateMailboxExpiration(®istrationEntry)
if nil != err {
m.Logf(LogInfo, "Error updating mailbox expiration: %v\n", err)
}
- _, err = DeletePendingRegistrationFromDatabase(m.DB, &pendingEntry)
+ _, err = m.DB.DeletePendingRegistration(&pendingEntry)
if nil != err {
m.Logf(LogInfo, "Error deleting pending registration: %v\n", err)
}
@@ -656,7 +596,7 @@ func (m *Mailbox) deleteMessagesResponse(w http.ResponseWriter, r *http.Request)
h := sha512.New()
h.Write(pkey)
hMailbox := util.Base32CrockfordEncode(h.Sum(nil))
- m.checkPendingRegistrationUpdates(hMailbox)
+ m.checkPendingMailboxRegistrationUpdates(hMailbox)
var signedMsg [4 * 4]byte
binary.BigEndian.PutUint32(signedMsg[0:4], 4*4)
binary.BigEndian.PutUint32(signedMsg[4:8], gana.TalerSignaturePurposeMailboxMessagesDelete)
@@ -668,14 +608,14 @@ func (m *Mailbox) deleteMessagesResponse(w http.ResponseWriter, r *http.Request)
}
// Check that expectedETag actually exists
var entry InboxEntry
- err = GetInboxEntryFromDatabaseBySerial(m.DB, &entry, hMailbox, int64(expectedETag))
+ err = m.DB.GetInboxEntryBySerial(&entry, hMailbox, int64(expectedETag))
if err != nil {
m.Logf(LogDebug, "Message to delete not found with ID %d", expectedETag)
w.WriteHeader(http.StatusNotFound)
return
}
m.Logf(LogError, "Deleting from entry %v up to %d messages\n", entry, count)
- num, err := DeleteInboxEntryFromDatabaseBySerial(m.DB, &entry, count)
+ num, err := m.DB.DeleteInboxEntryBySerial(&entry, count)
if err != nil {
m.Logf(LogDebug, "Failed to delete messages: %v", err)
w.WriteHeader(http.StatusInternalServerError)
@@ -742,7 +682,7 @@ func (m *Mailbox) Initialize(cfg MailboxConfig) {
os.Exit(1)
}
m.MonthlyFee = monthlyFee
- updateFee, err := cfg.Ini.GetAmount("mailbox", "registration_update_fee", &talerutil.Amount{})
+ updateFee, err := cfg.Ini.GetAmount("mailbox", "registration_update_fee", &talerutil.Amount{})
if err != nil {
fmt.Printf("Failed to parse update fee: %v", err)
os.Exit(1)
@@ -758,7 +698,7 @@ func (m *Mailbox) Initialize(cfg MailboxConfig) {
m.DB = cfg.DB
go func() {
for {
- num, err := DeleteStaleRegistrationsFromDatabase(m.DB, time.Now())
+ num, err := m.DB.DeleteStaleRegistrations(time.Now())
if err != nil {
m.Logf(LogDebug, "Error purging stale registrations: `%v'.\n", err)
}
@@ -767,14 +707,14 @@ func (m *Mailbox) Initialize(cfg MailboxConfig) {
}
}()
// Clean up pending
- pendingExp, err := cfg.Ini.GetDuration("mailbox", "pending_registration_expiration", 24 * time.Hour)
+ pendingExp, err := cfg.Ini.GetDuration("mailbox", "pending_registration_expiration", 24*time.Hour)
if err != nil {
fmt.Printf("Failed to parse pending registration expiration: %v", err)
os.Exit(1)
}
go func() {
for {
- num, err := DeleteStalePendingRegistrationsFromDatabase(m.DB, time.Now().Add(-pendingExp))
+ num, err := m.DB.DeleteStalePendingRegistrations(time.Now().Add(-pendingExp))
if err != nil {
m.Logf(LogDebug, "Error purging stale registrations: `%v'.\n", err)
}