package federationapi

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"testing"
	"time"

	"github.com/element-hq/dendrite/federationapi/api"
	"github.com/element-hq/dendrite/federationapi/routing"
	"github.com/element-hq/dendrite/internal/caching"
	"github.com/element-hq/dendrite/internal/sqlutil"
	"github.com/element-hq/dendrite/setup/config"
	"github.com/element-hq/dendrite/setup/jetstream"
	"github.com/element-hq/dendrite/setup/process"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/gomatrixserverlib/fclient"
	"github.com/matrix-org/gomatrixserverlib/spec"
)

// server holds the per-"homeserver" state used by the tests in this file.
// Every field except name and validity is populated by TestMain.
type server struct {
	name      spec.ServerName           // server name
	validity  time.Duration             // key validity duration from now
	config    *config.FederationAPI     // skeleton config, from TestMain
	fedclient fclient.FederationClient  // uses MockRoundTripper
	cache     *caching.Caches           // server-specific cache
	api       api.FederationInternalAPI // server-specific server key API
}

// renew sets the server's key validity period to one hour and writes the
// same value into the server's config.
func (s *server) renew() {
	// This updates the validity period to be an hour in the
	// future, which is particularly useful in server A and
	// server C's cases which have validity either as now or
	// in the past.
	s.validity = time.Hour
	s.config.Matrix.KeyValidityPeriod = s.validity
}

var (
	serverKeyID = gomatrixserverlib.KeyID("ed25519:auto")
	serverA     = &server{name: "a.com", validity: time.Duration(0)} // expires now
	serverB     = &server{name: "b.com", validity: time.Hour}        // expires in an hour
	serverC     = &server{name: "c.com", validity: -time.Hour}       // expired an hour ago
)

// servers indexes the test servers by host name; MockRoundTripper uses it to
// resolve the host of an intercepted request.
var servers = map[string]*server{
	"a.com": serverA,
	"b.com": serverB,
	"c.com": serverC,
}

// TestMain builds a server key API for every entry in servers and then runs
// the tests in this package. The process exit code is the result of m.Run, or
// -1 if the database file for a server could not be created.
func TestMain(m *testing.M) {
	// Set up the server key API for each "server" that we
	// will use in our tests. The setup lives in a closure so that
	// the deferred cleanups run when it returns, i.e. after m.Run
	// has finished but before os.Exit is called.
	os.Exit(func() int {
		for _, s := range servers {
			// Generate a new key.
			_, testPriv, err := ed25519.GenerateKey(nil)
			if err != nil {
				panic("can't generate identity key: " + err.Error())
			}

			// Create a new cache but don't enable prometheus!
			s.cache = caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
			natsInstance := jetstream.NATSInstance{}

			// Create a temporary directory for JetStream. It has to
			// outlive this loop iteration, hence the function-level defer.
			d, err := os.MkdirTemp("./", "jetstream*")
			if err != nil {
				panic(err)
			}
			defer os.RemoveAll(d)

			// Draw up just enough Dendrite config for the server key
			// API to work.
			cfg := &config.Dendrite{}
			cfg.Defaults(config.DefaultOpts{
				Generate:       true,
				SingleDatabase: false,
			})
			cfg.Global.ServerName = spec.ServerName(s.name)
			cfg.Global.PrivateKey = testPriv
			cfg.Global.JetStream.InMemory = true
			cfg.Global.JetStream.TopicPrefix = string(s.name[:1])
			cfg.Global.JetStream.StoragePath = config.Path(d)
			cfg.Global.KeyID = serverKeyID
			cfg.Global.KeyValidityPeriod = s.validity
			cfg.FederationAPI.KeyPerspectives = nil

			f, err := os.CreateTemp(d, "federation_keys_test*.db")
			if err != nil {
				// Report why setup failed instead of exiting silently.
				fmt.Fprintf(os.Stderr, "can't create federation keys database file: %s\n", err)
				return -1
			}
			defer f.Close()
			cfg.FederationAPI.Database.ConnectionString = config.DataSource("file:" + f.Name())
			s.config = &cfg.FederationAPI

			// Create a transport which redirects federation requests to
			// the mock round tripper. Since we're not *really* listening for
			// federation requests then this will return the key instead.
			transport := &http.Transport{}
			transport.RegisterProtocol("matrix", &MockRoundTripper{})

			// Create the federation client.
			s.fedclient = fclient.NewFederationClient(
				s.config.Matrix.SigningIdentities(),
				fclient.WithTransport(transport),
			)

			// Finally, build the server key APIs.
			processCtx := process.NewProcessContext()
			cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
			s.api = NewInternalAPI(processCtx, cfg, cm, &natsInstance, s.fedclient, nil, s.cache, nil, true)
		}

		// Now that we have built our server key APIs, start the
		// rest of the tests.
		return m.Run()
	}())
}

