diff --git a/appservice/api/query.go b/appservice/api/query.go index cd74d866..e53ad425 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -23,7 +23,7 @@ import ( "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +85,7 @@ func RetrieveUserProfile( ctx context.Context, userID string, asAPI AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) (*authtypes.Profile, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 211b8d65..acf4406c 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -283,8 +283,6 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 3d9ba8aa..8b9c88f2 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index d678ada9..a65f3b70 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -28,7 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -37,7 +37,7 @@ func AddPublicRoutes( router *mux.Router, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, - accountsDB accounts.Database, + accountsDB userdb.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index e89d8ff2..80ac2293 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -137,7 +137,7 @@ type fledglingEvent struct { func CreateRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { // TODO (#267): Check room ID doesn't clash with an existing one, and we @@ -151,7 +151,7 @@ func CreateRoom( func createRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, roomID string, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { logger := util.GetLogger(req.Context()) diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 578aaec5..d30a87a5 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // Prepare to ask the roomserver to perform the room join. diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7b9d8acd..7ecab9d4 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -36,7 +36,7 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, keyserverAPI api.KeyInternalAPI, device *userapi.Device, - accountDB accounts.Database, cfg *config.ClientAPI, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index b48b9e93..ec5c998b 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -54,7 +54,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, + req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 58f18760..11223924 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -30,7 +30,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -39,7 +39,7 @@ import ( var errMissingUserID = errors.New("'user_id' must be supplied") func SendBan( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -81,7 +81,7 @@ func SendBan( return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) } -func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, +func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { @@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us } func SendKick( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -165,7 +165,7 @@ func SendKick( } func SendUnban( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -200,7 +200,7 @@ func SendUnban( } func SendInvite( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -271,7 +271,7 @@ func SendInvite( func buildMembershipEvent( ctx context.Context, - targetUserID, reason string, accountDB accounts.Database, + targetUserID, reason string, accountDB userdb.Database, device *userapi.Device, membership, roomID string, isDirect bool, cfg *config.ClientAPI, evTime time.Time, @@ -312,7 +312,7 @@ func loadProfile( ctx context.Context, userID string, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, asAPI appserviceAPI.AppServiceQueryAPI, ) (*authtypes.Profile, error) { _, serverName, err := gomatrixserverlib.SplitID('@', userID) @@ -366,7 +366,7 @@ func checkAndProcessThreepid( body *threepid.MembershipRequest, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, evTime time.Time, ) (inviteStored bool, errRes *util.JSONResponse) { diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index b2442443..49951019 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -9,7 +9,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -29,7 +29,7 @@ type newPasswordAuth struct { func Password( req *http.Request, userAPI api.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 26aa64ce..8f89e97f 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -19,7 +19,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to @@ -82,7 +82,7 @@ func UnpeekRoomByID( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, ) util.JSONResponse { unpeekReq := roomserverAPI.PerformUnpeekRequest{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 017facd2..717cbda7 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" @@ -36,7 +36,7 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, @@ -65,7 +65,7 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -92,7 +92,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -182,7 +182,7 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -209,7 +209,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -302,7 +302,7 @@ func SetDisplayName( // Returns an error when something goes wrong or specifically // eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( - ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, + ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index f73cc662..d00d9886 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -44,7 +44,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -448,7 +448,7 @@ func validateApplicationService( func Register( req *http.Request, userAPI userapi.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { var r registerRequest @@ -899,7 +899,7 @@ type availableResponse struct { func RegisterAvailable( req *http.Request, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) util.JSONResponse { username := req.URL.Query().Get("username") diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index da2ccf2f..63dcaa41 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -34,7 +34,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -51,7 +51,7 @@ func Setup( eduAPI eduServerAPI.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 3abf3db2..fd214b34 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -33,7 +33,7 @@ type typingContentJSON struct { // sends the typing events to client API typingProducer func SendTyping( req *http.Request, device *userapi.Device, roomID string, - userID string, accountDB accounts.Database, + userID string, accountDB userdb.Database, eduAPI api.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index f4d23379..d89b6295 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -40,7 +40,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf Code: http.StatusBadRequest, JSON: jsonerror.MatrixError{ ErrCode: "M_THREEPID_IN_USE", - Err: accounts.Err3PIDInUse.Error(), + Err: userdb.Err3PIDInUse.Error(), }, } } @@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -170,7 +170,7 @@ func GetAssociated3PIDs( } // Forget3PID implements POST /account/3pid/delete -func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { +func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse { var body authtypes.ThreePID if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index db62ce06..9d9a2ba7 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -87,7 +87,7 @@ var ( func CheckAndProcessInvite( ctx context.Context, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, db accounts.Database, + rsAPI api.RoomserverInternalAPI, db userdb.Database, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index d9202eb0..3003896c 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) const usage = `Usage: %s @@ -77,9 +77,14 @@ func main() { pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, - }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) + accountDB, err := userdb.NewDatabase( + &config.DatabaseOptions{ + ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, + }, + cfg.Global.ServerName, bcrypt.DefaultCost, + cfg.UserAPI.OpenIDTokenLifetimeMS, + api.DefaultLoginTokenLifetime, + ) if err != nil { logrus.Fatalln("Failed to connect to the database:", err.Error()) } diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 7cbd0b6d..78536901 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -126,7 +126,6 @@ func main() { cfg.FederationAPI.FederationMaxRetries = 6 cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index a897dcd1..5810a7f1 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -160,7 +160,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 52e69ee5..49e096bd 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -79,7 +79,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 62eea78f..664f644f 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -164,7 +164,6 @@ func startup() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 59de07cd..0ea41b4c 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -167,7 +167,6 @@ func main() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index f87665fb..ba5a87a7 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -32,7 +32,6 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI) } cfg.Global.TrustedIDServers = []string{ "matrix.org", diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 8ed5cbd9..31a5b005 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -8,12 +8,11 @@ import ( "log" "os" - pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" "github.com/pressly/goose" + pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -26,8 +25,7 @@ const ( RoomServer = "roomserver" SigningKeyServer = "signingkeyserver" SyncAPI = "syncapi" - UserAPIAccounts = "userapi_accounts" - UserAPIDevices = "userapi_devices" + UserAPI = "userapi" ) var ( @@ -35,7 +33,7 @@ var ( flags = flag.NewFlagSet("goose", flag.ExitOnError) component = flags.String("component", "", "dendrite component name") knownDBs = []string{ - AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, + AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI, } ) @@ -143,18 +141,14 @@ Commands: func loadSQLiteDeltas(component string) { switch component { - case UserAPIAccounts: - slaccounts.LoadFromGoose() - case UserAPIDevices: - sldevices.LoadFromGoose() + case UserAPI: + slusers.LoadFromGoose() } } func loadPostgresDeltas(component string) { switch component { - case UserAPIAccounts: - pgaccounts.LoadFromGoose() - case UserAPIDevices: - pgdevices.LoadFromGoose() + case UserAPI: + pgusers.LoadFromGoose() } } diff --git a/internal/test/config.go b/internal/test/config.go index 4fb6a946..0372fb9c 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.RoomServer.Database.ConnectionString = config.DataSource(database) cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database) cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress() diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ffbcac94..1c6b0677 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne } func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) + msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domain := string(serverName) // query local devices if serverName == a.ThisServer { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -554,10 +554,60 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( } func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &userapi.QueryDevicesResponse{} + if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + if !uapidevices.UserExists { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("user %q does not exist", req.UserID), + } + return + } + existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) + for _, key := range uapidevices.Devices { + existingDeviceMap[key.ID] = struct{}{} + } + + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + if _, ok := existingDeviceMap[k.DeviceID]; !ok { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + + if len(toClean) > 0 { + if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()), + } + return + } + logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale keyserver device key entries", len(toClean)) + } + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { - _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + var serverName gomatrixserverlib.ServerName + _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) if err != nil { continue // ignore invalid users } @@ -568,6 +618,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } + // check that the device in question actually exists in the user + // API before we try and store a key for it + if _, ok := existingDeviceMap[key.DeviceID]; !ok { + continue + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { @@ -583,29 +638,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per }) } - // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceMessage, len(keysToStore)) - for i := range keysToStore { - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, - }, - } - } - if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), - } - return - } if req.OnlyDisplayNameUpdates { // add the display name field from keysToStore into existingKeys keysToStore = appendDisplayNames(existingKeys, keysToStore) } // store the device keys and emit changes - err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 0110860e..4dffe695 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -53,7 +53,7 @@ type Database interface { // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 5ae0da96..628301cf 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - countStreamIDsForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5914d28e..deee76eb 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) } func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index fa1c930d..b461424c 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true } - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index c4c99d8c..4d513724 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false) if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index e44757e1..ff70a236 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -38,7 +38,7 @@ type DeviceKeys interface { InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } diff --git a/setup/base/base.go b/setup/base/base.go index 819fe1ad..e3997754 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -38,7 +38,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/gorilla/mux" @@ -273,8 +273,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. -func (b *BaseDendrite) CreateAccountsDB() accounts.Database { - db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) +func (b *BaseDendrite) CreateAccountsDB() userdb.Database { + db, err := userdb.NewDatabase( + &b.Cfg.UserAPI.AccountDatabase, + b.Cfg.Global.ServerName, + b.Cfg.UserAPI.BCryptCost, + b.Cfg.UserAPI.OpenIDTokenLifetimeMS, + userapi.DefaultLoginTokenLifetime, + ) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index b2cde2e9..1cb5eba1 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -16,9 +16,6 @@ type UserAPI struct { // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` - // The Device database stores session information for the devices of logged - // in local users. It is accessed by the UserAPI. - DeviceDatabase DatabaseOptions `yaml:"device_database"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes @@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) { c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781" c.AccountDatabase.Defaults(10) - c.DeviceDatabase.Defaults(10) if generate { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" - c.DeviceDatabase.ConnectionString = "file:userapi_devices.db" } c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS @@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) - checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString)) checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) } diff --git a/setup/monolith.go b/setup/monolith.go index e6c95522..61125e4a 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -38,7 +38,7 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - AccountDB accounts.Database + AccountDB userdb.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client FedClient *gomatrixserverlib.FederationClient diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go index f3aa037e..e2207bb5 100644 --- a/userapi/api/api_logintoken.go +++ b/userapi/api/api_logintoken.go @@ -19,6 +19,13 @@ import ( "time" ) +// DefaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const DefaultLoginTokenLifetime = 2 * time.Minute + type LoginTokenInternalAPI interface { // PerformLoginTokenCreation creates a new login token and associates it with the provided data. PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error diff --git a/userapi/internal/api.go b/userapi/internal/api.go index f96d4804..f54cc613 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -31,13 +31,11 @@ import ( keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" ) type UserInternalAPI struct { - AccountDB accounts.Database - DeviceDB devices.Database + DB storage.Database ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService @@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) + return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) + acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { return err } res.PasswordUpdated = true @@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) } if err != nil { return err @@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil @@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if domain != a.ServerName { return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) } - prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { if err == sql.ErrNoRows { return nil @@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil } func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) + profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) if err != nil { return err } @@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer } func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { - devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { return err } @@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if domain != a.ServerName { return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) } - devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { return err } + res.UserExists = true res.Devices = devs return nil } @@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } @@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.AccountDB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local) if err != nil { return err } @@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } - device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) + device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) if err != nil { if err == sql.ErrNoRows { return nil @@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc if err != nil { return err } - acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart) + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) if err != nil { return err } @@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart) res.AccountDeactivated = err == nil return err } @@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { token := util.RandomString(24) - exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) + exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) res.Token = api.OpenIDToken{ Token: token, @@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) + openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) if err != nil { return err } @@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } return nil } - exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) if err != nil { res.Error = fmt.Sprintf("failed to delete backup: %s", err) } @@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Create metadata if req.Version == "" { - version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to create backup: %s", err) } @@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Update metadata if len(req.Keys.Rooms) == 0 { - err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to update backup: %s", err) } @@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { // you can only upload keys for the CURRENT version - version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") + version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { res.Error = fmt.Sprintf("failed to query version: %s", err) return @@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform }) } } - count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { res.Error = fmt.Sprintf("failed to upsert keys: %s", err) return @@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform } func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { @@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB res.Exists = !deleted if !req.ReturnKeys { - res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } return } - result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) return diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 86ffc58f..f1bf391e 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if domain != a.ServerName { return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) } - tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { return err } @@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap // PerformLoginTokenDeletion ensures the token doesn't exist. func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { util.GetLogger(ctx).Info("PerformLoginTokenDeletion") - return a.DeviceDB.RemoveLoginToken(ctx, req.Token) + return a.DB.RemoveLoginToken(ctx, req.Token) } // QueryLoginToken returns the data associated with a login token. If // the token is not valid, success is returned, but res.Data == nil. func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { - tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) + tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token) if err != nil { res.Data = nil if err == sql.ErrNoRows { @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if domain != a.ServerName { return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) } - if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go deleted file mode 100644 index 8ff91cf1..00000000 --- a/userapi/storage/devices/interface.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "context" - - "github.com/matrix-org/dendrite/userapi/api" -) - -type Database interface { - GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) - GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - // CreateDevice makes a new device associated with the given user ID localpart. - // If there is already a device with the same device ID for this user, that access token will be revoked - // and replaced with the given accessToken. If the given accessToken is already in use for another device, - // an error will be returned. - // If no device ID is given one is generated. - // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error - RemoveDevice(ctx context.Context, deviceID, localpart string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error - // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) - - // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. - CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) - - // RemoveLoginToken removes the named token (and may clean up other expired tokens). - RemoveLoginToken(ctx context.Context, token string) error - - // GetLoginTokenDataByToken returns the data associated with the given token. - // May return sql.ErrNoRows. - GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) -} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go deleted file mode 100644 index fd9d513f..00000000 --- a/userapi/storage/devices/postgres/storage.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 -) - -// Database represents a device database. -type Database struct { - db *sql.DB - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - var d devicesStatements - var lt loginTokenStatements - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - if err = lt.execSchema(db); err != nil { - return nil, err - } - - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - if err = d.prepare(db, serverName); err != nil { - return nil, err - } - if err = lt.prepare(db); err != nil { - return nil, err - } - - return &Database{db, d, lt, loginTokenLifetime}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() - if err != nil { - return nil, err - } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) -} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go deleted file mode 100644 index 6e90413b..00000000 --- a/userapi/storage/devices/sqlite3/storage.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - - loginTokenByteLength = 32 -) - -// Database represents a device database. -type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - writer := sqlutil.NewExclusiveWriter() - var d devicesStatements - var lt loginTokenStatements - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - if err = lt.execSchema(db); err != nil { - return nil, err - } - - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - if err = d.prepare(db, writer, serverName); err != nil { - return nil, err - } - if err = lt.prepare(db); err != nil { - return nil, err - } - return &Database{db, writer, d, lt, loginTokenLifetime}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() - if err != nil { - return nil, err - } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) -} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go deleted file mode 100644 index 15cf8150..00000000 --- a/userapi/storage/devices/storage.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package devices - -import ( - "fmt" - "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters. loginTokenLifetime determines how long a -// login token from CreateLoginToken is valid. -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go deleted file mode 100644 index 3de7880b..00000000 --- a/userapi/storage/devices/storage_wasm.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "fmt" - "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -func NewDatabase( - dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, - loginTokenLifetime time.Duration, -) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/interface.go similarity index 67% rename from userapi/storage/accounts/interface.go rename to userapi/storage/interface.go index a2185774..a131dac4 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/interface.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "context" @@ -60,6 +60,35 @@ type Database interface { UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) + + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error + RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenDataByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go similarity index 100% rename from userapi/storage/accounts/postgres/account_data_table.go rename to userapi/storage/postgres/account_data_table.go diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go similarity index 100% rename from userapi/storage/accounts/postgres/accounts_table.go rename to userapi/storage/postgres/accounts_table.go diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go similarity index 100% rename from userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go rename to userapi/storage/postgres/deltas/20200929203058_is_active.go diff --git a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go similarity index 89% rename from userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index 290f854c..1bbb0a9d 100644 --- a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go similarity index 100% rename from userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go rename to userapi/storage/postgres/deltas/2022021013023800_add_account_type.go diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go similarity index 99% rename from userapi/storage/devices/postgres/devices_table.go rename to userapi/storage/postgres/devices_table.go index 7de9f5f9..64cc0b71 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -117,6 +117,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error { } func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + if err = s.execSchema(db); err != nil { + return + } if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go similarity index 100% rename from userapi/storage/accounts/postgres/key_backup_table.go rename to userapi/storage/postgres/key_backup_table.go diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go similarity index 100% rename from userapi/storage/accounts/postgres/key_backup_version_table.go rename to userapi/storage/postgres/key_backup_version_table.go diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go similarity index 98% rename from userapi/storage/devices/postgres/logintoken_table.go rename to userapi/storage/postgres/logintoken_table.go index f601fc7d..508a6898 100644 --- a/userapi/storage/devices/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp // prepare runs statement preparation. func (s *loginTokenStatements) prepare(db *sql.DB) error { + if err := s.execSchema(db); err != nil { + return err + } return sqlutil.StatementList{ {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go similarity index 100% rename from userapi/storage/accounts/postgres/openid_table.go rename to userapi/storage/postgres/openid_table.go diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go similarity index 100% rename from userapi/storage/accounts/postgres/profile_table.go rename to userapi/storage/postgres/profile_table.go diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/postgres/storage.go similarity index 70% rename from userapi/storage/accounts/postgres/storage.go rename to userapi/storage/postgres/storage.go index d31efd25..73419279 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -16,7 +16,9 @@ package postgres import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -30,7 +32,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" // Import the postgres database driver. _ "github.com/lib/pq" @@ -47,14 +49,23 @@ type Database struct { threepids threepidStatements openIDTokens tokenStatements keyBackupVersions keyBackupVersionStatements + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 } +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) + // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err @@ -63,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver serverName: serverName, db: db, writer: sqlutil.NewDummyWriter(), + loginTokenLifetime: loginTokenLifetime, bcryptCost: bcryptCost, openIDTokenLifetimeMS: openIDTokenLifetimeMS, } @@ -74,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err @@ -103,6 +116,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.keyBackups.prepare(db); err != nil { return nil, err } + if err = d.devices.prepare(db, serverName); err != nil { + return nil, err + } + if err = d.loginTokens.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -515,3 +534,196 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go similarity index 100% rename from userapi/storage/accounts/postgres/threepid_table.go rename to userapi/storage/postgres/threepid_table.go diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/account_data_table.go rename to userapi/storage/sqlite3/account_data_table.go diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/accounts_table.go rename to userapi/storage/sqlite3/accounts_table.go diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/sqlite3/constraint_wasm.go similarity index 100% rename from userapi/storage/accounts/sqlite3/constraint_wasm.go rename to userapi/storage/sqlite3/constraint_wasm.go diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go similarity index 100% rename from userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go rename to userapi/storage/sqlite3/deltas/20200929203058_is_active.go diff --git a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go similarity index 94% rename from userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 26209826..ebf90800 100644 --- a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go similarity index 100% rename from userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go rename to userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go similarity index 99% rename from userapi/storage/devices/sqlite3/devices_table.go rename to userapi/storage/sqlite3/devices_table.go index 955d8ac7..119ecdf9 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -106,6 +106,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error { func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db s.writer = writer + if err = s.execSchema(db); err != nil { + return + } if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/key_backup_table.go rename to userapi/storage/sqlite3/key_backup_table.go diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/key_backup_version_table.go rename to userapi/storage/sqlite3/key_backup_version_table.go diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go similarity index 98% rename from userapi/storage/devices/sqlite3/logintoken_table.go rename to userapi/storage/sqlite3/logintoken_table.go index 75ef272f..52322b46 100644 --- a/userapi/storage/devices/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp // prepare runs statement preparation. func (s *loginTokenStatements) prepare(db *sql.DB) error { + if err := s.execSchema(db); err != nil { + return err + } return sqlutil.StatementList{ {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/openid_table.go rename to userapi/storage/sqlite3/openid_table.go diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/profile_table.go rename to userapi/storage/sqlite3/profile_table.go diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go similarity index 71% rename from userapi/storage/accounts/sqlite3/storage.go rename to userapi/storage/sqlite3/storage.go index 0bab16ca..56ec1b6a 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -16,7 +16,9 @@ package sqlite3 import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -31,7 +33,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // Database represents an account database @@ -47,6 +49,9 @@ type Database struct { openIDTokens tokenStatements keyBackupVersions keyBackupVersionStatements keyBackups keyBackupStatements + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -57,8 +62,14 @@ type Database struct { threepidsMu sync.Mutex } +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) + // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err @@ -67,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver serverName: serverName, db: db, writer: sqlutil.NewExclusiveWriter(), + loginTokenLifetime: loginTokenLifetime, bcryptCost: bcryptCost, openIDTokenLifetimeMS: openIDTokenLifetimeMS, } @@ -78,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err @@ -108,6 +121,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.keyBackups.prepare(db); err != nil { return nil, err } + if err = d.devices.prepare(db, d.writer, serverName); err != nil { + return nil, err + } + if err = d.loginTokens.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -547,3 +566,196 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/threepid_table.go rename to userapi/storage/sqlite3/threepid_table.go diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/storage.go similarity index 81% rename from userapi/storage/accounts/storage.go rename to userapi/storage/storage.go index f43f7efd..4711439a 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/storage.go @@ -15,26 +15,27 @@ //go:build !wasm // +build !wasm -package accounts +package storage import ( "fmt" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/storage_wasm.go similarity index 87% rename from userapi/storage/accounts/storage_wasm.go rename to userapi/storage/storage_wasm.go index 11a88a20..701dcd83 100644 --- a/userapi/storage/accounts/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) @@ -27,10 +28,11 @@ func NewDatabase( serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index c7e1f667..4a5793ab 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -23,18 +23,10 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) -// defaultLoginTokenLifetime determines how old a valid token may be. -// -// NOTSPEC: The current spec says "SHOULD be limited to around five -// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. -// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). -const defaultLoginTokenLifetime = 2 * time.Minute - // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) + db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) + return newInternalAPI(db, cfg, appServices, keyAPI) } func newInternalAPI( - accountDB accounts.Database, - deviceDB devices.Database, + db storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { return &internal.UserInternalAPI{ - AccountDB: accountDB, - DeviceDB: deviceDB, + DB: db, ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 141dd96d..4214c07f 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -31,8 +31,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" ) const ( @@ -43,23 +42,19 @@ type apiTestOpts struct { loginTokenLifetime time.Duration } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) { if opts.loginTokenLifetime == 0 { - opts.loginTokenLifetime = defaultLoginTokenLifetime + opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } dbopts := &config.DatabaseOptions{ ConnectionString: "file::memory:", MaxOpenConnections: 1, MaxIdleConnections: 1, } - accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) + accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime) if err != nil { t.Fatalf("failed to create account DB: %s", err) } - deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) - if err != nil { - t.Fatalf("failed to create device DB: %s", err) - } cfg := &config.UserAPI{ Matrix: &config.Global{ @@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a }, } - return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) {