Add resource locking
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
1 /*
2    Copyright (c) 2019 AT&T Intellectual Property.
3    Copyright (c) 2018-2019 Nokia.
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16 */
17
18 package sdlgoredis
19
20 import (
21         "errors"
22         "fmt"
23         "os"
24         "strconv"
25         "strings"
26         "time"
27
28         "github.com/go-redis/redis"
29 )
30
31 type ChannelNotificationCb func(channel string, payload ...string)
32
33 type intChannels struct {
34         addChannel    chan string
35         removeChannel chan string
36         exit          chan bool
37 }
38
39 type DB struct {
40         client       RedisClient
41         subscribe    SubscribeFn
42         redisModules bool
43         cbMap        map[string]ChannelNotificationCb
44         ch           intChannels
45 }
46
47 type Subscriber interface {
48         Channel() <-chan *redis.Message
49         Subscribe(channels ...string) error
50         Unsubscribe(channels ...string) error
51         Close() error
52 }
53
54 type SubscribeFn func(client RedisClient, channels ...string) Subscriber
55
56 type RedisClient interface {
57         Command() *redis.CommandsInfoCmd
58         Close() error
59         Subscribe(channels ...string) *redis.PubSub
60         MSet(pairs ...interface{}) *redis.StatusCmd
61         Do(args ...interface{}) *redis.Cmd
62         MGet(keys ...string) *redis.SliceCmd
63         Del(keys ...string) *redis.IntCmd
64         Keys(pattern string) *redis.StringSliceCmd
65         SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
66         SAdd(key string, members ...interface{}) *redis.IntCmd
67         SRem(key string, members ...interface{}) *redis.IntCmd
68         SMembers(key string) *redis.StringSliceCmd
69         SIsMember(key string, member interface{}) *redis.BoolCmd
70         SCard(key string) *redis.IntCmd
71         PTTL(key string) *redis.DurationCmd
72         Eval(script string, keys []string, args ...interface{}) *redis.Cmd
73         EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
74         ScriptExists(scripts ...string) *redis.BoolSliceCmd
75         ScriptLoad(script string) *redis.StringCmd
76 }
77
78 func checkResultAndError(result interface{}, err error) (bool, error) {
79         if err != nil {
80                 if err == redis.Nil {
81                         return false, nil
82                 }
83                 return false, err
84         }
85         if result == "OK" {
86                 return true, nil
87         }
88         return false, nil
89 }
90
91 func checkIntResultAndError(result interface{}, err error) (bool, error) {
92         if err != nil {
93                 return false, err
94         }
95         if result == 1 {
96                 return true, nil
97         }
98         return false, nil
99 }
100
101 func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
102         return client.Subscribe(channels...)
103 }
104
105 func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
106         db := DB{
107                 client:       client,
108                 subscribe:    subscribe,
109                 redisModules: true,
110                 cbMap:        make(map[string]ChannelNotificationCb, 0),
111                 ch: intChannels{
112                         addChannel:    make(chan string),
113                         removeChannel: make(chan string),
114                         exit:          make(chan bool),
115                 },
116         }
117
118         return &db
119 }
120
121 func Create() *DB {
122         hostname := os.Getenv("DBAAS_SERVICE_HOST")
123         if hostname == "" {
124                 hostname = "localhost"
125         }
126         port := os.Getenv("DBAAS_SERVICE_PORT")
127         if port == "" {
128                 port = "6379"
129         }
130         redisAddress := hostname + ":" + port
131         client := redis.NewClient(&redis.Options{
132                 Addr:     redisAddress,
133                 Password: "", // no password set
134                 DB:       0,  // use default DB
135                 PoolSize: 20,
136         })
137         db := CreateDB(client, subscribeNotifications)
138         db.CheckCommands()
139         return db
140 }
141
142 func (db *DB) CheckCommands() {
143         commands, err := db.client.Command().Result()
144         if err == nil {
145                 redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
146                         "msetmpub", "delmpub"}
147                 for _, v := range redisModuleCommands {
148                         _, ok := commands[v]
149                         if !ok {
150                                 db.redisModules = false
151                         }
152                 }
153         } else {
154                 fmt.Println(err)
155         }
156 }
157
158 func (db *DB) CloseDB() error {
159         return db.client.Close()
160 }
161
162 func (db *DB) UnsubscribeChannelDB(channels ...string) {
163         for _, v := range channels {
164                 db.ch.removeChannel <- v
165                 delete(db.cbMap, v)
166                 if len(db.cbMap) == 0 {
167                         db.ch.exit <- true
168                 }
169         }
170 }
171
172 func (db *DB) SubscribeChannelDB(cb ChannelNotificationCb, channelPrefix, eventSeparator string, channels ...string) {
173         if len(db.cbMap) == 0 {
174                 for _, v := range channels {
175                         db.cbMap[v] = cb
176                 }
177
178                 go func(cbMap *map[string]ChannelNotificationCb,
179                         channelPrefix,
180                         eventSeparator string,
181                         ch intChannels,
182                         channels ...string) {
183                         sub := db.subscribe(db.client, channels...)
184                         rxChannel := sub.Channel()
185                         for {
186                                 select {
187                                 case msg := <-rxChannel:
188                                         cb, ok := (*cbMap)[msg.Channel]
189                                         if ok {
190                                                 cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
191                                         }
192                                 case channel := <-ch.addChannel:
193                                         sub.Subscribe(channel)
194                                 case channel := <-ch.removeChannel:
195                                         sub.Unsubscribe(channel)
196                                 case exit := <-ch.exit:
197                                         if exit {
198                                                 if err := sub.Close(); err != nil {
199                                                         fmt.Println(err)
200                                                 }
201                                                 return
202                                         }
203                                 }
204                         }
205                 }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
206
207         } else {
208                 for _, v := range channels {
209                         db.cbMap[v] = cb
210                         db.ch.addChannel <- v
211                 }
212         }
213 }
214
215 func (db *DB) MSet(pairs ...interface{}) error {
216         return db.client.MSet(pairs...).Err()
217 }
218
219 func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
220         if !db.redisModules {
221                 return errors.New("Redis deployment doesn't support MSETMPUB command")
222         }
223         command := make([]interface{}, 0)
224         command = append(command, "MSETMPUB")
225         command = append(command, len(pairs)/2)
226         command = append(command, len(channelsAndEvents)/2)
227         for _, d := range pairs {
228                 command = append(command, d)
229         }
230         for _, d := range channelsAndEvents {
231                 command = append(command, d)
232         }
233         _, err := db.client.Do(command...).Result()
234         return err
235 }
236
237 func (db *DB) MGet(keys []string) ([]interface{}, error) {
238         return db.client.MGet(keys...).Result()
239 }
240
241 func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
242         if !db.redisModules {
243                 return errors.New("Redis deployment not supporting command DELMPUB")
244         }
245         command := make([]interface{}, 0)
246         command = append(command, "DELMPUB")
247         command = append(command, len(keys))
248         command = append(command, len(channelsAndEvents)/2)
249         for _, d := range keys {
250                 command = append(command, d)
251         }
252         for _, d := range channelsAndEvents {
253                 command = append(command, d)
254         }
255         _, err := db.client.Do(command...).Result()
256         return err
257
258 }
259
260 func (db *DB) Del(keys []string) error {
261         _, err := db.client.Del(keys...).Result()
262         return err
263 }
264
265 func (db *DB) Keys(pattern string) ([]string, error) {
266         return db.client.Keys(pattern).Result()
267 }
268
269 func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
270         if !db.redisModules {
271                 return false, errors.New("Redis deployment not supporting command")
272         }
273
274         return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
275 }
276
277 func (db *DB) SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error) {
278         if !db.redisModules {
279                 return false, errors.New("Redis deployment not supporting command SETIEPUB")
280         }
281         return checkResultAndError(db.client.Do("SETIEPUB", key, newData, oldData, channel, message).Result())
282 }
283
284 func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, error) {
285         if !db.redisModules {
286                 return false, errors.New("Redis deployment not supporting command SETNXPUB")
287         }
288         return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
289 }
290 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
291         return db.client.SetNX(key, data, expiration).Result()
292 }
293
294 func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
295         if !db.redisModules {
296                 return false, errors.New("Redis deployment not supporting command")
297         }
298         return checkIntResultAndError(db.client.Do("DELIEPUB", key, data, channel, message).Result())
299 }
300
301 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
302         if !db.redisModules {
303                 return false, errors.New("Redis deployment not supporting command")
304         }
305         return checkIntResultAndError(db.client.Do("DELIE", key, data).Result())
306 }
307
308 func (db *DB) SAdd(key string, data ...interface{}) error {
309         _, err := db.client.SAdd(key, data...).Result()
310         return err
311 }
312
313 func (db *DB) SRem(key string, data ...interface{}) error {
314         _, err := db.client.SRem(key, data...).Result()
315         return err
316 }
317
318 func (db *DB) SMembers(key string) ([]string, error) {
319         result, err := db.client.SMembers(key).Result()
320         return result, err
321 }
322
323 func (db *DB) SIsMember(key string, data interface{}) (bool, error) {
324         result, err := db.client.SIsMember(key, data).Result()
325         return result, err
326 }
327
328 func (db *DB) SCard(key string) (int64, error) {
329         result, err := db.client.SCard(key).Result()
330         return result, err
331 }
332
333 func (db *DB) PTTL(key string) (time.Duration, error) {
334         result, err := db.client.PTTL(key).Result()
335         return result, err
336 }
337
338 var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
339
340 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
341         expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
342         result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
343         if err != nil {
344                 return err
345         }
346         if result == int64(1) {
347                 return nil
348         }
349         return errors.New("Lock not held")
350 }