// MockRoundTripper is an http.RoundTripper that answers server key requests
// for the hosts listed in servers without touching the network.
type MockRoundTripper struct{}

// RoundTrip responds to a /_matrix/key/v2/server request with the local keys
// of the server named by req.Host. It returns an error if the host is not one
// of the test servers or if any other path is requested.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
	// The http.RoundTripper contract requires the request body to be
	// closed, including on errors.
	if req.Body != nil {
		defer req.Body.Close()
	}

	// Check if the request is looking for keys from a server that
	// we know about in the test. The only reason this should go wrong
	// is if the test is broken.
	s, ok := servers[req.Host]
	if !ok {
		return nil, fmt.Errorf("server not known: %s", req.Host)
	}

	// We're intercepting /_matrix/key/v2/server requests here, so check
	// that the URL supplied in the request is for that.
	if req.URL.Path != "/_matrix/key/v2/server" {
		return nil, fmt.Errorf("unexpected request path: %s", req.URL.Path)
	}

	// Get the keys and JSON-ify them.
	keys := routing.LocalKeys(s.config, spec.ServerName(req.Host))
	body, err := json.MarshalIndent(keys.JSON, "", "  ")
	if err != nil {
		return nil, err
	}

	// And respond.
	res = &http.Response{
		StatusCode: 200,
		Body:       io.NopCloser(bytes.NewReader(body)),
	}
	return res, nil
}

// TestServersRequestOwnKeys checks that every test server can fetch its own
// signing key through its server key API.
func TestServersRequestOwnKeys(t *testing.T) {
	// Each server will request its own keys. There's no reason
	// for this to fail as each server should know its own keys.
	// Every server runs as its own subtest so that a failure names
	// the server involved (map iteration order is random).
	for name, s := range servers {
		t.Run(name, func(t *testing.T) {
			req := gomatrixserverlib.PublicKeyLookupRequest{
				ServerName: s.name,
				KeyID:      serverKeyID,
			}
			res, err := s.api.FetchKeys(
				context.Background(),
				map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
					req: spec.AsTimestamp(time.Now()),
				},
			)
			if err != nil {
				t.Fatalf("server could not fetch own key: %s", err)
			}
			key, ok := res[req]
			if !ok {
				t.Fatalf("server didn't return its own key in the results")
			}
			t.Logf("%s's key expires at %s\n", name, key.ValidUntilTS.Time())
		})
	}
}

// TestRenewalBehaviour checks that server A observes a different validity
// timestamp for server C's key after server C renews it.
//
// Note that this test calls serverC.renew(), which changes package-level
// state shared with the other tests in this file.
func TestRenewalBehaviour(t *testing.T) {
	ctx := context.Background()

	// Server A will request Server C's key but their validity period
	// is an hour in the past. We'll retrieve the key as, even though it's
	// past its validity, it will be able to verify past events.
	req := gomatrixserverlib.PublicKeyLookupRequest{
		ServerName: serverC.name,
		KeyID:      serverKeyID,
	}

	res, err := serverA.api.FetchKeys(
		ctx,
		map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
			req: spec.AsTimestamp(time.Now()),
		},
	)
	if err != nil {
		t.Fatalf("server A failed to retrieve server C key: %s", err)
	}
	if len(res) != 1 {
		t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
	}
	originalKey, ok := res[req]
	if !ok {
		t.Fatalf("server C isn't included in the key fetch response")
	}
	originalValidity := originalKey.ValidUntilTS

	// We're now going to kick server C into renewing its key. Since we're
	// happy at this point that the key that we already have is from the past
	// then repeating a key fetch should cause us to try and renew the key.
	// If so, then the new key will end up in our cache.
	serverC.renew()

	res, err = serverA.api.FetchKeys(
		ctx,
		map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
			req: spec.AsTimestamp(time.Now()),
		},
	)
	if err != nil {
		t.Fatalf("server A failed to retrieve server C key: %s", err)
	}
	if len(res) != 1 {
		t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
	}
	renewedKey, ok := res[req]
	if !ok {
		t.Fatalf("server C isn't included in the key fetch response")
	}
	currentValidity := renewedKey.ValidUntilTS

	if originalValidity == currentValidity {
		t.Fatalf("server C key should have renewed but didn't")
	}
}