Improve locking
[nonrtric/plt/sme.git] / capifcore / internal / providermanagement / providermanagement.go
index 0213e89..ea1f1e6 100644 (file)
@@ -42,20 +42,20 @@ type ServiceRegister interface {
 }
 
 type ProviderManager struct {
-       onboardedProviders map[string]provapi.APIProviderEnrolmentDetails
-       lock               sync.Mutex
+       registeredProviders map[string]provapi.APIProviderEnrolmentDetails
+       lock                sync.Mutex
 }
 
 func NewProviderManager() *ProviderManager {
        return &ProviderManager{
-               onboardedProviders: make(map[string]provapi.APIProviderEnrolmentDetails),
+               registeredProviders: make(map[string]provapi.APIProviderEnrolmentDetails),
        }
 }
 
 func (pm *ProviderManager) IsFunctionRegistered(functionId string) bool {
        registered := false
 out:
-       for _, provider := range pm.onboardedProviders {
+       for _, provider := range pm.registeredProviders {
                for _, registeredFunc := range *provider.ApiProvFuncs {
                        if *registeredFunc.ApiProvFuncId == functionId {
                                registered = true
@@ -68,7 +68,7 @@ out:
 }
 
 func (pm *ProviderManager) GetAefsForPublisher(apfId string) []string {
-       for _, provider := range pm.onboardedProviders {
+       for _, provider := range pm.registeredProviders {
                for _, registeredFunc := range *provider.ApiProvFuncs {
                        if *registeredFunc.ApiProvFuncId == apfId && registeredFunc.ApiProvFuncRole == provapi.ApiProviderFuncRoleAPF {
                                return getExposedFuncs(provider.ApiProvFuncs)
@@ -119,13 +119,13 @@ func (pm *ProviderManager) prepareNewProvider(newProvider *provapi.APIProviderEn
        newProvider.ApiProvDomId = pm.getDomainId(newProvider.ApiProvDomInfo)
 
        pm.registerFunctions(newProvider.ApiProvFuncs)
-       pm.onboardedProviders[*newProvider.ApiProvDomId] = *newProvider
+       pm.registeredProviders[*newProvider.ApiProvDomId] = *newProvider
 }
 
 func (pm *ProviderManager) DeleteRegistrationsRegistrationId(ctx echo.Context, registrationId string) error {
 
-       log.Debug(pm.onboardedProviders)
-       if _, ok := pm.onboardedProviders[registrationId]; ok {
+       log.Debug(pm.registeredProviders)
+       if _, ok := pm.registeredProviders[registrationId]; ok {
                pm.deleteProvider(registrationId)
        }
 
@@ -136,13 +136,10 @@ func (pm *ProviderManager) deleteProvider(registrationId string) {
        log.Debug("Deleting provider", registrationId)
        pm.lock.Lock()
        defer pm.lock.Unlock()
-       delete(pm.onboardedProviders, registrationId)
+       delete(pm.registeredProviders, registrationId)
 }
 
 func (pm *ProviderManager) PutRegistrationsRegistrationId(ctx echo.Context, registrationId string) error {
-       pm.lock.Lock()
-       defer pm.lock.Unlock()
-
        errMsg := "Unable to update provider due to %s."
        registeredProvider, err := pm.checkIfProviderIsRegistered(registrationId, ctx)
        if err != nil {
@@ -154,14 +151,11 @@ func (pm *ProviderManager) PutRegistrationsRegistrationId(ctx echo.Context, regi
                return sendCoreError(ctx, http.StatusBadRequest, fmt.Sprintf(errMsg, err))
        }
 
-       updateDomainInfo(&updatedProvider, registeredProvider)
-
-       registeredProvider.ApiProvFuncs, err = updateFuncs(updatedProvider.ApiProvFuncs, registeredProvider.ApiProvFuncs)
+       err = pm.updateProvider(updatedProvider, registeredProvider)
        if err != nil {
                return sendCoreError(ctx, http.StatusBadRequest, fmt.Sprintf(errMsg, err))
        }
 
-       pm.onboardedProviders[*registeredProvider.ApiProvDomId] = *registeredProvider
        err = ctx.JSON(http.StatusOK, *registeredProvider)
        if err != nil {
                // Something really bad happened, tell Echo that our handler failed
@@ -172,7 +166,7 @@ func (pm *ProviderManager) PutRegistrationsRegistrationId(ctx echo.Context, regi
 }
 
 func (pm *ProviderManager) checkIfProviderIsRegistered(registrationId string, ctx echo.Context) (*provapi.APIProviderEnrolmentDetails, error) {
-       registeredProvider, ok := pm.onboardedProviders[registrationId]
+       registeredProvider, ok := pm.registeredProviders[registrationId]
        if !ok {
                return nil, fmt.Errorf("provider not onboarded")
        }
@@ -188,6 +182,22 @@ func getProviderFromRequest(ctx echo.Context) (provapi.APIProviderEnrolmentDetai
        return updatedProvider, nil
 }
 
+func (pm *ProviderManager) updateProvider(updatedProvider provapi.APIProviderEnrolmentDetails, registeredProvider *provapi.APIProviderEnrolmentDetails) error {
+       pm.lock.Lock()
+       defer pm.lock.Unlock()
+
+       updateDomainInfo(&updatedProvider, registeredProvider)
+
+       funcsAfterUpdate, err := updateFuncs(updatedProvider.ApiProvFuncs, registeredProvider.ApiProvFuncs)
+       if err == nil {
+               registeredProvider.ApiProvFuncs = funcsAfterUpdate
+
+               pm.registeredProviders[*registeredProvider.ApiProvDomId] = *registeredProvider
+               return nil
+       }
+       return err
+}
+
 func updateDomainInfo(updatedProvider, registeredProvider *provapi.APIProviderEnrolmentDetails) {
        if updatedProvider.ApiProvDomInfo != nil {
                registeredProvider.ApiProvDomInfo = updatedProvider.ApiProvDomInfo