dendrite/test/memory_relay_db.go

134 lines
2.9 KiB
Go

// Copyright 2024 New Vector Ltd.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package test
import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// InMemoryRelayDatabase is a map-backed, in-memory implementation of the
// relay queue storage methods defined on it in this file, for use from the
// test package. Transaction JSON is stored under a numeric ID (NID), and each
// destination server has an ordered list of the NIDs queued for it.
type InMemoryRelayDatabase struct {
	// nid is the next NID that InsertQueueJSON will assign.
	nid int64
	// nidMutex serialises NID assignment in InsertQueueJSON.
	// NOTE(review): the other methods read and write the maps below without
	// taking this lock, so concurrent use is not fully synchronised.
	nidMutex sync.Mutex
	// transactions maps an NID to its stored transaction JSON.
	transactions map[int64]json.RawMessage
	// associations maps a destination server to the NIDs queued for it, in
	// the order they were inserted.
	associations map[spec.ServerName][]int64
}
// NewInMemoryRelayDatabase returns an empty database whose first assigned
// NID will be 1. The zero value of InMemoryRelayDatabase is not usable
// (its maps are nil), so always construct it through this function.
func NewInMemoryRelayDatabase() *InMemoryRelayDatabase {
	db := new(InMemoryRelayDatabase)
	db.nid = 1
	db.transactions = map[int64]json.RawMessage{}
	db.associations = map[spec.ServerName][]int64{}
	return db
}
// InsertQueueEntry appends nid to the queue of NIDs destined for serverName.
// The NID is not de-duplicated, so inserting the same NID twice queues it
// twice. ctx, txn and transactionID are ignored.
func (d *InMemoryRelayDatabase) InsertQueueEntry(
	ctx context.Context,
	txn *sql.Tx,
	transactionID gomatrixserverlib.TransactionID,
	serverName spec.ServerName,
	nid int64,
) error {
	// append on a nil slice allocates a fresh one, so a server with no
	// existing queue needs no special initialisation.
	queued := d.associations[serverName]
	d.associations[serverName] = append(queued, nid)
	return nil
}
// DeleteQueueEntries removes every occurrence of each NID in jsonNIDs from
// the queue associated with serverName, preserving the order of the NIDs
// that remain. Unknown servers and NIDs that are not queued are ignored.
// ctx and txn are ignored.
//
// The previous implementation removed elements from the slice while ranging
// over it. That skipped the element shifted into the removed slot, and it
// panicked with "slice bounds out of range" when a NID appeared more than
// once. Filtering once against a set avoids both problems.
func (d *InMemoryRelayDatabase) DeleteQueueEntries(
	ctx context.Context,
	txn *sql.Tx,
	serverName spec.ServerName,
	jsonNIDs []int64,
) error {
	entries := d.associations[serverName]
	if len(entries) == 0 || len(jsonNIDs) == 0 {
		return nil
	}

	toDelete := make(map[int64]struct{}, len(jsonNIDs))
	for _, nid := range jsonNIDs {
		toDelete[nid] = struct{}{}
	}

	// Filter in place: kept shares entries' backing array and never gets
	// ahead of the read position, so no element is overwritten before it is
	// examined.
	kept := entries[:0]
	for _, nid := range entries {
		if _, drop := toDelete[nid]; !drop {
			kept = append(kept, nid)
		}
	}
	d.associations[serverName] = kept
	return nil
}
// SelectQueueEntries returns up to limit NIDs queued for serverName, taken
// from the front of the queue (i.e. in insertion order). The result is a
// copy, so callers may modify it freely. A non-positive limit or an unknown
// server yields an empty, non-nil slice. ctx and txn are ignored.
func (d *InMemoryRelayDatabase) SelectQueueEntries(
	ctx context.Context,
	txn *sql.Tx,
	serverName spec.ServerName,
	limit int,
) ([]int64, error) {
	queued := d.associations[serverName]

	count := len(queued)
	if limit < count {
		count = limit
	}
	if count < 0 {
		count = 0
	}

	results := make([]int64, count)
	copy(results, queued[:count])
	return results, nil
}
// SelectQueueEntryCount reports how many NIDs are currently queued for
// serverName, which is zero for a server with no queue. ctx and txn are
// ignored.
func (d *InMemoryRelayDatabase) SelectQueueEntryCount(
	ctx context.Context,
	txn *sql.Tx,
	serverName spec.ServerName,
) (int64, error) {
	queued := d.associations[serverName]
	count := int64(len(queued))
	return count, nil
}
// InsertQueueJSON stores json under a newly assigned NID and returns that
// NID. NIDs are handed out sequentially under nidMutex, so concurrent callers
// never receive the same NID. ctx and txn are ignored.
func (d *InMemoryRelayDatabase) InsertQueueJSON(
	ctx context.Context,
	txn *sql.Tx,
	json string,
) (int64, error) {
	d.nidMutex.Lock()
	defer d.nidMutex.Unlock()

	// Store before advancing the counter so the counter is untouched if the
	// map write fails (e.g. on a zero-value database with a nil map).
	assigned := d.nid
	d.transactions[assigned] = []byte(json)
	d.nid = assigned + 1
	return assigned, nil
}
// DeleteQueueJSON discards the stored JSON for each NID in nids. NIDs that
// have no stored JSON are ignored. ctx and txn are ignored.
func (d *InMemoryRelayDatabase) DeleteQueueJSON(
	ctx context.Context,
	txn *sql.Tx,
	nids []int64,
) error {
	for i := range nids {
		delete(d.transactions, nids[i])
	}
	return nil
}
// SelectQueueJSON looks up the stored JSON for each NID in jsonNIDs and
// returns the hits keyed by NID. NIDs with no stored JSON are omitted from
// the result. The returned byte slices are the stored values themselves, not
// copies. ctx and txn are ignored.
func (d *InMemoryRelayDatabase) SelectQueueJSON(
	ctx context.Context,
	txn *sql.Tx,
	jsonNIDs []int64,
) (map[int64][]byte, error) {
	found := make(map[int64][]byte)
	for i := 0; i < len(jsonNIDs); i++ {
		nid := jsonNIDs[i]
		raw, exists := d.transactions[nid]
		if !exists {
			continue
		}
		found[nid] = raw
	}
	return found, nil
}