import (
"errors"
"fmt"
+ "github.com/go-redis/redis"
"os"
"strconv"
"strings"
+ "sync"
"time"
-
- "github.com/go-redis/redis"
)
type ChannelNotificationCb func(channel string, payload ...string)
exit chan bool
}
+type sharedCbMap struct {
+ m sync.Mutex
+ cbMap map[string]ChannelNotificationCb
+}
+
type DB struct {
client RedisClient
subscribe SubscribeFn
redisModules bool
- cbMap map[string]ChannelNotificationCb
+ sCbMap *sharedCbMap
ch intChannels
}
client: client,
subscribe: subscribe,
redisModules: true,
- cbMap: make(map[string]ChannelNotificationCb, 0),
+ sCbMap: &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
ch: intChannels{
addChannel: make(chan string),
removeChannel: make(chan string),
func (db *DB) UnsubscribeChannelDB(channels ...string) {
for _, v := range channels {
+ db.sCbMap.Remove(v)
db.ch.removeChannel <- v
- delete(db.cbMap, v)
- if len(db.cbMap) == 0 {
+ if db.sCbMap.Count() == 0 {
db.ch.exit <- true
}
}
}
func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
- if len(db.cbMap) == 0 {
+ if db.sCbMap.Count() == 0 {
for _, v := range channels {
- db.cbMap[v] = cb
+ db.sCbMap.Add(v, cb)
}
- go func(cbMap *map[string]ChannelNotificationCb,
+ go func(sCbMap *sharedCbMap,
channelPrefix,
eventSeparator string,
ch intChannels,
channels ...string) {
sub := db.subscribe(db.client, channels...)
rxChannel := sub.Channel()
+ lCbMap := sCbMap.GetMapCopy()
for {
select {
case msg := <-rxChannel:
- cb, ok := (*cbMap)[msg.Channel]
+ cb, ok := lCbMap[msg.Channel]
if ok {
cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
}
case channel := <-ch.addChannel:
+ lCbMap = sCbMap.GetMapCopy()
sub.Subscribe(channel)
case channel := <-ch.removeChannel:
+ lCbMap = sCbMap.GetMapCopy()
sub.Unsubscribe(channel)
case exit := <-ch.exit:
if exit {
}
}
}
- }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
+ }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
} else {
for _, v := range channels {
- db.cbMap[v] = cb
+ db.sCbMap.Add(v, cb)
db.ch.addChannel <- v
}
}
}
return errors.New("Lock not held")
}
+
+func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
+ sCbMap.m.Lock()
+ defer sCbMap.m.Unlock()
+ sCbMap.cbMap[channel] = cb
+}
+
+func (sCbMap *sharedCbMap) Remove(channel string) {
+ sCbMap.m.Lock()
+ defer sCbMap.m.Unlock()
+ delete(sCbMap.cbMap, channel)
+}
+
+func (sCbMap *sharedCbMap) Count() int {
+ sCbMap.m.Lock()
+ defer sCbMap.m.Unlock()
+ return len(sCbMap.cbMap)
+}
+
+func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
+ sCbMap.m.Lock()
+ defer sCbMap.m.Unlock()
+ mapCopy := make(map[string]ChannelNotificationCb, 0)
+ for i, v := range sCbMap.cbMap {
+ mapCopy[i] = v
+ }
+ return mapCopy
+}
receivedChannel2 = channel
}, "{prefix}", "---", "{prefix}channel2")
+ time.Sleep(1 * time.Second)
db.UnsubscribeChannelDB("{prefix}channel1")
ch <- &msg2
db.UnsubscribeChannelDB("{prefix}channel2")
ps.AssertExpectations(t)
}
+func TestSubscribeChannelReDBSubscribeAfterUnsubscribe(t *testing.T) {
+ ps, r, db := setup(true)
+ ch := make(chan *redis.Message)
+ msg := redis.Message{
+ Channel: "{prefix}channel",
+ Pattern: "pattern",
+ Payload: "event",
+ }
+ ps.On("Channel").Return(ch)
+ ps.On("Unsubscribe").Return(nil)
+ ps.On("Close").Return(nil)
+ count := 0
+ receivedChannel := ""
+
+ db.SubscribeChannelDB(func(channel string, payload ...string) {
+ count++
+ receivedChannel = channel
+ }, "{prefix}", "---", "{prefix}channel")
+ ch <- &msg
+ db.UnsubscribeChannelDB("{prefix}channel")
+ time.Sleep(1 * time.Second)
+
+ db.SubscribeChannelDB(func(channel string, payload ...string) {
+ count++
+ receivedChannel = channel
+ }, "{prefix}", "---", "{prefix}channel")
+ ch <- &msg
+ db.UnsubscribeChannelDB("{prefix}channel")
+
+ time.Sleep(1 * time.Second)
+ assert.Equal(t, 2, count)
+ assert.Equal(t, "channel", receivedChannel)
+ r.AssertExpectations(t)
+ ps.AssertExpectations(t)
+}
+
func TestPTTLSuccessfully(t *testing.T) {
_, r, db := setup(true)
expectedKey := "key"