Take DBAAS multi-channel publishing Redis modules into use
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
index b032d8f..327946e 100644 (file)
    limitations under the License.
 */
 
+/*
+ * This source code is part of the near-RT RIC (RAN Intelligent Controller)
+ * platform project (RICP).
+ */
+
 package sdlgoredis
 
 import (
        "errors"
        "fmt"
+       "github.com/go-redis/redis"
        "os"
+       "strconv"
        "strings"
-
-       "github.com/go-redis/redis"
+       "sync"
+       "time"
 )
 
 type ChannelNotificationCb func(channel string, payload ...string)
@@ -34,13 +41,50 @@ type intChannels struct {
        exit          chan bool
 }
 
+type sharedCbMap struct {
+       m     sync.Mutex
+       cbMap map[string]ChannelNotificationCb
+}
+
 type DB struct {
-       client       *redis.Client
+       client       RedisClient
+       subscribe    SubscribeFn
        redisModules bool
-       cbMap        map[string]ChannelNotificationCb
+       sCbMap       *sharedCbMap
        ch           intChannels
 }
 
+type Subscriber interface {
+       Channel() <-chan *redis.Message
+       Subscribe(channels ...string) error
+       Unsubscribe(channels ...string) error
+       Close() error
+}
+
+type SubscribeFn func(client RedisClient, channels ...string) Subscriber
+
+type RedisClient interface {
+       Command() *redis.CommandsInfoCmd
+       Close() error
+       Subscribe(channels ...string) *redis.PubSub
+       MSet(pairs ...interface{}) *redis.StatusCmd
+       Do(args ...interface{}) *redis.Cmd
+       MGet(keys ...string) *redis.SliceCmd
+       Del(keys ...string) *redis.IntCmd
+       Keys(pattern string) *redis.StringSliceCmd
+       SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
+       SAdd(key string, members ...interface{}) *redis.IntCmd
+       SRem(key string, members ...interface{}) *redis.IntCmd
+       SMembers(key string) *redis.StringSliceCmd
+       SIsMember(key string, member interface{}) *redis.BoolCmd
+       SCard(key string) *redis.IntCmd
+       PTTL(key string) *redis.DurationCmd
+       Eval(script string, keys []string, args ...interface{}) *redis.Cmd
+       EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
+       ScriptExists(scripts ...string) *redis.BoolSliceCmd
+       ScriptLoad(script string) *redis.StringCmd
+}
+
 func checkResultAndError(result interface{}, err error) (bool, error) {
        if err != nil {
                if err == redis.Nil {
@@ -50,24 +94,48 @@ func checkResultAndError(result interface{}, err error) (bool, error) {
        }
        if result == "OK" {
                return true, nil
-       } else {
-               return false, nil
        }
+       return false, nil
 }
 
 func checkIntResultAndError(result interface{}, err error) (bool, error) {
        if err != nil {
                return false, err
        }
-       if result.(int64) == 1 {
-               return true, nil
-       } else {
-               return false, nil
+       if n, ok := result.(int64); ok {
+               if n == 1 {
+                       return true, nil
+               }
+       } else if n, ok := result.(int); ok {
+               if n == 1 {
+                       return true, nil
+               }
        }
+       return false, nil
+}
 
+func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
+       return client.Subscribe(channels...)
+}
+
+func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
+       db := DB{
+               client:       client,
+               subscribe:    subscribe,
+               redisModules: true,
+               sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
+               ch: intChannels{
+                       addChannel:    make(chan string),
+                       removeChannel: make(chan string),
+                       exit:          make(chan bool),
+               },
+       }
+
+       return &db
 }
 
 func Create() *DB {
+       var client *redis.Client
        hostname := os.Getenv("DBAAS_SERVICE_HOST")
        if hostname == "" {
                hostname = "localhost"
@@ -76,28 +144,36 @@ func Create() *DB {
        if port == "" {
                port = "6379"
        }
-       redisAddress := hostname + ":" + port
-       client := redis.NewClient(&redis.Options{
-               Addr:     redisAddress,
-               Password: "", // no password set
-               DB:       0,  // use default DB
-               PoolSize: 20,
-       })
-
-       db := DB{
-               client:       client,
-               redisModules: true,
-               cbMap:        make(map[string]ChannelNotificationCb, 0),
-               ch: intChannels{
-                       addChannel:    make(chan string),
-                       removeChannel: make(chan string),
-                       exit:          make(chan bool),
-               },
+       sentinelPort := os.Getenv("DBAAS_SERVICE_SENTINEL_PORT")
+       masterName := os.Getenv("DBAAS_MASTER_NAME")
+       if sentinelPort == "" {
+               redisAddress := hostname + ":" + port
+               client = redis.NewClient(&redis.Options{
+                       Addr:       redisAddress,
+                       Password:   "", // no password set
+                       DB:         0,  // use default DB
+                       PoolSize:   20,
+                       MaxRetries: 2,
+               })
+       } else {
+               sentinelAddress := hostname + ":" + sentinelPort
+               client = redis.NewFailoverClient(&redis.FailoverOptions{
+                       MasterName:    masterName,
+                       SentinelAddrs: []string{sentinelAddress},
+                       PoolSize:      20,
+                       MaxRetries:    2,
+               })
        }
+       db := CreateDB(client, subscribeNotifications)
+       db.CheckCommands()
+       return db
+}
 
+func (db *DB) CheckCommands() {
        commands, err := db.client.Command().Result()
        if err == nil {
-               redisModuleCommands := []string{"setie", "delie", "msetpub", "setiepub", "setnxpub", "delpub"}
+               redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
+                       "msetmpub", "delmpub"}
                for _, v := range redisModuleCommands {
                        _, ok := commands[v]
                        if !ok {
@@ -107,7 +183,6 @@ func Create() *DB {
        } else {
                fmt.Println(err)
        }
-       return &db
 }
 
 func (db *DB) CloseDB() error {
@@ -116,37 +191,40 @@ func (db *DB) CloseDB() error {
 
 func (db *DB) UnsubscribeChannelDB(channels ...string) {
        for _, v := range channels {
+               db.sCbMap.Remove(v)
                db.ch.removeChannel <- v
-               delete(db.cbMap, v)
-               if len(db.cbMap) == 0 {
+               if db.sCbMap.Count() == 0 {
                        db.ch.exit <- true
                }
        }
 }
 
-func (db *DB) SubscribeChannelDB(cb ChannelNotificationCb, channelPrefix, eventSeparator string, channels ...string) {
-       if len(db.cbMap) == 0 {
+func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
+       if db.sCbMap.Count() == 0 {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                }
 
-               go func(cbMap *map[string]ChannelNotificationCb,
+               go func(sCbMap *sharedCbMap,
                        channelPrefix,
                        eventSeparator string,
                        ch intChannels,
                        channels ...string) {
-                       sub := db.client.Subscribe(channels...)
+                       sub := db.subscribe(db.client, channels...)
                        rxChannel := sub.Channel()
+                       lCbMap := sCbMap.GetMapCopy()
                        for {
                                select {
                                case msg := <-rxChannel:
-                                       cb, ok := (*cbMap)[msg.Channel]
+                                       cb, ok := lCbMap[msg.Channel]
                                        if ok {
                                                cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
                                        }
                                case channel := <-ch.addChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Subscribe(channel)
                                case channel := <-ch.removeChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Unsubscribe(channel)
                                case exit := <-ch.exit:
                                        if exit {
@@ -157,11 +235,11 @@ func (db *DB) SubscribeChannelDB(cb ChannelNotificationCb, channelPrefix, eventS
                                        }
                                }
                        }
-               }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
+               }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
 
        } else {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                        db.ch.addChannel <- v
                }
        }
@@ -171,16 +249,20 @@ func (db *DB) MSet(pairs ...interface{}) error {
        return db.client.MSet(pairs...).Err()
 }
 
-func (db *DB) MSetPub(channel, message string, pairs ...interface{}) error {
+func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
        if !db.redisModules {
-               return errors.New("Redis deployment doesn't support MSETPUB command")
+               return errors.New("Redis deployment doesn't support MSETMPUB command")
        }
        command := make([]interface{}, 0)
-       command = append(command, "MSETPUB")
+       command = append(command, "MSETMPUB")
+       command = append(command, len(pairs)/2)
+       command = append(command, len(channelsAndEvents)/2)
        for _, d := range pairs {
                command = append(command, d)
        }
-       command = append(command, channel, message)
+       for _, d := range channelsAndEvents {
+               command = append(command, d)
+       }
        _, err := db.client.Do(command...).Result()
        return err
 }
@@ -189,18 +271,23 @@ func (db *DB) MGet(keys []string) ([]interface{}, error) {
        return db.client.MGet(keys...).Result()
 }
 
-func (db *DB) DelPub(channel, message string, keys []string) error {
+func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
        if !db.redisModules {
-               return errors.New("Redis deployment not supporting command DELPUB")
+               return errors.New("Redis deployment not supporting command DELMPUB")
        }
        command := make([]interface{}, 0)
-       command = append(command, "DELPUB")
+       command = append(command, "DELMPUB")
+       command = append(command, len(keys))
+       command = append(command, len(channelsAndEvents)/2)
        for _, d := range keys {
                command = append(command, d)
        }
-       command = append(command, channel, message)
+       for _, d := range channelsAndEvents {
+               command = append(command, d)
+       }
        _, err := db.client.Do(command...).Result()
        return err
+
 }
 
 func (db *DB) Del(keys []string) error {
@@ -220,28 +307,53 @@ func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
        return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
 }
 
-func (db *DB) SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error) {
+func (db *DB) SetIEPub(channelsAndEvents []string, key string, oldData, newData interface{}) (bool, error) {
        if !db.redisModules {
-               return false, errors.New("Redis deployment not supporting command SETIEPUB")
+               return false, errors.New("Redis deployment not supporting command SETIEMPUB")
+       }
+       capacity := 4 + len(channelsAndEvents)
+       command := make([]interface{}, 0, capacity)
+       command = append(command, "SETIEMPUB")
+       command = append(command, key)
+       command = append(command, newData)
+       command = append(command, oldData)
+       for _, ce := range channelsAndEvents {
+               command = append(command, ce)
        }
-       return checkResultAndError(db.client.Do("SETIEPUB", key, newData, oldData, channel, message).Result())
+       return checkResultAndError(db.client.Do(command...).Result())
 }
 
-func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, error) {
+func (db *DB) SetNXPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
        if !db.redisModules {
-               return false, errors.New("Redis deployment not supporting command SETNXPUB")
+               return false, errors.New("Redis deployment not supporting command SETNXMPUB")
        }
-       return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
+       capacity := 3 + len(channelsAndEvents)
+       command := make([]interface{}, 0, capacity)
+       command = append(command, "SETNXMPUB")
+       command = append(command, key)
+       command = append(command, data)
+       for _, ce := range channelsAndEvents {
+               command = append(command, ce)
+       }
+       return checkResultAndError(db.client.Do(command...).Result())
 }
-func (db *DB) SetNX(key string, data interface{}) (bool, error) {
-       return db.client.SetNX(key, data, 0).Result()
+func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
+       return db.client.SetNX(key, data, expiration).Result()
 }
 
-func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
+func (db *DB) DelIEPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
        if !db.redisModules {
-               return false, errors.New("Redis deployment not supporting command")
+               return false, errors.New("Redis deployment not supporting command DELIEMPUB")
        }
-       return checkIntResultAndError(db.client.Do("DELIEPUB", key, data, channel, message).Result())
+       capacity := 3 + len(channelsAndEvents)
+       command := make([]interface{}, 0, capacity)
+       command = append(command, "DELIEMPUB")
+       command = append(command, key)
+       command = append(command, data)
+       for _, ce := range channelsAndEvents {
+               command = append(command, ce)
+       }
+       return checkIntResultAndError(db.client.Do(command...).Result())
 }
 
 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
@@ -275,3 +387,50 @@ func (db *DB) SCard(key string) (int64, error) {
        result, err := db.client.SCard(key).Result()
        return result, err
 }
+
+func (db *DB) PTTL(key string) (time.Duration, error) {
+       result, err := db.client.PTTL(key).Result()
+       return result, err
+}
+
+var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
+
+func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
+       expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
+       result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
+       if err != nil {
+               return err
+       }
+       if result == int64(1) {
+               return nil
+       }
+       return errors.New("Lock not held")
+}
+
+func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       sCbMap.cbMap[channel] = cb
+}
+
+func (sCbMap *sharedCbMap) Remove(channel string) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       delete(sCbMap.cbMap, channel)
+}
+
+func (sCbMap *sharedCbMap) Count() int {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       return len(sCbMap.cbMap)
+}
+
+func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       mapCopy := make(map[string]ChannelNotificationCb, 0)
+       for i, v := range sCbMap.cbMap {
+               mapCopy[i] = v
+       }
+       return mapCopy
+}