dendrite/userapi/internal/device_list_update_test.go

541 lines
16 KiB
Go

// Copyright 2024 New Vector Ltd.
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package internal
import (
"context"
"crypto/ed25519"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
api2 "github.com/element-hq/dendrite/federationapi/api"
"github.com/element-hq/dendrite/federationapi/statistics"
"github.com/element-hq/dendrite/internal/caching"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
roomserver "github.com/element-hq/dendrite/roomserver/api"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/setup/process"
"github.com/element-hq/dendrite/test"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage"
)
// ctx is the background context shared by the tests in this file.
var ctx = context.Background()
// mockKeyChangeProducer is a test double that records every key change it is
// asked to produce instead of publishing it anywhere.
type mockKeyChangeProducer struct {
	// events holds all messages passed to ProduceKeyChanges, in call order.
	events []api.DeviceMessage
}

// ProduceKeyChanges appends keys to the recorded events and always returns nil.
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
	for i := range keys {
		p.events = append(p.events, keys[i])
	}
	return nil
}
// mockDeviceListUpdaterDatabase is an in-memory stand-in for the device list
// updater's database, used by the tests in this file.
type mockDeviceListUpdaterDatabase struct {
	// staleUsers maps a user ID to its current stale flag. Guarded by mu.
	staleUsers map[string]bool
	// prevIDsExist supplies the answer returned by PrevIDsExists.
	prevIDsExist func(string, []int64) bool
	// storedKeys accumulates every message passed to StoreRemoteDeviceKeys.
	storedKeys []api.DeviceMessage
	// mu protects staleUsers.
	mu sync.Mutex
}
// DeleteStaleDeviceLists is a no-op in this mock: userIDs is ignored and nil is
// always returned.
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
	return nil
}
// StaleDeviceLists returns the user IDs currently flagged as stale whose server
// name equals one of domains. When domains is empty, every stale user ID is
// returned. A stale user ID that cannot be split into localpart and server name
// aborts the scan and the parse error is returned. The order of the result
// follows map iteration and is therefore unspecified.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	matchAll := len(domains) == 0
	var stale []string
	for userID, flagged := range d.staleUsers {
		if !flagged {
			continue
		}
		// Parse before filtering so that malformed IDs are reported even when
		// no domain filter was supplied.
		_, serverName, err := gomatrixserverlib.SplitID('@', userID)
		if err != nil {
			return nil, err
		}
		wanted := matchAll
		for i := 0; !wanted && i < len(domains); i++ {
			wanted = domains[i] == serverName
		}
		if wanted {
			stale = append(stale, userID)
		}
	}
	return stale, nil
}
// MarkDeviceListStale records isStale as the stale flag for userID while
// holding the mock's mutex. It always returns nil.
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	d.staleUsers[userID] = isStale
	return nil
}
// isStale reports the stale flag recorded for userID, reading it under the
// mock's mutex. Users that were never marked report false.
func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
	d.mu.Lock()
	defer d.mu.Unlock()

	stale, known := d.staleUsers[userID]
	return known && stale
}
// StoreRemoteDeviceKeys records the given keys by appending them to storedKeys
// so tests can inspect what would have been persisted. This mock does not
// replace or remove existing entries and ignores clear. It always returns nil.
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
	for i := range keys {
		d.storedKeys = append(d.storedKeys, keys[i])
	}
	return nil
}
// PrevIDsExists delegates to the test-supplied prevIDsExist callback and
// returns its answer with a nil error.
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
	exists := d.prevIDsExist(userID, prevIDs)
	return exists, nil
}
// DeviceKeysJSON is a no-op in this mock: keys is left untouched and nil is
// always returned.
func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
	return nil
}
// mockDeviceListUpdaterAPI is a stateless stand-in for the API dependency of
// the device list updater.
type mockDeviceListUpdaterAPI struct{}

// PerformUploadDeviceKeys does nothing; req and res are left untouched.
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
}
// testIsBlacklistedOrBackingOff is the blacklist/backoff check handed to the
// updater in tests. It returns a fresh, zero-valued ServerStatistics and a nil
// error for every server name.
var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) {
	stats := new(statistics.ServerStatistics)
	return stats, nil
}
// roundTripper adapts a plain function so it can be used wherever an
// http.RoundTripper is expected.
type roundTripper struct {
	// fn is invoked for every request passed to RoundTrip.
	fn func(*http.Request) (*http.Response, error)
}

