Bump Redis client version to v8.11.4
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
index 78c8b5a..278be2a 100644 (file)
 package sdlgoredis
 
 import (
+       "context"
        "errors"
        "fmt"
-       "github.com/go-redis/redis/v7"
+       "github.com/go-redis/redis/v8"
        "io"
        "log"
        "net"
@@ -61,6 +62,7 @@ type Config struct {
 }
 
 type DB struct {
+       ctx          context.Context
        client       RedisClient
        sentinel     RedisSentinelCreateCb
        subscribe    SubscribeFn
@@ -72,46 +74,48 @@ type DB struct {
 }
 
 type Subscriber interface {
-       Channel() <-chan *redis.Message
-       Subscribe(channels ...string) error
-       Unsubscribe(channels ...string) error
+       Channel(opts ...redis.ChannelOption) <-chan *redis.Message
+       Subscribe(ctx context.Context, channels ...string) error
+       Unsubscribe(ctx context.Context, channels ...string) error
        Close() error
 }
 
-type SubscribeFn func(client RedisClient, channels ...string) Subscriber
+type SubscribeFn func(ctx context.Context, client RedisClient, channels ...string) Subscriber
 
 type RedisClient interface {
-       Command() *redis.CommandsInfoCmd
+       Command(ctx context.Context) *redis.CommandsInfoCmd
        Close() error
-       Subscribe(channels ...string) *redis.PubSub
-       MSet(pairs ...interface{}) *redis.StatusCmd
-       Do(args ...interface{}) *redis.Cmd
-       MGet(keys ...string) *redis.SliceCmd
-       Del(keys ...string) *redis.IntCmd
-       Keys(pattern string) *redis.StringSliceCmd
-       SetNX(key string, value interface{}, expiration time.Duration) *redis.BoolCmd
-       SAdd(key string, members ...interface{}) *redis.IntCmd
-       SRem(key string, members ...interface{}) *redis.IntCmd
-       SMembers(key string) *redis.StringSliceCmd
-       SIsMember(key string, member interface{}) *redis.BoolCmd
-       SCard(key string) *redis.IntCmd
-       PTTL(key string) *redis.DurationCmd
-       Eval(script string, keys []string, args ...interface{}) *redis.Cmd
-       EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
-       ScriptExists(scripts ...string) *redis.BoolSliceCmd
-       ScriptLoad(script string) *redis.StringCmd
-       Info(section ...string) *redis.StringCmd
-}
-
-var dbLogger *log.Logger
+       Subscribe(ctx context.Context, channels ...string) *redis.PubSub
+       MSet(ctx context.Context, pairs ...interface{}) *redis.StatusCmd
+       Do(ctx context.Context, args ...interface{}) *redis.Cmd
+       MGet(ctx context.Context, keys ...string) *redis.SliceCmd
+       Del(ctx context.Context, keys ...string) *redis.IntCmd
+       Keys(ctx context.Context, pattern string) *redis.StringSliceCmd
+       SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd
+       SAdd(ctx context.Context, key string, members ...interface{}) *redis.IntCmd
+       SRem(ctx context.Context, key string, members ...interface{}) *redis.IntCmd
+       SMembers(ctx context.Context, key string) *redis.StringSliceCmd
+       SIsMember(ctx context.Context, key string, member interface{}) *redis.BoolCmd
+       SCard(ctx context.Context, key string) *redis.IntCmd
+       PTTL(ctx context.Context, key string) *redis.DurationCmd
+       Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
+       EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
+       ScriptExists(ctx context.Context, scripts ...string) *redis.BoolSliceCmd
+       ScriptLoad(ctx context.Context, script string) *redis.StringCmd
+       Info(ctx context.Context, section ...string) *redis.StringCmd
+}
+
+var dbLogger *logger
 
 func init() {
-       dbLogger = log.New(os.Stdout, "database: ", log.LstdFlags|log.Lshortfile)
+       dbLogger = &logger{
+               log: log.New(os.Stdout, "database: ", log.LstdFlags|log.Lshortfile),
+       }
        redis.SetLogger(dbLogger)
 }
 
 func SetDbLogger(out io.Writer) {
-       dbLogger.SetOutput(out)
+       dbLogger.log.SetOutput(out)
 }
 
 func checkResultAndError(result interface{}, err error) (bool, error) {
@@ -143,12 +147,13 @@ func checkIntResultAndError(result interface{}, err error) (bool, error) {
        return false, nil
 }
 
-func subscribeNotifications(client RedisClient, channels ...string) Subscriber {
-       return client.Subscribe(channels...)
+func subscribeNotifications(ctx context.Context, client RedisClient, channels ...string) Subscriber {
+       return client.Subscribe(ctx, channels...)
 }
 
 func CreateDB(client RedisClient, subscribe SubscribeFn, sentinelCreateCb RedisSentinelCreateCb, cfg Config, sentinelAddr string) *DB {
        db := DB{
+               ctx:          context.Background(),
                client:       client,
                sentinel:     sentinelCreateCb,
                subscribe:    subscribe,
@@ -266,7 +271,7 @@ func newRedisClient(addr, port, clusterName string, isHa bool) RedisClient {
 }
 
 func (db *DB) CheckCommands() {
-       commands, err := db.client.Command().Result()
+       commands, err := db.client.Command(db.ctx).Result()
        if err == nil {
                redisModuleCommands := []string{"setie", "delie", "setiepub", "setnxpub",
                        "msetmpub", "delmpub"}
@@ -277,7 +282,7 @@ func (db *DB) CheckCommands() {
                        }
                }
        } else {
-               dbLogger.Printf("SDL DB commands checking failure: %s\n", err)
+               dbLogger.Printf(db.ctx, "SDL DB commands checking failure: %s\n", err)
        }
 }
 
@@ -306,7 +311,7 @@ func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, even
                        eventSeparator string,
                        ch intChannels,
                        channels ...string) {
-                       sub := db.subscribe(db.client, channels...)
+                       sub := db.subscribe(db.ctx, db.client, channels...)
                        rxChannel := sub.Channel()
                        lCbMap := sCbMap.GetMapCopy()
                        for {
@@ -318,14 +323,14 @@ func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, even
                                        }
                                case channel := <-ch.addChannel:
                                        lCbMap = sCbMap.GetMapCopy()
-                                       sub.Subscribe(channel)
+                                       sub.Subscribe(db.ctx, channel)
                                case channel := <-ch.removeChannel:
                                        lCbMap = sCbMap.GetMapCopy()
-                                       sub.Unsubscribe(channel)
+                                       sub.Unsubscribe(db.ctx, channel)
                                case exit := <-ch.exit:
                                        if exit {
                                                if err := sub.Close(); err != nil {
-                                                       dbLogger.Printf("SDL DB channel closing failure: %s\n", err)
+                                                       dbLogger.Printf(db.ctx, "SDL DB channel closing failure: %s\n", err)
                                                }
                                                return
                                        }
@@ -342,7 +347,7 @@ func (db *DB) SubscribeChannelDB(cb func(string, ...string), channelPrefix, even
 }
 
 func (db *DB) MSet(pairs ...interface{}) error {
-       return db.client.MSet(pairs...).Err()
+       return db.client.MSet(db.ctx, pairs...).Err()
 }
 
 func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
@@ -359,12 +364,12 @@ func (db *DB) MSetMPub(channelsAndEvents []string, pairs ...interface{}) error {
        for _, d := range channelsAndEvents {
                command = append(command, d)
        }
-       _, err := db.client.Do(command...).Result()
+       _, err := db.client.Do(db.ctx, command...).Result()
        return err
 }
 
 func (db *DB) MGet(keys []string) ([]interface{}, error) {
-       return db.client.MGet(keys...).Result()
+       return db.client.MGet(db.ctx, keys...).Result()
 }
 
 func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
@@ -381,18 +386,18 @@ func (db *DB) DelMPub(channelsAndEvents []string, keys []string) error {
        for _, d := range channelsAndEvents {
                command = append(command, d)
        }
-       _, err := db.client.Do(command...).Result()
+       _, err := db.client.Do(db.ctx, command...).Result()
        return err
 
 }
 
 func (db *DB) Del(keys []string) error {
-       _, err := db.client.Del(keys...).Result()
+       _, err := db.client.Del(db.ctx, keys...).Result()
        return err
 }
 
 func (db *DB) Keys(pattern string) ([]string, error) {
-       return db.client.Keys(pattern).Result()
+       return db.client.Keys(db.ctx, pattern).Result()
 }
 
 func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
@@ -400,7 +405,7 @@ func (db *DB) SetIE(key string, oldData, newData interface{}) (bool, error) {
                return false, errors.New("Redis deployment not supporting command")
        }
 
-       return checkResultAndError(db.client.Do("SETIE", key, newData, oldData).Result())
+       return checkResultAndError(db.client.Do(db.ctx, "SETIE", key, newData, oldData).Result())
 }
 
 func (db *DB) SetIEPub(channelsAndEvents []string, key string, oldData, newData interface{}) (bool, error) {
@@ -416,7 +421,7 @@ func (db *DB) SetIEPub(channelsAndEvents []string, key string, oldData, newData
        for _, ce := range channelsAndEvents {
                command = append(command, ce)
        }
-       return checkResultAndError(db.client.Do(command...).Result())
+       return checkResultAndError(db.client.Do(db.ctx, command...).Result())
 }
 
 func (db *DB) SetNXPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
@@ -431,10 +436,10 @@ func (db *DB) SetNXPub(channelsAndEvents []string, key string, data interface{})
        for _, ce := range channelsAndEvents {
                command = append(command, ce)
        }
-       return checkResultAndError(db.client.Do(command...).Result())
+       return checkResultAndError(db.client.Do(db.ctx, command...).Result())
 }
 func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
-       return db.client.SetNX(key, data, expiration).Result()
+       return db.client.SetNX(db.ctx, key, data, expiration).Result()
 }
 
 func (db *DB) DelIEPub(channelsAndEvents []string, key string, data interface{}) (bool, error) {
@@ -449,49 +454,49 @@ func (db *DB) DelIEPub(channelsAndEvents []string, key string, data interface{})
        for _, ce := range channelsAndEvents {
                command = append(command, ce)
        }
-       return checkIntResultAndError(db.client.Do(command...).Result())
+       return checkIntResultAndError(db.client.Do(db.ctx, command...).Result())
 }
 
 func (db *DB) DelIE(key string, data interface{}) (bool, error) {
        if !db.redisModules {
                return false, errors.New("Redis deployment not supporting command")
        }
-       return checkIntResultAndError(db.client.Do("DELIE", key, data).Result())
+       return checkIntResultAndError(db.client.Do(db.ctx, "DELIE", key, data).Result())
 }
 
 func (db *DB) SAdd(key string, data ...interface{}) error {
-       _, err := db.client.SAdd(key, data...).Result()
+       _, err := db.client.SAdd(db.ctx, key, data...).Result()
        return err
 }
 
 func (db *DB) SRem(key string, data ...interface{}) error {
-       _, err := db.client.SRem(key, data...).Result()
+       _, err := db.client.SRem(db.ctx, key, data...).Result()
        return err
 }
 
 func (db *DB) SMembers(key string) ([]string, error) {
-       result, err := db.client.SMembers(key).Result()
+       result, err := db.client.SMembers(db.ctx, key).Result()
        return result, err
 }
 
 func (db *DB) SIsMember(key string, data interface{}) (bool, error) {
-       result, err := db.client.SIsMember(key, data).Result()
+       result, err := db.client.SIsMember(db.ctx, key, data).Result()
        return result, err
 }
 
 func (db *DB) SCard(key string) (int64, error) {
-       result, err := db.client.SCard(key).Result()
+       result, err := db.client.SCard(db.ctx, key).Result()
        return result, err
 }
 
 func (db *DB) PTTL(key string) (time.Duration, error) {
-       result, err := db.client.PTTL(key).Result()
+       result, err := db.client.PTTL(db.ctx, key).Result()
        return result, err
 }
 
 func (db *DB) Info() (*DbInfo, error) {
        var info DbInfo
-       resultStr, err := db.client.Info("all").Result()
+       resultStr, err := db.client.Info(db.ctx, "all").Result()
        if err != nil {
                return &info, err
        }
@@ -872,7 +877,7 @@ var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then
 
 func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
        expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
-       result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
+       result, err := luaRefresh.Run(db.ctx, db.client, []string{key}, data, expirationStr).Result()
        if err != nil {
                return err
        }