Support redis sentinel configuration
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
index ab56b12..f1c7e62 100644 (file)
@@ -21,6 +21,7 @@ import (
        "errors"
        "fmt"
        "os"
+       "strconv"
        "strings"
        "time"
 
@@ -67,6 +68,11 @@ type RedisClient interface {
        SMembers(key string) *redis.StringSliceCmd
        SIsMember(key string, member interface{}) *redis.BoolCmd
        SCard(key string) *redis.IntCmd
+       PTTL(key string) *redis.DurationCmd
+       Eval(script string, keys []string, args ...interface{}) *redis.Cmd
+       EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
+       ScriptExists(scripts ...string) *redis.BoolSliceCmd
+       ScriptLoad(script string) *redis.StringCmd
 }
 
 func checkResultAndError(result interface{}, err error) (bool, error) {
@@ -86,7 +92,7 @@ func checkIntResultAndError(result interface{}, err error) (bool, error) {
        if err != nil {
                return false, err
        }
-       if result == 1 {
+       if result.(int64) == int64(1) {
                return true, nil
        }
        return false, nil
@@ -113,6 +119,7 @@ func CreateDB(client RedisClient, subscribe SubscribeFn) *DB {
 }
 
 func Create() *DB {
+       var client *redis.Client
        hostname := os.Getenv("DBAAS_SERVICE_HOST")
        if hostname == "" {
                hostname = "localhost"
@@ -121,13 +128,26 @@ func Create() *DB {
        if port == "" {
                port = "6379"
        }
-       redisAddress := hostname + ":" + port
-       client := redis.NewClient(&redis.Options{
-               Addr:     redisAddress,
-               Password: "", // no password set
-               DB:       0,  // use default DB
-               PoolSize: 20,
-       })
+       sentinelPort := os.Getenv("DBAAS_SERVICE_SENTINEL_PORT")
+       masterName := os.Getenv("DBAAS_MASTER_NAME")
+       if sentinelPort == "" {
+               redisAddress := hostname + ":" + port
+               client = redis.NewClient(&redis.Options{
+                       Addr:       redisAddress,
+                       Password:   "", // no password set
+                       DB:         0,  // use default DB
+                       PoolSize:   20,
+                       MaxRetries: 2,
+               })
+       } else {
+               sentinelAddress := hostname + ":" + sentinelPort
+               client = redis.NewFailoverClient(&redis.FailoverOptions{
+                       MasterName:    masterName,
+                       SentinelAddrs: []string{sentinelAddress},
+                       PoolSize:      20,
+                       MaxRetries:    2,
+               })
+       }
        db := CreateDB(client, subscribeNotifications)
        db.CheckCommands()
        return db
@@ -281,8 +301,8 @@ func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, er
        }
        return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
 }
-func (db *DB) SetNX(key string, data interface{}) (bool, error) {
-       return db.client.SetNX(key, data, 0).Result()
+func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
+       return db.client.SetNX(key, data, expiration).Result()
 }
 
 func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
@@ -323,3 +343,22 @@ func (db *DB) SCard(key string) (int64, error) {
        result, err := db.client.SCard(key).Result()
        return result, err
 }
+
+func (db *DB) PTTL(key string) (time.Duration, error) {
+       result, err := db.client.PTTL(key).Result()
+       return result, err
+}
+
+var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
+
+func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
+       expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
+       result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
+       if err != nil {
+               return err
+       }
+       if result == int64(1) {
+               return nil
+       }
+       return errors.New("Lock not held")
+}