Add resource locking
[ric-plt/sdlgo.git] / internal / sdlgoredis / sdlgoredis.go
index ab56b12..0ccccb6 100644 (file)
@@ -21,6 +21,7 @@ import (
        "errors"
        "fmt"
        "os"
+       "strconv"
        "strings"
        "time"
 
@@ -67,6 +68,11 @@ type RedisClient interface {
        SMembers(key string) *redis.StringSliceCmd
        SIsMember(key string, member interface{}) *redis.BoolCmd
        SCard(key string) *redis.IntCmd
+       PTTL(key string) *redis.DurationCmd
+       Eval(script string, keys []string, args ...interface{}) *redis.Cmd
+       EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
+       ScriptExists(scripts ...string) *redis.BoolSliceCmd
+       ScriptLoad(script string) *redis.StringCmd
 }
 
 func checkResultAndError(result interface{}, err error) (bool, error) {
@@ -281,8 +287,8 @@ func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, er
        }
        return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
 }
-func (db *DB) SetNX(key string, data interface{}) (bool, error) {
-       return db.client.SetNX(key, data, 0).Result()
+func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
+       return db.client.SetNX(key, data, expiration).Result()
 }
 
 func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
@@ -323,3 +329,22 @@ func (db *DB) SCard(key string) (int64, error) {
        result, err := db.client.SCard(key).Result()
        return result, err
 }
+
+func (db *DB) PTTL(key string) (time.Duration, error) {
+       result, err := db.client.PTTL(key).Result()
+       return result, err
+}
+
+var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
+
+func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
+       expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
+       result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
+       if err != nil {
+               return err
+       }
+       if result == int64(1) {
+               return nil
+       }
+       return errors.New("Lock not held")
+}