// RoundTrip hands req to the wrapped function and returns its result unchanged.
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := t.fn(req)
	return resp, err
}
// newFedClient builds a federation client for the server "example.test" whose
// HTTP transport is replaced by tripper, so tests can intercept every outgoing
// federation request. A fresh ed25519 signing key is generated on each call.
//
// The helper has no *testing.T, so a key-generation failure is reported by
// panicking rather than being silently discarded and surfacing later as a nil
// private key.
func newFedClient(tripper func(*http.Request) (*http.Response, error)) fclient.FederationClient {
	// A nil reader makes GenerateKey fall back to crypto/rand.
	_, pkey, err := ed25519.GenerateKey(nil)
	if err != nil {
		panic("newFedClient: generating ed25519 signing key: " + err.Error())
	}
	identity := &fclient.SigningIdentity{
		ServerName: spec.ServerName("example.test"),
		KeyID:      gomatrixserverlib.KeyID("ed25519:test"),
		PrivateKey: pkey,
	}
	return fclient.NewFederationClient(
		[]*fclient.SigningIdentity{identity},
		fclient.WithTransport(&roundTripper{fn: tripper}),
	)
}
// TestUpdateHavePrevID checks that when the database reports all previous
// stream IDs as known, an incoming device list update is stored, emitted to the
// key change producer, and the user is not flagged as stale.
func TestUpdateHavePrevID(t *testing.T) {
	db := &mockDeviceListUpdaterDatabase{
		staleUsers:   make(map[string]bool),
		prevIDsExist: func(string, []int64) bool { return true },
	}
	ap := &mockDeviceListUpdaterAPI{}
	producer := &mockKeyChangeProducer{}
	updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost", caching.DisableMetrics, testIsBlacklistedOrBackingOff)

	event := gomatrixserverlib.DeviceListUpdateEvent{
		DeviceDisplayName: "Foo Bar",
		Deleted:           false,
		DeviceID:          "FOO",
		Keys:              []byte(`{"key":"value"}`),
		PrevID:            []int64{0},
		StreamID:          1,
		UserID:            "@alice:localhost",
	}
	if err := updater.Update(ctx, event); err != nil {
		t.Fatalf("Update returned an error: %s", err)
	}

	// The stored and emitted message must mirror the incoming event.
	want := api.DeviceMessage{
		Type:     api.TypeDeviceKeyUpdate,
		StreamID: event.StreamID,
		DeviceKeys: &api.DeviceKeys{
			DeviceID:    event.DeviceID,
			DisplayName: event.DeviceDisplayName,
			KeyJSON:     event.Keys,
			UserID:      event.UserID,
		},
	}
	wantMsgs := []api.DeviceMessage{want}

	if got := producer.events; !reflect.DeepEqual(got, wantMsgs) {
		t.Errorf("Update didn't produce correct event, got %v want %v", got, want)
	}
	if got := db.storedKeys; !reflect.DeepEqual(got, wantMsgs) {
		t.Errorf("DB didn't store correct event, got %v want %v", got, want)
	}
	if db.isStale(event.UserID) {
		t.Errorf("%s incorrectly marked as stale", event.UserID)
	}
}
// Test that device keys are fetched from the remote server if we are missing prev IDs
// and that the user's devices are marked as stale until it succeeds.
//
// The federation client is stubbed: it expects exactly one request to the
// /user/devices endpoint for the remote user and answers with a single device
// at stream ID 5. The test then verifies that this device (not the one from the
// incoming event) is what gets stored and emitted.
func TestUpdateNoPrevID(t *testing.T) {
	db := &mockDeviceListUpdaterDatabase{
		staleUsers: make(map[string]bool),
		prevIDsExist: func(string, []int64) bool {
			return false
		},
	}
	ap := &mockDeviceListUpdaterAPI{}
	producer := &mockKeyChangeProducer{}
	remoteUserID := "@alice:example.somewhere"
	// wg is released by the federation stub once the request has been served.
	var wg sync.WaitGroup
	wg.Add(1)
	keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
	fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
		defer wg.Done()
		if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
			return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
		}
		return &http.Response{
			StatusCode: 200,
			Body: io.NopCloser(strings.NewReader(`
{
"user_id": "` + remoteUserID + `",
"stream_id": 5,
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": ` + keyJSON + `,
"device_display_name": "Mobile Phone"
}
]
}
`)),
		}, nil
	})
	updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test", caching.DisableMetrics, testIsBlacklistedOrBackingOff)
	if err := updater.Start(); err != nil {
		t.Fatalf("failed to start updater: %s", err)
	}
	event := gomatrixserverlib.DeviceListUpdateEvent{
		DeviceDisplayName: "Mobile Phone",
		Deleted:           false,
		DeviceID:          "another_device_id",
		Keys:              []byte(`{"key":"value"}`),
		PrevID:            []int64{3},
		StreamID:          4,
		UserID:            remoteUserID,
	}
	err := updater.Update(ctx, event)
	if err != nil {
		t.Fatalf("Update returned an error: %s", err)
	}
	t.Log("waiting for /users/devices to be called...")
	wg.Wait()
	// wait a bit for db to be updated...
	time.Sleep(100 * time.Millisecond)
	want := api.DeviceMessage{
		Type:     api.TypeDeviceKeyUpdate,
		StreamID: 5,
		DeviceKeys: &api.DeviceKeys{
			DeviceID:    "JLAFKJWSCS",
			DisplayName: "Mobile Phone",
			UserID:      remoteUserID,
			KeyJSON:     []byte(keyJSON),
		},
	}
	// Now we should have a fresh list and the keys and emitted something
	if db.isStale(event.UserID) {
		t.Errorf("%s still marked as stale", event.UserID)
	}
	if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
		// Only log the key lengths when there is a produced device-key event to
		// inspect: indexing an empty slice (or dereferencing a nil DeviceKeys)
		// would panic here and hide the real assertion failure below.
		if len(producer.events) > 0 && producer.events[0].DeviceKeys != nil {
			t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
		}
		t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
	}
	if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
		t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
	}
}
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
// update is still ongoing.
//
// The test is currently skipped unconditionally (see the Skipf call below), so
// nothing after that call runs.
func TestDebounce(t *testing.T) {
// NOTE(review): presumably the panic seen on CI is the second close of
// incomingFedReq in the federation stub below, i.e. federation being hit
// more than once — confirm before re-enabling.
t.Skipf("panic on closed channel on GHA")
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int64) bool {
return true
},
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
// fedCh carries the single canned federation response to the stub.
fedCh := make(chan *http.Response, 1)
srv := spec.ServerName("example.com")
userID := "@alice:example.com"
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
// incomingFedReq is closed by the stub to signal that a federation request arrived.
incomingFedReq := make(chan struct{})
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
}
// Closing on every request means a second federation request panics here
// with a close of an already-closed channel.
close(incomingFedReq)
// Block until the test body supplies the response.
return <-fedCh, nil
})
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost", caching.DisableMetrics, testIsBlacklistedOrBackingOff)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
// hit this 5 times
var wg sync.WaitGroup
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
defer wg.Done()
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
t.Errorf("ManualUpdate: %s", err)
}
}()
}
// wait until the updater hits federation
select {
case <-incomingFedReq:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for updater to hit federation")
}
// user should be marked as stale
if !db.isStale(userID) {
t.Errorf("user %s not marked as stale", userID)
}
// now send the response over federation
fedCh <- &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`
{
"user_id": "` + userID + `",
"stream_id": 5,
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": ` + keyJSON + `,
"device_display_name": "Mobile Phone"
}
]
}
`)),
}
close(fedCh)
// wait until all 5 ManualUpdates return. If we hit federation again, the stub
// closes incomingFedReq a second time, which panics (close of closed channel);
// a receive from the now-closed fedCh would by itself only yield a nil response.
wg.Wait()
// user is no longer stale now
if db.isStale(userID) {
t.Errorf("user %s is marked as stale", userID)
}
}
// mustCreateKeyserverDB opens a key database of the requested type for a test.
// It returns the database together with the cleanup function obtained from
// test.PrepareDBConnectionString, and fails the test immediately if the
// database cannot be opened.
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
	t.Helper()

	connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
	cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
	dbOpts := &config.DatabaseOptions{
		ConnectionString: config.DataSource(connStr),
	}

	db, err := storage.NewKeyDatabase(cm, dbOpts)
	if err != nil {
		t.Fatal(err)
	}
	return db, clearDB
}
// mockKeyserverRoomserverAPI is a roomserver stand-in that answers
// QueryLeftUsers with a fixed list of user IDs.
type mockKeyserverRoomserverAPI struct {
	// leftUsers is copied verbatim into every QueryLeftUsers response.
	leftUsers []string
}

