Use TransactionWriter in roomserver SQLite (#1208)
This commit is contained in:
parent
489f34fed7
commit
d76eb1b994
@ -49,13 +49,16 @@ const bulkSelectEventJSONSQL = `
|
||||
|
||||
type eventJSONStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventJSONStmt *sql.Stmt
|
||||
bulkSelectEventJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||
s := &eventJSONStatements{}
|
||||
s.db = db
|
||||
s := &eventJSONStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(eventJSONSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -69,8 +72,10 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||
func (s *eventJSONStatements) InsertEventJSON(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||
|
@ -64,6 +64,7 @@ const bulkSelectEventStateKeyNIDSQL = `
|
||||
|
||||
type eventStateKeyStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventStateKeyNIDStmt *sql.Stmt
|
||||
selectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||
@ -71,8 +72,10 @@ type eventStateKeyStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
||||
s := &eventStateKeyStatements{}
|
||||
s.db = db
|
||||
s := &eventStateKeyStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(eventStateKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -89,12 +92,18 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
var err error
|
||||
var res sql.Result
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventStateKeyNID, err = res.LastInsertId()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
|
@ -78,6 +78,7 @@ const bulkSelectEventTypeNIDSQL = `
|
||||
|
||||
type eventTypeStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventTypeNIDStmt *sql.Stmt
|
||||
insertEventTypeNIDResultStmt *sql.Stmt
|
||||
selectEventTypeNIDStmt *sql.Stmt
|
||||
@ -85,8 +86,10 @@ type eventTypeStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
||||
s := &eventTypeStatements{}
|
||||
s.db = db
|
||||
s := &eventTypeStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(eventTypesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -104,12 +107,15 @@ func (s *eventTypeStatements) InsertEventTypeNID(
|
||||
ctx context.Context, tx *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
var err error
|
||||
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
|
||||
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
|
||||
if _, err = insertStmt.ExecContext(ctx, eventType); err == nil {
|
||||
err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
|
||||
}
|
||||
err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
|
||||
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
|
||||
_, err := insertStmt.ExecContext(ctx, eventType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
|
||||
})
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
|
@ -99,6 +99,7 @@ const selectRoomNIDForEventNIDSQL = "" +
|
||||
|
||||
type eventStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventStmt *sql.Stmt
|
||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||
@ -115,8 +116,10 @@ type eventStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
||||
s := &eventStatements{}
|
||||
s.db = db
|
||||
s := &eventStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(eventsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -151,19 +154,23 @@ func (s *eventStatements) InsertEvent(
|
||||
depth int64,
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
// attempt to insert: the last_row_id is the event NID
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
result, err := insertStmt.ExecContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
modified, err := result.RowsAffected()
|
||||
if modified == 0 && err == nil {
|
||||
return 0, 0, sql.ErrNoRows
|
||||
}
|
||||
eventNID, err := result.LastInsertId()
|
||||
var eventNID int64
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
result, err := insertStmt.ExecContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modified, err := result.RowsAffected()
|
||||
if modified == 0 && err == nil {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
eventNID, err = result.LastInsertId()
|
||||
return err
|
||||
})
|
||||
return types.EventNID(eventNID), 0, err
|
||||
}
|
||||
|
||||
@ -279,8 +286,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||
func (s *eventStatements) UpdateEventState(
|
||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||
) error {
|
||||
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||
return err
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEventSentToOutput(
|
||||
@ -288,17 +297,15 @@ func (s *eventStatements) SelectEventSentToOutput(
|
||||
) (sentToOutput bool, err error) {
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt)
|
||||
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
|
||||
//err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
|
||||
if err != nil {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
||||
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
||||
//_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID))
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEventID(
|
||||
|
@ -63,6 +63,8 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
|
||||
`
|
||||
|
||||
type inviteStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteActiveForUserInRoomStmt *sql.Stmt
|
||||
updateInviteRetiredStmt *sql.Stmt
|
||||
@ -70,7 +72,10 @@ type inviteStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||
s := &inviteStatements{}
|
||||
s := &inviteStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(inviteSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -90,42 +95,48 @@ func (s *inviteStatements) InsertInviteEvent(
|
||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
result, err := stmt.ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
count, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count != 0, nil
|
||||
var count int64
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
result, err := stmt.ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count != 0, err
|
||||
}
|
||||
|
||||
func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventIDs []string, err error) {
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer (func() { err = rows.Close() })()
|
||||
for rows.Next() {
|
||||
var inviteEventID string
|
||||
if err = rows.Scan(&inviteEventID); err != nil {
|
||||
return nil, err
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventIDs = append(eventIDs, inviteEventID)
|
||||
}
|
||||
|
||||
// now retire the invites
|
||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
defer (func() { err = rows.Close() })()
|
||||
for rows.Next() {
|
||||
var inviteEventID string
|
||||
if err = rows.Scan(&inviteEventID); err != nil {
|
||||
return err
|
||||
}
|
||||
eventIDs = append(eventIDs, inviteEventID)
|
||||
}
|
||||
// now retire the invites
|
||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -76,6 +76,8 @@ const updateMembershipSQL = "" +
|
||||
" WHERE room_nid = $4 AND target_nid = $5"
|
||||
|
||||
type membershipStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipForUpdateStmt *sql.Stmt
|
||||
selectMembershipFromRoomAndTargetStmt *sql.Stmt
|
||||
@ -87,7 +89,10 @@ type membershipStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
s := &membershipStatements{}
|
||||
s := &membershipStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(membershipSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -110,9 +115,11 @@ func (s *membershipStatements) InsertMembership(
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
localTarget bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||
@ -194,9 +201,11 @@ func (s *membershipStatements) UpdateMembership(
|
||||
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
|
||||
)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
@ -53,12 +53,17 @@ const selectPreviousEventExistsSQL = `
|
||||
`
|
||||
|
||||
type previousEventStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertPreviousEventStmt *sql.Stmt
|
||||
selectPreviousEventExistsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
||||
s := &previousEventStatements{}
|
||||
s := &previousEventStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(previousEventSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -77,11 +82,13 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
||||
previousEventReferenceSHA256 []byte,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||
)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Check if the event reference exists
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
@ -43,13 +44,18 @@ const selectPublishedSQL = "" +
|
||||
"SELECT published FROM roomserver_published WHERE room_id = $1"
|
||||
|
||||
type publishedStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
upsertPublishedStmt *sql.Stmt
|
||||
selectAllPublishedStmt *sql.Stmt
|
||||
selectPublishedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||
s := &publishedStatements{}
|
||||
s := &publishedStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(publishedSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -64,8 +70,10 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||
func (s *publishedStatements) UpsertRoomPublished(
|
||||
ctx context.Context, roomID string, published bool,
|
||||
) (err error) {
|
||||
_, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
|
||||
return
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
_, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
|
@ -52,6 +52,8 @@ const markRedactionValidatedSQL = "" +
|
||||
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
|
||||
|
||||
type redactionStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertRedactionStmt *sql.Stmt
|
||||
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
||||
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
|
||||
@ -59,7 +61,10 @@ type redactionStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||
s := &redactionStatements{}
|
||||
s := &redactionStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(redactionsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -76,9 +81,11 @@ func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||
func (s *redactionStatements) InsertRedaction(
|
||||
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
||||
@ -114,7 +121,9 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
|
||||
func (s *redactionStatements) MarkRedactionValidated(
|
||||
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
@ -55,6 +56,8 @@ const deleteRoomAliasSQL = `
|
||||
`
|
||||
|
||||
type roomAliasesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertRoomAliasStmt *sql.Stmt
|
||||
selectRoomIDFromAliasStmt *sql.Stmt
|
||||
selectAliasesFromRoomIDStmt *sql.Stmt
|
||||
@ -63,7 +66,10 @@ type roomAliasesStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{}
|
||||
s := &roomAliasesStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(roomAliasesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -80,8 +86,10 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
func (s *roomAliasesStatements) InsertRoomAlias(
|
||||
ctx context.Context, alias string, roomID string, creatorUserID string,
|
||||
) (err error) {
|
||||
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||
return
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
_, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
@ -130,6 +138,8 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
func (s *roomAliasesStatements) DeleteRoomAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (err error) {
|
||||
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
||||
return
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
_, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
@ -64,6 +64,8 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
selectLatestEventNIDsStmt *sql.Stmt
|
||||
@ -74,7 +76,10 @@ type roomStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
s := &roomStatements{}
|
||||
s := &roomStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(roomsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -94,9 +99,12 @@ func (s *roomStatements) InsertRoomNID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (types.RoomNID, error) {
|
||||
var err error
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil {
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
_, err := insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||
return err
|
||||
})
|
||||
if err == nil {
|
||||
return s.SelectRoomNID(ctx, txn, roomID)
|
||||
} else {
|
||||
return types.RoomNID(0), err
|
||||
@ -155,15 +163,17 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||
lastEventSentNID types.EventNID,
|
||||
stateSnapshotNID types.StateSnapshotNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
roomNID,
|
||||
)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
roomNID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionForRoomID(
|
||||
|
@ -74,6 +74,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||
|
||||
type stateBlockStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertStateDataStmt *sql.Stmt
|
||||
selectNextStateBlockNIDStmt *sql.Stmt
|
||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
@ -81,8 +82,10 @@ type stateBlockStatements struct {
|
||||
}
|
||||
|
||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{}
|
||||
s.db = db
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -104,24 +107,26 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||
return 0, nil
|
||||
}
|
||||
var stateBlockNID types.StateBlockNID
|
||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
return stateBlockNID, nil
|
||||
for _, entry := range entries {
|
||||
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return stateBlockNID, err
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
|
@ -50,13 +50,16 @@ const bulkSelectStateBlockNIDsSQL = "" +
|
||||
|
||||
type stateSnapshotStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertStateStmt *sql.Stmt
|
||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{}
|
||||
s.db = db
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -75,14 +78,19 @@ func (s *stateSnapshotStatements) InsertState(
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||
if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)); err2 == nil {
|
||||
lastRowID, err3 := res.LastInsertId()
|
||||
if err3 != nil {
|
||||
err = err3
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastRowID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(lastRowID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -44,12 +44,17 @@ const selectTransactionEventIDSQL = `
|
||||
`
|
||||
|
||||
type transactionStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertTransactionStmt *sql.Stmt
|
||||
selectTransactionEventIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
|
||||
s := &transactionStatements{}
|
||||
s := &transactionStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(transactionsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -68,11 +73,13 @@ func (s *transactionStatements) InsertTransaction(
|
||||
userID string,
|
||||
eventID string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||
_, err = stmt.ExecContext(
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *transactionStatements) SelectTransactionEventID(
|
||||
|
Loading…
Reference in New Issue
Block a user