Fix Go routine race condition in DB notification map
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
index 574e7b0..72eaebe 100644 (file)
@@ -25,12 +25,12 @@ package sdlgoredis
 import (
        "errors"
        "fmt"
+       "github.com/go-redis/redis"
        "os"
        "strconv"
        "strings"
+       "sync"
        "time"
-
-       "github.com/go-redis/redis"
 )
 
 type ChannelNotificationCb func(channel string, payload ...string)
@@ -41,11 +41,16 @@ type intChannels struct {
        exit          chan bool
 }
 
+type sharedCbMap struct {
+       m     sync.Mutex
+       cbMap map[string]ChannelNotificationCb
+}
+
 type DB struct {
        client       RedisClient
        subscribe    SubscribeFn
        redisModules bool
-       cbMap        map[string]ChannelNotificationCb
+       sCbMap       *sharedCbMap
        ch           intChannels
 }
 
@@ -112,7 +117,7 @@ func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
                client:       client,
                subscribe:    subscribe,
                redisModules: true,
-               cbMap:        make(map[string]ChannelNotificationCb, 0),
+               sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
                ch: intChannels{
                        addChannel:    make(chan string),
                        removeChannel: make(chan string),
@@ -180,37 +185,40 @@ func (db *DB) CloseDB() error {
 
 func (db *DB) UnsubscribeChannelDB(channels ...string) {
        for _, v := range channels {
+               db.sCbMap.Remove(v)
                db.ch.removeChannel <- v
-               delete(db.cbMap, v)
-               if len(db.cbMap) == 0 {
+               if db.sCbMap.Count() == 0 {
                        db.ch.exit <- true
                }
        }
 }
 
 func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
-       if len(db.cbMap) == 0 {
+       if db.sCbMap.Count() == 0 {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                }
 
-               go func(cbMap *map[string]ChannelNotificationCb,
+               go func(sCbMap *sharedCbMap,
                        channelPrefix,
                        eventSeparator string,
                        ch intChannels,
                        channels ...string) {
                        sub := db.subscribe(db.client, channels...)
                        rxChannel := sub.Channel()
+                       lCbMap := sCbMap.GetMapCopy()
                        for {
                                select {
                                case msg := <-rxChannel:
-                                       cb, ok := (*cbMap)[msg.Channel]
+                                       cb, ok := lCbMap[msg.Channel]
                                        if ok {
                                                cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
                                        }
                                case channel := <-ch.addChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Subscribe(channel)
                                case channel := <-ch.removeChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Unsubscribe(channel)
                                case exit := <-ch.exit:
                                        if exit {
@@ -221,11 +229,11 @@ func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, even
                                        }
                                }
                        }
-               }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
+               }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
 
        } else {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                        db.ch.addChannel <- v
                }
        }
@@ -367,3 +375,31 @@ func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration)
        }
        return errors.New("Lock not held")
 }
+
+func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       sCbMap.cbMap[channel] = cb
+}
+
+func (sCbMap *sharedCbMap) Remove(channel string) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       delete(sCbMap.cbMap, channel)
+}
+
+func (sCbMap *sharedCbMap) Count() int {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       return len(sCbMap.cbMap)
+}
+
+func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       mapCopy := make(map[string]ChannelNotificationCb, 0)
+       for i, v := range sCbMap.cbMap {
+               mapCopy[i] = v
+       }
+       return mapCopy
+}