diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go index c50cd1fd..9e3b7d6e 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go @@ -17,6 +17,8 @@ package accounts import ( "context" "database/sql" + + "github.com/matrix-org/gomatrixserverlib" ) const filterSchema = ` @@ -38,12 +40,16 @@ CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart) const selectFilterSQL = "" + "SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" +const selectFilterIDByContentSQL = "" + + "SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2" + const insertFilterSQL = "" + "INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id" type filterStatements struct { - selectFilterStmt *sql.Stmt - insertFilterStmt *sql.Stmt + selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt } func (s *filterStatements) prepare(db *sql.DB) (err error) { @@ -54,6 +60,9 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return } + if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { + return + } if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { return } @@ -62,14 +71,37 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { func (s *filterStatements) selectFilter( ctx context.Context, localpart string, filterID string, -) (filter string, err error) { +) (filter []byte, err error) { err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter) return } func (s *filterStatements) insertFilter( - ctx context.Context, filter string, localpart string, -) (pos string, err error) { - err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos) + ctx context.Context, filter []byte, localpart string, +) (filterID string, err error) { + var existingFilterID string + + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + filterJSON, err := gomatrixserverlib.CanonicalJSON(filter) + if err != nil { + return "", err + } + + // Check if filter already exists in the database + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil { + return "", err + } + // If it does, return the existing ID + if len(existingFilterID) != 0 { + return existingFilterID, err + } + + // Otherwise insert the filter and return the new ID + err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). + Scan(&filterID) return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 33fbbd86..d5712eb5 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -321,22 +321,25 @@ func (d *Database) GetThreePIDsForLocalpart( } // GetFilter looks up the filter associated with a given local user and filter ID. -// Returns an error if no such filter exists or if there was an error taling to the database. +// Returns a filter represented as a byte slice. Otherwise returns an error if +// no such filter exists or if there was an error talking to the database. func (d *Database) GetFilter( ctx context.Context, localpart string, filterID string, -) (string, error) { +) ([]byte, error) { return d.filter.selectFilter(ctx, localpart, filterID) } // PutFilter puts the passed filter into the database. -// Returns an error if something goes wrong. +// Returns the filterID as a string. Otherwise returns an error if something +// goes wrong. func (d *Database) PutFilter( - ctx context.Context, localpart, filter string, + ctx context.Context, localpart string, filter []byte, ) (string, error) { return d.filter.insertFilter(ctx, filter, localpart) } -// CheckAccountAvailability checks if the username/localpart is already present in the database. +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/filter.go b/src/github.com/matrix-org/dendrite/clientapi/routing/filter.go index 3c623147..4b84e293 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/filter.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/filter.go @@ -60,7 +60,7 @@ func GetFilter( } } filter := gomatrix.Filter{} - err = json.Unmarshal([]byte(res), &filter) + err = json.Unmarshal(res, &filter) if err != nil { httputil.LogThenError(req, err) } @@ -111,7 +111,7 @@ func PutFilter( } } - filterID, err := accountDB.PutFilter(req.Context(), localpart, string(filterArray)) + filterID, err := accountDB.PutFilter(req.Context(), localpart, filterArray) if err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 92c9b427..29e25764 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/common/config" - log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" @@ -37,6 +36,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" ) const (