From 238646ee3c8b684f1734c554178550b0fc07199f Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 18 Sep 2017 15:51:26 +0100 Subject: [PATCH] Add contexts to device database (#233) * Add contexts to device database * Remove spurious whitespace --- hooks/pre-commit | 16 +++++---- .../dendrite/clientapi/auth/auth.go | 5 +-- .../auth/storage/devices/devices_table.go | 36 ++++++++++++------- .../clientapi/auth/storage/devices/storage.go | 21 +++++++---- .../dendrite/clientapi/readers/login.go | 4 ++- .../dendrite/clientapi/readers/logout.go | 2 +- .../dendrite/clientapi/writers/register.go | 6 ++-- .../dendrite/cmd/create-account/main.go | 4 ++- 8 files changed, 60 insertions(+), 34 deletions(-) diff --git a/hooks/pre-commit b/hooks/pre-commit index b10073ca..904d38dc 100755 --- a/hooks/pre-commit +++ b/hooks/pre-commit @@ -7,6 +7,16 @@ export GOGC=400 export GOPATH="$(pwd):$(pwd)/vendor" export PATH="$PATH:$(pwd)/vendor/bin:$(pwd)/bin" +echo "Checking that it builds" +gb build + +# Check that all the packages can build. +# When `go build` is given multiple packages it won't output anything, and just +# checks that everything builds. This seems to do a better job of handling +# missing imports than `gb build` does. +echo "Double checking it builds..." +go build github.com/matrix-org/dendrite/cmd/... + echo "Installing lint search engine..." go install github.com/alecthomas/gometalinter/ gometalinter --config=linter.json ./... --install @@ -20,11 +30,5 @@ misspell -error src *.md echo "Testing..." gb test -# Check that all the packages can build. -# When `go build` is given multiple packages it won't output anything, and just -# checks that everything builds. This seems to do a better job of handling -# missing imports than `gb build` does. -echo "Double checking it builds..." -go build github.com/matrix-org/dendrite/cmd/... echo "Done!" diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go index 24df9019..833cf544 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go @@ -16,6 +16,7 @@ package auth import ( + "context" "crypto/rand" "database/sql" "encoding/base64" @@ -42,7 +43,7 @@ var tokenByteLength = 32 // DeviceDatabase represents a device database. type DeviceDatabase interface { // Look up the device matching the given access token. - GetDeviceByAccessToken(token string) (*authtypes.Device, error) + GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error) } // VerifyAccessToken verifies that an access token was supplied in the given HTTP request @@ -57,7 +58,7 @@ func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *auth } return } - device, err = deviceDB.GetDeviceByAccessToken(token) + device, err = deviceDB.GetDeviceByAccessToken(req.Context(), token) if err != nil { if err == sql.ErrNoRows { resErr = &util.JSONResponse{ diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go index f346f169..65112cd8 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go @@ -15,10 +15,13 @@ package devices import ( + "context" "database/sql" "fmt" "time" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -84,27 +87,36 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice(txn *sql.Tx, id, localpart, accessToken string) (dev *authtypes.Device, err error) { +func (s *devicesStatements) insertDevice( + ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, +) (*authtypes.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - if _, err = txn.Stmt(s.insertDeviceStmt).Exec(id, localpart, accessToken, createdTimeMS); err == nil { - dev = &authtypes.Device{ - ID: id, - UserID: makeUserID(localpart, s.serverName), - AccessToken: accessToken, - } + stmt := common.TxStmt(txn, s.insertDeviceStmt) + if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS); err != nil { + return nil, err } - return + return &authtypes.Device{ + ID: id, + UserID: makeUserID(localpart, s.serverName), + AccessToken: accessToken, + }, nil } -func (s *devicesStatements) deleteDevice(txn *sql.Tx, id, localpart string) error { - _, err := txn.Stmt(s.deleteDeviceStmt).Exec(id, localpart) +func (s *devicesStatements) deleteDevice( + ctx context.Context, txn *sql.Tx, id, localpart string, +) error { + stmt := common.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) return err } -func (s *devicesStatements) selectDeviceByToken(accessToken string) (*authtypes.Device, error) { +func (s *devicesStatements) selectDeviceByToken( + ctx context.Context, accessToken string, +) (*authtypes.Device, error) { var dev authtypes.Device var localpart string - err := s.selectDeviceByTokenStmt.QueryRow(accessToken).Scan(&dev.ID, &localpart) + stmt := s.selectDeviceByTokenStmt + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart) if err == nil { dev.UserID = makeUserID(localpart, s.serverName) dev.AccessToken = accessToken diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go index a58c2634..6b7ff919 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go @@ -15,6 +15,7 @@ package devices import ( + "context" "database/sql" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -44,8 +45,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // GetDeviceByAccessToken returns the device matching the given access token. // Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) { - return d.devices.selectDeviceByToken(token) +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*authtypes.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) } // CreateDevice makes a new device associated with the given user ID localpart. @@ -53,15 +56,17 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro // and replaced with the given accessToken. If the given accessToken is already in use for another device, // an error will be returned. // Returns the device on success. -func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) { +func (d *Database) CreateDevice( + ctx context.Context, localpart, deviceID, accessToken string, +) (dev *authtypes.Device, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error // Revoke existing token for this device - if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil { + if err = d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != nil { return err } - dev, err = d.devices.insertDevice(txn, deviceID, localpart, accessToken) + dev, err = d.devices.insertDevice(ctx, txn, deviceID, localpart, accessToken) if err != nil { return err } @@ -74,9 +79,11 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a // matching with the given device ID and user ID localpart // If the device doesn't exist, it will not return an error // If something went wrong during the deletion, it will return the SQL error -func (d *Database) RemoveDevice(deviceID string, localpart string) error { +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } return nil diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/login.go b/src/github.com/matrix-org/dendrite/clientapi/readers/login.go index 027560cc..43011890 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/login.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/login.go @@ -98,7 +98,9 @@ func Login( } // TODO: Use the device ID in the request - dev, err := deviceDB.CreateDevice(acc.Localpart, auth.UnknownDeviceID, token) + dev, err := deviceDB.CreateDevice( + req.Context(), acc.Localpart, auth.UnknownDeviceID, token, + ) if err != nil { return util.JSONResponse{ Code: 500, diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go b/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go index 585527fc..9e46c74c 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go @@ -41,7 +41,7 @@ func Logout( return httputil.LogThenError(req, err) } - if err := deviceDB.RemoveDevice(device.ID, localpart); err != nil { + if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go index a6f2c387..dfebaf6d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go @@ -135,9 +135,7 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices switch r.Auth.Type { case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, r.Password, - ) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password) default: return util.JSONResponse{ Code: 501, @@ -182,7 +180,7 @@ func completeRegistration( } // // TODO: Use the device ID in the request. - dev, err := deviceDB.CreateDevice(username, auth.UnknownDeviceID, token) + dev, err := deviceDB.CreateDevice(ctx, username, auth.UnknownDeviceID, token) if err != nil { return util.JSONResponse{ Code: 500, diff --git a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go index 82f1fec3..d031afc2 100644 --- a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go @@ -86,7 +86,9 @@ func main() { accessToken = &t } - device, err := deviceDB.CreateDevice(*username, "create-account-script", *accessToken) + device, err := deviceDB.CreateDevice( + context.Background(), *username, "create-account-script", *accessToken, + ) if err != nil { fmt.Println(err.Error()) os.Exit(1)