Keep track of membership in Client API (#159)
* Saving memberships * Removed unused index * Removed useless log * Fixed membership not being saved on the right conditions + added membership removal * Updated outdated comment * Use server lib method + check server name + use new roomserver API * Better handling of events from the room server * Fixed membership removal * Corrected indentation * Fix tests (hopefully) * Replace broken kafka mirror * Apply requested changes on database management * Remove useless check and function * Moved memberships update to the database package * Use new common function * Remove useless function
This commit is contained in:
parent
b06d1124f7
commit
d9b8e5de45
@ -0,0 +1,85 @@
|
||||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const membershipSchema = `
|
||||
-- Stores data about users memberships to rooms.
|
||||
CREATE TABLE IF NOT EXISTS memberships (
|
||||
-- The Matrix user ID localpart for the member
|
||||
localpart TEXT NOT NULL,
|
||||
-- The room this user is a member of
|
||||
room_id TEXT NOT NULL,
|
||||
-- The ID of the join membership event
|
||||
event_id TEXT NOT NULL,
|
||||
|
||||
-- A user can only be member of a room once
|
||||
PRIMARY KEY (localpart, room_id)
|
||||
);
|
||||
|
||||
-- Use index to process deletion by ID more efficiently
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS membership_event_id ON memberships(event_id);
|
||||
`
|
||||
|
||||
const insertMembershipSQL = "" +
|
||||
"INSERT INTO memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectMembershipSQL = "" +
|
||||
"SELECT * from memberships WHERE localpart = $1 AND room_id = $2"
|
||||
|
||||
const selectMembershipsByLocalpartSQL = "" +
|
||||
"SELECT room_id FROM memberships WHERE localpart = $1"
|
||||
|
||||
const deleteMembershipsByEventIDsSQL = "" +
|
||||
"DELETE FROM memberships WHERE event_id = ANY($1)"
|
||||
|
||||
type membershipStatements struct {
|
||||
deleteMembershipsByEventIDsStmt *sql.Stmt
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipByEventIDStmt *sql.Stmt
|
||||
selectMembershipsByLocalpartStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(membershipSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) {
|
||||
_, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) {
|
||||
_, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs))
|
||||
return
|
||||
}
|
@ -18,6 +18,7 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
// Import the postgres database driver.
|
||||
@ -26,9 +27,12 @@ import (
|
||||
|
||||
// Database represents an account database
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
db *sql.DB
|
||||
partitions common.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
memberships membershipStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
@ -38,6 +42,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
partitions := common.PartitionOffsetStatements{}
|
||||
if err = partitions.Prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a := accountsStatements{}
|
||||
if err = a.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
@ -46,7 +54,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||
if err = p.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{db, a, p}, nil
|
||||
m := membershipStatements{}
|
||||
if err = m.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{db, partitions, a, p, m, serverName}, nil
|
||||
}
|
||||
|
||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
@ -93,6 +105,85 @@ func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtype
|
||||
return d.accounts.insertAccount(localpart, hash)
|
||||
}
|
||||
|
||||
// PartitionOffsets implements common.PartitionStorer
|
||||
func (d *Database) PartitionOffsets(topic string) ([]common.PartitionOffset, error) {
|
||||
return d.partitions.SelectPartitionOffsets(topic)
|
||||
}
|
||||
|
||||
// SetPartitionOffset implements common.PartitionStorer
|
||||
func (d *Database) SetPartitionOffset(topic string, partition int32, offset int64) error {
|
||||
return d.partitions.UpsertPartitionOffset(topic, partition, offset)
|
||||
}
|
||||
|
||||
// SaveMembership saves the user matching a given localpart as a member of a given
|
||||
// room. It also stores the ID of the `join` membership event.
|
||||
// If a membership already exists between the user and the room, or of the
|
||||
// insert fails, returns the SQL error
|
||||
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error {
|
||||
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
|
||||
}
|
||||
|
||||
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
|
||||
// event ID is included in a given array of events IDs
|
||||
// If the removal fails, or if there is no membership to remove, returns an error
|
||||
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
|
||||
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
|
||||
}
|
||||
|
||||
// UpdateMemberships adds the "join" membership events included in a given state
|
||||
// events array, and removes those which ID is included in a given array of events
|
||||
// IDs. All of the process is run in a transaction, which commits only once/if every
|
||||
// insertion and deletion has been successfully processed.
|
||||
// Returns a SQL error if there was an issue with any part of the process
|
||||
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
|
||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range eventsToAdd {
|
||||
if err := d.newMembership(event, txn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// newMembership will save a new membership in the database if the given state
|
||||
// event is a "join" membership event
|
||||
// If the event isn't a "join" membership event, does nothing
|
||||
// If an error occurred, returns it
|
||||
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
|
||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We only want state events from local users
|
||||
if string(serverName) != string(d.serverName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
eventID := ev.EventID()
|
||||
roomID := ev.RoomID()
|
||||
membership, err := ev.Membership()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only "join" membership events can be considered as new memberships
|
||||
if membership == "join" {
|
||||
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashPassword(plaintext string) (hash string, err error) {
|
||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
||||
return string(hashBytes), err
|
||||
|
@ -0,0 +1,141 @@
|
||||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/dendrite/common/config"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
sarama "gopkg.in/Shopify/sarama.v1"
|
||||
)
|
||||
|
||||
// OutputRoomEvent consumes events that originated in the room server.
|
||||
type OutputRoomEvent struct {
|
||||
roomServerConsumer *common.ContinualConsumer
|
||||
db *accounts.Database
|
||||
query api.RoomserverQueryAPI
|
||||
serverName string
|
||||
}
|
||||
|
||||
// NewOutputRoomEvent creates a new OutputRoomEvent consumer. Call Start() to begin consuming from room servers.
|
||||
func NewOutputRoomEvent(cfg *config.Dendrite, store *accounts.Database) (*OutputRoomEvent, error) {
|
||||
kafkaConsumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomServerURL := cfg.RoomServerURL()
|
||||
|
||||
consumer := common.ContinualConsumer{
|
||||
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
}
|
||||
s := &OutputRoomEvent{
|
||||
roomServerConsumer: &consumer,
|
||||
db: store,
|
||||
query: api.NewRoomserverQueryAPIHTTP(roomServerURL, nil),
|
||||
serverName: string(cfg.Matrix.ServerName),
|
||||
}
|
||||
consumer.ProcessMessage = s.onMessage
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputRoomEvent) Start() error {
|
||||
return s.roomServerConsumer.Start()
|
||||
}
|
||||
|
||||
// onMessage is called when the sync server receives a new event from the room server output log.
|
||||
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||
// sync stream position may race and be incorrectly calculated.
|
||||
func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
|
||||
// Parse out the event JSON
|
||||
var output api.OutputEvent
|
||||
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("roomserver output log: message parse failure")
|
||||
return nil
|
||||
}
|
||||
|
||||
if output.Type != api.OutputTypeNewRoomEvent {
|
||||
log.WithField("type", output.Type).Debug(
|
||||
"roomserver output log: ignoring unknown output type",
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
ev := output.NewRoomEvent.Event
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": ev.EventID(),
|
||||
"room_id": ev.RoomID(),
|
||||
"type": ev.Type(),
|
||||
}).Info("received event from roomserver")
|
||||
|
||||
events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// lookupStateEvents looks up the state events that are added by a new event.
|
||||
func (s *OutputRoomEvent) lookupStateEvents(
|
||||
addsStateEventIDs []string, event gomatrixserverlib.Event,
|
||||
) ([]gomatrixserverlib.Event, error) {
|
||||
// Fast path if there aren't any new state events.
|
||||
if len(addsStateEventIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Fast path if the only state event added is the event itself.
|
||||
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
|
||||
return []gomatrixserverlib.Event{event}, nil
|
||||
}
|
||||
|
||||
result := []gomatrixserverlib.Event{}
|
||||
missing := []string{}
|
||||
for _, id := range addsStateEventIDs {
|
||||
// Append the current event in the results if its ID is in the events list
|
||||
if id == event.EventID() {
|
||||
result = append(result, event)
|
||||
} else {
|
||||
// If the event isn't the current one, add it to the list of events
|
||||
// to retrieve from the roomserver
|
||||
missing = append(missing, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Request the missing events from the roomserver
|
||||
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
|
||||
var eventResp api.QueryEventsByIDResponse
|
||||
if err := s.query.QueryEventsByID(&eventReq, &eventResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, eventResp.Events...)
|
||||
|
||||
return result, nil
|
||||
}
|
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
||||
"github.com/matrix-org/dendrite/clientapi/consumers"
|
||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||
"github.com/matrix-org/dendrite/clientapi/routing"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -86,6 +87,14 @@ func main() {
|
||||
KeyDatabase: keyDB,
|
||||
}
|
||||
|
||||
consumer, err := consumers.NewOutputRoomEvent(cfg, accountDB)
|
||||
if err != nil {
|
||||
log.Panicf("startup: failed to create room server consumer: %s", err)
|
||||
}
|
||||
if err = consumer.Start(); err != nil {
|
||||
log.Panicf("startup: failed to start room server consumer")
|
||||
}
|
||||
|
||||
log.Info("Starting client API server on ", cfg.Listen.ClientAPI)
|
||||
routing.Setup(
|
||||
http.DefaultServeMux, http.DefaultClient, *cfg, roomserverProducer,
|
||||
|
@ -104,14 +104,17 @@ func startMediaAPI(suffix string, dynamicThumbnails bool) (*exec.Cmd, chan error
|
||||
|
||||
proxyCmd, proxyCmdChan := test.StartProxy(proxyAddr, cfg)
|
||||
|
||||
cmd, cmdChan := test.StartServer(
|
||||
serverType,
|
||||
serverArgs,
|
||||
test.InitDatabase(
|
||||
postgresDatabase,
|
||||
postgresContainerName,
|
||||
databases,
|
||||
)
|
||||
|
||||
cmd, cmdChan := test.CreateBackgroundCommand(
|
||||
filepath.Join(filepath.Dir(os.Args[0]), "dendrite-"+serverType+"-server"),
|
||||
serverArgs,
|
||||
)
|
||||
|
||||
fmt.Printf("==TESTSERVER== STARTED %v -> %v : %v\n", proxyAddr, cfg.Listen.MediaAPI, dir)
|
||||
return cmd, cmdChan, string(cfg.Listen.MediaAPI), proxyCmd, proxyCmdChan, proxyAddr, dir
|
||||
}
|
||||
|
@ -147,9 +147,7 @@ func startSyncServer() (*exec.Cmd, chan error) {
|
||||
testDatabaseName,
|
||||
}
|
||||
|
||||
cmd, cmdChan := test.StartServer(
|
||||
"sync-api",
|
||||
serverArgs,
|
||||
test.InitDatabase(
|
||||
postgresDatabase,
|
||||
postgresContainerName,
|
||||
databases,
|
||||
@ -165,6 +163,11 @@ func startSyncServer() (*exec.Cmd, chan error) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cmd, cmdChan := test.CreateBackgroundCommand(
|
||||
filepath.Join(filepath.Dir(os.Args[0]), "dendrite-sync-api-server"),
|
||||
serverArgs,
|
||||
)
|
||||
|
||||
return cmd, cmdChan
|
||||
}
|
||||
|
||||
|
@ -65,12 +65,8 @@ func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan err
|
||||
return cmd, cmdChan
|
||||
}
|
||||
|
||||
// StartServer creates the database and config file needed for the server to run and
|
||||
// then starts the server. The Cmd being executed is returned. A channel is also returned,
|
||||
// which will have any termination errors sent down it, followed immediately by the channel being closed.
|
||||
// If postgresContainerName is not an empty string, psql will be run from inside that container. If it is
|
||||
// an empty string, psql will be assumed to be in PATH.
|
||||
func StartServer(serverType string, serverArgs []string, postgresDatabase, postgresContainerName string, databases []string) (*exec.Cmd, chan error) {
|
||||
// InitDatabase creates the database and config file needed for the server to run
|
||||
func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
|
||||
if len(databases) > 0 {
|
||||
var dbCmd string
|
||||
var dbArgs []string
|
||||
@ -89,11 +85,6 @@ func StartServer(serverType string, serverArgs []string, postgresDatabase, postg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return CreateBackgroundCommand(
|
||||
filepath.Join(filepath.Dir(os.Args[0]), "dendrite-"+serverType+"-server"),
|
||||
serverArgs,
|
||||
)
|
||||
}
|
||||
|
||||
// StartProxy creates a reverse proxy
|
||||
|
@ -5,7 +5,7 @@ set -eu
|
||||
# The mirror to download kafka from is picked from the list of mirrors at
|
||||
# https://www.apache.org/dyn/closer.cgi?path=/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
|
||||
# TODO: Check the signature since we are downloading over HTTP.
|
||||
MIRROR=http://mirror.ox.ac.uk/sites/rsync.apache.org/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
|
||||
MIRROR=http://apache.mirror.anlx.net/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
|
||||
|
||||
# Only download the kafka if it isn't already downloaded.
|
||||
test -f kafka.tgz || wget $MIRROR -O kafka.tgz
|
||||
|
Loading…
Reference in New Issue
Block a user