Add resource locking 06/706/1 v0.3.0
authorMarco Tallskog <marco.tallskog@nokia.com>
Wed, 14 Aug 2019 11:50:23 +0000 (14:50 +0300)
committerMarco Tallskog <marco.tallskog@nokia.com>
Wed, 14 Aug 2019 12:01:26 +0000 (15:01 +0300)
Implement methods that enable applications to create locks for
a resources. Resource locks are plain keys with a random value
with an expiration time. Only the holder of a lock can release it either
by removing the lock or by letting it expire (after it will be
automatically removed). By using a random value (that only lock holder
knows), it can be ensured that others can't remove it. Resource locks
are per namespace.

Implemented methods are:
LockResource: this method will check if the lock already exists and if
not, it will create it with an expiration time. It is possible for the
client to define with options how many times the lock will be tried and
the interval how often it is tried. If retry is used, this method will
block.
ReleaseResource: remove the lock. If removing is tried after the
expiration, an error will be returned.
RefreshResource: with this method, it is possible to set a new
expiration time for the lock. If the lock has already expired, an error
will be returned.
CheckResource: application can query the remaining expiration time with
this method, regardless it is the owner of the lock or not.

Change-Id: Ic6f5274c1740c7e36ddaba564024cffcc4c5de3d
Signed-off-by: Marco Tallskog <marco.tallskog@nokia.com>
internal/sdlgoredis/sdlgoredis.go
internal/sdlgoredis/sdlgoredis_test.go
sdl.go
sdl_test.go

index ab56b12..0ccccb6 100644 (file)
@@ -21,6 +21,7 @@ import (
        "errors"
        "fmt"
        "os"
+       "strconv"
        "strings"
        "time"
 
@@ -67,6 +68,11 @@ type RedisClient interface {
        SMembers(key string) *redis.StringSliceCmd
        SIsMember(key string, member interface{}) *redis.BoolCmd
        SCard(key string) *redis.IntCmd
+       PTTL(key string) *redis.DurationCmd
+       Eval(script string, keys []string, args ...interface{}) *redis.Cmd
+       EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
+       ScriptExists(scripts ...string) *redis.BoolSliceCmd
+       ScriptLoad(script string) *redis.StringCmd
 }
 
 func checkResultAndError(result interface{}, err error) (bool, error) {
@@ -281,8 +287,8 @@ func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, er
        }
        return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result())
 }
