Add context to the federationsender database (#231)
This commit is contained in:
parent
dc5dd4c5d2
commit
5ada8872bb
@ -123,8 +123,12 @@ func (s *OutputRoomEvent) processMessage(ore api.OutputNewRoomEvent) error {
|
||||
// TODO: handle EventIDMismatchError and recover the current state by talking
|
||||
// to the roomserver
|
||||
oldJoinedHosts, err := s.db.UpdateRoom(
|
||||
ore.Event.RoomID(), ore.LastSentEventID, ore.Event.EventID(),
|
||||
addsJoinedHosts, ore.RemovesStateEventIDs,
|
||||
context.TODO(),
|
||||
ore.Event.RoomID(),
|
||||
ore.LastSentEventID,
|
||||
ore.Event.EventID(),
|
||||
addsJoinedHosts,
|
||||
ore.RemovesStateEventIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
@ -78,20 +79,29 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) insertJoinedHosts(
|
||||
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomID, eventID string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
|
||||
stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
|
||||
_, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
|
||||
func (s *joinedHostsStatements) deleteJoinedHosts(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) error {
|
||||
stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
|
||||
func (s *joinedHostsStatements) selectJoinedHosts(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID)
|
||||
stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -66,17 +67,22 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||
|
||||
// insertRoom inserts the room if it didn't already exist.
|
||||
// If the room didn't exist then last_event_id is set to the empty string.
|
||||
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
|
||||
_, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID)
|
||||
func (s *roomStatements) insertRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) error {
|
||||
_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||
// The row must already exist in the table. Callers can ensure that the row
|
||||
// exists by calling insertRoom first.
|
||||
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
|
||||
func (s *roomStatements) selectRoomForUpdate(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (string, error) {
|
||||
var lastEventID string
|
||||
err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
|
||||
stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -85,7 +91,10 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
|
||||
|
||||
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
|
||||
// have already been called earlier within the transaction.
|
||||
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
|
||||
_, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID)
|
||||
func (s *roomStatements) updateRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
|
||||
) error {
|
||||
stmt := common.TxStmt(txn, s.updateRoomStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
|
||||
return err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -73,35 +74,38 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
|
||||
// UpdateRoom updates the joined hosts for a room and returns what the joined
|
||||
// hosts were before the update.
|
||||
func (d *Database) UpdateRoom(
|
||||
ctx context.Context,
|
||||
roomID, oldEventID, newEventID string,
|
||||
addHosts []types.JoinedHost,
|
||||
removeHosts []string,
|
||||
) (joinedHosts []types.JoinedHost, err error) {
|
||||
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if err = d.insertRoom(txn, roomID); err != nil {
|
||||
if err = d.insertRoom(ctx, txn, roomID); err != nil {
|
||||
return err
|
||||
}
|
||||
lastSentEventID, err := d.selectRoomForUpdate(txn, roomID)
|
||||
lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lastSentEventID != oldEventID {
|
||||
return types.EventIDMismatchError{lastSentEventID, oldEventID}
|
||||
return types.EventIDMismatchError{
|
||||
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
|
||||
}
|
||||
}
|
||||
joinedHosts, err = d.selectJoinedHosts(txn, roomID)
|
||||
joinedHosts, err = d.selectJoinedHosts(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, add := range addHosts {
|
||||
err = d.insertJoinedHosts(txn, roomID, add.MemberEventID, add.ServerName)
|
||||
err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = d.deleteJoinedHosts(txn, removeHosts); err != nil {
|
||||
if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
|
||||
return err
|
||||
}
|
||||
return d.updateRoom(txn, roomID, newEventID)
|
||||
return d.updateRoom(ctx, txn, roomID, newEventID)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user