5f649600ce598b2ead2a9dffbbbdf9c1062f5ae6
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
1 /*
2    Copyright (c) 2019 AT&T Intellectual Property.
3    Copyright (c) 2018-2019 Nokia.
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16 */
17
18 /*
19  * This source code is part of the near-RT RIC (RAN Intelligent Controller)
20  * platform project (RICP).
21  */
22
23 package sdlgoredis
24
25 import (
26         "errors"
27         "github.com/go-redis/redis/v7"
28         "io"
29         "log"
30         "os"
31         "strconv"
32         "strings"
33         "sync"
34         "time"
35 )
36
37 type ChannelNotificationCb func(channel string, payload ...string)
38 type RedisClientCreator func(addr, port, clusterName string, isHa bool) RedisClient
39
40 type intChannels struct {
41         addChannel    chan string
42         removeChannel chan string
43         exit          chan bool
44 }
45
46 type sharedCbMap struct {
47         m     sync.Mutex
48         cbMap map[string]ChannelNotificationCb
49 }
50
51 type Config struct {
52         hostname        string
53         port            string
54         masterName      string
55         sentinelPort    string
56         clusterAddrList string
57 }
58
59 type DB struct {
60         client       RedisClient
61         sentinel     RedisSentinelCreateCb
62         subscribe    SubscribeFn
63         redisModules bool
64         sCbMap       *sharedCbMap
65         ch           intChannels
66         cfg          Config
67         addr         string
68 }
69
70 type Subscriber interface {
71         Channel() <-chan *redis.Message
72         Subscribe(channels ...string) error
73         Unsubscribe(channels ...string) error
74         Close() error
75 }
76
77 type SubscribeFn func(client RedisClient, channels ...string) Subscriber
78
79 type RedisClient interface {
80         Command() *redis.CommandsInfoCmd
81         Close() error
82         Subscribe(channels ...string) *redis.PubSub
83         MSet(pairs ...interface{}) *redis.StatusCmd
84         Do(args ...interface{}) *redis.Cmd
85         MGet(keys ...string) *redis.SliceCmd
86         Del(keys ...string) *redis.IntCmd
87         Keys(pattern string) *redis.StringSliceCmd
88         SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
89         SAdd(key string, members ...interface{}) *redis.IntCmd
90         SRem(key string, members ...interface{}) *redis.IntCmd
91         SMembers(key string) *redis.StringSliceCmd
92         SIsMember(key string, member interface{}) *redis.BoolCmd
93         SCard(key string) *redis.IntCmd
94         PTTL(key string) *redis.DurationCmd
95         Eval(script string, keys []string, args ...interface{}) *redis.Cmd
96         EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
97         ScriptExists(scripts ...string) *redis.BoolSliceCmd
98         ScriptLoad(script string) *redis.StringCmd
99         Info(section ...string) *redis.StringCmd
100 }
101
102 var dbLogger *log.Logger
103
104 func init() {
105         dbLogger = log.New(os.Stdout, "database: ", log.LstdFlags|log.Lshortfile)
106         redis.SetLogger(dbLogger)
107 }
108
109 func SetDbLogger(out io.Writer) {
110         dbLogger.SetOutput(out)
111 }
112
113 func checkResultAndError(result interface{}, err error) (bool, error) {
114         if err != nil {
115                 if err == redis.Nil {
116                         return false, nil
117                 }
118                 return false, err
119         }
120         if result == "OK" {
121                 return true, nil
122         }
123         return false, nil
124 }
125
126 func checkIntResultAndError(result interface{}, err error) (bool, error) {
127         if err != nil {
128                 return false, err
129         }
130         if n, ok := result.(int64); ok {
131                 if n == 1 {
132                         return true, nil
133                 }
134         } else if n, ok := result.(int); ok {
135                 if n == 1 {
136                         return true, nil
137                 }
138         }
139         return false, nil
140 }
141
142 func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
143         return client.Subscribe(channels...)
144 }
145
146 func CreateDB(client RedisClient, subscribe SubscribeFn, sentinelCreateCb RedisSentinelCreateCb, cfg Config, sentinelAddr string) *DB {
147         db := DB{
148                 client:       client,
149                 sentinel:     sentinelCreateCb,
150                 subscribe:    subscribe,
151                 redisModules: true,
152                 sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
153                 ch: intChannels{
154                         addChannel:    make(chan string),
155                         removeChannel: make(chan string),
156                         exit:          make(chan bool),
157                 },
158                 cfg:  cfg,
159                 addr: sentinelAddr,
160         }
161
162         return &db
163 }
164
165 func Create() []*DB {
166         osimpl := osImpl{}
167         return ReadConfigAndCreateDbClients(osimpl, newRedisClient, subscribeNotifications, newRedisSentinel)
168 }
169
170 func readConfig(osI OS) Config {
171         cfg := Config{
172                 hostname:        osI.Getenv("DBAAS_SERVICE_HOST", "localhost"),
173                 port:            osI.Getenv("DBAAS_SERVICE_PORT", "6379"),
174                 masterName:      osI.Getenv("DBAAS_MASTER_NAME", ""),
175                 sentinelPort:    osI.Getenv("DBAAS_SERVICE_SENTINEL_PORT", ""),
176                 clusterAddrList: osI.Getenv("DBAAS_CLUSTER_ADDR_LIST", ""),
177         }
178         return cfg
179 }
180
181 type OS interface {
182         Getenv(key string, defValue string) string
183 }
184
185 type osImpl struct{}
186
187 func (osImpl) Getenv(key string, defValue string) string {
188         val := os.Getenv(key)
189         if val == "" {
190                 val = defValue
191         }
192         return val
193 }
194
195 func ReadConfigAndCreateDbClients(osI OS, clientCreator RedisClientCreator,
196         subscribe SubscribeFn,
197         sentinelCreateCb RedisSentinelCreateCb) []*DB {
198         cfg := readConfig(osI)
199         return createDbClients(cfg, clientCreator, subscribe, sentinelCreateCb)
200 }
201
202 func createDbClients(cfg Config, clientCreator RedisClientCreator,
203         subscribe SubscribeFn,
204         sentinelCreateCb RedisSentinelCreateCb) []*DB {
205         if cfg.clusterAddrList == "" {
206                 return []*DB{createLegacyDbClient(cfg, clientCreator, subscribe, sentinelCreateCb)}
207         }
208
209         dbs := []*DB{}
210
211         addrList := strings.Split(cfg.clusterAddrList, ",")
212         for _, addr := range addrList {
213                 db := createDbClient(cfg, addr, clientCreator, subscribe, sentinelCreateCb)
214                 dbs = append(dbs, db)
215         }
216         return dbs
217 }
218
219 func createLegacyDbClient(cfg Config, clientCreator RedisClientCreator,
220         subscribe SubscribeFn,
221         sentinelCreateCb RedisSentinelCreateCb) *DB {
222         return createDbClient(cfg, cfg.hostname, clientCreator, subscribe, sentinelCreateCb)
223 }
224
225 func createDbClient(cfg Config, hostName string, clientCreator RedisClientCreator,
226         subscribe SubscribeFn,
227         sentinelCreateCb RedisSentinelCreateCb) *DB {
228         var client RedisClient
229         var db *DB
230         if cfg.sentinelPort == "" {
231                 client = clientCreator(hostName, cfg.port, "", false)
232                 db = CreateDB(client, subscribe, nil, cfg, hostName)
233         } else {
234                 client = clientCreator(hostName, cfg.sentinelPort, cfg.masterName, true)
235                 db = CreateDB(client, subscribe, sentinelCreateCb, cfg, hostName)
236         }
237         db.CheckCommands()
238         return db
239 }
240
241 func newRedisClient(addr, port, clusterName string, isHa bool) RedisClient {
242         if isHa == true {
243                 sentinelAddress := addr + ":" + port
244                 return redis.NewFailoverClient(
245                         &redis.FailoverOptions{
246                                 MasterName:    clusterName,
247                                 SentinelAddrs: []string{sentinelAddress},
248                                 PoolSize:      20,
249                                 MaxRetries:    2,
250                         },
251                 )
252         }
253         redisAddress := addr + ":" + port
254         return redis.NewClient(&redis.Options{
255                 Addr:       redisAddress,
256                 Password:   "", // no password set
257                 DB:         0,  // use default DB
258                 PoolSize:   20,
259                 MaxRetries: 2,
260         })
261 }
262
263 func (db *DB) CheckCommands() {
264         commands, err := db.client.Command().Result()
265         if err == nil {
266                 redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
267                         "msetmpub", "delmpub"}
268                 for _, v := range redisModuleCommands {
269                         _, ok := commands[v]
270                         if !ok {
271                                 db.redisModules = false
272                         }
273                 }
274         } else {
275                 dbLogger.Printf("SDL DB commands checking failure: %s\n", err)
276         }
277 }
278
279 func (db *DB) CloseDB() error {
280         return db.client.Close()
281 }
282
283 func (db *DB) UnsubscribeChannelDB(channels ...string) {
284         for _, v := range channels {
285                 db.sCbMap.Remove(v)
286                 db.ch.removeChannel <- v
287                 if db.sCbMap.Count() == 0 {
288                         db.ch.exit <- true
289                 }
290         }
291 }
292
293 func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
294         if db.sCbMap.Count() == 0 {
295                 for _, v := range channels {
296                         db.sCbMap.Add(v, cb)
297                 }
298
299                 go func(sCbMap *sharedCbMap,
300                         channelPrefix,
301                         eventSeparator string,
302                         ch intChannels,
303                         channels ...string) {
304                         sub := db.subscribe(db.client, channels...)
305                         rxChannel := sub.Channel()
306                         lCbMap := sCbMap.GetMapCopy()
307                         for {
308                                 select {
309                                 case msg := <-rxChannel:
310                                         cb, ok := lCbMap[msg.Channel]
311                                         if ok {
312                                                 cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
313                                         }
314                                 case channel := <-ch.addChannel:
315                                         lCbMap = sCbMap.GetMapCopy()
316                                         sub.Subscribe(channel)
317                                 case channel := <-ch.removeChannel:
318                                         lCbMap = sCbMap.GetMapCopy()
319                                         sub.Unsubscribe(channel)
320                                 case exit := <-ch.exit:
321                                         if exit {
322                                                 if err := sub.Close(); err != nil {
323                                                         dbLogger.Printf("SDL DB channel closing failure: %s\n", err)
324                                                 }
325                                                 return
326                                         }
327                                 }
328                         }
329                 }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
330
331         } else {
332                 for _, v := range channels {
333                         db.sCbMap.Add(v, cb)
334                         db.ch.addChannel <- v
335                 }
336         }
337 }
338
339 func (db *DB) MSet(pairs ...interface{}) error {
340         return db.client.MSet(pairs...).Err()
341 }
342
343 func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
344         if !db.redisModules {
345                 return errors.New("Redis deployment doesn't support MSETMPUB command")
346         }
347         command := make([]interface{}, 0)
348         command = append(command, "MSETMPUB")
349         command = append(command, len(pairs)/2)
350         command = append(command, len(channelsAndEvents)/2)
351         for _, d := range pairs {
352                 command = append(command, d)
353         }
354         for _, d := range channelsAndEvents {
355                 command = append(command, d)
356         }
357         _, err := db.client.Do(command...).Result()
358         return err
359 }
360
361 func (db *DB) MGet(keys []string) ([]interface{}, error) {
362         return db.client.MGet(keys...).Result()
363 }
364
365 func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
366         if !db.redisModules {
367                 return errors.New("Redis deployment not supporting command DELMPUB")
368         }
369         command := make([]interface{}, 0)
370         command = append(command, "DELMPUB")
371         command = append(command, len(keys))
372         command = append(command, len(channelsAndEvents)/2)
373         for _, d := range keys {
374                 command = append(command, d)
375         }
376         for _, d := range channelsAndEvents {
377                 command = append(command, d)
378         }
379         _, err := db.client.Do(command...).Result()
380         return err
381
382 }
383
384 func (db *DB) Del(keys []string) error {
385         _, err := db.client.Del(keys...).Result()
386         return err
387 }
388
389 func (db *DB) Keys(pattern string) ([]string, error) {
390         return db.client.Keys(pattern).Result()
391 }
392
393 func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
394         if !db.redisModules {
395                 return false, errors.New("Redis deployment not supporting command")
396         }
397
398         return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
399 }
400
401 func (db *DB) SetIEPub(channelsAndEvents []string, key string, oldData, newData interface{}) (bool, error) {
402         if !db.redisModules {
403                 return false, errors.New("Redis deployment not supporting command SETIEMPUB")
404         }
405         capacity := 4 + len(channelsAndEvents)
406         command := make([]interface{}, 0, capacity)
407         command = append(command, "SETIEMPUB")
408         command = append(command, key)
409         command = append(command, newData)
410         command = append(command, oldData)
411         for _, ce := range channelsAndEvents {
412                 command = append(command, ce)
413         }
414         return checkResultAndError(db.client.Do(command...).Result())
415 }
416
417 func (db *DB) SetNXPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
418         if !db.redisModules {
419                 return false, errors.New("Redis deployment not supporting command SETNXMPUB")
420         }
421         capacity := 3 + len(channelsAndEvents)
422         command := make([]interface{}, 0, capacity)
423         command = append(command, "SETNXMPUB")
424         command = append(command, key)
425         command = append(command, data)
426         for _, ce := range channelsAndEvents {
427                 command = append(command, ce)
428         }
429         return checkResultAndError(db.client.Do(command...).Result())
430 }
431 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
432         return db.client.SetNX(key, data, expiration).Result()
433 }
434
435 func (db *DB) DelIEPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
436         if !db.redisModules {
437                 return false, errors.New("Redis deployment not supporting command DELIEMPUB")
438         }
439         capacity := 3 + len(channelsAndEvents)
440         command := make([]interface{}, 0, capacity)
441         command = append(command, "DELIEMPUB")
442         command = append(command, key)
443         command = append(command, data)
444         for _, ce := range channelsAndEvents {
445                 command = append(command, ce)
446         }
447         return checkIntResultAndError(db.client.Do(command...).Result())
448 }
449
450 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
451         if !db.redisModules {
452                 return false, errors.New("Redis deployment not supporting command")
453         }
454         return checkIntResultAndError(db.client.Do("DELIE", key, data).Result())
455 }
456
457 func (db *DB) SAdd(key string, data ...interface{}) error {
458         _, err := db.client.SAdd(key, data...).Result()
459         return err
460 }
461
462 func (db *DB) SRem(key string, data ...interface{}) error {
463         _, err := db.client.SRem(key, data...).Result()
464         return err
465 }
466
467 func (db *DB) SMembers(key string) ([]string, error) {
468         result, err := db.client.SMembers(key).Result()
469         return result, err
470 }
471
472 func (db *DB) SIsMember(key string, data interface{}) (bool, error) {
473         result, err := db.client.SIsMember(key, data).Result()
474         return result, err
475 }
476
477 func (db *DB) SCard(key string) (int64, error) {
478         result, err := db.client.SCard(key).Result()
479         return result, err
480 }
481
482 func (db *DB) PTTL(key string) (time.Duration, error) {
483         result, err := db.client.PTTL(key).Result()
484         return result, err
485 }
486
487 func (db *DB) Info() (*DbInfo, error) {
488         var info DbInfo
489         resultStr, err := db.client.Info("all").Result()
490         result := strings.Split(strings.ReplaceAll(resultStr, "\r\n", "\n"), "\n")
491         readRedisInfoReplyFields(result, &info)
492         return &info, err
493 }
494
495 func readRedisInfoReplyFields(input []string, info *DbInfo) {
496         for _, line := range input {
497                 if idx := strings.Index(line, "role:"); idx != -1 {
498                         roleStr := line[idx+len("role:"):]
499                         if roleStr == "master" {
500                                 info.Fields.MasterRole = true
501                         }
502                 } else if idx := strings.Index(line, "connected_slaves:"); idx != -1 {
503                         cntStr := line[idx+len("connected_slaves:"):]
504                         if cnt, err := strconv.ParseUint(cntStr, 10, 32); err == nil {
505                                 info.Fields.ConnectedReplicaCnt = uint32(cnt)
506                         }
507                 }
508         }
509 }
510
511 func (db *DB) State() (*DbState, error) {
512         if db.cfg.sentinelPort != "" {
513                 //Establish connection to Redis sentinel. The reason why connection is done
514                 //here instead of time of the SDL instance creation is that for the time being
515                 //sentinel connection is needed only here to get state information and
516                 //state information is needed only by 'sdlcli' hence it is not time critical
517                 //and also we want to avoid opening unnecessary TCP connections towards Redis
518                 //sentinel for every SDL instance. Now it is done only when 'sdlcli' is used.
519                 sentinelClient := db.sentinel(&db.cfg, db.addr)
520                 return sentinelClient.GetDbState()
521         } else {
522                 var dbState DbState
523                 info, err := db.Info()
524                 if err != nil {
525                         return &dbState, err
526                 }
527                 dbState = fillDbStateFromDbInfo(info)
528                 return &dbState, err
529         }
530 }
531
532 func fillDbStateFromDbInfo(info *DbInfo) DbState {
533         var dbState DbState
534         if info.Fields.MasterRole == true {
535                 dbState = DbState{
536                         MasterDbState: MasterDbState{
537                                 Fields: MasterDbStateFields{
538                                         Role:  "master",
539                                         Flags: "master",
540                                 },
541                         },
542                 }
543         }
544         return dbState
545 }
546
547 var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
548
549 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
550         expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
551         result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
552         if err != nil {
553                 return err
554         }
555         if result == int64(1) {
556                 return nil
557         }
558         return errors.New("Lock not held")
559 }
560
561 func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
562         sCbMap.m.Lock()
563         defer sCbMap.m.Unlock()
564         sCbMap.cbMap[channel] = cb
565 }
566
567 func (sCbMap *sharedCbMap) Remove(channel string) {
568         sCbMap.m.Lock()
569         defer sCbMap.m.Unlock()
570         delete(sCbMap.cbMap, channel)
571 }
572
573 func (sCbMap *sharedCbMap) Count() int {
574         sCbMap.m.Lock()
575         defer sCbMap.m.Unlock()
576         return len(sCbMap.cbMap)
577 }
578
579 func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
580         sCbMap.m.Lock()
581         defer sCbMap.m.Unlock()
582         mapCopy := make(map[string]ChannelNotificationCb, 0)
583         for i, v := range sCbMap.cbMap {
584                 mapCopy[i] = v
585         }
586         return mapCopy
587 }