diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go index ea71fed9..fafd4cb8 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go @@ -16,6 +16,7 @@ package main import ( "context" + "database/sql" "flag" "net/http" "os" @@ -199,7 +200,21 @@ func (m *monolith) setupFederation() { func (m *monolith) setupKafka() { if m.cfg.Kafka.UseNaffka { - naff, err := naffka.New(&naffka.MemoryDatabase{}) + db, err := sql.Open("postgres", string(m.cfg.Database.Naffka)) + if err != nil { + log.WithFields(log.Fields{ + log.ErrorKey: err, + }).Panic("Failed to open naffka database") + } + + naffkaDB, err := naffka.NewPostgresqlDatabase(db) + if err != nil { + log.WithFields(log.Fields{ + log.ErrorKey: err, + }).Panic("Failed to setup naffka database") + } + + naff, err := naffka.New(naffkaDB) if err != nil { log.WithFields(log.Fields{ log.ErrorKey: err, diff --git a/src/github.com/matrix-org/dendrite/common/config/config.go b/src/github.com/matrix-org/dendrite/common/config/config.go index 65a9f898..79788902 100644 --- a/src/github.com/matrix-org/dendrite/common/config/config.go +++ b/src/github.com/matrix-org/dendrite/common/config/config.go @@ -148,6 +148,8 @@ type Dendrite struct { // The PublicRoomsAPI database stores information used to compute the public // room directory. It is only accessed by the PublicRoomsAPI server. PublicRoomsAPI DataSource `yaml:"public_rooms_api"` + // The Naffka database is used internally by the naffka library, if used. + Naffka DataSource `yaml:"naffka,omitempty"` } `yaml:"database"` // TURN Server Config @@ -386,6 +388,8 @@ func (config *Dendrite) check(monolithic bool) error { if !monolithic { problems = append(problems, fmt.Sprintf("naffka can only be used in a monolithic server")) } + + checkNotEmpty("database.naffka", string(config.Database.Naffka)) } else { // If we aren't using naffka then we need to have at least one kafka // server to talk to. diff --git a/vendor/manifest b/vendor/manifest index ff61798d..ba6dab4c 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -141,7 +141,7 @@ { "importpath": "github.com/matrix-org/naffka", "repository": "https://github.com/matrix-org/naffka", - "revision": "d28656e34f96a8eeaab53e3b7678c9ce14af5786", + "revision": "662bfd0841d0194bfe0a700d54226bb96eac574d", "branch": "master" }, { diff --git a/vendor/src/github.com/matrix-org/naffka/memorydatabase.go b/vendor/src/github.com/matrix-org/naffka/memorydatabase.go index 05d1f3ee..6166c493 100644 --- a/vendor/src/github.com/matrix-org/naffka/memorydatabase.go +++ b/vendor/src/github.com/matrix-org/naffka/memorydatabase.go @@ -8,7 +8,8 @@ import ( // A MemoryDatabase stores the message history as arrays in memory. // It can be used to run unit tests. // If the process is stopped then any messages that haven't been -// processed by a consumer are lost forever. +// processed by a consumer are lost forever and all offsets become +// invalid. type MemoryDatabase struct { topicsMutex sync.Mutex topics map[string]*memoryDatabaseTopic @@ -58,10 +59,7 @@ func (m *MemoryDatabase) getTopic(topicName string) *memoryDatabaseTopic { // StoreMessages implements Database func (m *MemoryDatabase) StoreMessages(topic string, messages []Message) error { - if err := m.getTopic(topic).addMessages(messages); err != nil { - return err - } - return nil + return m.getTopic(topic).addMessages(messages) } // FetchMessages implements Database @@ -73,10 +71,10 @@ func (m *MemoryDatabase) FetchMessages(topic string, startOffset, endOffset int6 if startOffset >= endOffset { return nil, fmt.Errorf("start offset %d greater than or equal to end offset %d", startOffset, endOffset) } - if startOffset < -1 { - return nil, fmt.Errorf("start offset %d less than -1", startOffset) + if startOffset < 0 { + return nil, fmt.Errorf("start offset %d less than 0", startOffset) } - return messages[startOffset+1 : endOffset], nil + return messages[startOffset:endOffset], nil } // MaxOffsets implements Database diff --git a/vendor/src/github.com/matrix-org/naffka/naffka.go b/vendor/src/github.com/matrix-org/naffka/naffka.go index d429ffda..e384b04f 100644 --- a/vendor/src/github.com/matrix-org/naffka/naffka.go +++ b/vendor/src/github.com/matrix-org/naffka/naffka.go @@ -13,6 +13,7 @@ import ( // single go process. It implements both the sarama.SyncProducer and the // sarama.Consumer interfaces. This means it can act as a drop in replacement // for kafka for testing or single instance deployment. +// Does not support multiple partitions. type Naffka struct { db Database topicsMutex sync.Mutex @@ -28,6 +29,7 @@ func New(db Database) (*Naffka, error) { } for topicName, offset := range maxOffsets { n.topics[topicName] = &topic{ + db: db, topicName: topicName, nextOffset: offset + 1, } @@ -64,7 +66,7 @@ type Database interface { // So for a given topic the message with offset n+1 is stored after the // the message with offset n. StoreMessages(topic string, messages []Message) error - // FetchMessages fetches all messages with an offset greater than but not + // FetchMessages fetches all messages with an offset greater than and // including startOffset and less than but not including endOffset. // The range of offsets requested must not overlap with those stored by a // concurrent StoreMessages. The message offsets within the requested range @@ -138,6 +140,7 @@ func (n *Naffka) Partitions(topic string) ([]int32, error) { } // ConsumePartition implements sarama.Consumer +// Note: offset is *inclusive*, i.e. it will include the message with that offset. func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) { if partition != 0 { return nil, fmt.Errorf("Unknown partition ID %d", partition) @@ -166,13 +169,16 @@ func (n *Naffka) Close() error { const channelSize = 1024 +// partitionConsumer ensures that all messages written to a particular +// topic, from an offset, get sent in order to a channel. +// Implements sarama.PartitionConsumer type partitionConsumer struct { topic *topic messages chan *sarama.ConsumerMessage - // Whether the consumer is ready for new messages or whether it - // is catching up on historic messages. + // Whether the consumer is in "catchup" mode or not. + // See "catchup" function for details. // Reads and writes to this field are proctected by the topic mutex. - ready bool + catchingUp bool } // AsyncClose implements sarama.PartitionConsumer @@ -201,66 +207,101 @@ func (c *partitionConsumer) HighWaterMarkOffset() int64 { return c.topic.highwaterMark() } -// block writes the message to the consumer blocking until the consumer is ready -// to add the message to the channel. Once the message is successfully added to -// the channel it will catch up by pulling historic messsages from the database. -func (c *partitionConsumer) block(cmsg *sarama.ConsumerMessage) { - c.messages <- cmsg - c.catchup(cmsg.Offset) +// catchup makes the consumer go into "catchup" mode, where messages are read +// from the database instead of directly from producers. +// Once the consumer is up to date, i.e. no new messages in the database, then +// the consumer will go back into normal mode where new messages are written +// directly to the channel. +// Must be called with the c.topic.mutex lock +func (c *partitionConsumer) catchup(fromOffset int64) { + // If we're already in catchup mode or up to date, noop + if c.catchingUp || fromOffset == c.topic.nextOffset { + return + } + + c.catchingUp = true + + // Due to the checks above there can only be one of these goroutines + // running at a time + go func() { + for { + // Check if we're up to date yet. If we are we exit catchup mode. + c.topic.mutex.Lock() + nextOffset := c.topic.nextOffset + if fromOffset == nextOffset { + c.catchingUp = false + c.topic.mutex.Unlock() + return + } + c.topic.mutex.Unlock() + + // Limit the number of messages we request from the database to be the + // capacity of the channel. + if nextOffset > fromOffset+int64(cap(c.messages)) { + nextOffset = fromOffset + int64(cap(c.messages)) + } + // Fetch the messages from the database. + msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset) + if err != nil { + // TODO: Add option to write consumer errors to an errors channel + // as an alternative to logging the errors. + log.Print("Error reading messages: ", err) + // Wait before retrying. + // TODO: Maybe use an exponentional backoff scheme here. + // TODO: This timeout should take account of all the other goroutines + // that might be doing the same thing. (If there are a 10000 consumers + // then we don't want to end up retrying every millisecond) + time.Sleep(10 * time.Second) + continue + } + if len(msgs) == 0 { + // This should only happen if the database is corrupted and has lost the + // messages between the requested offsets. + log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset) + } + + // Pass the messages into the consumer channel. + // Blocking each write until the channel has enough space for the message. + for i := range msgs { + c.messages <- msgs[i].consumerMessage(c.topic.topicName) + } + // Update our the offset for the next loop iteration. + fromOffset = msgs[len(msgs)-1].Offset + 1 + } + }() } -// catchup reads historic messages from the database until the consumer has caught -// up on all the historic messages. -func (c *partitionConsumer) catchup(fromOffset int64) { - for { - // First check if we have caught up. - caughtUp, nextOffset := c.topic.hasCaughtUp(c, fromOffset) - if caughtUp { - return - } - // Limit the number of messages we request from the database to be the - // capacity of the channel. - if nextOffset > fromOffset+int64(cap(c.messages)) { - nextOffset = fromOffset + int64(cap(c.messages)) - } - // Fetch the messages from the database. - msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset) - if err != nil { - // TODO: Add option to write consumer errors to an errors channel - // as an alternative to logging the errors. - log.Print("Error reading messages: ", err) - // Wait before retrying. - // TODO: Maybe use an exponentional backoff scheme here. - // TODO: This timeout should take account of all the other goroutines - // that might be doing the same thing. (If there are a 10000 consumers - // then we don't want to end up retrying every millisecond) - time.Sleep(10 * time.Second) - continue - } - if len(msgs) == 0 { - // This should only happen if the database is corrupted and has lost the - // messages between the requested offsets. - log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset) - } +// notifyNewMessage tells the consumer about a new message +// Must be called with the c.topic.mutex lock +func (c *partitionConsumer) notifyNewMessage(cmsg *sarama.ConsumerMessage) { + // If we're in "catchup" mode then the catchup routine will send the + // message later, since cmsg has already been written to the database + if c.catchingUp { + return + } - // Pass the messages into the consumer channel. - // Blocking each write until the channel has enough space for the message. - for i := range msgs { - c.messages <- msgs[i].consumerMessage(c.topic.topicName) - } - // Update our the offset for the next loop iteration. - fromOffset = msgs[len(msgs)-1].Offset + // Otherwise, lets try writing the message directly to the channel + select { + case c.messages <- cmsg: + default: + // The messages channel has filled up, so lets go into catchup + // mode. Once the channel starts being read from again messages + // will be read from the database + c.catchup(cmsg.Offset) } } type topic struct { - db Database - topicName string - mutex sync.Mutex - consumers []*partitionConsumer + db Database + topicName string + mutex sync.Mutex + consumers []*partitionConsumer + // nextOffset is the offset that will be assigned to the next message in + // this topic, i.e. one greater than the last message offset. nextOffset int64 } +// send writes messages to a topic. func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error { var err error // Encode the message keys and values. @@ -298,21 +339,10 @@ func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error { t.nextOffset = offset // Now notify the consumers about the messages. - for i := range msgs { - cmsg := msgs[i].consumerMessage(t.topicName) + for _, msg := range msgs { + cmsg := msg.consumerMessage(t.topicName) for _, c := range t.consumers { - if c.ready { - select { - case c.messages <- cmsg: - default: - // The consumer wasn't ready to receive a message because - // the channel buffer was full. - // Fork a goroutine to send the message so that we don't - // block sending messages to the other consumers. - c.ready = false - go c.block(cmsg) - } - } + c.notifyNewMessage(cmsg) } } @@ -330,27 +360,17 @@ func (t *topic) consume(offset int64) *partitionConsumer { offset = t.nextOffset } if offset == sarama.OffsetOldest { - offset = -1 + offset = 0 } c.messages = make(chan *sarama.ConsumerMessage, channelSize) t.consumers = append(t.consumers, c) - // Start catching up on historic messages in the background. - go c.catchup(offset) - return c -} -func (t *topic) hasCaughtUp(c *partitionConsumer, offset int64) (bool, int64) { - t.mutex.Lock() - defer t.mutex.Unlock() - // Check if we have caught up while holding a lock on the topic so there - // isn't a way for our check to race with a new message being sent on the topic. - if offset+1 == t.nextOffset { - // We've caught up, the consumer can now receive messages as they are - // sent rather than fetching them from the database. - c.ready = true - return true, t.nextOffset + // If we're not streaming from the latest offset we need to go into + // "catchup" mode + if offset != t.nextOffset { + c.catchup(offset) } - return false, t.nextOffset + return c } func (t *topic) highwaterMark() int64 { diff --git a/vendor/src/github.com/matrix-org/naffka/naffka_test.go b/vendor/src/github.com/matrix-org/naffka/naffka_test.go index d1a26710..dd69c95f 100644 --- a/vendor/src/github.com/matrix-org/naffka/naffka_test.go +++ b/vendor/src/github.com/matrix-org/naffka/naffka_test.go @@ -1,6 +1,7 @@ package naffka import ( + "strconv" "testing" "time" @@ -84,3 +85,142 @@ func TestDelayedReceive(t *testing.T) { t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value)) } } + +func TestCatchup(t *testing.T) { + naffka, err := New(&MemoryDatabase{}) + if err != nil { + t.Fatal(err) + } + producer := sarama.SyncProducer(naffka) + consumer := sarama.Consumer(naffka) + + const topic = "testTopic" + const value = "Hello, World" + + message := sarama.ProducerMessage{ + Value: sarama.StringEncoder(value), + Topic: topic, + } + + if _, _, err = producer.SendMessage(&message); err != nil { + t.Fatal(err) + } + + c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest) + if err != nil { + t.Fatal(err) + } + + var result *sarama.ConsumerMessage + select { + case result = <-c.Messages(): + case _ = <-time.NewTimer(10 * time.Second).C: + t.Fatal("expected to receive a message") + } + + if string(result.Value) != value { + t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value)) + } + + currOffset := result.Offset + + const value2 = "Hello, World2" + const value3 = "Hello, World3" + + _, _, err = producer.SendMessage(&sarama.ProducerMessage{ + Value: sarama.StringEncoder(value2), + Topic: topic, + }) + if err != nil { + t.Fatal(err) + } + + _, _, err = producer.SendMessage(&sarama.ProducerMessage{ + Value: sarama.StringEncoder(value3), + Topic: topic, + }) + if err != nil { + t.Fatal(err) + } + + t.Logf("Streaming from %q", currOffset+1) + + c2, err := consumer.ConsumePartition(topic, 0, currOffset+1) + if err != nil { + t.Fatal(err) + } + + var result2 *sarama.ConsumerMessage + select { + case result2 = <-c2.Messages(): + case _ = <-time.NewTimer(10 * time.Second).C: + t.Fatal("expected to receive a message") + } + + if string(result2.Value) != value2 { + t.Fatalf("wrong value: wanted %q got %q", value2, string(result2.Value)) + } +} + +func TestChannelSaturation(t *testing.T) { + // The channel returned by c.Messages() has a fixed capacity + + naffka, err := New(&MemoryDatabase{}) + if err != nil { + t.Fatal(err) + } + producer := sarama.SyncProducer(naffka) + consumer := sarama.Consumer(naffka) + const topic = "testTopic" + const baseValue = "testValue: " + + c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest) + if err != nil { + t.Fatal(err) + } + + channelSize := cap(c.Messages()) + + // We want to send enough messages to fill up the channel, so lets double + // the size of the channel. And add three in case its a zero sized channel + numberMessagesToSend := 2*channelSize + 3 + + var sentMessages []string + + for i := 0; i < numberMessagesToSend; i++ { + value := baseValue + strconv.Itoa(i) + + message := sarama.ProducerMessage{ + Topic: topic, + Value: sarama.StringEncoder(value), + } + + sentMessages = append(sentMessages, value) + + if _, _, err = producer.SendMessage(&message); err != nil { + t.Fatal(err) + } + } + + var result *sarama.ConsumerMessage + + j := 0 + for ; j < numberMessagesToSend; j++ { + select { + case result = <-c.Messages(): + case _ = <-time.NewTimer(10 * time.Second).C: + t.Fatalf("failed to receive message %d out of %d", j+1, numberMessagesToSend) + } + + expectedValue := sentMessages[j] + if string(result.Value) != expectedValue { + t.Fatalf("wrong value: wanted %q got %q", expectedValue, string(result.Value)) + } + } + + select { + case result = <-c.Messages(): + t.Fatalf("expected to only receive %d messages", numberMessagesToSend) + default: + } +} diff --git a/vendor/src/github.com/matrix-org/naffka/postgresqldatabase.go b/vendor/src/github.com/matrix-org/naffka/postgresqldatabase.go new file mode 100644 index 00000000..d121630c --- /dev/null +++ b/vendor/src/github.com/matrix-org/naffka/postgresqldatabase.go @@ -0,0 +1,296 @@ +package naffka + +import ( + "database/sql" + "sync" + "time" +) + +const postgresqlSchema = ` +-- The topic table assigns each topic a unique numeric ID. +CREATE SEQUENCE IF NOT EXISTS naffka_topic_nid_seq; +CREATE TABLE IF NOT EXISTS naffka_topics ( + topic_name TEXT PRIMARY KEY, + topic_nid BIGINT NOT NULL DEFAULT nextval('naffka_topic_nid_seq') +); + +-- The messages table contains the actual messages. +CREATE TABLE IF NOT EXISTS naffka_messages ( + topic_nid BIGINT NOT NULL, + message_offset BIGINT NOT NULL, + message_key BYTEA NOT NULL, + message_value BYTEA NOT NULL, + message_timestamp_ns BIGINT NOT NULL, + UNIQUE (topic_nid, message_offset) +); +` + +const insertTopicSQL = "" + + "INSERT INTO naffka_topics (topic_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + + " RETURNING (topic_nid)" + +const selectTopicSQL = "" + + "SELECT topic_nid FROM naffka_topics WHERE topic_name = $1" + +const selectTopicsSQL = "" + + "SELECT topic_name, topic_nid FROM naffka_topics" + +const insertMessageSQL = "" + + "INSERT INTO naffka_messages (topic_nid, message_offset, message_key, message_value, message_timestamp_ns)" + + " VALUES ($1, $2, $3, $4, $5)" + +const selectMessagesSQL = "" + + "SELECT message_offset, message_key, message_value, message_timestamp_ns" + + " FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" + + " ORDER BY message_offset ASC" + +const selectMaxOffsetSQL = "" + + "SELECT message_offset FROM naffka_messages WHERE topic_nid = $1" + + " ORDER BY message_offset DESC LIMIT 1" + +type postgresqlDatabase struct { + db *sql.DB + topicsMutex sync.Mutex + topicNIDs map[string]int64 + insertTopicStmt *sql.Stmt + selectTopicStmt *sql.Stmt + selectTopicsStmt *sql.Stmt + insertMessageStmt *sql.Stmt + selectMessagesStmt *sql.Stmt + selectMaxOffsetStmt *sql.Stmt +} + +// NewPostgresqlDatabase creates a new naffka database using a postgresql database. +// Returns an error if there was a problem setting up the database. +func NewPostgresqlDatabase(db *sql.DB) (Database, error) { + var err error + + p := &postgresqlDatabase{ + db: db, + topicNIDs: map[string]int64{}, + } + + if _, err = db.Exec(postgresqlSchema); err != nil { + return nil, err + } + + for _, s := range []struct { + sql string + stmt **sql.Stmt + }{ + {insertTopicSQL, &p.insertTopicStmt}, + {selectTopicSQL, &p.selectTopicStmt}, + {selectTopicsSQL, &p.selectTopicsStmt}, + {insertMessageSQL, &p.insertMessageStmt}, + {selectMessagesSQL, &p.selectMessagesStmt}, + {selectMaxOffsetSQL, &p.selectMaxOffsetStmt}, + } { + *s.stmt, err = db.Prepare(s.sql) + if err != nil { + return nil, err + } + } + return p, nil +} + +// StoreMessages implements Database. +func (p *postgresqlDatabase) StoreMessages(topic string, messages []Message) error { + // Store the messages inside a single database transaction. + return withTransaction(p.db, func(txn *sql.Tx) error { + s := txn.Stmt(p.insertMessageStmt) + topicNID, err := p.assignTopicNID(txn, topic) + if err != nil { + return err + } + for _, m := range messages { + _, err = s.Exec(topicNID, m.Offset, m.Key, m.Value, m.Timestamp.UnixNano()) + if err != nil { + return err + } + } + return nil + }) +} + +// FetchMessages implements Database. +func (p *postgresqlDatabase) FetchMessages(topic string, startOffset, endOffset int64) (messages []Message, err error) { + topicNID, err := p.getTopicNID(nil, topic) + if err != nil { + return + } + rows, err := p.selectMessagesStmt.Query(topicNID, startOffset, endOffset) + if err != nil { + return + } + defer rows.Close() + for rows.Next() { + var ( + offset int64 + key []byte + value []byte + timestampNano int64 + ) + if err = rows.Scan(&offset, &key, &value, ×tampNano); err != nil { + return + } + messages = append(messages, Message{ + Offset: offset, + Key: key, + Value: value, + Timestamp: time.Unix(0, timestampNano), + }) + } + return +} + +// MaxOffsets implements Database. +func (p *postgresqlDatabase) MaxOffsets() (map[string]int64, error) { + topicNames, err := p.selectTopics() + if err != nil { + return nil, err + } + result := map[string]int64{} + for topicName, topicNID := range topicNames { + // Lookup the maximum offset. + maxOffset, err := p.selectMaxOffset(topicNID) + if err != nil { + return nil, err + } + if maxOffset > -1 { + // Don't include the topic if we haven't sent any messages on it. + result[topicName] = maxOffset + } + // Prefill the numeric ID cache. + p.addTopicNIDToCache(topicName, topicNID) + } + return result, nil +} + +// selectTopics fetches the names and numeric IDs for all the topics the +// database is aware of. +func (p *postgresqlDatabase) selectTopics() (map[string]int64, error) { + rows, err := p.selectTopicsStmt.Query() + if err != nil { + return nil, err + } + defer rows.Close() + result := map[string]int64{} + for rows.Next() { + var ( + topicName string + topicNID int64 + ) + if err = rows.Scan(&topicName, &topicNID); err != nil { + return nil, err + } + result[topicName] = topicNID + } + return result, nil +} + +// selectMaxOffset selects the maximum offset for a topic. +// Returns -1 if there aren't any messages for that topic. +// Returns an error if there was a problem talking to the database. +func (p *postgresqlDatabase) selectMaxOffset(topicNID int64) (maxOffset int64, err error) { + err = p.selectMaxOffsetStmt.QueryRow(topicNID).Scan(&maxOffset) + if err == sql.ErrNoRows { + return -1, nil + } + return maxOffset, err +} + +// getTopicNID finds the numeric ID for a topic. +// The txn argument is optional, this can be used outside a transaction +// by setting the txn argument to nil. +func (p *postgresqlDatabase) getTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) { + // Get from the cache. + topicNID = p.getTopicNIDFromCache(topicName) + if topicNID != 0 { + return topicNID, nil + } + // Get from the database + s := p.selectTopicStmt + if txn != nil { + s = txn.Stmt(s) + } + err = s.QueryRow(topicName).Scan(&topicNID) + if err == sql.ErrNoRows { + return 0, nil + } + if err != nil { + return 0, err + } + // Update the shared cache. + p.addTopicNIDToCache(topicName, topicNID) + return topicNID, nil +} + +// assignTopicNID assigns a new numeric ID to a topic. +// The txn argument is mandatory, this is always called inside a transaction. +func (p *postgresqlDatabase) assignTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) { + // Check if we already have a numeric ID for the topic name. + topicNID, err = p.getTopicNID(txn, topicName) + if err != nil { + return 0, err + } + if topicNID != 0 { + return topicNID, err + } + // We don't have a numeric ID for the topic name so we add an entry to the + // topics table. If the insert stmt succeeds then it will return the ID. + err = txn.Stmt(p.insertTopicStmt).QueryRow(topicName).Scan(&topicNID) + if err == sql.ErrNoRows { + // If the insert stmt succeeded, but didn't return any rows then it + // means that someone has added a row for the topic name between us + // selecting it the first time and us inserting our own row. + // (N.B. postgres only returns modified rows when using "RETURNING") + // So we can now just select the row that someone else added. + // TODO: This is probably unnecessary since naffka writes to a topic + // from a single thread. + return p.getTopicNID(txn, topicName) + } + if err != nil { + return 0, err + } + // Update the cache. + p.addTopicNIDToCache(topicName, topicNID) + return topicNID, nil +} + +// getTopicNIDFromCache returns the topicNID from the cache or returns 0 if the +// topic is not in the cache. +func (p *postgresqlDatabase) getTopicNIDFromCache(topicName string) (topicNID int64) { + p.topicsMutex.Lock() + defer p.topicsMutex.Unlock() + return p.topicNIDs[topicName] +} + +// addTopicNIDToCache adds the numeric ID for the topic to the cache. +func (p *postgresqlDatabase) addTopicNIDToCache(topicName string, topicNID int64) { + p.topicsMutex.Lock() + defer p.topicsMutex.Unlock() + p.topicNIDs[topicName] = topicNID +} + +// withTransaction runs a block of code passing in an SQL transaction +// If the code returns an error or panics then the transactions is rolledback +// Otherwise the transaction is committed. +func withTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { + txn, err := db.Begin() + if err != nil { + return + } + defer func() { + if r := recover(); r != nil { + txn.Rollback() + panic(r) + } else if err != nil { + txn.Rollback() + } else { + err = txn.Commit() + } + }() + err = fn(txn) + return +}