327946e214c78190686d5d444917b3da39d939be
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
1 /*
2    Copyright (c) 2019 AT&T Intellectual Property.
3    Copyright (c) 2018-2019 Nokia.
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16 */
17
18 /*
19  * This source code is part of the near-RT RIC (RAN Intelligent Controller)
20  * platform project (RICP).
21  */
22
23 package sdlgoredis
24
25 import (
26         "errors"
27         "fmt"
28         "github.com/go-redis/redis"
29         "os"
30         "strconv"
31         "strings"
32         "sync"
33         "time"
34 )
35
36 type ChannelNotificationCb func(channel string, payload ...string)
37
38 type intChannels struct {
39         addChannel    chan string
40         removeChannel chan string
41         exit          chan bool
42 }
43
44 type sharedCbMap struct {
45         m     sync.Mutex
46         cbMap map[string]ChannelNotificationCb
47 }
48
49 type DB struct {
50         client       RedisClient
51         subscribe    SubscribeFn
52         redisModules bool
53         sCbMap       *sharedCbMap
54         ch           intChannels
55 }
56
57 type Subscriber interface {
58         Channel() <-chan *redis.Message
59         Subscribe(channels ...string) error
60         Unsubscribe(channels ...string) error
61         Close() error
62 }
63
64 type SubscribeFn func(client RedisClient, channels ...string) Subscriber
65
66 type RedisClient interface {
67         Command() *redis.CommandsInfoCmd
68         Close() error
69         Subscribe(channels ...string) *redis.PubSub
70         MSet(pairs ...interface{}) *redis.StatusCmd
71         Do(args ...interface{}) *redis.Cmd
72         MGet(keys ...string) *redis.SliceCmd
73         Del(keys ...string) *redis.IntCmd
74         Keys(pattern string) *redis.StringSliceCmd
75         SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
76         SAdd(key string, members ...interface{}) *redis.IntCmd
77         SRem(key string, members ...interface{}) *redis.IntCmd
78         SMembers(key string) *redis.StringSliceCmd
79         SIsMember(key string, member interface{}) *redis.BoolCmd
80         SCard(key string) *redis.IntCmd
81         PTTL(key string) *redis.DurationCmd
82         Eval(script string, keys []string, args ...interface{}) *redis.Cmd
83         EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
84         ScriptExists(scripts ...string) *redis.BoolSliceCmd
85         ScriptLoad(script string) *redis.StringCmd
86 }
87
88 func checkResultAndError(result interface{}, err error) (bool, error) {
89         if err != nil {
90                 if err == redis.Nil {
91                         return false, nil
92                 }
93                 return false, err
94         }
95         if result == "OK" {
96                 return true, nil
97         }
98         return false, nil
99 }
100
101 func checkIntResultAndError(result interface{}, err error) (bool, error) {
102         if err != nil {
103                 return false, err
104         }
105         if n, ok := result.(int64); ok {
106                 if n == 1 {
107                         return true, nil
108                 }
109         } else if n, ok := result.(int); ok {
110                 if n == 1 {
111                         return true, nil
112                 }
113         }
114         return false, nil
115 }
116
117 func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
118         return client.Subscribe(channels...)
119 }
120
121 func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
122         db := DB{
123                 client:       client,
124                 subscribe:    subscribe,
125                 redisModules: true,
126                 sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
127                 ch: intChannels{
128                         addChannel:    make(chan string),
129                         removeChannel: make(chan string),
130                         exit:          make(chan bool),
131                 },
132         }
133
134         return &db
135 }
136
137 func Create() *DB {
138         var client *redis.Client
139         hostname := os.Getenv("DBAAS_SERVICE_HOST")
140         if hostname == "" {
141                 hostname = "localhost"
142         }
143         port := os.Getenv("DBAAS_SERVICE_PORT")
144         if port == "" {
145                 port = "6379"
146         }
147         sentinelPort := os.Getenv("DBAAS_SERVICE_SENTINEL_PORT")
148         masterName := os.Getenv("DBAAS_MASTER_NAME")
149         if sentinelPort == "" {
150                 redisAddress := hostname + ":" + port
151                 client = redis.NewClient(&redis.Options{
152                         Addr:       redisAddress,
153                         Password:   "", // no password set
154                         DB:         0,  // use default DB
155                         PoolSize:   20,
156                         MaxRetries: 2,
157                 })
158         } else {
159                 sentinelAddress := hostname + ":" + sentinelPort
160                 client = redis.NewFailoverClient(&redis.FailoverOptions{
161                         MasterName:    masterName,
162                         SentinelAddrs: []string{sentinelAddress},
163                         PoolSize:      20,
164                         MaxRetries:    2,
165                 })
166         }
167         db := CreateDB(client, subscribeNotifications)
168         db.CheckCommands()
169         return db
170 }
171
172 func (db *DB) CheckCommands() {
173         commands, err := db.client.Command().Result()
174         if err == nil {
175                 redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
176                         "msetmpub", "delmpub"}
177                 for _, v := range redisModuleCommands {
178                         _, ok := commands[v]
179                         if !ok {
180                                 db.redisModules = false
181                         }
182                 }
183         } else {
184                 fmt.Println(err)
185         }
186 }
187
188 func (db *DB) CloseDB() error {
189         return db.client.Close()
190 }
191
192 func (db *DB) UnsubscribeChannelDB(channels ...string) {
193         for _, v := range channels {
194                 db.sCbMap.Remove(v)
195                 db.ch.removeChannel <- v
196                 if db.sCbMap.Count() == 0 {
197                         db.ch.exit <- true
198                 }
199         }
200 }
201
202 func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
203         if db.sCbMap.Count() == 0 {
204                 for _, v := range channels {
205                         db.sCbMap.Add(v, cb)
206                 }
207
208                 go func(sCbMap *sharedCbMap,
209                         channelPrefix,
210                         eventSeparator string,
211                         ch intChannels,
212                         channels ...string) {
213                         sub := db.subscribe(db.client, channels...)
214                         rxChannel := sub.Channel()
215                         lCbMap := sCbMap.GetMapCopy()
216                         for {
217                                 select {
218                                 case msg := <-rxChannel:
219                                         cb, ok := lCbMap[msg.Channel]
220                                         if ok {
221                                                 cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
222                                         }
223                                 case channel := <-ch.addChannel:
224                                         lCbMap = sCbMap.GetMapCopy()
225                                         sub.Subscribe(channel)
226                                 case channel := <-ch.removeChannel:
227                                         lCbMap = sCbMap.GetMapCopy()
228                                         sub.Unsubscribe(channel)
229                                 case exit := <-ch.exit:
230                                         if exit {
231                                                 if err := sub.Close(); err != nil {
232                                                         fmt.Println(err)
233                                                 }
234                                                 return
235                                         }
236                                 }
237                         }
238                 }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
239
240         } else {
241                 for _, v := range channels {
242                         db.sCbMap.Add(v, cb)
243                         db.ch.addChannel <- v
244                 }
245         }
246 }
247
248 func (db *DB) MSet(pairs ...interface{}) error {
249         return db.client.MSet(pairs...).Err()
250 }
251
252 func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
253         if !db.redisModules {
254                 return errors.New("Redis deployment doesn't support MSETMPUB command")
255         }
256         command := make([]interface{}, 0)
257         command = append(command, "MSETMPUB")
258         command = append(command, len(pairs)/2)
259         command = append(command, len(channelsAndEvents)/2)
260         for _, d := range pairs {
261                 command = append(command, d)
262         }
263         for _, d := range channelsAndEvents {
264                 command = append(command, d)
265         }
266         _, err := db.client.Do(command...).Result()
267         return err
268 }
269
270 func (db *DB) MGet(keys []string) ([]interface{}, error) {
271         return db.client.MGet(keys...).Result()
272 }
273
274 func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
275         if !db.redisModules {
276                 return errors.New("Redis deployment not supporting command DELMPUB")
277         }
278         command := make([]interface{}, 0)
279         command = append(command, "DELMPUB")
280         command = append(command, len(keys))
281         command = append(command, len(channelsAndEvents)/2)
282         for _, d := range keys {
283                 command = append(command, d)
284         }
285         for _, d := range channelsAndEvents {
286                 command = append(command, d)
287         }
288         _, err := db.client.Do(command...).Result()
289         return err
290
291 }
292
293 func (db *DB) Del(keys []string) error {
294         _, err := db.client.Del(keys...).Result()
295         return err
296 }
297
298 func (db *DB) Keys(pattern string) ([]string, error) {
299         return db.client.Keys(pattern).Result()
300 }
301
302 func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
303         if !db.redisModules {
304                 return false, errors.New("Redis deployment not supporting command")
305         }
306
307         return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
308 }
309
310 func (db *DB) SetIEPub(channelsAndEvents []string, key string, oldData, newData interface{}) (bool, error) {
311         if !db.redisModules {
312                 return false, errors.New("Redis deployment not supporting command SETIEMPUB")
313         }
314         capacity := 4 + len(channelsAndEvents)
315         command := make([]interface{}, 0, capacity)
316         command = append(command, "SETIEMPUB")
317         command = append(command, key)
318         command = append(command, newData)
319         command = append(command, oldData)
320         for _, ce := range channelsAndEvents {
321                 command = append(command, ce)
322         }
323         return checkResultAndError(db.client.Do(command...).Result())
324 }
325
326 func (db *DB) SetNXPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
327         if !db.redisModules {
328                 return false, errors.New("Redis deployment not supporting command SETNXMPUB")
329         }
330         capacity := 3 + len(channelsAndEvents)
331         command := make([]interface{}, 0, capacity)
332         command = append(command, "SETNXMPUB")
333         command = append(command, key)
334         command = append(command, data)
335         for _, ce := range channelsAndEvents {
336                 command = append(command, ce)
337         }
338         return checkResultAndError(db.client.Do(command...).Result())
339 }
340 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
341         return db.client.SetNX(key, data, expiration).Result()
342 }
343
344 func (db *DB) DelIEPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
345         if !db.redisModules {
346                 return false, errors.New("Redis deployment not supporting command DELIEMPUB")
347         }
348         capacity := 3 + len(channelsAndEvents)
349         command := make([]interface{}, 0, capacity)
350         command = append(command, "DELIEMPUB")
351         command = append(command, key)
352         command = append(command, data)
353         for _, ce := range channelsAndEvents {
354                 command = append(command, ce)
355         }
356         return checkIntResultAndError(db.client.Do(command...).Result())
357 }
358
359 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
360         if !db.redisModules {
361                 return false, errors.New("Redis deployment not supporting command")
362         }
363         return checkIntResultAndError(db.client.Do("DELIE", key, data).Result())
364 }
365
366 func (db *DB) SAdd(key string, data ...interface{}) error {
367         _, err := db.client.SAdd(key, data...).Result()
368         return err
369 }
370
371 func (db *DB) SRem(key string, data ...interface{}) error {
372         _, err := db.client.SRem(key, data...).Result()
373         return err
374 }
375
376 func (db *DB) SMembers(key string) ([]string, error) {
377         result, err := db.client.SMembers(key).Result()
378         return result, err
379 }
380
381 func (db *DB) SIsMember(key string, data interface{}) (bool, error) {
382         result, err := db.client.SIsMember(key, data).Result()
383         return result, err
384 }
385
386 func (db *DB) SCard(key string) (int64, error) {
387         result, err := db.client.SCard(key).Result()
388         return result, err
389 }
390
391 func (db *DB) PTTL(key string) (time.Duration, error) {
392         result, err := db.client.PTTL(key).Result()
393         return result, err
394 }
395
396 var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
397
398 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
399         expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
400         result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
401         if err != nil {
402                 return err
403         }
404         if result == int64(1) {
405                 return nil
406         }
407         return errors.New("Lock not held")
408 }
409
410 func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
411         sCbMap.m.Lock()
412         defer sCbMap.m.Unlock()
413         sCbMap.cbMap[channel] = cb
414 }
415
416 func (sCbMap *sharedCbMap) Remove(channel string) {
417         sCbMap.m.Lock()
418         defer sCbMap.m.Unlock()
419         delete(sCbMap.cbMap, channel)
420 }
421
422 func (sCbMap *sharedCbMap) Count() int {
423         sCbMap.m.Lock()
424         defer sCbMap.m.Unlock()
425         return len(sCbMap.cbMap)
426 }
427
428 func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
429         sCbMap.m.Lock()
430         defer sCbMap.m.Unlock()
431         mapCopy := make(map[string]ChannelNotificationCb, 0)
432         for i, v := range sCbMap.cbMap {
433                 mapCopy[i] = v
434         }
435         return mapCopy
436 }