// QueryLeftUsers ignores req, sets res.LeftUsers to the configured list and
// returns nil.
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
	left := m.leftUsers
	res.LeftUsers = left
	return nil
}
// TestDeviceListUpdater_CleanUp marks two users as stale, has the mocked
// roomserver report one of them (Bob) as having left, runs CleanUp, and then
// expects only the other user (Alice) to remain in the stale list. The scenario
// is repeated for every database type supplied by test.WithAllDatabases.
func TestDeviceListUpdater_CleanUp(t *testing.T) {
	processCtx := process.NewProcessContext()

	alice := test.NewUser(t)
	bob := test.NewUser(t)

	// Bob is not joined to any of our rooms
	rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		db, clearDB := mustCreateKeyserverDB(t, dbType)
		defer clearDB()

		// Alice's entry should survive the clean-up, Bob's should be deleted.
		for _, userID := range []string{alice.ID, bob.ID} {
			if err := db.MarkDeviceListStale(processCtx.Context(), userID, true); err != nil {
				t.Error(err)
			}
		}

		updater := NewDeviceListUpdater(processCtx, db, nil, nil, nil, 0, rsAPI, "test", caching.DisableMetrics, testIsBlacklistedOrBackingOff)
		if err := updater.CleanUp(); err != nil {
			t.Error(err)
		}

		// check that we still have Alice in our stale list
		staleUsers, err := db.StaleDeviceLists(ctx, []spec.ServerName{"test"})
		if err != nil {
			t.Error(err)
		}

		// There should only be Alice
		const wantCount = 1
		if got := len(staleUsers); got != wantCount {
			t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, got)
		}
		if first := staleUsers[0]; first != alice.ID {
			t.Fatalf("unexpected stale device list user: %s, want %s", first, alice.ID)
		}
	})
}
// Test_dedupeStateList is a table-driven test for dedupeStaleLists. The cases
// expect at most one user ID to be kept per server name (the first one seen)
// and user IDs that are not valid to be dropped.
func Test_dedupeStateList(t *testing.T) {
	const (
		alice         = "@alice:localhost"
		bob           = "@bob:localhost"
		charlie       = "@charlie:notlocalhost"
		invalidUserID = "iaminvalid:localhost"
	)

	type testCase struct {
		name       string
		staleLists []string
		want       []string
	}
	cases := []testCase{
		{name: "empty stateLists", staleLists: []string{}, want: []string{}},
		{name: "single entry", staleLists: []string{alice}, want: []string{alice}},
		{name: "multiple entries without dupe servers", staleLists: []string{alice, charlie}, want: []string{alice, charlie}},
		{name: "multiple entries with dupe servers", staleLists: []string{alice, bob, charlie}, want: []string{alice, charlie}},
		{name: "list with invalid userID", staleLists: []string{alice, bob, invalidUserID}, want: []string{alice}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := dedupeStaleLists(tc.staleLists)
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("dedupeStaleLists() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestDeviceListUpdaterIgnoreBlacklisted checks that notifyWorkers forwards a
// user on a reachable server to the worker channel but does not forward a user
// whose server is reported as blacklisted by isBlacklistedOrBackingOffFn.
func TestDeviceListUpdaterIgnoreBlacklisted(t *testing.T) {
	unreachableServer := spec.ServerName("notlocalhost")

	// The single worker channel is closed last (first defer registered) so the
	// consumer goroutine below terminates when the test returns.
	workerCh := make(chan spec.ServerName)
	defer close(workerCh)

	updater := DeviceListUpdater{
		workerChans: []chan spec.ServerName{workerCh},
		isBlacklistedOrBackingOffFn: func(s spec.ServerName) (*statistics.ServerStatistics, error) {
			if s == unreachableServer {
				return nil, &api2.FederationClientError{Blacklisted: true}
			}
			return nil, nil
		},
		mu:             &sync.Mutex{},
		userIDToChanMu: &sync.Mutex{},
		userIDToChan:   make(map[string]chan bool),
		userIDToMutex:  make(map[string]*sync.Mutex),
	}

	// happy case
	alice := "@alice:localhost"
	aliceCh := updater.assignChannel(alice)
	defer updater.clearChannel(alice)

	// failing case
	bob := "@bob:" + string(unreachableServer)
	bobCh := updater.assignChannel(bob)
	defer updater.clearChannel(bob)

	expectedServers := map[spec.ServerName]struct{}{
		"localhost": {},
	}
	unexpectedServers := make(map[spec.ServerName]struct{})

	// Consume the worker channel, ticking off expected servers and recording
	// anything else.
	go func() {
		for serverName := range workerCh {
			if serverName == "localhost" {
				delete(expectedServers, serverName)
				aliceCh <- true // unblock notifyWorkers
				continue
			}
			unexpectedServers[serverName] = struct{}{}
			if serverName == unreachableServer {
				// this should not happen as it is "filtered" away by the blacklist
				bobCh <- true
			}
		}
	}()

	// alice is not blacklisted
	updater.notifyWorkers(alice)
	// bob is blacklisted
	updater.notifyWorkers(bob)

	for server := range expectedServers {
		t.Errorf("Server still in expectedServers map: %s", server)
	}
	for server := range unexpectedServers {
		t.Errorf("unexpected server in result: %s", server)
	}
}