From 33961a269bf51f6c713fcb00576cdc0b9ac98ee9 Mon Sep 17 00:00:00 2001 From: Timo Tietavainen Date: Wed, 7 Oct 2020 07:07:24 +0300 Subject: [PATCH] Fix Go routine race condition in DB notification map Map cbMap was accessible by the main Go routine and also by a Go routine, which handled incoming notifications from database (Redis). Problem was visible, when unit tests were run with '-race' flag. Fix the issue by adding a mutex lock to guard cbMap accesses. Go routine what handles incoming notification has a local lockless copy of the cbMap to make sure that notification handling speed won't be suffering from mutex locks. Local copy of cbMap is updated whenever the cbMap has been updated. Signed-off-by: Timo Tietavainen Change-Id: I066ff1d71340db2240a7ea6aeeb575f098488608 --- docs/release-notes.rst | 6 +++- internal/sdlgoredis/sdlgoredis.go | 60 +++++++++++++++++++++++++++------- internal/sdlgoredis/sdlgoredis_test.go | 37 +++++++++++++++++++++ sdl.go | 2 -- 4 files changed, 90 insertions(+), 15 deletions(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 0c9700f..d2c1c37 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -30,7 +30,11 @@ This document provides the release notes of the sdlgo. Version history --------------- -[v0.5.3] - 2020-08-17 +[0.5.4] - 2020-10-07 + +* Fix Go routine race condition when new DB notifications are subscribed. + +[0.5.3] - 2020-08-17 * Take Redis client version 6.15.9 into use. diff --git a/internal/sdlgoredis/sdlgoredis.go b/internal/sdlgoredis/sdlgoredis.go index 574e7b0..72eaebe 100644 --- a/internal/sdlgoredis/sdlgoredis.go +++ b/internal/sdlgoredis/sdlgoredis.go @@ -25,12 +25,12 @@ package sdlgoredis import ( "errors" "fmt" + "github.com/go-redis/redis" "os" "strconv" "strings" + "sync" "time" - - "github.com/go-redis/redis" ) type ChannelNotificationCb func(channel string, payload ...string) @@ -41,11 +41,16 @@ type intChannels struct { exit chan bool } +type sharedCbMap struct { + m sync.Mutex + cbMap map[string]ChannelNotificationCb +} + type DB struct { client RedisClient subscribe SubscribeFn redisModules bool - cbMap map[string]ChannelNotificationCb + sCbMap *sharedCbMap ch intChannels } @@ -112,7 +117,7 @@ func CreateDB(client RedisClient, subscribe SubscribeFn) *DB { client: client, subscribe: subscribe, redisModules: true, - cbMap: make(map[string]ChannelNotificationCb, 0), + sCbMap: &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)}, ch: intChannels{ addChannel: make(chan string), removeChannel: make(chan string), @@ -180,37 +185,40 @@ func (db *DB) CloseDB() error { func (db *DB) UnsubscribeChannelDB(channels ...string) { for _, v := range channels { + db.sCbMap.Remove(v) db.ch.removeChannel <- v - delete(db.cbMap, v) - if len(db.cbMap) == 0 { + if db.sCbMap.Count() == 0 { db.ch.exit <- true } } } func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) { - if len(db.cbMap) == 0 { + if db.sCbMap.Count() == 0 { for _, v := range channels { - db.cbMap[v] = cb + db.sCbMap.Add(v, cb) } - go func(cbMap *map[string]ChannelNotificationCb, + go func(sCbMap *sharedCbMap, channelPrefix, eventSeparator string, ch intChannels, channels ...string) { sub := db.subscribe(db.client, channels...) rxChannel := sub.Channel() + lCbMap := sCbMap.GetMapCopy() for { select { case msg := <-rxChannel: - cb, ok := (*cbMap)[msg.Channel] + cb, ok := lCbMap[msg.Channel] if ok { cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...) } case channel := <-ch.addChannel: + lCbMap = sCbMap.GetMapCopy() sub.Subscribe(channel) case channel := <-ch.removeChannel: + lCbMap = sCbMap.GetMapCopy() sub.Unsubscribe(channel) case exit := <-ch.exit: if exit { @@ -221,11 +229,11 @@ func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, even } } } - }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...) + }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...) } else { for _, v := range channels { - db.cbMap[v] = cb + db.sCbMap.Add(v, cb) db.ch.addChannel <- v } } @@ -367,3 +375,31 @@ func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) } return errors.New("Lock not held") } + +func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) { + sCbMap.m.Lock() + defer sCbMap.m.Unlock() + sCbMap.cbMap[channel] = cb +} + +func (sCbMap *sharedCbMap) Remove(channel string) { + sCbMap.m.Lock() + defer sCbMap.m.Unlock() + delete(sCbMap.cbMap, channel) +} + +func (sCbMap *sharedCbMap) Count() int { + sCbMap.m.Lock() + defer sCbMap.m.Unlock() + return len(sCbMap.cbMap) +} + +func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb { + sCbMap.m.Lock() + defer sCbMap.m.Unlock() + mapCopy := make(map[string]ChannelNotificationCb, 0) + for i, v := range sCbMap.cbMap { + mapCopy[i] = v + } + return mapCopy +} diff --git a/internal/sdlgoredis/sdlgoredis_test.go b/internal/sdlgoredis/sdlgoredis_test.go index ce72607..4869d8f 100644 --- a/internal/sdlgoredis/sdlgoredis_test.go +++ b/internal/sdlgoredis/sdlgoredis_test.go @@ -715,6 +715,7 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) { receivedChannel2 = channel }, "{prefix}", "---", "{prefix}channel2") + time.Sleep(1 * time.Second) db.UnsubscribeChannelDB("{prefix}channel1") ch <- &msg2 db.UnsubscribeChannelDB("{prefix}channel2") @@ -726,6 +727,42 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) { ps.AssertExpectations(t) } +func TestSubscribeChannelReDBSubscribeAfterUnsubscribe(t *testing.T) { + ps, r, db := setup(true) + ch := make(chan *redis.Message) + msg := redis.Message{ + Channel: "{prefix}channel", + Pattern: "pattern", + Payload: "event", + } + ps.On("Channel").Return(ch) + ps.On("Unsubscribe").Return(nil) + ps.On("Close").Return(nil) + count := 0 + receivedChannel := "" + + db.SubscribeChannelDB(func(channel string, payload ...string) { + count++ + receivedChannel = channel + }, "{prefix}", "---", "{prefix}channel") + ch <- &msg + db.UnsubscribeChannelDB("{prefix}channel") + time.Sleep(1 * time.Second) + + db.SubscribeChannelDB(func(channel string, payload ...string) { + count++ + receivedChannel = channel + }, "{prefix}", "---", "{prefix}channel") + ch <- &msg + db.UnsubscribeChannelDB("{prefix}channel") + + time.Sleep(1 * time.Second) + assert.Equal(t, 2, count) + assert.Equal(t, "channel", receivedChannel) + r.AssertExpectations(t) + ps.AssertExpectations(t) +} + func TestPTTLSuccessfully(t *testing.T) { _, r, db := setup(true) expectedKey := "key" diff --git a/sdl.go b/sdl.go index e5e8076..25194c8 100644 --- a/sdl.go +++ b/sdl.go @@ -93,8 +93,6 @@ func NewSdlInstance(NameSpace string, db *Database) *SdlInstance { //callback as quickly as possible. E.g. reading in callback context should be avoided //and using of Go signals is recommended. Also it should be noted that in case of several //events received from different channels, callbacks are called in series one by one. -// -//This function is NOT SAFE FOR CONCURRENT USE by multiple goroutines. func (s *SdlInstance) SubscribeChannel(cb func(string, ...string), channels ...string) error { s.SubscribeChannelDB(cb, s.nsPrefix, s.eventSeparator, s.setNamespaceToChannels(channels...)...) return nil -- 2.16.6