Fix Go routine race condition in DB notification map
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
1 /*
2    Copyright (c) 2019 AT&T Intellectual Property.
3    Copyright (c) 2018-2019 Nokia.
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16 */
17
18 /*
19  * This source code is part of the near-RT RIC (RAN Intelligent Controller)
20  * platform project (RICP).
21  */
22
23 package sdlgoredis
24
25 import (
26         "errors"
27         "fmt"
28         "github.com/go-redis/redis"
29         "os"
30         "strconv"
31         "strings"
32         "sync"
33         "time"
34 )
35
36 type ChannelNotificationCb func(channel string, payload ...string)
37
38 type intChannels struct {
39         addChannel    chan string
40         removeChannel chan string
41         exit          chan bool
42 }
43
44 type sharedCbMap struct {
45         m     sync.Mutex
46         cbMap map[string]ChannelNotificationCb
47 }
48
49 type DB struct {
50         client       RedisClient
51         subscribe    SubscribeFn
52         redisModules bool
53         sCbMap       *sharedCbMap
54         ch           intChannels
55 }
56
57 type Subscriber interface {
58         Channel() <-chan *redis.Message
59         Subscribe(channels ...string) error
60         Unsubscribe(channels ...string) error
61         Close() error
62 }
63
64 type SubscribeFn func(client RedisClient, channels ...string) Subscriber
65
66 type RedisClient interface {
67         Command() *redis.CommandsInfoCmd
68         Close() error
69         Subscribe(channels ...string) *redis.PubSub
70         MSet(pairs ...interface{}) *redis.StatusCmd
71         Do(args ...interface{}) *redis.Cmd
72         MGet(keys ...string) *redis.SliceCmd
73         Del(keys ...string) *redis.IntCmd
74         Keys(pattern string) *redis.StringSliceCmd
75         SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
76         SAdd(key string, members ...interface{}) *redis.IntCmd
77         SRem(key string, members ...interface{}) *redis.IntCmd
78         SMembers(key string) *redis.StringSliceCmd
79         SIsMember(key string, member interface{}) *redis.BoolCmd
80         SCard(key string) *redis.IntCmd
81         PTTL(key string) *redis.DurationCmd
82         Eval(script string, keys []string, args ...interface{}) *redis.Cmd
83         EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
84         ScriptExists(scripts ...string) *redis.BoolSliceCmd
85         ScriptLoad(script string) *redis.StringCmd
86 }
87
88 func checkResultAndError(result interface{}, err error) (bool, error) {
89         if err != nil {
90                 if err == redis.Nil {
91                         return false, nil
92                 }
93                 return false, err
94         }
95         if result == "OK" {
96                 return true, nil
97         }
98         return false, nil
99 }
100
101 func checkIntResultAndError(result interface{}, err error) (bool, error) {
102         if err != nil {
103                 return false, err
104         }
105         if result.(int) == int(1) {
106                 return true, nil
107         }
108         return false, nil
109 }
110
111 func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
112         return client.Subscribe(channels...)
113 }
114
115 func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
116         db := DB{
117                 client:       client,
118                 subscribe:    subscribe,
119                 redisModules: true,
120                 sCbMap:       &sharedCbMap{cbMap: make(map[string]ChannelNotificationCb, 0)},
121                 ch: intChannels{
122                         addChannel:    make(chan string),
123                         removeChannel: make(chan string),
124                         exit:          make(chan bool),
125                 },
126         }
127
128         return &db
129 }
130
131 func Create() *DB {
132         var client *redis.Client
133         hostname := os.Getenv("DBAAS_SERVICE_HOST")
134         if hostname == "" {
135                 hostname = "localhost"
136         }
137         port := os.Getenv("DBAAS_SERVICE_PORT")
138         if port == "" {
139                 port = "6379"
140         }
141         sentinelPort := os.Getenv("DBAAS_SERVICE_SENTINEL_PORT")
142         masterName := os.Getenv("DBAAS_MASTER_NAME")
143         if sentinelPort == "" {
144                 redisAddress := hostname + ":" + port
145                 client = redis.NewClient(&redis.Options{
146                         Addr:       redisAddress,
147                         Password:   "", // no password set
148                         DB:         0,  // use default DB
149                         PoolSize:   20,
150                         MaxRetries: 2,
151                 })
152         } else {
153                 sentinelAddress := hostname + ":" + sentinelPort
154                 client = redis.NewFailoverClient(&redis.FailoverOptions{
155                         MasterName:    masterName,
156                         SentinelAddrs: []string{sentinelAddress},
157                         PoolSize:      20,
158                         MaxRetries:    2,
159                 })
160         }
161         db := CreateDB(client, subscribeNotifications)
162         db.CheckCommands()
163         return db
164 }
165
166 func (db *DB) CheckCommands() {
167         commands, err := db.client.Command().Result()
168         if err == nil {
169                 redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
170                         "msetmpub", "delmpub"}
171                 for _, v := range redisModuleCommands {
172                         _, ok := commands[v]
173                         if !ok {
174                                 db.redisModules = false
175                         }
176                 }
177         } else {
178                 fmt.Println(err)
179         }
180 }
181
182 func (db *DB) CloseDB() error {
183         return db.client.Close()
184 }
185
186 func (db *DB) UnsubscribeChannelDB(channels ...string) {
187         for _, v := range channels {
188                 db.sCbMap.Remove(v)
189                 db.ch.removeChannel <- v
190                 if db.sCbMap.Count() == 0 {
191                         db.ch.exit <- true
192                 }
193         }
194 }
195
196 func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, eventSeparator string, channels ...string) {
197         if db.sCbMap.Count() == 0 {
198                 for _, v := range channels {
199                         db.sCbMap.Add(v, cb)
200                 }
201
202                 go func(sCbMap *sharedCbMap,
203                         channelPrefix,
204                         eventSeparator string,
205                         ch intChannels,
206                         channels ...string) {
207                         sub := db.subscribe(db.client, channels...)
208                         rxChannel := sub.Channel()
209                         lCbMap := sCbMap.GetMapCopy()
210                         for {
211                                 select {
212                                 case msg := <-rxChannel:
213                                         cb, ok := lCbMap[msg.Channel]
214                                         if ok {
215                                                 cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
216                                         }
217                                 case channel := <-ch.addChannel:
218                                         lCbMap = sCbMap.GetMapCopy()
219                                         sub.Subscribe(channel)
220                                 case channel := <-ch.removeChannel:
221                                         lCbMap = sCbMap.GetMapCopy()
222                                         sub.Unsubscribe(channel)
223                                 case exit := <-ch.exit:
224                                         if exit {
225                                                 if err := sub.Close(); err != nil {
226                                                         fmt.Println(err)
227                                                 }
228                                                 return
229                                         }
230                                 }
231                         }
232                 }(db.sCbMap, channelPrefix, eventSeparator, db.ch, channels...)
233
234         } else {
235                 for _, v := range channels {
236                         db.sCbMap.Add(v, cb)
237                         db.ch.addChannel <- v
238                 }
239         }
240 }
241
242 func (db *DB) MSet(pairs ...interface{}) error {
243         return db.client.MSet(pairs...).Err()
244 }
245
246 func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
247         if !db.redisModules {
248                 return errors.New("Redis deployment doesn't support MSETMPUB command")
249         }
250         command := make([]interface{}, 0)
251         command = append(command, "MSETMPUB")
252         command = append(command, len(pairs)/2)
253         command = append(command, len(channelsAndEvents)/2)
254         for _, d := range pairs {
255                 command = append(command, d)
256         }
257         for _, d := range channelsAndEvents {
258                 command = append(command, d)
259         }
260         _, err := db.client.Do(command...).Result()
261         return err
262 }
263
264 func (db *DB) MGet(keys []string) ([]interface{}, error) {
265         return db.client.MGet(keys...).Result()
266 }
267
268 func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
269         if !db.redisModules {
270                 return errors.New("Redis deployment not supporting command DELMPUB")
271         }
272         command := make([]interface{}, 0)
273         command = append(command, "DELMPUB")
274         command = append(command, len(keys))
275         command = append(command, len(channelsAndEvents)/2)
276         for _, d := range keys {
277                 command = append(command, d)
278         }
279         for _, d := range channelsAndEvents {
280                 command = append(command, d)
281         }
282         _, err := db.client.Do(command...).Result()
283         return err
284
285 }
286
287 func (db *DB) Del(keys []string) error {
288         _, err := db.client.Del(keys...).Result()
289         return err
290 }
291
292 func (db *DB) Keys(pattern string) ([]string, error) {
293         return db.client.Keys(pattern).Result()
294 }
295
296 func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
297         if !db.redisModules {
298                 return false, errors.New("Redis deployment not supporting command")
299         }
300
301         return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
302 }
303
304 func (db *DB) SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error) {
305         if !db.redisModules {
306                 return false, errors.New("Redis deployment not supporting command SETIEPUB")
307         }
308         return checkResultAndError(db.client.Do("SETIEPUB", key, newData, oldData, channel, message).Result())
309 }
310
311 func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, error) {
312         if !db.redisModules {
313                 return false, errors.New("Redis deployment not supporting command SETNXPUB")
314         }
315         return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
316 }
317 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
318         return db.client.SetNX(key, data, expiration).Result()
319 }
320
321 func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
322         if !db.redisModules {
323                 return false, errors.New("Redis deployment not supporting command")
324         }
325         return checkIntResultAndError(db.client.Do("DELIEPUB", key, data, channel, message).Result())
326 }
327
328 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
329         if !db.redisModules {
330                 return false, errors.New("Redis deployment not supporting command")
331         }
332         return checkIntResultAndError(db.client.Do("DELIE", key, data).Result())
333 }
334
335 func (db *DB) SAdd(key string, data ...interface{}) error {
336         _, err := db.client.SAdd(key, data...).Result()
337         return err
338 }
339
340 func (db *DB) SRem(key string, data ...interface{}) error {
341         _, err := db.client.SRem(key, data...).Result()
342         return err
343 }
344
345 func (db *DB) SMembers(key string) ([]string, error) {
346         result, err := db.client.SMembers(key).Result()
347         return result, err
348 }
349
350 func (db *DB) SIsMember(key string, data interface{}) (bool, error) {
351         result, err := db.client.SIsMember(key, data).Result()
352         return result, err
353 }
354
355 func (db *DB) SCard(key string) (int64, error) {
356         result, err := db.client.SCard(key).Result()
357         return result, err
358 }
359
360 func (db *DB) PTTL(key string) (time.Duration, error) {
361         result, err := db.client.PTTL(key).Result()
362         return result, err
363 }
364
365 var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
366
367 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
368         expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
369         result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
370         if err != nil {
371                 return err
372         }
373         if result == int64(1) {
374                 return nil
375         }
376         return errors.New("Lock not held")
377 }
378
379 func (sCbMap *sharedCbMap) Add(channel string, cb ChannelNotificationCb) {
380         sCbMap.m.Lock()
381         defer sCbMap.m.Unlock()
382         sCbMap.cbMap[channel] = cb
383 }
384
385 func (sCbMap *sharedCbMap) Remove(channel string) {
386         sCbMap.m.Lock()
387         defer sCbMap.m.Unlock()
388         delete(sCbMap.cbMap, channel)
389 }
390
391 func (sCbMap *sharedCbMap) Count() int {
392         sCbMap.m.Lock()
393         defer sCbMap.m.Unlock()
394         return len(sCbMap.cbMap)
395 }
396
397 func (sCbMap *sharedCbMap) GetMapCopy() map[string]ChannelNotificationCb {
398         sCbMap.m.Lock()
399         defer sCbMap.m.Unlock()
400         mapCopy := make(map[string]ChannelNotificationCb, 0)
401         for i, v := range sCbMap.cbMap {
402                 mapCopy[i] = v
403         }
404         return mapCopy
405 }