Fix Go routine race condition in DB notification map 15/4815/3 cherry v0.5.4
authorTimo Tietavainen <timo.tietavainen@nokia.com>
Wed, 7 Oct 2020 04:07:24 +0000 (07:07 +0300)
committerTimo Tietavainen <timo.tietavainen@nokia.com>
Wed, 7 Oct 2020 17:15:48 +0000 (20:15 +0300)
Map cbMap was accessible by the main Go routine and also by a Go routine, which
handled incoming notifications from database (Redis). Problem was visible, when
unit tests were run with '-race' flag. Fix the issue by adding a mutex lock to
guard cbMap accesses. Go routine what handles incoming notification has a local
lockless copy of the cbMap to make sure that notification handling speed won't
be suffering from mutex locks. Local copy of cbMap is updated whenever the
cbMap has been updated.

Signed-off-by: Timo Tietavainen <timo.tietavainen@nokia.com>
Change-Id: I066ff1d71340db2240a7ea6aeeb575f098488608

docs/release-notes.rst
internal/sdlgoredis/sdlgoredis.go
internal/sdlgoredis/sdlgoredis_test.go
sdl.go

index 0c9700f..d2c1c37 100644 (file)
@@ -30,7 +30,11 @@ This document provides the release notes of the sdlgo.
 Version history
 ---------------
 
-[v0.5.3] - 2020-08-17
+[0.5.4] - 2020-10-07
+
+* Fix Go routine race condition when new DB notifications are subscribed.
+
+[0.5.3] - 2020-08-17
 
 * Take Redis client version 6.15.9 into use.
 
index 574e7b0..72eaebe 100644 (file)
@@ -25,12 +25,12 @@ package sdlgoredis
 import (
        "errors"
        "fmt"
+       "github.com/go-redis/redis"
        "os"
        "strconv"
        "strings"
+       "sync"
        "time"
-
-       "github.com/go-redis/redis"
 )
 
 type ChannelNotificationCb func(channel string, payload ...string)
@@ -41,11 +41,16 @@ type intChannels struct {
        exit          chan bool
 }
 
+type sharedCbMap struct {
+       m     sync.Mutex
+       cbMap map[string]ChannelNotificationCb
+}
+
 type DB struct {
        client       RedisClient
        subscribe    SubscribeFn
        redisModules bool
-       cbMap        map[string]ChannelNotificationCb
+       sCbMap       *sharedCbMap
        ch           intChannels
 }
 
@@ -112,7 +117,7 @@ func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
                client:       client,
                subscribe:    subscribe,
                redisModules: true,
-               cbMap:        make(map[string]ChannelNotificationCb, 0),
+               sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
                ch: intChannels{
                        addChannel:    make(chan string),
                        removeChannel: make(chan string),
@@ -180,37 +185,40 @@ func (db *DB) CloseDB() error {
 
 func (db *DB) UnsubscribeChannelDB(channels ...string) {
        for _, v := range channels {
+               db.sCbMap.Remove(v)
                db.ch.removeChannel <- v
-               delete(db.cbMap, v)
-               if len(db.cbMap) == 0 {
+               if db.sCbMap.Count() == 0 {
                        db.ch.exit <- true
                }
        }
 }
 
 func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
-       if len(db.cbMap) == 0 {
+       if db.sCbMap.Count() == 0 {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                }
 
-               go func(cbMap *map[string]ChannelNotificationCb,
+               go func(sCbMap *sharedCbMap,
                        channelPrefix,
                        eventSeparator string,
                        ch intChannels,
                        channels ...string) {
                        sub := db.subscribe(db.client, channels...)
                        rxChannel := sub.Channel()
+                       lCbMap := sCbMap.GetMapCopy()
                        for {
                                select {
                                case msg := <-rxChannel:
-                                       cb, ok := (*cbMap)[msg.Channel]
+                                       cb, ok := lCbMap[msg.Channel]
                                        if ok {
                                                cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
                                        }
                                case channel := <-ch.addChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Subscribe(channel)
                                case channel := <-ch.removeChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Unsubscribe(channel)
                                case exit := <-ch.exit:
                                        if exit {
@@ -221,11 +229,11 @@ func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, even
                                        }
                                }
                        }
-               }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
+               }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
 
        } else {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                        db.ch.addChannel <- v
                }
        }
@@ -367,3 +375,31 @@ func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration)
        }
        return errors.New("Lock not held")
 }
+
+func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       sCbMap.cbMap[channel] = cb
+}
+
+func (sCbMap *sharedCbMap) Remove(channel string) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       delete(sCbMap.cbMap, channel)
+}
+
+func (sCbMap *sharedCbMap) Count() int {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       return len(sCbMap.cbMap)
+}
+
+func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       mapCopy := make(map[string]ChannelNotificationCb, 0)
+       for i, v := range sCbMap.cbMap {
+               mapCopy[i] = v
+       }
+       return mapCopy
+}
index ce72607..4869d8f 100644 (file)
@@ -715,6 +715,7 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) {
                receivedChannel2 = channel
        }, "{prefix}", "---", "{prefix}channel2")
 
+       time.Sleep(1 * time.Second)
        db.UnsubscribeChannelDB("{prefix}channel1")
        ch <- &msg2
        db.UnsubscribeChannelDB("{prefix}channel2")
@@ -726,6 +727,42 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) {
        ps.AssertExpectations(t)
 }
 
+func TestSubscribeChannelReDBSubscribeAfterUnsubscribe(t *testing.T) {
+       ps, r, db := setup(true)
+       ch := make(chan *redis.Message)
+       msg := redis.Message{
+               Channel: "{prefix}channel",
+               Pattern: "pattern",
+               Payload: "event",
+       }
+       ps.On("Channel").Return(ch)
+       ps.On("Unsubscribe").Return(nil)
+       ps.On("Close").Return(nil)
+       count := 0
+       receivedChannel := ""
+
+       db.SubscribeChannelDB(func(channel string, payload ...string) {
+               count++
+               receivedChannel = channel
+       }, "{prefix}", "---", "{prefix}channel")
+       ch <- &msg
+       db.UnsubscribeChannelDB("{prefix}channel")
+       time.Sleep(1 * time.Second)
+
+       db.SubscribeChannelDB(func(channel string, payload ...string) {
+               count++
+               receivedChannel = channel
+       }, "{prefix}", "---", "{prefix}channel")
+       ch <- &msg
+       db.UnsubscribeChannelDB("{prefix}channel")
+
+       time.Sleep(1 * time.Second)
+       assert.Equal(t, 2, count)
+       assert.Equal(t, "channel", receivedChannel)
+       r.AssertExpectations(t)
+       ps.AssertExpectations(t)
+}
+
 func TestPTTLSuccessfully(t *testing.T) {
        _, r, db := setup(true)
        expectedKey := "key"
diff --git a/sdl.go b/sdl.go
index e5e8076..25194c8 100644 (file)
--- a/sdl.go
+++ b/sdl.go
@@ -93,8 +93,6 @@ func NewSdlInstance(NameSpace string, db *Database) *SdlInstance {
 //callback as quickly as possible. E.g. reading in callback context should be avoided
 //and using of Go signals is recommended. Also it should be noted that in case of several
 //events received from different channels, callbacks are called in series one by one.
-//
-//This function is NOT SAFE FOR CONCURRENT USE by multiple goroutines.
 func (s *SdlInstance) SubscribeChannel(cb func(string, ...string), channels ...string) error {
        s.SubscribeChannelDB(cb, s.nsPrefix, s.eventSeparator, s.setNamespaceToChannels(channels...)...)
        return nil