-func (db *DB) SetNX(key string, data interface{}) (bool, error) {
-       return db.client.SetNX(key, data, 0).Result()
+func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
+       return db.client.SetNX(key, data, expiration).Result()
 }
 
 func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) {
@@ -323,3 +329,22 @@ func (db *DB) SCard(key string) (int64, error) {
        result, err := db.client.SCard(key).Result()
        return result, err
 }
+
+func (db *DB) PTTL(key string) (time.Duration, error) {
+       result, err := db.client.PTTL(key).Result()
+       return result, err
+}
+
+var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
+
+func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
+       expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10)
+       result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result()
+       if err != nil {
+               return err
+       }
+       if result == int64(1) {
+               return nil
+       }
+       return errors.New("Lock not held")
+}
index 81c211a..962704c 100644 (file)
@@ -1,4 +1,3 @@
-
 /*
    Copyright (c) 2019 AT&T Intellectual Property.
    Copyright (c) 2018-2019 Nokia.
 package sdlgoredis_test
 
 import (
-       "testing"
        "errors"
+       "strconv"
+       "testing"
        "time"
 
-       "github.com/go-redis/redis"
        "gerrit.o-ran-sc.org/r/ric-plt/sdlgo/internal/sdlgoredis"
+       "github.com/go-redis/redis"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/mock"
 )
@@ -109,6 +109,26 @@ func (m *clientMock) SCard(key string) *redis.IntCmd {
        return m.Called(key).Get(0).(*redis.IntCmd)
 }
 
+func (m *clientMock) PTTL(key string) *redis.DurationCmd {
+       return m.Called(key).Get(0).(*redis.DurationCmd)
+}
+
+func (m *clientMock) Eval(script string, keys []string, args ...interface{}) *redis.Cmd {
+       return m.Called(script, keys).Get(0).(*redis.Cmd)
+}
+
+func (m *clientMock) EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd {
+       return m.Called(sha1, keys, args).Get(0).(*redis.Cmd)
+}
+
+func (m *clientMock) ScriptExists(scripts ...string) *redis.BoolSliceCmd {
+       return m.Called(scripts).Get(0).(*redis.BoolSliceCmd)
+}
+
+func (m *clientMock) ScriptLoad(script string) *redis.StringCmd {
+       return m.Called(script).Get(0).(*redis.StringCmd)
+}
+
 func setSubscribeNotifications() (*pubSubMock, sdlgoredis.SubscribeFn) {
        mock := new(pubSubMock)
        return mock, func(client sdlgoredis.RedisClient, channels ...string) sdlgoredis.Subscriber {
@@ -128,12 +148,12 @@ func setup(commandsExists bool) (*pubSubMock, *clientMock, *sdlgoredis.DB) {
 
        if commandsExists {
                cmdResult = map[string]*redis.CommandInfo{
-                       "setie": &dummyCommandInfo,
-                       "delie": &dummyCommandInfo,
+                       "setie":    &dummyCommandInfo,
+                       "delie":    &dummyCommandInfo,
                        "setiepub": &dummyCommandInfo,
                        "setnxpub": &dummyCommandInfo,
                        "msetmpub": &dummyCommandInfo,
-                       "delmpub": &dummyCommandInfo,
+                       "delmpub":  &dummyCommandInfo,
                }
        } else {
                cmdResult = map[string]*redis.CommandInfo{
@@ -149,7 +169,7 @@ func setup(commandsExists bool) (*pubSubMock, *clientMock, *sdlgoredis.DB) {
 func TestMSetSuccessfully(t *testing.T) {
        _, r, db := setup(true)
        expectedKeysAndValues := []interface{}{"key1", "value1", "key2", 2}
-       r.On("MSet",expectedKeysAndValues).Return(redis.NewStatusResult("OK", nil))
+       r.On("MSet", expectedKeysAndValues).Return(redis.NewStatusResult("OK", nil))
        err := db.MSet("key1", "value1", "key2", 2)
        assert.Nil(t, err)
        r.AssertExpectations(t)
@@ -158,7 +178,7 @@ func TestMSetSuccessfully(t *testing.T) {
 func TestMSetFailure(t *testing.T) {
        _, r, db := setup(true)
        expectedKeysAndValues := []interface{}{"key1", "value1", "key2", 2}
-       r.On("MSet",expectedKeysAndValues).Return(redis.NewStatusResult("OK", errors.New("Some error")))
+       r.On("MSet", expectedKeysAndValues).Return(redis.NewStatusResult("OK", errors.New("Some error")))
        err := db.MSet("key1", "value1", "key2", 2)
        assert.NotNil(t, err)
        r.AssertExpectations(t)
@@ -167,30 +187,30 @@ func TestMSetFailure(t *testing.T) {
 func TestMSetMPubSuccessfully(t *testing.T) {
        _, r, db := setup(true)
        expectedMessage := []interface{}{"MSETMPUB", 2, 2, "key1", "val1", "key2", "val2",
-                                                                        "chan1", "event1", "chan2", "event2"}
+               "chan1", "event1", "chan2", "event2"}
        r.On("Do", expectedMessage).Return(redis.NewCmdResult("", nil))
        assert.Nil(t, db.MSetMPub([]string{"chan1", "event1", "chan2", "event2"},
-                                                         "key1", "val1", "key2", "val2"))
+               "key1", "val1", "key2", "val2"))
        r.AssertExpectations(t)
 }
 
 func TestMsetMPubFailure(t *testing.T) {
        _, r, db := setup(true)
        expectedMessage := []interface{}{"MSETMPUB", 2, 2, "key1", "val1", "key2", "val2",
-                                                                        "chan1", "event1", "chan2", "event2"}
+               "chan1", "event1", "chan2", "event2"}
        r.On("Do", expectedMessage).Return(redis.NewCmdResult("", errors.New("Some error")))
        assert.NotNil(t, db.MSetMPub([]string{"chan1", "event1", "chan2", "event2"},
-                                                            "key1", "val1", "key2", "val2"))
+               "key1", "val1", "key2", "val2"))
        r.AssertExpectations(t)
 }
 
 func TestMSetMPubCommandMissing(t *testing.T) {
        _, r, db := setup(false)
        expectedMessage := []interface{}{"MSETMPUB", 2, 2, "key1", "val1", "key2", "val2",
-                                                                        "chan1", "event1", "chan2", "event2"}
+               "chan1", "event1", "chan2", "event2"}
        r.AssertNotCalled(t, "Do", expectedMessage)
        assert.NotNil(t, db.MSetMPub([]string{"chan1", "event1", "chan2", "event2"},
-                                                                "key1", "val1", "key2", "val2"))
+               "key1", "val1", "key2", "val2"))
        r.AssertExpectations(t)
 
 }
@@ -211,7 +231,7 @@ func TestMGetFailure(t *testing.T) {
        expectedKeys := []string{"key1", "key2", "key3"}
        expectedResult := []interface{}{nil}
        r.On("MGet", expectedKeys).Return(redis.NewSliceResult(expectedResult,
-                                                                                                                  errors.New("Some error")))
+               errors.New("Some error")))
        result, err := db.MGet([]string{"key1", "key2", "key3"})
        assert.Equal(t, result, expectedResult)
        assert.NotNil(t, err)
@@ -221,30 +241,30 @@ func TestMGetFailure(t *testing.T) {
 func TestDelMPubSuccessfully(t *testing.T) {
        _, r, db := setup(true)
        expectedMessage := []interface{}{"DELMPUB", 2, 2, "key1", "key2", "chan1", "event1",
-                                                                        "chan2", "event2"}
+               "chan2", "event2"}
        r.On("Do", expectedMessage).Return(redis.NewCmdResult("", nil))
        assert.Nil(t, db.DelMPub([]string{"chan1", "event1", "chan2", "event2"},
-                                                        []string{"key1", "key2"}))
+               []string{"key1", "key2"}))
        r.AssertExpectations(t)
 }
 
 func TestDelMPubFailure(t *testing.T) {
        _, r, db := setup(true)
        expectedMessage := []interface{}{"DELMPUB", 2, 2, "key1", "key2", "chan1", "event1",
-                                                                        "chan2", "event2"}
+               "chan2", "event2"}
        r.On("Do", expectedMessage).Return(redis.NewCmdResult("", errors.New("Some error")))
        assert.NotNil(t, db.DelMPub([]string{"chan1", "event1", "chan2", "event2"},
-                                                               []string{"key1", "key2"}))
+               []string{"key1", "key2"}))
        r.AssertExpectations(t)
 }
 
 func TestDelMPubCommandMissing(t *testing.T) {
        _, r, db := setup(false)
        expectedMessage := []interface{}{"DELMPUB", 2, 2, "key1", "key2", "chan1", "event1",
-                                                                        "chan2", "event2"}
+               "chan2", "event2"}
        r.AssertNotCalled(t, "Do", expectedMessage)
        assert.NotNil(t, db.DelMPub([]string{"chan1", "event1", "chan2", "event2"},
-                                                               []string{"key1", "key2"}))
+               []string{"key1", "key2"}))
        r.AssertExpectations(t)
 }
 
@@ -280,7 +300,7 @@ func TestKeysFailure(t *testing.T) {
        expectedPattern := "pattern*"
        expectedResult := []string{}
        r.On("Keys", expectedPattern).Return(redis.NewStringSliceResult(expectedResult,
-                                                                                                                                   errors.New("Some error")))
+               errors.New("Some error")))
        _, err := db.Keys("pattern*")
        assert.NotNil(t, err)
        r.AssertExpectations(t)
@@ -411,7 +431,7 @@ func TestSetNXSuccessfully(t *testing.T) {
        expectedKey := "key"
        expectedData := "data"
        r.On("SetNX", expectedKey, expectedData, time.Duration(0)).Return(redis.NewBoolResult(true, nil))
-       result, err := db.SetNX("key", "data")
+       result, err := db.SetNX("key", "data", 0)
        assert.True(t, result)
        assert.Nil(t, err)
        r.AssertExpectations(t)
@@ -422,7 +442,7 @@ func TestSetNXUnsuccessfully(t *testing.T) {
        expectedKey := "key"
        expectedData := "data"
        r.On("SetNX", expectedKey, expectedData, time.Duration(0)).Return(redis.NewBoolResult(false, nil))
-       result, err := db.SetNX("key", "data")
+       result, err := db.SetNX("key", "data", 0)
        assert.False(t, result)
        assert.Nil(t, err)
        r.AssertExpectations(t)
@@ -433,8 +453,8 @@ func TestSetNXFailure(t *testing.T) {
        expectedKey := "key"
        expectedData := "data"
        r.On("SetNX", expectedKey, expectedData, time.Duration(0)).
-               Return(redis.NewBoolResult(false,errors.New("Some error")))
-       result, err := db.SetNX("key", "data")
+               Return(redis.NewBoolResult(false, errors.New("Some error")))
+       result, err := db.SetNX("key", "data", 0)
        assert.False(t, result)
        assert.NotNil(t, err)
        r.AssertExpectations(t)
@@ -572,7 +592,7 @@ func TestSMembersFailure(t *testing.T) {
        expectedKey := "key"
        expectedResult := []string{"member1", "member2"}
        r.On("SMembers", expectedKey).Return(redis.NewStringSliceResult(expectedResult,
-                                                                                                                                       errors.New("Some error")))
+               errors.New("Some error")))
        result, err := db.SMembers("key")
        assert.Equal(t, result, expectedResult)
        assert.NotNil(t, err)
@@ -646,10 +666,10 @@ func TestSubscribeChannelDBSubscribeRXUnsubscribe(t *testing.T) {
        ps.On("Close").Return(nil)
        count := 0
        receivedChannel := ""
-       db.SubscribeChannelDB(func(channel string, payload ...string){
+       db.SubscribeChannelDB(func(channel string, payload ...string) {
                count++
                receivedChannel = channel
-               },"{prefix}", "---", "{prefix}channel")
+       }, "{prefix}", "---", "{prefix}channel")
        ch <- &msg
        db.UnsubscribeChannelDB("{prefix}channel")
        time.Sleep(1 * time.Second)
@@ -679,16 +699,16 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) {
        ps.On("Close").Return(nil)
        count := 0
        receivedChannel1 := ""
-       db.SubscribeChannelDB(func(channel string, payload ...string){
+       db.SubscribeChannelDB(func(channel string, payload ...string) {
                count++
                receivedChannel1 = channel
-               },"{prefix}", "---", "{prefix}channel1")
+       }, "{prefix}", "---", "{prefix}channel1")
        ch <- &msg1
        receivedChannel2 := ""
-       db.SubscribeChannelDB(func(channel string, payload ...string){
+       db.SubscribeChannelDB(func(channel string, payload ...string) {
                count++
                receivedChannel2 = channel
-               },"{prefix}", "---", "{prefix}channel2")
+       }, "{prefix}", "---", "{prefix}channel2")
 
        db.UnsubscribeChannelDB("{prefix}channel1")
        ch <- &msg2
@@ -699,4 +719,70 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) {
        assert.Equal(t, "channel2", receivedChannel2)
        r.AssertExpectations(t)
        ps.AssertExpectations(t)
-}
\ No newline at end of file
+}
+
+func TestPTTLSuccessfully(t *testing.T) {
+       _, r, db := setup(true)
+       expectedKey := "key"
+       expectedResult := time.Duration(1)
+       r.On("PTTL", expectedKey).Return(redis.NewDurationResult(expectedResult,
+               nil))
+       result, err := db.PTTL("key")
+       assert.Equal(t, result, expectedResult)
+       assert.Nil(t, err)
+       r.AssertExpectations(t)
+}
+
+func TestPTTLFailure(t *testing.T) {
+       _, r, db := setup(true)
+       expectedKey := "key"
+       expectedResult := time.Duration(1)
+       r.On("PTTL", expectedKey).Return(redis.NewDurationResult(expectedResult,
+               errors.New("Some error")))
+       result, err := db.PTTL("key")
+       assert.Equal(t, result, expectedResult)
+       assert.NotNil(t, err)
+       r.AssertExpectations(t)
+}
+
+func TestPExpireIESuccessfully(t *testing.T) {
+       _, r, db := setup(true)
+       expectedKey := "key"
+       expectedData := "data"
+       expectedDuration := strconv.FormatInt(int64(10000), 10)
+
+       r.On("EvalSha", mock.Anything, []string{expectedKey}, []interface{}{expectedData, expectedDuration}).
+               Return(redis.NewCmdResult(int64(1), nil))
+
+       err := db.PExpireIE("key", "data", 10*time.Second)
+       assert.Nil(t, err)
+       r.AssertExpectations(t)
+}
+
+func TestPExpireIEFailure(t *testing.T) {
+       _, r, db := setup(true)
+       expectedKey := "key"
+       expectedData := "data"
+       expectedDuration := strconv.FormatInt(int64(10000), 10)
+
+       r.On("EvalSha", mock.Anything, []string{expectedKey}, []interface{}{expectedData, expectedDuration}).
+               Return(redis.NewCmdResult(int64(1), errors.New("Some error")))
+
+       err := db.PExpireIE("key", "data", 10*time.Second)
+       assert.NotNil(t, err)
+       r.AssertExpectations(t)
+}
+
+func TestPExpireIELockNotHeld(t *testing.T) {
+       _, r, db := setup(true)
+       expectedKey := "key"
+       expectedData := "data"
+       expectedDuration := strconv.FormatInt(int64(10000), 10)
+
+       r.On("EvalSha", mock.Anything, []string{expectedKey}, []interface{}{expectedData, expectedDuration}).
+               Return(redis.NewCmdResult(int64(0), nil))
+
+       err := db.PExpireIE("key", "data", 10*time.Second)
+       assert.NotNil(t, err)
+       r.AssertExpectations(t)
+}
diff --git a/sdl.go b/sdl.go
index 78dd15d..426c2c4 100644 (file)
--- a/sdl.go
+++ b/sdl.go
 package sdlgo
 
 import (
+       "crypto/rand"
+       "encoding/base64"
        "errors"
        "fmt"
+       "io"
        "reflect"
        "strings"
+       "sync"
+       "time"
 
        "gerrit.o-ran-sc.org/r/ric-plt/sdlgo/internal/sdlgoredis"
 )
@@ -32,6 +37,8 @@ type SdlInstance struct {
        nameSpace      string
        nsPrefix       string
        eventSeparator string
+       mutex          sync.Mutex
+       tmp            []byte
        iDatabase
 }
 
@@ -294,7 +301,7 @@ func (s *SdlInstance) SetIf(key string, oldData, newData interface{}) (bool, err
 //given channel.
 func (s *SdlInstance) SetIfNotExistsAndPublish(channelsAndEvents []string, key string, data interface{}) (bool, error) {
        if len(channelsAndEvents) == 0 {
-               return s.SetNX(s.nsPrefix+key, data)
+               return s.SetNX(s.nsPrefix+key, data, 0)
        }
        if err := s.checkChannelsAndEvents("SetIfNotExistsAndPublish", channelsAndEvents); err != nil {
                return false, err
@@ -307,7 +314,7 @@ func (s *SdlInstance) SetIfNotExistsAndPublish(channelsAndEvents []string, key s
 //then it's value is not changed. Checking the key existence and potential set operation
 //is done atomically.
 func (s *SdlInstance) SetIfNotExists(key string, data interface{}) (bool, error) {
-       return s.SetNX(s.nsPrefix+key, data)
+       return s.SetNX(s.nsPrefix+key, data, 0)
 }
 
 //RemoveAndPublish removes data from SDL. Operation is done atomically, i.e. either all succeeds or fails.
@@ -466,6 +473,122 @@ func (s *SdlInstance) GroupSize(group string) (int64, error) {
        return retVal, err
 }
 
+func (s *SdlInstance) randomToken() (string, error) {
+       s.mutex.Lock()
+       defer s.mutex.Unlock()
+
+       if len(s.tmp) == 0 {
+               s.tmp = make([]byte, 16)
+       }
+
+       if _, err := io.ReadFull(rand.Reader, s.tmp); err != nil {
+               return "", err
+       }
+
+       return base64.RawURLEncoding.EncodeToString(s.tmp), nil
+}
+
+//LockResource function is used for locking a resource. The resource lock in
+//practice is a key with random value that is set to expire after a time
+//period. The value written to key is a random value, thus only the instance
+//created a lock, can release it. Resource locks are per namespace.
+func (s *SdlInstance) LockResource(resource string, expiration time.Duration, opt *Options) (*Lock, error) {
+       value, err := s.randomToken()
+       if err != nil {
+               return nil, err
+       }
+
+       var retryTimer *time.Timer
+       for i, attempts := 0, opt.getRetryCount()+1; i < attempts; i++ {
+               ok, err := s.SetNX(s.nsPrefix+resource, value, expiration)
+               if err != nil {
+                       return nil, err
+               } else if ok {
+                       return &Lock{s: s, key: resource, value: value}, nil
+               }
+               if retryTimer == nil {
+                       retryTimer = time.NewTimer(opt.getRetryWait())
+                       defer retryTimer.Stop()
+               } else {
+                       retryTimer.Reset(opt.getRetryWait())
+               }
+
+               select {
+               case <-retryTimer.C:
+               }
+       }
+       return nil, errors.New("Lock not obtained")
+}
+
+//ReleaseResource removes the lock from a resource. If lock is already
+//expired or some other instance is keeping the lock (lock taken after expiration),
+//an error is returned.
+func (l *Lock) ReleaseResource() error {
+       ok, err := l.s.DelIE(l.s.nsPrefix+l.key, l.value)
+
+       if err != nil {
+               return err
+       }
+       if !ok {
+               return errors.New("Lock not held")
+       }
+       return nil
+}
+
+//RefreshResource function can be used to set a new expiration time for the
+//resource lock (if the lock still exists). The old remaining expiration
+//time is overwritten with the given new expiration time.
+func (l *Lock) RefreshResource(expiration time.Duration) error {
+       err := l.s.PExpireIE(l.s.nsPrefix+l.key, l.value, expiration)
+       return err
+}
+
+//CheckResource returns the expiration time left for a resource.
+//If the resource doesn't exist, -2 is returned.
+func (s *SdlInstance) CheckResource(resource string) (time.Duration, error) {
+       result, err := s.PTTL(s.nsPrefix + resource)
+       if err != nil {
+               return 0, err
+       }
+       if result == time.Duration(-1) {
+               return 0, errors.New("invalid resource given, no expiration time attached")
+       }
+       return result, nil
+}
+
+//Options struct defines the behaviour for getting the resource lock.
+type Options struct {
+       //The number of time the lock will be tried.
+       //Default: 0 = no retry
+       RetryCount int
+
+       //Wait between the retries.
+       //Default: 100ms
+       RetryWait time.Duration
+}
+
+func (o *Options) getRetryCount() int {
+       if o != nil && o.RetryCount > 0 {
+               return o.RetryCount
+       }
+       return 0
+}
+
+func (o *Options) getRetryWait() time.Duration {
+       if o != nil && o.RetryWait > 0 {
+               return o.RetryWait
+       }
+       return 100 * time.Millisecond
+}
+
+//Lock struct identifies the resource lock instance. Releasing and adjusting the
+//expirations are done using the methods defined for this struct.
+type Lock struct {
+       s     *SdlInstance
+       key   string
+       value string
+}
+
 type iDatabase interface {
        SubscribeChannelDB(cb sdlgoredis.ChannelNotificationCb, channelPrefix, eventSeparator string, channels ...string)
        UnsubscribeChannelDB(channels ...string)
@@ -478,7 +601,7 @@ type iDatabase interface {
        Keys(key string) ([]string, error)
        SetIE(key string, oldData, newData interface{}) (bool, error)
        SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error)
-       SetNX(key string, data interface{}) (bool, error)
+       SetNX(key string, data interface{}, expiration time.Duration) (bool, error)
        SetNXPub(channel, message, key string, data interface{}) (bool, error)
        DelIE(key string, data interface{}) (bool, error)
        DelIEPub(channel, message, key string, data interface{}) (bool, error)
@@ -487,4 +610,6 @@ type iDatabase interface {
        SMembers(key string) ([]string, error)
        SIsMember(key string, data interface{}) (bool, error)
        SCard(key string) (int64, error)
+       PTTL(key string) (time.Duration, error)
+       PExpireIE(key string, data interface{}, expiration time.Duration) error
 }
index 652ec55..832e8eb 100644 (file)
@@ -19,8 +19,9 @@ package sdlgo_test
 
 import (
        "errors"
-       "testing"
        "reflect"
+       "testing"
+       "time"
 
        "gerrit.o-ran-sc.org/r/ric-plt/sdlgo"
        "gerrit.o-ran-sc.org/r/ric-plt/sdlgo/internal/sdlgoredis"
@@ -85,8 +86,8 @@ func (m *mockDB) SetIEPub(channel, message, key string, oldData, newData interfa
        return a.Bool(0), a.Error(1)
 }
 
-func (m *mockDB) SetNX(key string, data interface{}) (bool, error) {
-       a := m.Called(key, data)
+func (m *mockDB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) {
+       a := m.Called(key, data, expiration)
        return a.Bool(0), a.Error(1)
 }
 
@@ -129,6 +130,16 @@ func (m *mockDB) SCard(key string) (int64, error) {
        return a.Get(0).(int64), a.Error(1)
 }
 
+func (m *mockDB) PTTL(key string) (time.Duration, error) {
+       a := m.Called(key)
+       return a.Get(0).(time.Duration), a.Error(1)
+}
+
+func (m *mockDB) PExpireIE(key string, data interface{}, expiration time.Duration) error {
+       a := m.Called(key, data, expiration)
+       return a.Error(0)
+}
+
 func setup() (*mockDB, *sdlgo.SdlInstance) {
        m := new(mockDB)
        i := sdlgo.NewSdlInstance("namespace", m)
@@ -136,21 +147,21 @@ func setup() (*mockDB, *sdlgo.SdlInstance) {
 }
 
 func verifySliceInOrder(a, b []string) bool {
-               for i, v := range a {
-                       found := false
-                       if i%2 == 0 {
-                               for j, x := range b {
-                                       if j%2 == 0 && x == v && a[i+1] == b[j+1] {
-                                               found = true
-                                               break
-                                       }
-                               }
-                               if !found {
-                                       return false
+       for i, v := range a {
+               found := false
+               if i%2 == 0 {
+                       for j, x := range b {
+                               if j%2 == 0 && x == v && a[i+1] == b[j+1] {
+                                       found = true
+                                       break
                                }
                        }
+                       if !found {
+                               return false
+                       }
                }
-               return true
+       }
+       return true
 
 }
 
@@ -292,16 +303,16 @@ func TestWriteByteArrayAsValue(t *testing.T) {
        m.AssertExpectations(t)
 }
 
-func TestWriteMapAsInput(t *testing.T){
+func TestWriteMapAsInput(t *testing.T) {
        m, i := setup()
 
        setExpected := []interface{}{"{namespace},key1", "string123",
-                                                               "{namespace},key22", 12,
-                                                               "{namespace},key333", []byte{1,2,3,4,5}}
+               "{namespace},key22", 12,
+               "{namespace},key333", []byte{1, 2, 3, 4, 5}}
        inputMap := map[string]interface{}{
-               "key1": "string123",
-               "key22": 12,
-               "key333": []byte{1,2,3,4,5},
+               "key1":   "string123",
+               "key22":  12,
+               "key333": []byte{1, 2, 3, 4, 5},
        }
 
        m.On("MSet", mock.MatchedBy(func(input []interface{}) bool {
@@ -318,7 +329,7 @@ func TestWriteMapAsInput(t *testing.T){
                        }
                }
                return true
-               })).Return(nil)
+       })).Return(nil)
 
        err := i.Set(inputMap)
        assert.Nil(t, err)
@@ -436,10 +447,10 @@ func TestWriteAndPublishOneKeyOneChannel(t *testing.T) {
 }
 
 func TestWriteAndPublishSeveralChannelsAndEvents(t *testing.T) {
-       m , i := setup()
+       m, i := setup()
 
-       expectedChannelsAndEvents := []string{"{namespace},channel1", "event1___event2", 
-                                                                                 "{namespace},channel2", "event3___event4"}
+       expectedChannelsAndEvents := []string{"{namespace},channel1", "event1___event2",
+               "{namespace},channel2", "event3___event4"}
        expectedKeyVal := []interface{}{"{namespace},key", "data"}
 
        verifyChannelAndEvent := func(input []string) bool {
@@ -448,7 +459,7 @@ func TestWriteAndPublishSeveralChannelsAndEvents(t *testing.T) {
        m.On("MSetMPub", mock.MatchedBy(verifyChannelAndEvent), expectedKeyVal).Return(nil)
        m.AssertNotCalled(t, "MSet", expectedKeyVal)
        err := i.SetAndPublish([]string{"channel1", "event1", "channel2", "event3", "channel1", "event2", "channel2", "event4"},
-                                                                       "key", "data")
+               "key", "data")
        assert.Nil(t, err)
        m.AssertExpectations(t)
 }
@@ -532,7 +543,7 @@ func TestRemoveAndPublishSeveralChannelsAndEventsSuccessfully(t *testing.T) {
        m, i := setup()
 
        expectedChannelAndEvent := []string{"{namespace},channel1", "event1___event2",
-                                                                               "{namespace},channel2", "event3___event4"}
+               "{namespace},channel2", "event3___event4"}
        expectedKeys := []string{"{namespace},key1", "{namespace},key2"}
 
        verifyChannelAndEvent := func(input []string) bool {
@@ -540,8 +551,8 @@ func TestRemoveAndPublishSeveralChannelsAndEventsSuccessfully(t *testing.T) {
        }
        m.On("DelMPub", mock.MatchedBy(verifyChannelAndEvent), expectedKeys).Return(nil)
        err := i.RemoveAndPublish([]string{"channel1", "event1", "channel2", "event3",
-                                                                       "channel1", "event2", "channel2", "event4"},
-                                                                       []string{"key1", "key2"})
+               "channel1", "event2", "channel2", "event4"},
+               []string{"key1", "key2"})
        assert.Nil(t, err)
        m.AssertExpectations(t)
 }
@@ -856,7 +867,7 @@ func TestSetIfNotExistsAndPublishNoChannels(t *testing.T) {
        expectedKey := "{namespace},key"
        expectedData := interface{}("data")
 
-       m.On("SetNX", expectedKey, expectedData).Return(true, nil)
+       m.On("SetNX", expectedKey, expectedData, time.Duration(0)).Return(true, nil)
        status, err := i.SetIfNotExistsAndPublish([]string{}, "key", "data")
        assert.Nil(t, err)
        assert.True(t, status)
@@ -887,7 +898,7 @@ func TestSetIfNotExistsAndPublishIncorrectChannels(t *testing.T) {
        expectedData := interface{}("data")
 
        m.AssertNotCalled(t, "SetNXPub", expectedChannel, expectedEvent, expectedKey, expectedData)
-       m.AssertNotCalled(t, "SetNX", expectedKey, expectedData)
+       m.AssertNotCalled(t, "SetNX", expectedKey, expectedData, 0)
        status, err := i.SetIfNotExistsAndPublish([]string{"channel", "event", "channel2"}, "key", "data")
        assert.NotNil(t, err)
        assert.False(t, status)
@@ -914,7 +925,7 @@ func TestSetIfNotExistsSuccessfullyOkStatus(t *testing.T) {
 
        mSetNXExpectedKey := string("{namespace},key1")
        mSetNXExpectedData := interface{}("data")
-       m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData).Return(true, nil)
+       m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData, time.Duration(0)).Return(true, nil)
        status, err := i.SetIfNotExists("key1", "data")
        assert.Nil(t, err)
        assert.True(t, status)
@@ -926,7 +937,7 @@ func TestSetIfNotExistsSuccessfullyNOKStatus(t *testing.T) {
 
        mSetNXExpectedKey := string("{namespace},key1")
        mSetNXExpectedData := interface{}("data")
-       m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData).Return(false, nil)
+       m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData, time.Duration(0)).Return(false, nil)
        status, err := i.SetIfNotExists("key1", "data")
        assert.Nil(t, err)
        assert.False(t, status)
@@ -938,7 +949,7 @@ func TestSetIfNotExistsFailure(t *testing.T) {
 
        mSetNXExpectedKey := string("{namespace},key1")
        mSetNXExpectedData := interface{}("data")
-       m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData).Return(false, errors.New("Some error"))
+       m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData, time.Duration(0)).Return(false, errors.New("Some error"))
        status, err := i.SetIfNotExists("key1", "data")
        assert.NotNil(t, err)
        assert.False(t, status)
@@ -1075,7 +1086,7 @@ func TestRemoveAllAndPublishKeysReturnError(t *testing.T) {
        mKeysExpected := string("{namespace},*")
        mKeysReturn := []string{"{namespace},key1", "{namespace},key2"}
        mDelExpected := mKeysReturn
-       expectedChannelAndEvent := []string{"{namespace},channel", "event" }
+       expectedChannelAndEvent := []string{"{namespace},channel", "event"}
        m.On("Keys", mKeysExpected).Return(mKeysReturn, errors.New("Some error"))
        m.AssertNotCalled(t, "DelMPub", expectedChannelAndEvent, mDelExpected)
        err := i.RemoveAllAndPublish([]string{"channel", "event"})
@@ -1223,7 +1234,7 @@ func TestGetMembersSuccessfully(t *testing.T) {
        m.On("SMembers", groupExpected).Return(returnExpected, nil)
 
        result, err := i.GetMembers("group")
-       assert.Nil(t,err)
+       assert.Nil(t, err)
        assert.Equal(t, result, returnExpected)
        m.AssertExpectations(t)
 }
@@ -1236,7 +1247,7 @@ func TestGetMembersFail(t *testing.T) {
        m.On("SMembers", groupExpected).Return(returnExpected, errors.New("Some error"))
 
        result, err := i.GetMembers("group")
-       assert.NotNil(t,err)
+       assert.NotNil(t, err)
        assert.Equal(t, []string{}, result)
        m.AssertExpectations(t)
 }
@@ -1308,4 +1319,195 @@ func TestGroupSizeFail(t *testing.T) {
        assert.NotNil(t, err)
        assert.Equal(t, int64(0), result)
        m.AssertExpectations(t)
-}
\ No newline at end of file
+}
+
+func TestLockResourceSuccessfully(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil)
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{})
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       m.AssertExpectations(t)
+}
+
+func TestLockResourceFailure(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, errors.New("Some error"))
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{})
+       assert.NotNil(t, err)
+       assert.Nil(t, lock)
+       m.AssertExpectations(t)
+}
+
+func TestLockResourceTrySeveralTimesSuccessfully(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once()
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 2,
+       })
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       m.AssertExpectations(t)
+}
+
+func TestLockResourceTrySeveralTimesFailure(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once()
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, errors.New("Some error")).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 2,
+       })
+       assert.NotNil(t, err)
+       assert.Nil(t, lock)
+       m.AssertExpectations(t)
+}
+
+func TestLockResourceTrySeveralTimesUnableToGetResource(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once()
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 1,
+       })
+       assert.NotNil(t, err)
+       assert.EqualError(t, err, "Lock not obtained")
+       assert.Nil(t, lock)
+       m.AssertExpectations(t)
+}
+
+func TestReleaseResourceSuccessfully(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once()
+       m.On("DelIE", resourceExpected, mock.Anything).Return(true, nil).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 1,
+       })
+       err2 := lock.ReleaseResource()
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       assert.Nil(t, err2)
+       m.AssertExpectations(t)
+}
+
+func TestReleaseResourceFailure(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once()
+       m.On("DelIE", resourceExpected, mock.Anything).Return(true, errors.New("Some error")).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 1,
+       })
+       err2 := lock.ReleaseResource()
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       assert.NotNil(t, err2)
+       m.AssertExpectations(t)
+}
+
+func TestReleaseResourceLockNotHeld(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once()
+       m.On("DelIE", resourceExpected, mock.Anything).Return(false, nil).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 1,
+       })
+       err2 := lock.ReleaseResource()
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       assert.NotNil(t, err2)
+       assert.EqualError(t, err2, "Lock not held")
+       m.AssertExpectations(t)
+}
+
+func TestRefreshResourceSuccessfully(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once()
+       m.On("PExpireIE", resourceExpected, mock.Anything, time.Duration(1)).Return(nil).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 1,
+       })
+       err2 := lock.RefreshResource(time.Duration(1))
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       assert.Nil(t, err2)
+       m.AssertExpectations(t)
+}
+
+func TestRefreshResourceFailure(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once()
+       m.On("PExpireIE", resourceExpected, mock.Anything, time.Duration(1)).Return(errors.New("Some error")).Once()
+
+       lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{
+               RetryCount: 1,
+       })
+       err2 := lock.RefreshResource(time.Duration(1))
+       assert.Nil(t, err)
+       assert.NotNil(t, lock)
+       assert.NotNil(t, err2)
+       m.AssertExpectations(t)
+}
+
+func TestCheckResourceSuccessfully(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("PTTL", resourceExpected).Return(time.Duration(1), nil)
+       result, err := i.CheckResource("resource")
+       assert.Nil(t, err)
+       assert.Equal(t, result, time.Duration(1))
+       m.AssertExpectations(t)
+}
+
+func TestCheckResourceFailure(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("PTTL", resourceExpected).Return(time.Duration(1), errors.New("Some error"))
+       result, err := i.CheckResource("resource")
+       assert.NotNil(t, err)
+       assert.EqualError(t, err, "Some error")
+       assert.Equal(t, result, time.Duration(0))
+       m.AssertExpectations(t)
+}
+
+func TestCheckResourceInvalidResource(t *testing.T) {
+       m, i := setup()
+
+       resourceExpected := "{namespace},resource"
+       m.On("PTTL", resourceExpected).Return(time.Duration(-1), nil)
+       result, err := i.CheckResource("resource")
+       assert.NotNil(t, err)
+       assert.EqualError(t, err, "invalid resource given, no expiration time attached")
+       assert.Equal(t, result, time.Duration(0))
+       m.AssertExpectations(t)
+}