diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index e2ff79c3..84103f38 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -131,17 +131,15 @@ func (m *DendriteMonolith) Start() { ) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - + stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, keyRing, + base, federation, rsAPI, stateAPI, keyRing, ) // The underlying roomserver implementation needs to be able to call the fedsender. // This is different to rsAPI which can be the http client which doesn't need this dependency rsAPI.SetFederationSenderAPI(fsAPI) - stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 4e774acc..7333e8b4 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -153,6 +153,7 @@ func main() { base, serverKeyAPI, ) + stateAPI := currentstateserver.NewInternalAPI(base.Base.Cfg, base.Base.KafkaConsumer) rsAPI := roomserver.NewInternalAPI( &base.Base, keyRing, federation, ) @@ -161,10 +162,9 @@ func main() { ) asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) fsAPI := federationsender.NewInternalAPI( - &base.Base, federation, rsAPI, keyRing, + &base.Base, federation, rsAPI, stateAPI, keyRing, ) rsAPI.SetFederationSenderAPI(fsAPI) - stateAPI := currentstateserver.NewInternalAPI(base.Base.Cfg, base.Base.KafkaConsumer) provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI, stateAPI) err = provider.Start() if err != nil { diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 0655a2a3..8f6b0eaf 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -117,17 +117,15 @@ func main() { ) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - + stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, keyRing, + base, federation, rsAPI, stateAPI, keyRing, ) rsComponent.SetFederationSenderAPI(fsAPI) embed.Embed(base.BaseMux, *instancePort, "Yggdrasil Demo") - stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, diff --git a/cmd/dendrite-federation-sender-server/main.go b/cmd/dendrite-federation-sender-server/main.go index 20bc1070..fa6cf7ab 100644 --- a/cmd/dendrite-federation-sender-server/main.go +++ b/cmd/dendrite-federation-sender-server/main.go @@ -31,7 +31,7 @@ func main() { rsAPI := base.RoomserverHTTPClient() fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, keyRing, + base, federation, rsAPI, base.CurrentStateAPIClient(), keyRing, ) federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index bce5fce0..c75ef8fb 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -109,8 +109,10 @@ func main() { asAPI = base.AppserviceHTTPClient() } + stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) + fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, keyRing, + base, federation, rsAPI, stateAPI, keyRing, ) if base.UseHTTPAPIs { federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) @@ -120,8 +122,6 @@ func main() { // This is different to rsAPI which can be the http client which doesn't need this dependency rsImpl.SetFederationSenderAPI(fsAPI) - stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 0df53e06..fd407e6e 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -208,17 +208,16 @@ func main() { KeyDatabase: fetcher, } + stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, ) - fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing) + fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, stateAPI, &keyRing) rsAPI.SetFederationSenderAPI(fedSenderAPI) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) - stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 6369c708..0d1d85b4 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -13,10 +13,11 @@ package routing import ( + "encoding/json" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - userapi "github.com/matrix-org/dendrite/userapi/api" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -24,30 +25,35 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - userAPI userapi.UserInternalAPI, + keyAPI keyapi.KeyInternalAPI, userID string, ) util.JSONResponse { - response := gomatrixserverlib.RespUserDevices{ - UserID: userID, - // TODO: we should return an incrementing stream ID each time the device - // list changes for delta changes to be recognised - StreamID: 0, - } - - var res userapi.QueryDevicesResponse - err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + var res keyapi.QueryDeviceMessagesResponse + keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ UserID: userID, }, &res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryDevices failed") + if res.Error != nil { + util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") return jsonerror.InternalServerError() } + response := gomatrixserverlib.RespUserDevices{ + UserID: userID, + StreamID: res.StreamID, + } + for _, dev := range res.Devices { + var key gomatrixserverlib.RespUserDeviceKeys + err := json.Unmarshal(dev.DeviceKeys.KeyJSON, &key) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Warnf("malformed device key: %s", string(dev.DeviceKeys.KeyJSON)) + continue + } + device := gomatrixserverlib.RespUserDevice{ - DeviceID: dev.ID, + DeviceID: dev.DeviceID, DisplayName: dev.DisplayName, - Keys: []gomatrixserverlib.RespUserDeviceKeys{}, + Keys: []gomatrixserverlib.RespUserDeviceKeys{key}, } response.Devices = append(response.Devices, device) } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 50b7bdd2..9808d623 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -186,7 +186,7 @@ func Setup( "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, userAPI, vars["userID"], + httpReq, keyAPI, vars["userID"], ) }, )).Methods(http.MethodGet) diff --git a/federationsender/consumers/keychange.go b/federationsender/consumers/keychange.go new file mode 100644 index 00000000..4c3d23b5 --- /dev/null +++ b/federationsender/consumers/keychange.go @@ -0,0 +1,135 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/Shopify/sarama" + stateapi "github.com/matrix-org/dendrite/currentstateserver/api" + "github.com/matrix-org/dendrite/federationsender/queue" + "github.com/matrix-org/dendrite/federationsender/storage" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +// KeyChangeConsumer consumes events that originate in key server. +type KeyChangeConsumer struct { + consumer *internal.ContinualConsumer + db storage.Database + queues *queue.OutgoingQueues + serverName gomatrixserverlib.ServerName + stateAPI stateapi.CurrentStateInternalAPI +} + +// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. +func NewKeyChangeConsumer( + cfg *config.Dendrite, + kafkaConsumer sarama.Consumer, + queues *queue.OutgoingQueues, + store storage.Database, + stateAPI stateapi.CurrentStateInternalAPI, +) *KeyChangeConsumer { + c := &KeyChangeConsumer{ + consumer: &internal.ContinualConsumer{ + Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + }, + queues: queues, + db: store, + serverName: cfg.Matrix.ServerName, + stateAPI: stateAPI, + } + c.consumer.ProcessMessage = c.onMessage + + return c +} + +// Start consuming from key servers +func (t *KeyChangeConsumer) Start() error { + if err := t.consumer.Start(); err != nil { + return fmt.Errorf("t.consumer.Start: %w", err) + } + return nil +} + +// onMessage is called in response to a message received on the +// key change events topic from the key server. +func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var m api.DeviceMessage + if err := json.Unmarshal(msg.Value, &m); err != nil { + log.WithError(err).Errorf("failed to read device message from key change topic") + return nil + } + logger := log.WithField("user_id", m.UserID) + + // only send key change events which originated from us + _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID) + if err != nil { + logger.WithError(err).Error("Failed to extract domain from key change event") + return nil + } + if originServerName != t.serverName { + return nil + } + + var queryRes stateapi.QueryRoomsForUserResponse + err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{ + UserID: m.UserID, + WantMembership: "join", + }, &queryRes) + if err != nil { + logger.WithError(err).Error("failed to calculate joined rooms for user") + return nil + } + // send this key change to all servers who share rooms with this user. + destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs) + if err != nil { + logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") + return nil + } + + // Pack the EDU and marshal it + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MDeviceListUpdate, + Origin: string(t.serverName), + } + event := gomatrixserverlib.DeviceListUpdateEvent{ + UserID: m.UserID, + DeviceID: m.DeviceID, + DeviceDisplayName: m.DisplayName, + StreamID: m.StreamID, + PrevID: prevID(m.StreamID), + Deleted: len(m.KeyJSON) == 0, + Keys: m.KeyJSON, + } + if edu.Content, err = json.Marshal(event); err != nil { + return err + } + + log.Infof("Sending device list update message to %q", destinations) + return t.queues.SendEDU(edu, t.serverName, destinations) +} + +func prevID(streamID int) []int { + if streamID <= 1 { + return nil + } + return []int{streamID - 1} +} diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index 9e14f6ec..fbf506aa 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -16,6 +16,7 @@ package federationsender import ( "github.com/gorilla/mux" + stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/consumers" "github.com/matrix-org/dendrite/federationsender/internal" @@ -41,6 +42,7 @@ func NewInternalAPI( base *setup.BaseDendrite, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, + stateAPI stateapi.CurrentStateInternalAPI, keyRing *gomatrixserverlib.KeyRing, ) api.FederationSenderInternalAPI { federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender), base.Cfg.DbProperties()) @@ -76,6 +78,12 @@ func NewInternalAPI( if err := tsConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start typing server consumer") } + keyConsumer := consumers.NewKeyChangeConsumer( + base.Cfg, base.KafkaConsumer, queues, federationSenderDB, stateAPI, + ) + if err := keyConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start key server consumer") + } return internal.NewFederationSenderInternalAPI(federationSenderDB, base.Cfg, rsAPI, federation, keyRing, stats, queues) } diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index b79499d3..734b368f 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -30,6 +30,8 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index af0a5258..52865996 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -60,12 +60,16 @@ const selectJoinedHostsSQL = "" + const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" +const selectJoinedHostsForRoomsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" + type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -88,6 +92,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { + return + } return } @@ -144,6 +151,27 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( return result, rows.Err() } +func (s *joinedHostsStatements) SelectJoinedHostsForRooms( + ctx context.Context, roomIDs []string, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 52f02a28..4a681de6 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -123,6 +123,10 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx) } +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries. diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index bd917c61..4ae980d7 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -59,13 +59,17 @@ const selectJoinedHostsSQL = "" + const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" +const selectJoinedHostsForRoomsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" + type joinedHostsStatements struct { - db *sql.DB - writer *sqlutil.TransactionWriter - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt + db *sql.DB + writer *sqlutil.TransactionWriter + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -89,6 +93,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { + return + } return } @@ -153,6 +160,32 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( return result, rows.Err() } +func (s *joinedHostsStatements) SelectJoinedHostsForRooms( + ctx context.Context, roomIDs []string, +) ([]gomatrixserverlib.ServerName, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i := range roomIDs { + iRoomIDs[i] = roomIDs[i] + } + + rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go index 2def48d0..c6f8a2d5 100644 --- a/federationsender/storage/tables/interface.go +++ b/federationsender/storage/tables/interface.go @@ -53,6 +53,7 @@ type FederationSenderJoinedHosts interface { SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) } type FederationSenderRooms interface { diff --git a/go.mod b/go.mod index f087b087..f956a46a 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d + github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165739-3bd1ef0f0852 github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible diff --git a/go.sum b/go.sum index de7527d9..76bd338a 100644 --- a/go.sum +++ b/go.sum @@ -425,6 +425,10 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM= github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165250-352235625587 h1:n2IZkm5LI4lACulOa5WU6QwWUhHUtBZez7YIFr1fCOs= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165250-352235625587/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165739-3bd1ef0f0852 h1:OBvHjLWaT2KS9kGarX2ES0yKBL/wMxAeQB39tRrAAls= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165739-3bd1ef0f0852/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 080d0e5f..c864b328 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -32,6 +32,7 @@ type KeyInternalAPI interface { QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) } // KeyError is returned if there was a problem performing/querying the server @@ -188,3 +189,14 @@ type QueryOneTimeKeysResponse struct { Count OneTimeKeysCount Error *KeyError } + +type QueryDeviceMessagesRequest struct { + UserID string +} + +type QueryDeviceMessagesResponse struct { + // The latest stream ID + StreamID int + Devices []DeviceMessage + Error *KeyError +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 9027cbf4..474f30ff 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -179,6 +179,24 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne res.Count = *count } +func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { + msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query DB for device keys: %s", err), + } + return + } + maxStreamID := 0 + for _, m := range msgs { + if m.StreamID > maxStreamID { + maxStreamID = m.StreamID + } + } + res.Devices = msgs + res.StreamID = maxStreamID +} + func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index b65cbdaf..93b19051 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -27,11 +27,12 @@ import ( // HTTP paths for the internal HTTP APIs const ( - PerformUploadKeysPath = "/keyserver/performUploadKeys" - PerformClaimKeysPath = "/keyserver/performClaimKeys" - QueryKeysPath = "/keyserver/queryKeys" - QueryKeyChangesPath = "/keyserver/queryKeyChanges" - QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" + PerformUploadKeysPath = "/keyserver/performUploadKeys" + PerformClaimKeysPath = "/keyserver/performClaimKeys" + QueryKeysPath = "/keyserver/queryKeys" + QueryKeyChangesPath = "/keyserver/queryKeyChanges" + QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" + QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. @@ -126,6 +127,23 @@ func (h *httpKeyInternalAPI) QueryOneTimeKeys( } } +func (h *httpKeyInternalAPI) QueryDeviceMessages( + ctx context.Context, + request *api.QueryDeviceMessagesRequest, + response *api.QueryDeviceMessagesResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages") + defer span.Finish() + + apiURL := h.apiURL + QueryDeviceMessagesPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.KeyError{ + Err: err.Error(), + } + } +} + func (h *httpKeyInternalAPI) QueryKeyChanges( ctx context.Context, request *api.QueryKeyChangesRequest, diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 615b6f80..f0cd3038 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -69,6 +69,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryDeviceMessagesPath, + httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse { + request := api.QueryDeviceMessagesRequest{} + response := api.QueryDeviceMessagesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.QueryDeviceMessages(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryKeyChangesPath, httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { request := api.QueryKeyChangesRequest{} diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index af456af4..cc33c738 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -41,6 +41,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyCh } func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { +} +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) { + } type mockCurrentStateAPI struct {