From a10caff26de79e06caac2d00d3c11218e2d7ee87 Mon Sep 17 00:00:00 2001 From: Marco Tallskog Date: Wed, 14 Aug 2019 14:50:23 +0300 Subject: [PATCH] Add resource locking Implement methods that enable applications to create locks for a resources. Resource locks are plain keys with a random value with an expiration time. Only the holder of a lock can release it either by removing the lock or by letting it expire (after it will be automatically removed). By using a random value (that only lock holder knows), it can be ensured that others can't remove it. Resource locks are per namespace. Implemented methods are: LockResource: this method will check if the lock already exists and if not, it will create it with an expiration time. It is possible for the client to define with options how many times the lock will be tried and the interval how often it is tried. If retry is used, this method will block. ReleaseResource: remove the lock. If removing is tried after the expiration, an error will be returned. RefreshResource: with this method, it is possible to set a new expiration time for the lock. If the lock has already expired, an error will be returned. CheckResource: application can query the remaining expiration time with this method, regardless it is the owner of the lock or not. Change-Id: Ic6f5274c1740c7e36ddaba564024cffcc4c5de3d Signed-off-by: Marco Tallskog --- internal/sdlgoredis/sdlgoredis.go | 29 +++- internal/sdlgoredis/sdlgoredis_test.go | 154 ++++++++++++++---- sdl.go | 131 +++++++++++++++- sdl_test.go | 278 ++++++++++++++++++++++++++++----- 4 files changed, 515 insertions(+), 77 deletions(-) diff --git a/internal/sdlgoredis/sdlgoredis.go b/internal/sdlgoredis/sdlgoredis.go index ab56b12..0ccccb6 100644 --- a/internal/sdlgoredis/sdlgoredis.go +++ b/internal/sdlgoredis/sdlgoredis.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "os" + "strconv" "strings" "time" @@ -67,6 +68,11 @@ type RedisClient interface { SMembers(key string) *redis.StringSliceCmd SIsMember(key string, member interface{}) *redis.BoolCmd SCard(key string) *redis.IntCmd + PTTL(key string) *redis.DurationCmd + Eval(script string, keys []string, args ...interface{}) *redis.Cmd + EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd + ScriptExists(scripts ...string) *redis.BoolSliceCmd + ScriptLoad(script string) *redis.StringCmd } func checkResultAndError(result interface{}, err error) (bool, error) { @@ -281,8 +287,8 @@ func (db *DB) SetNXPub(channel, message, key string, data interface{}) (bool, er } return checkResultAndError(db.client.Do("SETNXPUB", key, data, channel, message).Result()) } -func (db *DB) SetNX(key string, data interface{}) (bool, error) { - return db.client.SetNX(key, data, 0).Result() +func (db *DB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) { + return db.client.SetNX(key, data, expiration).Result() } func (db *DB) DelIEPub(channel, message, key string, data interface{}) (bool, error) { @@ -323,3 +329,22 @@ func (db *DB) SCard(key string) (int64, error) { result, err := db.client.SCard(key).Result() return result, err } + +func (db *DB) PTTL(key string) (time.Duration, error) { + result, err := db.client.PTTL(key).Result() + return result, err +} + +var luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`) + +func (db *DB) PExpireIE(key string, data interface{}, expiration time.Duration) error { + expirationStr := strconv.FormatInt(int64(expiration/time.Millisecond), 10) + result, err := luaRefresh.Run(db.client, []string{key}, data, expirationStr).Result() + if err != nil { + return err + } + if result == int64(1) { + return nil + } + return errors.New("Lock not held") +} diff --git a/internal/sdlgoredis/sdlgoredis_test.go b/internal/sdlgoredis/sdlgoredis_test.go index 81c211a..962704c 100644 --- a/internal/sdlgoredis/sdlgoredis_test.go +++ b/internal/sdlgoredis/sdlgoredis_test.go @@ -1,4 +1,3 @@ - /* Copyright (c) 2019 AT&T Intellectual Property. Copyright (c) 2018-2019 Nokia. @@ -19,12 +18,13 @@ package sdlgoredis_test import ( - "testing" "errors" + "strconv" + "testing" "time" - "github.com/go-redis/redis" "gerrit.o-ran-sc.org/r/ric-plt/sdlgo/internal/sdlgoredis" + "github.com/go-redis/redis" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -109,6 +109,26 @@ func (m *clientMock) SCard(key string) *redis.IntCmd { return m.Called(key).Get(0).(*redis.IntCmd) } +func (m *clientMock) PTTL(key string) *redis.DurationCmd { + return m.Called(key).Get(0).(*redis.DurationCmd) +} + +func (m *clientMock) Eval(script string, keys []string, args ...interface{}) *redis.Cmd { + return m.Called(script, keys).Get(0).(*redis.Cmd) +} + +func (m *clientMock) EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd { + return m.Called(sha1, keys, args).Get(0).(*redis.Cmd) +} + +func (m *clientMock) ScriptExists(scripts ...string) *redis.BoolSliceCmd { + return m.Called(scripts).Get(0).(*redis.BoolSliceCmd) +} + +func (m *clientMock) ScriptLoad(script string) *redis.StringCmd { + return m.Called(script).Get(0).(*redis.StringCmd) +} + func setSubscribeNotifications() (*pubSubMock, sdlgoredis.SubscribeFn) { mock := new(pubSubMock) return mock, func(client sdlgoredis.RedisClient, channels ...string) sdlgoredis.Subscriber { @@ -128,12 +148,12 @@ func setup(commandsExists bool) (*pubSubMock, *clientMock, *sdlgoredis.DB) { if commandsExists { cmdResult = map[string]*redis.CommandInfo{ - "setie": &dummyCommandInfo, - "delie": &dummyCommandInfo, + "setie": &dummyCommandInfo, + "delie": &dummyCommandInfo, "setiepub": &dummyCommandInfo, "setnxpub": &dummyCommandInfo, "msetmpub": &dummyCommandInfo, - "delmpub": &dummyCommandInfo, + "delmpub": &dummyCommandInfo, } } else { cmdResult = map[string]*redis.CommandInfo{ @@ -149,7 +169,7 @@ func setup(commandsExists bool) (*pubSubMock, *clientMock, *sdlgoredis.DB) { func TestMSetSuccessfully(t *testing.T) { _, r, db := setup(true) expectedKeysAndValues := []interface{}{"key1", "value1", "key2", 2} - r.On("MSet",expectedKeysAndValues).Return(redis.NewStatusResult("OK", nil)) + r.On("MSet", expectedKeysAndValues).Return(redis.NewStatusResult("OK", nil)) err := db.MSet("key1", "value1", "key2", 2) assert.Nil(t, err) r.AssertExpectations(t) @@ -158,7 +178,7 @@ func TestMSetSuccessfully(t *testing.T) { func TestMSetFailure(t *testing.T) { _, r, db := setup(true) expectedKeysAndValues := []interface{}{"key1", "value1", "key2", 2} - r.On("MSet",expectedKeysAndValues).Return(redis.NewStatusResult("OK", errors.New("Some error"))) + r.On("MSet", expectedKeysAndValues).Return(redis.NewStatusResult("OK", errors.New("Some error"))) err := db.MSet("key1", "value1", "key2", 2) assert.NotNil(t, err) r.AssertExpectations(t) @@ -167,30 +187,30 @@ func TestMSetFailure(t *testing.T) { func TestMSetMPubSuccessfully(t *testing.T) { _, r, db := setup(true) expectedMessage := []interface{}{"MSETMPUB", 2, 2, "key1", "val1", "key2", "val2", - "chan1", "event1", "chan2", "event2"} + "chan1", "event1", "chan2", "event2"} r.On("Do", expectedMessage).Return(redis.NewCmdResult("", nil)) assert.Nil(t, db.MSetMPub([]string{"chan1", "event1", "chan2", "event2"}, - "key1", "val1", "key2", "val2")) + "key1", "val1", "key2", "val2")) r.AssertExpectations(t) } func TestMsetMPubFailure(t *testing.T) { _, r, db := setup(true) expectedMessage := []interface{}{"MSETMPUB", 2, 2, "key1", "val1", "key2", "val2", - "chan1", "event1", "chan2", "event2"} + "chan1", "event1", "chan2", "event2"} r.On("Do", expectedMessage).Return(redis.NewCmdResult("", errors.New("Some error"))) assert.NotNil(t, db.MSetMPub([]string{"chan1", "event1", "chan2", "event2"}, - "key1", "val1", "key2", "val2")) + "key1", "val1", "key2", "val2")) r.AssertExpectations(t) } func TestMSetMPubCommandMissing(t *testing.T) { _, r, db := setup(false) expectedMessage := []interface{}{"MSETMPUB", 2, 2, "key1", "val1", "key2", "val2", - "chan1", "event1", "chan2", "event2"} + "chan1", "event1", "chan2", "event2"} r.AssertNotCalled(t, "Do", expectedMessage) assert.NotNil(t, db.MSetMPub([]string{"chan1", "event1", "chan2", "event2"}, - "key1", "val1", "key2", "val2")) + "key1", "val1", "key2", "val2")) r.AssertExpectations(t) } @@ -211,7 +231,7 @@ func TestMGetFailure(t *testing.T) { expectedKeys := []string{"key1", "key2", "key3"} expectedResult := []interface{}{nil} r.On("MGet", expectedKeys).Return(redis.NewSliceResult(expectedResult, - errors.New("Some error"))) + errors.New("Some error"))) result, err := db.MGet([]string{"key1", "key2", "key3"}) assert.Equal(t, result, expectedResult) assert.NotNil(t, err) @@ -221,30 +241,30 @@ func TestMGetFailure(t *testing.T) { func TestDelMPubSuccessfully(t *testing.T) { _, r, db := setup(true) expectedMessage := []interface{}{"DELMPUB", 2, 2, "key1", "key2", "chan1", "event1", - "chan2", "event2"} + "chan2", "event2"} r.On("Do", expectedMessage).Return(redis.NewCmdResult("", nil)) assert.Nil(t, db.DelMPub([]string{"chan1", "event1", "chan2", "event2"}, - []string{"key1", "key2"})) + []string{"key1", "key2"})) r.AssertExpectations(t) } func TestDelMPubFailure(t *testing.T) { _, r, db := setup(true) expectedMessage := []interface{}{"DELMPUB", 2, 2, "key1", "key2", "chan1", "event1", - "chan2", "event2"} + "chan2", "event2"} r.On("Do", expectedMessage).Return(redis.NewCmdResult("", errors.New("Some error"))) assert.NotNil(t, db.DelMPub([]string{"chan1", "event1", "chan2", "event2"}, - []string{"key1", "key2"})) + []string{"key1", "key2"})) r.AssertExpectations(t) } func TestDelMPubCommandMissing(t *testing.T) { _, r, db := setup(false) expectedMessage := []interface{}{"DELMPUB", 2, 2, "key1", "key2", "chan1", "event1", - "chan2", "event2"} + "chan2", "event2"} r.AssertNotCalled(t, "Do", expectedMessage) assert.NotNil(t, db.DelMPub([]string{"chan1", "event1", "chan2", "event2"}, - []string{"key1", "key2"})) + []string{"key1", "key2"})) r.AssertExpectations(t) } @@ -280,7 +300,7 @@ func TestKeysFailure(t *testing.T) { expectedPattern := "pattern*" expectedResult := []string{} r.On("Keys", expectedPattern).Return(redis.NewStringSliceResult(expectedResult, - errors.New("Some error"))) + errors.New("Some error"))) _, err := db.Keys("pattern*") assert.NotNil(t, err) r.AssertExpectations(t) @@ -411,7 +431,7 @@ func TestSetNXSuccessfully(t *testing.T) { expectedKey := "key" expectedData := "data" r.On("SetNX", expectedKey, expectedData, time.Duration(0)).Return(redis.NewBoolResult(true, nil)) - result, err := db.SetNX("key", "data") + result, err := db.SetNX("key", "data", 0) assert.True(t, result) assert.Nil(t, err) r.AssertExpectations(t) @@ -422,7 +442,7 @@ func TestSetNXUnsuccessfully(t *testing.T) { expectedKey := "key" expectedData := "data" r.On("SetNX", expectedKey, expectedData, time.Duration(0)).Return(redis.NewBoolResult(false, nil)) - result, err := db.SetNX("key", "data") + result, err := db.SetNX("key", "data", 0) assert.False(t, result) assert.Nil(t, err) r.AssertExpectations(t) @@ -433,8 +453,8 @@ func TestSetNXFailure(t *testing.T) { expectedKey := "key" expectedData := "data" r.On("SetNX", expectedKey, expectedData, time.Duration(0)). - Return(redis.NewBoolResult(false,errors.New("Some error"))) - result, err := db.SetNX("key", "data") + Return(redis.NewBoolResult(false, errors.New("Some error"))) + result, err := db.SetNX("key", "data", 0) assert.False(t, result) assert.NotNil(t, err) r.AssertExpectations(t) @@ -572,7 +592,7 @@ func TestSMembersFailure(t *testing.T) { expectedKey := "key" expectedResult := []string{"member1", "member2"} r.On("SMembers", expectedKey).Return(redis.NewStringSliceResult(expectedResult, - errors.New("Some error"))) + errors.New("Some error"))) result, err := db.SMembers("key") assert.Equal(t, result, expectedResult) assert.NotNil(t, err) @@ -646,10 +666,10 @@ func TestSubscribeChannelDBSubscribeRXUnsubscribe(t *testing.T) { ps.On("Close").Return(nil) count := 0 receivedChannel := "" - db.SubscribeChannelDB(func(channel string, payload ...string){ + db.SubscribeChannelDB(func(channel string, payload ...string) { count++ receivedChannel = channel - },"{prefix}", "---", "{prefix}channel") + }, "{prefix}", "---", "{prefix}channel") ch <- &msg db.UnsubscribeChannelDB("{prefix}channel") time.Sleep(1 * time.Second) @@ -679,16 +699,16 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) { ps.On("Close").Return(nil) count := 0 receivedChannel1 := "" - db.SubscribeChannelDB(func(channel string, payload ...string){ + db.SubscribeChannelDB(func(channel string, payload ...string) { count++ receivedChannel1 = channel - },"{prefix}", "---", "{prefix}channel1") + }, "{prefix}", "---", "{prefix}channel1") ch <- &msg1 receivedChannel2 := "" - db.SubscribeChannelDB(func(channel string, payload ...string){ + db.SubscribeChannelDB(func(channel string, payload ...string) { count++ receivedChannel2 = channel - },"{prefix}", "---", "{prefix}channel2") + }, "{prefix}", "---", "{prefix}channel2") db.UnsubscribeChannelDB("{prefix}channel1") ch <- &msg2 @@ -699,4 +719,70 @@ func TestSubscribeChannelDBSubscribeTwoUnsubscribeOne(t *testing.T) { assert.Equal(t, "channel2", receivedChannel2) r.AssertExpectations(t) ps.AssertExpectations(t) -} \ No newline at end of file +} + +func TestPTTLSuccessfully(t *testing.T) { + _, r, db := setup(true) + expectedKey := "key" + expectedResult := time.Duration(1) + r.On("PTTL", expectedKey).Return(redis.NewDurationResult(expectedResult, + nil)) + result, err := db.PTTL("key") + assert.Equal(t, result, expectedResult) + assert.Nil(t, err) + r.AssertExpectations(t) +} + +func TestPTTLFailure(t *testing.T) { + _, r, db := setup(true) + expectedKey := "key" + expectedResult := time.Duration(1) + r.On("PTTL", expectedKey).Return(redis.NewDurationResult(expectedResult, + errors.New("Some error"))) + result, err := db.PTTL("key") + assert.Equal(t, result, expectedResult) + assert.NotNil(t, err) + r.AssertExpectations(t) +} + +func TestPExpireIESuccessfully(t *testing.T) { + _, r, db := setup(true) + expectedKey := "key" + expectedData := "data" + expectedDuration := strconv.FormatInt(int64(10000), 10) + + r.On("EvalSha", mock.Anything, []string{expectedKey}, []interface{}{expectedData, expectedDuration}). + Return(redis.NewCmdResult(int64(1), nil)) + + err := db.PExpireIE("key", "data", 10*time.Second) + assert.Nil(t, err) + r.AssertExpectations(t) +} + +func TestPExpireIEFailure(t *testing.T) { + _, r, db := setup(true) + expectedKey := "key" + expectedData := "data" + expectedDuration := strconv.FormatInt(int64(10000), 10) + + r.On("EvalSha", mock.Anything, []string{expectedKey}, []interface{}{expectedData, expectedDuration}). + Return(redis.NewCmdResult(int64(1), errors.New("Some error"))) + + err := db.PExpireIE("key", "data", 10*time.Second) + assert.NotNil(t, err) + r.AssertExpectations(t) +} + +func TestPExpireIELockNotHeld(t *testing.T) { + _, r, db := setup(true) + expectedKey := "key" + expectedData := "data" + expectedDuration := strconv.FormatInt(int64(10000), 10) + + r.On("EvalSha", mock.Anything, []string{expectedKey}, []interface{}{expectedData, expectedDuration}). + Return(redis.NewCmdResult(int64(0), nil)) + + err := db.PExpireIE("key", "data", 10*time.Second) + assert.NotNil(t, err) + r.AssertExpectations(t) +} diff --git a/sdl.go b/sdl.go index 78dd15d..426c2c4 100644 --- a/sdl.go +++ b/sdl.go @@ -18,10 +18,15 @@ package sdlgo import ( + "crypto/rand" + "encoding/base64" "errors" "fmt" + "io" "reflect" "strings" + "sync" + "time" "gerrit.o-ran-sc.org/r/ric-plt/sdlgo/internal/sdlgoredis" ) @@ -32,6 +37,8 @@ type SdlInstance struct { nameSpace string nsPrefix string eventSeparator string + mutex sync.Mutex + tmp []byte iDatabase } @@ -294,7 +301,7 @@ func (s *SdlInstance) SetIf(key string, oldData, newData interface{}) (bool, err //given channel. func (s *SdlInstance) SetIfNotExistsAndPublish(channelsAndEvents []string, key string, data interface{}) (bool, error) { if len(channelsAndEvents) == 0 { - return s.SetNX(s.nsPrefix+key, data) + return s.SetNX(s.nsPrefix+key, data, 0) } if err := s.checkChannelsAndEvents("SetIfNotExistsAndPublish", channelsAndEvents); err != nil { return false, err @@ -307,7 +314,7 @@ func (s *SdlInstance) SetIfNotExistsAndPublish(channelsAndEvents []string, key s //then it's value is not changed. Checking the key existence and potential set operation //is done atomically. func (s *SdlInstance) SetIfNotExists(key string, data interface{}) (bool, error) { - return s.SetNX(s.nsPrefix+key, data) + return s.SetNX(s.nsPrefix+key, data, 0) } //RemoveAndPublish removes data from SDL. Operation is done atomically, i.e. either all succeeds or fails. @@ -466,6 +473,122 @@ func (s *SdlInstance) GroupSize(group string) (int64, error) { return retVal, err } +func (s *SdlInstance) randomToken() (string, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if len(s.tmp) == 0 { + s.tmp = make([]byte, 16) + } + + if _, err := io.ReadFull(rand.Reader, s.tmp); err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(s.tmp), nil +} + +//LockResource function is used for locking a resource. The resource lock in +//practice is a key with random value that is set to expire after a time +//period. The value written to key is a random value, thus only the instance +//created a lock, can release it. Resource locks are per namespace. +func (s *SdlInstance) LockResource(resource string, expiration time.Duration, opt *Options) (*Lock, error) { + value, err := s.randomToken() + if err != nil { + return nil, err + } + + var retryTimer *time.Timer + for i, attempts := 0, opt.getRetryCount()+1; i < attempts; i++ { + ok, err := s.SetNX(s.nsPrefix+resource, value, expiration) + if err != nil { + return nil, err + } else if ok { + return &Lock{s: s, key: resource, value: value}, nil + } + if retryTimer == nil { + retryTimer = time.NewTimer(opt.getRetryWait()) + defer retryTimer.Stop() + } else { + retryTimer.Reset(opt.getRetryWait()) + } + + select { + case <-retryTimer.C: + } + } + return nil, errors.New("Lock not obtained") +} + +//ReleaseResource removes the lock from a resource. If lock is already +//expired or some other instance is keeping the lock (lock taken after expiration), +//an error is returned. +func (l *Lock) ReleaseResource() error { + ok, err := l.s.DelIE(l.s.nsPrefix+l.key, l.value) + + if err != nil { + return err + } + if !ok { + return errors.New("Lock not held") + } + return nil +} + +//RefreshResource function can be used to set a new expiration time for the +//resource lock (if the lock still exists). The old remaining expiration +//time is overwritten with the given new expiration time. +func (l *Lock) RefreshResource(expiration time.Duration) error { + err := l.s.PExpireIE(l.s.nsPrefix+l.key, l.value, expiration) + return err +} + +//CheckResource returns the expiration time left for a resource. +//If the resource doesn't exist, -2 is returned. +func (s *SdlInstance) CheckResource(resource string) (time.Duration, error) { + result, err := s.PTTL(s.nsPrefix + resource) + if err != nil { + return 0, err + } + if result == time.Duration(-1) { + return 0, errors.New("invalid resource given, no expiration time attached") + } + return result, nil +} + +//Options struct defines the behaviour for getting the resource lock. +type Options struct { + //The number of time the lock will be tried. + //Default: 0 = no retry + RetryCount int + + //Wait between the retries. + //Default: 100ms + RetryWait time.Duration +} + +func (o *Options) getRetryCount() int { + if o != nil && o.RetryCount > 0 { + return o.RetryCount + } + return 0 +} + +func (o *Options) getRetryWait() time.Duration { + if o != nil && o.RetryWait > 0 { + return o.RetryWait + } + return 100 * time.Millisecond +} + +//Lock struct identifies the resource lock instance. Releasing and adjusting the +//expirations are done using the methods defined for this struct. +type Lock struct { + s *SdlInstance + key string + value string +} + type iDatabase interface { SubscribeChannelDB(cb sdlgoredis.ChannelNotificationCb, channelPrefix, eventSeparator string, channels ...string) UnsubscribeChannelDB(channels ...string) @@ -478,7 +601,7 @@ type iDatabase interface { Keys(key string) ([]string, error) SetIE(key string, oldData, newData interface{}) (bool, error) SetIEPub(channel, message, key string, oldData, newData interface{}) (bool, error) - SetNX(key string, data interface{}) (bool, error) + SetNX(key string, data interface{}, expiration time.Duration) (bool, error) SetNXPub(channel, message, key string, data interface{}) (bool, error) DelIE(key string, data interface{}) (bool, error) DelIEPub(channel, message, key string, data interface{}) (bool, error) @@ -487,4 +610,6 @@ type iDatabase interface { SMembers(key string) ([]string, error) SIsMember(key string, data interface{}) (bool, error) SCard(key string) (int64, error) + PTTL(key string) (time.Duration, error) + PExpireIE(key string, data interface{}, expiration time.Duration) error } diff --git a/sdl_test.go b/sdl_test.go index 652ec55..832e8eb 100644 --- a/sdl_test.go +++ b/sdl_test.go @@ -19,8 +19,9 @@ package sdlgo_test import ( "errors" - "testing" "reflect" + "testing" + "time" "gerrit.o-ran-sc.org/r/ric-plt/sdlgo" "gerrit.o-ran-sc.org/r/ric-plt/sdlgo/internal/sdlgoredis" @@ -85,8 +86,8 @@ func (m *mockDB) SetIEPub(channel, message, key string, oldData, newData interfa return a.Bool(0), a.Error(1) } -func (m *mockDB) SetNX(key string, data interface{}) (bool, error) { - a := m.Called(key, data) +func (m *mockDB) SetNX(key string, data interface{}, expiration time.Duration) (bool, error) { + a := m.Called(key, data, expiration) return a.Bool(0), a.Error(1) } @@ -129,6 +130,16 @@ func (m *mockDB) SCard(key string) (int64, error) { return a.Get(0).(int64), a.Error(1) } +func (m *mockDB) PTTL(key string) (time.Duration, error) { + a := m.Called(key) + return a.Get(0).(time.Duration), a.Error(1) +} + +func (m *mockDB) PExpireIE(key string, data interface{}, expiration time.Duration) error { + a := m.Called(key, data, expiration) + return a.Error(0) +} + func setup() (*mockDB, *sdlgo.SdlInstance) { m := new(mockDB) i := sdlgo.NewSdlInstance("namespace", m) @@ -136,21 +147,21 @@ func setup() (*mockDB, *sdlgo.SdlInstance) { } func verifySliceInOrder(a, b []string) bool { - for i, v := range a { - found := false - if i%2 == 0 { - for j, x := range b { - if j%2 == 0 && x == v && a[i+1] == b[j+1] { - found = true - break - } - } - if !found { - return false + for i, v := range a { + found := false + if i%2 == 0 { + for j, x := range b { + if j%2 == 0 && x == v && a[i+1] == b[j+1] { + found = true + break } } + if !found { + return false + } } - return true + } + return true } @@ -292,16 +303,16 @@ func TestWriteByteArrayAsValue(t *testing.T) { m.AssertExpectations(t) } -func TestWriteMapAsInput(t *testing.T){ +func TestWriteMapAsInput(t *testing.T) { m, i := setup() setExpected := []interface{}{"{namespace},key1", "string123", - "{namespace},key22", 12, - "{namespace},key333", []byte{1,2,3,4,5}} + "{namespace},key22", 12, + "{namespace},key333", []byte{1, 2, 3, 4, 5}} inputMap := map[string]interface{}{ - "key1": "string123", - "key22": 12, - "key333": []byte{1,2,3,4,5}, + "key1": "string123", + "key22": 12, + "key333": []byte{1, 2, 3, 4, 5}, } m.On("MSet", mock.MatchedBy(func(input []interface{}) bool { @@ -318,7 +329,7 @@ func TestWriteMapAsInput(t *testing.T){ } } return true - })).Return(nil) + })).Return(nil) err := i.Set(inputMap) assert.Nil(t, err) @@ -436,10 +447,10 @@ func TestWriteAndPublishOneKeyOneChannel(t *testing.T) { } func TestWriteAndPublishSeveralChannelsAndEvents(t *testing.T) { - m , i := setup() + m, i := setup() - expectedChannelsAndEvents := []string{"{namespace},channel1", "event1___event2", - "{namespace},channel2", "event3___event4"} + expectedChannelsAndEvents := []string{"{namespace},channel1", "event1___event2", + "{namespace},channel2", "event3___event4"} expectedKeyVal := []interface{}{"{namespace},key", "data"} verifyChannelAndEvent := func(input []string) bool { @@ -448,7 +459,7 @@ func TestWriteAndPublishSeveralChannelsAndEvents(t *testing.T) { m.On("MSetMPub", mock.MatchedBy(verifyChannelAndEvent), expectedKeyVal).Return(nil) m.AssertNotCalled(t, "MSet", expectedKeyVal) err := i.SetAndPublish([]string{"channel1", "event1", "channel2", "event3", "channel1", "event2", "channel2", "event4"}, - "key", "data") + "key", "data") assert.Nil(t, err) m.AssertExpectations(t) } @@ -532,7 +543,7 @@ func TestRemoveAndPublishSeveralChannelsAndEventsSuccessfully(t *testing.T) { m, i := setup() expectedChannelAndEvent := []string{"{namespace},channel1", "event1___event2", - "{namespace},channel2", "event3___event4"} + "{namespace},channel2", "event3___event4"} expectedKeys := []string{"{namespace},key1", "{namespace},key2"} verifyChannelAndEvent := func(input []string) bool { @@ -540,8 +551,8 @@ func TestRemoveAndPublishSeveralChannelsAndEventsSuccessfully(t *testing.T) { } m.On("DelMPub", mock.MatchedBy(verifyChannelAndEvent), expectedKeys).Return(nil) err := i.RemoveAndPublish([]string{"channel1", "event1", "channel2", "event3", - "channel1", "event2", "channel2", "event4"}, - []string{"key1", "key2"}) + "channel1", "event2", "channel2", "event4"}, + []string{"key1", "key2"}) assert.Nil(t, err) m.AssertExpectations(t) } @@ -856,7 +867,7 @@ func TestSetIfNotExistsAndPublishNoChannels(t *testing.T) { expectedKey := "{namespace},key" expectedData := interface{}("data") - m.On("SetNX", expectedKey, expectedData).Return(true, nil) + m.On("SetNX", expectedKey, expectedData, time.Duration(0)).Return(true, nil) status, err := i.SetIfNotExistsAndPublish([]string{}, "key", "data") assert.Nil(t, err) assert.True(t, status) @@ -887,7 +898,7 @@ func TestSetIfNotExistsAndPublishIncorrectChannels(t *testing.T) { expectedData := interface{}("data") m.AssertNotCalled(t, "SetNXPub", expectedChannel, expectedEvent, expectedKey, expectedData) - m.AssertNotCalled(t, "SetNX", expectedKey, expectedData) + m.AssertNotCalled(t, "SetNX", expectedKey, expectedData, 0) status, err := i.SetIfNotExistsAndPublish([]string{"channel", "event", "channel2"}, "key", "data") assert.NotNil(t, err) assert.False(t, status) @@ -914,7 +925,7 @@ func TestSetIfNotExistsSuccessfullyOkStatus(t *testing.T) { mSetNXExpectedKey := string("{namespace},key1") mSetNXExpectedData := interface{}("data") - m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData).Return(true, nil) + m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData, time.Duration(0)).Return(true, nil) status, err := i.SetIfNotExists("key1", "data") assert.Nil(t, err) assert.True(t, status) @@ -926,7 +937,7 @@ func TestSetIfNotExistsSuccessfullyNOKStatus(t *testing.T) { mSetNXExpectedKey := string("{namespace},key1") mSetNXExpectedData := interface{}("data") - m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData).Return(false, nil) + m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData, time.Duration(0)).Return(false, nil) status, err := i.SetIfNotExists("key1", "data") assert.Nil(t, err) assert.False(t, status) @@ -938,7 +949,7 @@ func TestSetIfNotExistsFailure(t *testing.T) { mSetNXExpectedKey := string("{namespace},key1") mSetNXExpectedData := interface{}("data") - m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData).Return(false, errors.New("Some error")) + m.On("SetNX", mSetNXExpectedKey, mSetNXExpectedData, time.Duration(0)).Return(false, errors.New("Some error")) status, err := i.SetIfNotExists("key1", "data") assert.NotNil(t, err) assert.False(t, status) @@ -1075,7 +1086,7 @@ func TestRemoveAllAndPublishKeysReturnError(t *testing.T) { mKeysExpected := string("{namespace},*") mKeysReturn := []string{"{namespace},key1", "{namespace},key2"} mDelExpected := mKeysReturn - expectedChannelAndEvent := []string{"{namespace},channel", "event" } + expectedChannelAndEvent := []string{"{namespace},channel", "event"} m.On("Keys", mKeysExpected).Return(mKeysReturn, errors.New("Some error")) m.AssertNotCalled(t, "DelMPub", expectedChannelAndEvent, mDelExpected) err := i.RemoveAllAndPublish([]string{"channel", "event"}) @@ -1223,7 +1234,7 @@ func TestGetMembersSuccessfully(t *testing.T) { m.On("SMembers", groupExpected).Return(returnExpected, nil) result, err := i.GetMembers("group") - assert.Nil(t,err) + assert.Nil(t, err) assert.Equal(t, result, returnExpected) m.AssertExpectations(t) } @@ -1236,7 +1247,7 @@ func TestGetMembersFail(t *testing.T) { m.On("SMembers", groupExpected).Return(returnExpected, errors.New("Some error")) result, err := i.GetMembers("group") - assert.NotNil(t,err) + assert.NotNil(t, err) assert.Equal(t, []string{}, result) m.AssertExpectations(t) } @@ -1308,4 +1319,195 @@ func TestGroupSizeFail(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, int64(0), result) m.AssertExpectations(t) -} \ No newline at end of file +} + +func TestLockResourceSuccessfully(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil) + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{}) + assert.Nil(t, err) + assert.NotNil(t, lock) + m.AssertExpectations(t) +} + +func TestLockResourceFailure(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, errors.New("Some error")) + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{}) + assert.NotNil(t, err) + assert.Nil(t, lock) + m.AssertExpectations(t) +} + +func TestLockResourceTrySeveralTimesSuccessfully(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once() + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 2, + }) + assert.Nil(t, err) + assert.NotNil(t, lock) + m.AssertExpectations(t) +} + +func TestLockResourceTrySeveralTimesFailure(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once() + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, errors.New("Some error")).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 2, + }) + assert.NotNil(t, err) + assert.Nil(t, lock) + m.AssertExpectations(t) +} + +func TestLockResourceTrySeveralTimesUnableToGetResource(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once() + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(false, nil).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 1, + }) + assert.NotNil(t, err) + assert.EqualError(t, err, "Lock not obtained") + assert.Nil(t, lock) + m.AssertExpectations(t) +} + +func TestReleaseResourceSuccessfully(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once() + m.On("DelIE", resourceExpected, mock.Anything).Return(true, nil).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 1, + }) + err2 := lock.ReleaseResource() + assert.Nil(t, err) + assert.NotNil(t, lock) + assert.Nil(t, err2) + m.AssertExpectations(t) +} + +func TestReleaseResourceFailure(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once() + m.On("DelIE", resourceExpected, mock.Anything).Return(true, errors.New("Some error")).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 1, + }) + err2 := lock.ReleaseResource() + assert.Nil(t, err) + assert.NotNil(t, lock) + assert.NotNil(t, err2) + m.AssertExpectations(t) +} + +func TestReleaseResourceLockNotHeld(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once() + m.On("DelIE", resourceExpected, mock.Anything).Return(false, nil).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 1, + }) + err2 := lock.ReleaseResource() + assert.Nil(t, err) + assert.NotNil(t, lock) + assert.NotNil(t, err2) + assert.EqualError(t, err2, "Lock not held") + m.AssertExpectations(t) +} + +func TestRefreshResourceSuccessfully(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once() + m.On("PExpireIE", resourceExpected, mock.Anything, time.Duration(1)).Return(nil).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 1, + }) + err2 := lock.RefreshResource(time.Duration(1)) + assert.Nil(t, err) + assert.NotNil(t, lock) + assert.Nil(t, err2) + m.AssertExpectations(t) +} + +func TestRefreshResourceFailure(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("SetNX", resourceExpected, mock.Anything, time.Duration(1)).Return(true, nil).Once() + m.On("PExpireIE", resourceExpected, mock.Anything, time.Duration(1)).Return(errors.New("Some error")).Once() + + lock, err := i.LockResource("resource", time.Duration(1), &sdlgo.Options{ + RetryCount: 1, + }) + err2 := lock.RefreshResource(time.Duration(1)) + assert.Nil(t, err) + assert.NotNil(t, lock) + assert.NotNil(t, err2) + m.AssertExpectations(t) +} + +func TestCheckResourceSuccessfully(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("PTTL", resourceExpected).Return(time.Duration(1), nil) + result, err := i.CheckResource("resource") + assert.Nil(t, err) + assert.Equal(t, result, time.Duration(1)) + m.AssertExpectations(t) +} + +func TestCheckResourceFailure(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("PTTL", resourceExpected).Return(time.Duration(1), errors.New("Some error")) + result, err := i.CheckResource("resource") + assert.NotNil(t, err) + assert.EqualError(t, err, "Some error") + assert.Equal(t, result, time.Duration(0)) + m.AssertExpectations(t) +} + +func TestCheckResourceInvalidResource(t *testing.T) { + m, i := setup() + + resourceExpected := "{namespace},resource" + m.On("PTTL", resourceExpected).Return(time.Duration(-1), nil) + result, err := i.CheckResource("resource") + assert.NotNil(t, err) + assert.EqualError(t, err, "invalid resource given, no expiration time attached") + assert.Equal(t, result, time.Duration(0)) + m.AssertExpectations(t) +} -- 2.16.6