Set MaxRetries count
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
1 /*
2    Copyright (c) 2019 AT&T Intellectual Property.
3    Copyright (c) 2018-2019 Nokia.
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16 */
17
18 package sdlgoredis
19
20 import (
21         "errors"
22         "fmt"
23         "os"
24         "strconv"
25         "strings"
26         "time"
27
28         "github.com/go-redis/redis"
29 )
30
31 type ChannelNotificationCb func(channel string, payload ...string)
32
33 type intChannels struct {
34         addChannel    chan string
35         removeChannel chan string
36         exit          chan bool
37 }
38
39 type DB struct {
40         client       RedisClient
41         subscribe    SubscribeFn
42         redisModules bool
43         cbMap        map[string]ChannelNotificationCb
44         ch           intChannels
45 }
46
47 type Subscriber interface {
48         Channel() <-chan *redis.Message
49         Subscribe(channels ...string) error
50         Unsubscribe(channels ...string) error
51         Close() error
52 }
53
54 type SubscribeFn func(client RedisClient, channels ...string) Subscriber
55
56 type RedisClient interface {
57         Command() *redis.CommandsInfoCmd
58         Close() error
59         Subscribe(channels ...string) *redis.PubSub
60         MSet(pairs ...interface{}) *redis.StatusCmd
61         Do(args ...interface{}) *redis.Cmd
62         MGet(keys ...string) *redis.SliceCmd
63         Del(keys ...string) *redis.IntCmd
64         Keys(pattern string) *redis.StringSliceCmd
65         SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
66         SAdd(key string, members ...interface{}) *redis.IntCmd
67         SRem(key string, members ...interface{}) *redis.IntCmd
68         SMembers(key string) *redis.StringSliceCmd
69         SIsMember(key string, member interface{}) *redis.BoolCmd
70         SCard(key string) *redis.IntCmd
71         PTTL(key string) *redis.DurationCmd
72         Eval(script string, keys []string, args ...interface{}) *redis.Cmd
73         EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
74         ScriptExists(scripts ...string) *redis.BoolSliceCmd
75         ScriptLoad(script string) *redis.StringCmd
76 }
77
78 func checkResultAndError(result interface{}, err error) (bool, error) {
79         if err != nil {
80                 if err == redis.Nil {
81                         return false, nil
82                 }
83                 return false, err
84         }
85         if result == "OK" {
86                 return true, nil
87         }
88         return false, nil
89 }
90
91 func checkIntResultAndError(result interface{}, err error) (bool, error) {
92         if err != nil {
93                 return false, err
94         }
95         if result == 1 {
96                 return true, nil
97         }
98         return false, nil
99 }
100
101 func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
102         return client.Subscribe(channels...)
103 }
104
105 func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
106         db := DB{
107                 client:       client,
108                 subscribe:    subscribe,
109                 redisModules: true,
110                 cbMap:        make(map[string]ChannelNotificationCb, 0),
111                 ch: intChannels{
112                         addChannel:    make(chan string),
113                         removeChannel: make(chan string),
114                         exit:          make(chan bool),
115                 },
116         }
117
118         return &db
119 }
120
121 func Create() *DB {
122         hostname := os.Getenv("DBAAS_SERVICE_HOST")
123         if hostname == "" {
124                 hostname = "localhost"
125         }
126         port := os.Getenv("DBAAS_SERVICE_PORT")
127         if port == "" {
128                 port = "6379"
129         }
130         redisAddress := hostname + ":" + port
131         client := redis.NewClient(&redis.Options{
132                 Addr:     redisAddress,
133                 Password: "", // no password set
134                 DB:       0,  // use default DB
135                 PoolSize: 20,
136                 MaxRetries: 2,
137         })
138         db := CreateDB(client, subscribeNotifications)
139         db.CheckCommands()
140         return db
141 }
142
143 func (db *DB) CheckCommands() {
144         commands, err := db.client.Command().Result()
145         if err == nil {
146                 redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
147                         "msetmpub", "delmpub"}
148                 for _, v := range redisModuleCommands {
149                         _, ok := commands[v]
150                         if !ok {
151                                 db.redisModules = false
152                         }
153                 }
154         } else {
155                 fmt.Println(err)
156         }
157 }
158
159 func (db *DB) CloseDB() error {
160         return db.client.Close()
161 }
162
163 func (db *DB) UnsubscribeChannelDB(channels ...string) {
164         for _, v := range channels {
165                 db.ch.removeChannel <- v
166                 delete(db.cbMap, v)
167                 if len(db.cbMap) == 0 {
168                         db.ch.exit <- true
169                 }
170         }
171 }
172
173 func (db *DB) SubscribeChannelDB(cb ChannelNotificationCb, channelPrefix, eventSeparator string, channels ...string) {
174         if len(db.cbMap) == 0 {
175                 for _, v := range channels {
176                         db.cbMap[v] = cb
177                 }
178
179                 go func(cbMap *map[string]ChannelNotificationCb,
180                         channelPrefix,
181                         eventSeparator string,
182                         ch intChannels,
183                         channels ...string) {
184                         sub := db.subscribe(db.client, channels...)
185                         rxChannel := sub.Channel()
186                         for {
187                                 select {
188                                 case msg := <-rxChannel:
189                                         cb, ok := (*cbMap)[msg.Channel]
190                                         if ok {
191                                                 cb(strings.TrimPrefix(msg.Channel, channelPrefix), strings.Split(msg.Payload, eventSeparator)...)
192                                         }
193                                 case channel := <-ch.addChannel:
194                                         sub.Subscribe(channel)
195                                 case channel := <-ch.removeChannel:
196                                         sub.Unsubscribe(channel)
197                                 case exit := <-ch.exit:
198                                         if exit {
199                                                 if err := sub.Close(); err != nil {
200                                                         fmt.Println(err)
201                                                 }
202                                                 return
203                                         }
204                                 }
205                         }
206                 }(&db.cbMap, channelPrefix, eventSeparator, db.ch, channels...)
207
208         } else {
209                 for _, v := range channels {
210                         db.cbMap[v] = cb
211                         db.ch.addChannel <- v
212                 }
213         }
214 }
215
216 func (db *DB) MSet(pairs ...interface{}) error {
217         return db.client.MSet(pairs...).Err()
218 }
219
220 func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
221         if !db.redisModules {
222                 return errors.New("Redis deployment doesn't support MSETMPUB command")
223         }
224         command := make([]interface{}, 0)
225         command = append(command, "MSETMPUB")
226         command = append(command, len(pairs)/2)
227         command = append(command, len(channelsAndEvents)/2)
228         for _, d := range pairs {
229                 command = append(command, d)
230         }
231         for _, d := range channelsAndEvents {
232                 command = append(command, d)
233         }
234         _, err := db.client.Do(command...).Result()
235         return err
236 }
237
238 func (db *DB) MGet(keys []string) ([]interface{}, error) {
239         return db.client.MGet(keys...).Result()
240 }
241
242 func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
243         if !db.redisModules {
244                 return errors.New("Redis deployment not supporting command DELMPUB")
245         }
246         command := make([]interface{}, 0)
247         command = append(command, "DELMPUB")
248         command = append(command, len(keys))
249         command = append(command, len(channelsAndEvents)/2)
250         for _, d := range keys {
251                 command = append(command, d)
252         }
253         for _, d := range channelsAndEvents {
254                 command = append(command, d)
255         }
256         _, err := db.client.Do(command...).Result()
257         return err
258
259 }
260
261 func (db *DB) Del(keys []string) error {
262         _, err := db.client.Del(keys...).Result()
263         return err
264 }
265
266 func (db *DB) Keys(pattern string) ([]string, error) {
267         return db.client.Keys(pattern).Result()
268 }
269
270 func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
271         if !db.redisModules {
272                 return false, errors.New("Redis deployment not supporting command")
273         }
274
275         return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
276 }
277
278 func (db *DB) SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error) {
279         if !db.redisModules {
280                 return false, errors.New("Redis deployment not supporting command SETIEPUB")
281         }
282         return checkResultAndError(db.client.Do("SETIEPUB", key, newData, oldData, channel, message).Result())
283 }
284
285 func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, error) {
286         if !db.redisModules {
287                 return false, errors.New("Redis deployment not supporting command SETNXPUB")
288         }
289         return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
290 }
291 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
292         return db.client.SetNX(key, data, expiration).Result()
293 }
294
295 func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
296         if !db.redisModules {
297                 return false, errors.New("Redis deployment not supporting command")
298         }
299         return checkIntResultAndError(db.client.Do("DELIEPUB", key, data, channel, message).Result())
300 }
301
302 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
303         if !db.redisModules {
304                 return false, errors.New("Redis deployment not supporting command")
305         }
306         return checkIntResultAndError(db.client.Do("DELIE", key, data).Result())
307 }
308
309 func (db *DB) SAdd(key string, data ...interface{}) error {
310         _, err := db.client.SAdd(key, data...).Result()
311         return err
312 }
313
314 func (db *DB) SRem(key string, data ...interface{}) error {
315         _, err := db.client.SRem(key, data...).Result()
316         return err
317 }
318
319 func (db *DB) SMembers(key string) ([]string, error) {
320         result, err := db.client.SMembers(key).Result()
321         return result, err
322 }
323
324 func (db *DB) SIsMember(key string, data interface{}) (bool, error) {
325         result, err := db.client.SIsMember(key, data).Result()
326         return result, err
327 }
328
329 func (db *DB) SCard(key string) (int64, error) {
330         result, err := db.client.SCard(key).Result()
331         return result, err
332 }
333
334 func (db *DB) PTTL(key string) (time.Duration, error) {
335         result, err := db.client.PTTL(key).Result()
336         return result, err
337 }
338
339 var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
340
341 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
342         expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
343         result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
344         if err != nil {
345                 return err
346         }
347         if result == int64(1) {
348                 return nil
349         }
350         return errors.New("Lock not held")
351 }