Fix sdlcli healthcheck DBAAS status in SEP install
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
index f1c7e62..77dce09 100644 (file)
    limitations under the License.
 */
 
+/*
+ * This source code is part of the near-RT RIC (RAN Intelligent Controller)
+ * platform project (RICP).
+ */
+
 package sdlgoredis
 
 import (
        "errors"
        "fmt"
+       "github.com/go-redis/redis/v7"
+       "io"
+       "log"
        "os"
        "strconv"
        "strings"
+       "sync"
        "time"
-
-       "github.com/go-redis/redis"
 )
 
 type ChannelNotificationCb func(channel string, payload ...string)
+type RedisClientCreator func(addr, port, clusterName string, isHa bool) RedisClient
 
 type intChannels struct {
        addChannel    chan string
@@ -36,12 +44,29 @@ type intChannels struct {
        exit          chan bool
 }
 
+type sharedCbMap struct {
+       m     sync.Mutex
+       cbMap map[string]ChannelNotificationCb
+}
+
+type Config struct {
+       hostname        string
+       port            string
+       masterName      string
+       sentinelPort    string
+       clusterAddrList string
+       nodeCnt         string
+}
+
 type DB struct {
        client       RedisClient
+       sentinel     RedisSentinelCreateCb
        subscribe    SubscribeFn
        redisModules bool
-       cbMap        map[string]ChannelNotificationCb
+       sCbMap       *sharedCbMap
        ch           intChannels
+       cfg          Config
+       addr         string
 }
 
 type Subscriber interface {
@@ -73,6 +98,18 @@ type RedisClient interface {
        EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
        ScriptExists(scripts ...string) *redis.BoolSliceCmd
        ScriptLoad(script string) *redis.StringCmd
+       Info(section ...string) *redis.StringCmd
+}
+
+var dbLogger *log.Logger
+
+func init() {
+       dbLogger = log.New(os.Stdout, "database: ", log.LstdFlags|log.Lshortfile)
+       redis.SetLogger(dbLogger)
+}
+
+func SetDbLogger(out io.Writer) {
+       dbLogger.SetOutput(out)
 }
 
 func checkResultAndError(result interface{}, err error) (bool, error) {
@@ -92,8 +129,14 @@ func checkIntResultAndError(result interface{}, err error) (bool, error) {
        if err != nil {
                return false, err
        }
-       if result.(int64) == int64(1) {
-               return true, nil
+       if n, ok := result.(int64); ok {
+               if n == 1 {
+                       return true, nil
+               }
+       } else if n, ok := result.(int); ok {
+               if n == 1 {
+                       return true, nil
+               }
        }
        return false, nil
 }
@@ -102,57 +145,124 @@ func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
        return client.Subscribe(channels...)
 }
 
-func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
+func CreateDB(client RedisClient, subscribe SubscribeFn, sentinelCreateCb RedisSentinelCreateCb, cfg Config, sentinelAddr string) *DB {
        db := DB{
                client:       client,
+               sentinel:     sentinelCreateCb,
                subscribe:    subscribe,
                redisModules: true,
-               cbMap:        make(map[string]ChannelNotificationCb, 0),
+               sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
                ch: intChannels{
                        addChannel:    make(chan string),
                        removeChannel: make(chan string),
                        exit:          make(chan bool),
                },
+               cfg:  cfg,
+               addr: sentinelAddr,
        }
 
        return &db
 }
 
-func Create() *DB {
-       var client *redis.Client
-       hostname := os.Getenv("DBAAS_SERVICE_HOST")
-       if hostname == "" {
-               hostname = "localhost"
-       }
-       port := os.Getenv("DBAAS_SERVICE_PORT")
-       if port == "" {
-               port = "6379"
-       }
-       sentinelPort := os.Getenv("DBAAS_SERVICE_SENTINEL_PORT")
-       masterName := os.Getenv("DBAAS_MASTER_NAME")
-       if sentinelPort == "" {
-               redisAddress := hostname + ":" + port
-               client = redis.NewClient(&redis.Options{
-                       Addr:       redisAddress,
-                       Password:   "", // no password set
-                       DB:         0,  // use default DB
-                       PoolSize:   20,
-                       MaxRetries: 2,
-               })
+func Create() []*DB {
+       osimpl := osImpl{}
+       return ReadConfigAndCreateDbClients(osimpl, newRedisClient, subscribeNotifications, newRedisSentinel)
+}
+
+func readConfig(osI OS) Config {
+       cfg := Config{
+               hostname:        osI.Getenv("DBAAS_SERVICE_HOST", "localhost"),
+               port:            osI.Getenv("DBAAS_SERVICE_PORT", "6379"),
+               masterName:      osI.Getenv("DBAAS_MASTER_NAME", ""),
+               sentinelPort:    osI.Getenv("DBAAS_SERVICE_SENTINEL_PORT", ""),
+               clusterAddrList: osI.Getenv("DBAAS_CLUSTER_ADDR_LIST", ""),
+               nodeCnt:         osI.Getenv("DBAAS_NODE_COUNT", "1"),
+       }
+       return cfg
+}
+
+type OS interface {
+       Getenv(key string, defValue string) string
+}
+
+type osImpl struct{}
+
+func (osImpl) Getenv(key string, defValue string) string {
+       val := os.Getenv(key)
+       if val == "" {
+               val = defValue
+       }
+       return val
+}
+
+func ReadConfigAndCreateDbClients(osI OS, clientCreator RedisClientCreator,
+       subscribe SubscribeFn,
+       sentinelCreateCb RedisSentinelCreateCb) []*DB {
+       cfg := readConfig(osI)
+       return createDbClients(cfg, clientCreator, subscribe, sentinelCreateCb)
+}
+
+func createDbClients(cfg Config, clientCreator RedisClientCreator,
+       subscribe SubscribeFn,
+       sentinelCreateCb RedisSentinelCreateCb) []*DB {
+       if cfg.clusterAddrList == "" {
+               return []*DB{createLegacyDbClient(cfg, clientCreator, subscribe, sentinelCreateCb)}
+       }
+
+       dbs := []*DB{}
+
+       addrList := strings.Split(cfg.clusterAddrList, ",")
+       for _, addr := range addrList {
+               db := createDbClient(cfg, addr, clientCreator, subscribe, sentinelCreateCb)
+               dbs = append(dbs, db)
+       }
+       return dbs
+}
+
+func createLegacyDbClient(cfg Config, clientCreator RedisClientCreator,
+       subscribe SubscribeFn,
+       sentinelCreateCb RedisSentinelCreateCb) *DB {
+       return createDbClient(cfg, cfg.hostname, clientCreator, subscribe, sentinelCreateCb)
+}
+
+func createDbClient(cfg Config, hostName string, clientCreator RedisClientCreator,
+       subscribe SubscribeFn,
+       sentinelCreateCb RedisSentinelCreateCb) *DB {
+       var client RedisClient
+       var db *DB
+       if cfg.sentinelPort == "" {
+               client = clientCreator(hostName, cfg.port, "", false)
+               db = CreateDB(client, subscribe, nil, cfg, hostName)
        } else {
-               sentinelAddress := hostname + ":" + sentinelPort
-               client = redis.NewFailoverClient(&redis.FailoverOptions{
-                       MasterName:    masterName,
-                       SentinelAddrs: []string{sentinelAddress},
-                       PoolSize:      20,
-                       MaxRetries:    2,
-               })
-       }
-       db := CreateDB(client, subscribeNotifications)
+               client = clientCreator(hostName, cfg.sentinelPort, cfg.masterName, true)
+               db = CreateDB(client, subscribe, sentinelCreateCb, cfg, hostName)
+       }
        db.CheckCommands()
        return db
 }
 
+func newRedisClient(addr, port, clusterName string, isHa bool) RedisClient {
+       if isHa == true {
+               sentinelAddress := addr + ":" + port
+               return redis.NewFailoverClient(
+                       &redis.FailoverOptions{
+                               MasterName:    clusterName,
+                               SentinelAddrs: []string{sentinelAddress},
+                               PoolSize:      20,
+                               MaxRetries:    2,
+                       },
+               )
+       }
+       redisAddress := addr + ":" + port
+       return redis.NewClient(&redis.Options{
+               Addr:       redisAddress,
+               Password:   "", // no password set
+               DB:         0,  // use default DB
+               PoolSize:   20,
+               MaxRetries: 2,
+       })
+}
+
 func (db *DB) CheckCommands() {
        commands, err := db.client.Command().Result()
        if err == nil {
@@ -165,7 +275,7 @@ func (db *DB) CheckCommands() {
                        }
                }
        } else {
-               fmt.Println(err)
+               dbLogger.Printf("SDL DB commands checking failure: %s\n", err)
        }
 }
 
@@ -175,52 +285,55 @@ func (db *DB) CloseDB() error {
 
 func (db *DB) UnsubscribeChannelDB(channels ...string) {
        for _, v := range channels {
+               db.sCbMap.Remove(v)
                db.ch.removeChannel <- v
-               delete(db.cbMap, v)
-               if len(db.cbMap) == 0 {
+               if db.sCbMap.Count() == 0 {
                        db.ch.exit <- true
                }
        }
 }
 
-func (db *DB) SubscribeChannelDB(cb ChannelNotificationCb, channelPrefix, eventSeparator string, channels ...string) {
-       if len(db.cbMap) == 0 {
+func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
+       if db.sCbMap.Count() == 0 {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                }
 
-               go func(cbMap *map[string]ChannelNotificationCb,
+               go func(sCbMap *sharedCbMap,
                        channelPrefix,
                        eventSeparator string,
                        ch intChannels,
                        channels ...string) {
                        sub := db.subscribe(db.client, channels...)
                        rxChannel := sub.Channel()
+                       lCbMap := sCbMap.GetMapCopy()
                        for {
                                select {
                                case msg := <-rxChannel:
-                                       cb, ok := (*cbMap)[msg.Channel]
+                                       cb, ok := lCbMap[msg.Channel]
                                        if ok {
                                                cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
                                        }
                                case channel := <-ch.addChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Subscribe(channel)
                                case channel := <-ch.removeChannel:
+                                       lCbMap = sCbMap.GetMapCopy()
                                        sub.Unsubscribe(channel)
                                case exit := <-ch.exit:
                                        if exit {
                                                if err := sub.Close(); err != nil {
-                                                       fmt.Println(err)
+                                                       dbLogger.Printf("SDL DB channel closing failure: %s\n", err)
                                                }
                                                return
                                        }
                                }
                        }
-               }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
+               }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
 
        } else {
                for _, v := range channels {
-                       db.cbMap[v] = cb
+                       db.sCbMap.Add(v, cb)
                        db.ch.addChannel <- v
                }
        }
@@ -288,28 +401,53 @@ func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
        return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
 }
 
-func (db *DB) SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error) {
+func (db *DB) SetIEPub(channelsAndEvents []string, key string, oldData, newData interface{}) (bool, error) {
        if !db.redisModules {
-               return false, errors.New("Redis deployment not supporting command SETIEPUB")
+               return false, errors.New("Redis deployment not supporting command SETIEMPUB")
+       }
+       capacity := 4 + len(channelsAndEvents)
+       command := make([]interface{}, 0, capacity)
+       command = append(command, "SETIEMPUB")
+       command = append(command, key)
+       command = append(command, newData)
+       command = append(command, oldData)
+       for _, ce := range channelsAndEvents {
+               command = append(command, ce)
        }
-       return checkResultAndError(db.client.Do("SETIEPUB", key, newData, oldData, channel, message).Result())
+       return checkResultAndError(db.client.Do(command...).Result())
 }
 
-func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, error) {
+func (db *DB) SetNXPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
        if !db.redisModules {
-               return false, errors.New("Redis deployment not supporting command SETNXPUB")
+               return false, errors.New("Redis deployment not supporting command SETNXMPUB")
        }
-       return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
+       capacity := 3 + len(channelsAndEvents)
+       command := make([]interface{}, 0, capacity)
+       command = append(command, "SETNXMPUB")
+       command = append(command, key)
+       command = append(command, data)
+       for _, ce := range channelsAndEvents {
+               command = append(command, ce)
+       }
+       return checkResultAndError(db.client.Do(command...).Result())
 }
 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
        return db.client.SetNX(key, data, expiration).Result()
 }
 
-func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
+func (db *DB) DelIEPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
        if !db.redisModules {
-               return false, errors.New("Redis deployment not supporting command")
+               return false, errors.New("Redis deployment not supporting command DELIEMPUB")
        }
-       return checkIntResultAndError(db.client.Do("DELIEPUB", key, data, channel, message).Result())
+       capacity := 3 + len(channelsAndEvents)
+       command := make([]interface{}, 0, capacity)
+       command = append(command, "DELIEMPUB")
+       command = append(command, key)
+       command = append(command, data)
+       for _, ce := range channelsAndEvents {
+               command = append(command, ce)
+       }
+       return checkIntResultAndError(db.client.Do(command...).Result())
 }
 
 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
@@ -349,6 +487,81 @@ func (db *DB) PTTL(key string) (time.Duration, error) {
        return result, err
 }
 
+func (db *DB) Info() (*DbInfo, error) {
+       var info DbInfo
+       resultStr, err := db.client.Info("all").Result()
+       if err != nil {
+               return &info, err
+       }
+
+       result := strings.Split(strings.ReplaceAll(resultStr, "\r\n", "\n"), "\n")
+       err = readRedisInfoReplyFields(result, &info)
+       return &info, err
+}
+
+func readRedisInfoReplyFields(input []string, info *DbInfo) error {
+       for _, line := range input {
+               if idx := strings.Index(line, "role:"); idx != -1 {
+                       roleStr := line[idx+len("role:"):]
+                       if roleStr == "master" {
+                               info.Fields.PrimaryRole = true
+                       }
+               } else if idx := strings.Index(line, "connected_slaves:"); idx != -1 {
+                       cntStr := line[idx+len("connected_slaves:"):]
+                       cnt, err := strconv.ParseUint(cntStr, 10, 32)
+                       if err != nil {
+                               return fmt.Errorf("Info reply error: %s", err.Error())
+                       }
+                       info.Fields.ConnectedReplicaCnt = uint32(cnt)
+               }
+       }
+       return nil
+}
+
+func (db *DB) State() (*DbState, error) {
+       dbState := new(DbState)
+       if db.cfg.sentinelPort != "" {
+               //Establish connection to Redis sentinel. The reason why connection is done
+               //here instead of time of the SDL instance creation is that for the time being
+               //sentinel connection is needed only here to get state information and
+               //state information is needed only by 'sdlcli' hence it is not time critical
+               //and also we want to avoid opening unnecessary TCP connections towards Redis
+               //sentinel for every SDL instance. Now it is done only when 'sdlcli' is used.
+               sentinelClient := db.sentinel(&db.cfg, db.addr)
+               return sentinelClient.GetDbState()
+       } else {
+               info, err := db.Info()
+               if err != nil {
+                       dbState.PrimaryDbState.Err = err
+                       return dbState, err
+               }
+               return db.fillDbStateFromDbInfo(info)
+       }
+}
+
+func (db *DB) fillDbStateFromDbInfo(info *DbInfo) (*DbState, error) {
+       var dbState DbState
+       if info.Fields.PrimaryRole == true {
+               dbState = DbState{
+                       PrimaryDbState: PrimaryDbState{
+                               Fields: PrimaryDbStateFields{
+                                       Role:  "master",
+                                       Flags: "master",
+                               },
+                       },
+               }
+       }
+
+       cnt, err := strconv.Atoi(db.cfg.nodeCnt)
+       if err != nil {
+               dbState.Err = fmt.Errorf("DBAAS_NODE_COUNT configuration value '%s' conversion to integer failed", db.cfg.nodeCnt)
+       } else {
+               dbState.ConfigNodeCnt = cnt
+       }
+
+       return &dbState, dbState.Err
+}
+
 var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
 
 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
@@ -362,3 +575,31 @@ func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration)
        }
        return errors.New("Lock not held")
 }
+
+func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       sCbMap.cbMap[channel] = cb
+}
+
+func (sCbMap *sharedCbMap) Remove(channel string) {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       delete(sCbMap.cbMap, channel)
+}
+
+func (sCbMap *sharedCbMap) Count() int {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       return len(sCbMap.cbMap)
+}
+
+func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
+       sCbMap.m.Lock()
+       defer sCbMap.m.Unlock()
+       mapCopy := make(map[string]ChannelNotificationCb, 0)
+       for i, v := range sCbMap.cbMap {
+               mapCopy[i] = v
+       }
+       return mapCopy
+}