Merge "ICS tests with istio and JWTs"
authorHenrik Andersson <henrik.b.andersson@est.tech>
Fri, 25 Mar 2022 10:14:25 +0000 (10:14 +0000)
committerGerrit Code Review <gerrit@o-ran-sc.org>
Fri, 25 Mar 2022 10:14:25 +0000 (10:14 +0000)
auth-token-fetch/HTTPClient.go
auth-token-fetch/HTTPClient_test.go
auth-token-fetch/config.go
auth-token-fetch/config_test.go
auth-token-fetch/go.mod
auth-token-fetch/go.sum
auth-token-fetch/main.go
auth-token-fetch/main_test.go

index ab76b13..a765461 100644 (file)
@@ -23,6 +23,7 @@ package main
 import (
        "bytes"
        "crypto/tls"
+       "crypto/x509"
        "fmt"
        "io"
 
@@ -38,10 +39,10 @@ type HTTPClient interface {
        Do(*http.Request) (*http.Response, error)
 }
 
-func CreateHttpClient(cert tls.Certificate, timeout time.Duration) *http.Client {
+func CreateHttpClient(cert tls.Certificate, caCerts *x509.CertPool, timeout time.Duration) *http.Client {
        return &http.Client{
                Timeout:   timeout,
-               Transport: createTransport(cert),
+               Transport: createTransport(cert, caCerts),
        }
 }
 
@@ -89,9 +90,11 @@ func getRequestError(response *http.Response) RequestError {
        return putError
 }
 
-func createTransport(cert tls.Certificate) *http.Transport {
+func createTransport(cert tls.Certificate, caCerts *x509.CertPool) *http.Transport {
        return &http.Transport{
                TLSClientConfig: &tls.Config{
+                       ClientCAs: caCerts,
+                       RootCAs:   caCerts,
                        Certificates: []tls.Certificate{
                                cert,
                        },
index e0a4cd1..7b8deb0 100644 (file)
@@ -43,7 +43,7 @@ func TestRequestError_Error(t *testing.T) {
 func Test_CreateClient(t *testing.T) {
        assertions := require.New(t)
 
-       client := CreateHttpClient(tls.Certificate{}, 5*time.Second)
+       client := CreateHttpClient(tls.Certificate{}, nil, 5*time.Second)
 
        transport := client.Transport
        assertions.Equal("*http.Transport", reflect.TypeOf(transport).String())
index 18d610d..dfeeb96 100644 (file)
@@ -33,6 +33,7 @@ import (
 type Config struct {
        LogLevel                log.Level
        CertPath                string
+       CACertsPath             string
        KeyPath                 string
        AuthServiceUrl          string
        GrantType               string
@@ -44,14 +45,15 @@ type Config struct {
 
 func NewConfig() *Config {
        return &Config{
-               CertPath:                getEnv("CERT_PATH", "security/tls.crt"),
-               KeyPath:                 getEnv("CERT_KEY_PATH", "security/tls.key"),
+               CertPath:                getEnv("CERT_PATH", "security/tls.crt", false),
+               KeyPath:                 getEnv("CERT_KEY_PATH", "security/tls.key", false),
+               CACertsPath:             getEnv("ROOT_CA_CERTS_PATH", "", false),
                LogLevel:                getLogLevel(),
-               GrantType:               getEnv("CREDS_GRANT_TYPE", ""),
-               ClientSecret:            getEnv("CREDS_CLIENT_SECRET", ""),
-               ClientId:                getEnv("CREDS_CLIENT_ID", ""),
-               AuthTokenOutputFileName: getEnv("OUTPUT_FILE", "/tmp/authToken.txt"),
-               AuthServiceUrl:          getEnv("AUTH_SERVICE_URL", "https://localhost:39687/example-singlelogin-sever/login"),
+               GrantType:               getEnv("CREDS_GRANT_TYPE", "", false),
+               ClientSecret:            getEnv("CREDS_CLIENT_SECRET", "", true),
+               ClientId:                getEnv("CREDS_CLIENT_ID", "", false),
+               AuthTokenOutputFileName: getEnv("OUTPUT_FILE", "/tmp/authToken.txt", false),
+               AuthServiceUrl:          getEnv("AUTH_SERVICE_URL", "https://localhost:39687/example-singlelogin-sever/login", false),
                RefreshMarginSeconds:    getEnvAsInt("REFRESH_MARGIN_SECONDS", 5, 1, 3600),
        }
 }
@@ -61,21 +63,29 @@ func validateConfiguration(configuration *Config) error {
                return fmt.Errorf("missing CERT_PATH and/or CERT_KEY_PATH")
        }
 
+       if configuration.CACertsPath == "" {
+               log.Warn("No Root CA certs loaded, no trust validation may be performed")
+       }
+
        return nil
 }
 
-func getEnv(key string, defaultVal string) string {
+func getEnv(key string, defaultVal string, secret bool) string {
        if value, exists := os.LookupEnv(key); exists {
-               log.Debugf("Using value: '%v' for '%v'", value, key)
+               if !secret {
+                       log.Debugf("Using value: '%v' for '%v'", value, key)
+               }
                return value
        } else {
-               log.Debugf("Using default value: '%v' for '%v'", defaultVal, key)
+               if !secret {
+                       log.Debugf("Using default value: '%v' for '%v'", defaultVal, key)
+               }
                return defaultVal
        }
 }
 
 func getEnvAsInt(name string, defaultVal int, min int, max int) int {
-       valueStr := getEnv(name, "")
+       valueStr := getEnv(name, fmt.Sprint(defaultVal), false)
        if value, err := strconv.Atoi(valueStr); err == nil {
                if value < min || value > max {
                        log.Warnf("Value out of range: '%v' for variable: '%v'. Default value: '%v' will be used", valueStr, name, defaultVal)
@@ -90,7 +100,7 @@ func getEnvAsInt(name string, defaultVal int, min int, max int) int {
 }
 
 func getLogLevel() log.Level {
-       logLevelStr := getEnv("LOG_LEVEL", "Info")
+       logLevelStr := getEnv("LOG_LEVEL", "Info", false)
        if loglevel, err := log.ParseLevel(logLevelStr); err == nil {
                return loglevel
        } else {
index 8b441c1..92a63d8 100644 (file)
@@ -39,6 +39,7 @@ func TestNew_envVarsSetConfigContainSetValues(t *testing.T) {
        os.Setenv("OUTPUT_FILE", "OUTPUT_FILE")
        os.Setenv("AUTH_SERVICE_URL", "AUTH_SERVICE_URL")
        os.Setenv("REFRESH_MARGIN_SECONDS", "33")
+       os.Setenv("ROOT_CA_CERTS_PATH", "ROOT_CA_CERTS_PATH")
 
        t.Cleanup(func() {
                os.Clearenv()
@@ -52,9 +53,11 @@ func TestNew_envVarsSetConfigContainSetValues(t *testing.T) {
                ClientSecret:            "CREDS_CLIENT_SECRET",
                ClientId:                "CREDS_CLIENT_ID",
                AuthTokenOutputFileName: "OUTPUT_FILE",
+               CACertsPath:             "ROOT_CA_CERTS_PATH",
                RefreshMarginSeconds:    33,
        }
        got := NewConfig()
+       assertions.Equal(nil, validateConfiguration(got))
 
        assertions.Equal(&wantConfig, got)
 }
index b1fd1b6..4fbd6d3 100644 (file)
@@ -11,7 +11,7 @@ require (
        github.com/davecgh/go-spew v1.1.1 // indirect
        github.com/kr/pretty v0.2.0 // indirect
        github.com/pmezard/go-difflib v1.0.0 // indirect
-       golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect
+       golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 // indirect
        gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
        gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
 )
index f638fbf..e23f462 100644 (file)
@@ -15,8 +15,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
 github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y=
-golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 h1:OH54vjqzRWmbJ62fjuhxy7AxFFgoHN0/DPc/UrL8cAs=
+golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
index 9a63534..41f49d3 100644 (file)
@@ -22,8 +22,8 @@ package main
 
 import (
        "crypto/tls"
+       "crypto/x509"
        "encoding/json"
-       "fmt"
        "io/ioutil"
        "net/http"
        "net/url"
@@ -74,14 +74,10 @@ func start(context *Context) {
                log.Fatalf("Stopping due to error: %v", err)
        }
 
-       var cert tls.Certificate
-       if c, err := loadCertificate(context.Config.CertPath, context.Config.KeyPath); err == nil {
-               cert = c
-       } else {
-               log.Fatalf("Stopping due to error: %v", err)
-       }
+       cert := loadCertificate(context.Config.CertPath, context.Config.KeyPath)
+       caCerts := loadCaCerts(context.Config.CACertsPath)
 
-       webClient := CreateHttpClient(cert, 10*time.Second)
+       webClient := CreateHttpClient(cert, caCerts, 10*time.Second)
 
        go periodicRefreshIwtToken(webClient, context)
 }
@@ -142,15 +138,29 @@ func fetchJwtToken(webClient *http.Client, configuration *Config) (JwtToken, err
        return jwt, err
 }
 
-func loadCertificate(certPath string, keyPath string) (tls.Certificate, error) {
+func loadCertificate(certPath string, keyPath string) tls.Certificate {
        log.WithFields(log.Fields{"certPath": certPath, "keyPath": keyPath}).Debug("Loading cert")
-       if cert, err := tls.LoadX509KeyPair(certPath, keyPath); err == nil {
-               return cert, nil
+       cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+       if check(err) {
+               return cert
        } else {
-               return tls.Certificate{}, fmt.Errorf("cannot create x509 keypair from cert file %s and key file %s due to: %v", certPath, keyPath, err)
+               log.Fatalf("cannot create x509 keypair from cert file %s and key file %s due to: %v", certPath, keyPath, err)
+               return tls.Certificate{}
        }
 }
 
+func loadCaCerts(caCertsPath string) *x509.CertPool {
+       var err error
+       if caCertsPath == "" {
+               return nil
+       }
+       caCert, err := ioutil.ReadFile(caCertsPath)
+       check(err)
+       caCertPool := x509.NewCertPool()
+       caCertPool.AppendCertsFromPEM(caCert)
+       return caCertPool
+}
+
 func keepAlive() {
        channel := make(chan int)
        <-channel
index c575614..cceff07 100644 (file)
@@ -28,7 +28,6 @@ import (
        "io/ioutil"
        "net/http"
        "os"
-       "sync"
        "testing"
        "time"
 
@@ -36,7 +35,7 @@ import (
        "github.com/stretchr/testify/require"
 )
 
-func createHttpClientMock(t *testing.T, configuration *Config, wg *sync.WaitGroup, token JwtToken) *http.Client {
+func createHttpClientMock(t *testing.T, configuration *Config, token JwtToken) *http.Client {
        assertions := require.New(t)
        clientMock := NewTestClient(func(req *http.Request) *http.Response {
                if req.URL.String() == configuration.AuthServiceUrl {
@@ -47,7 +46,7 @@ func createHttpClientMock(t *testing.T, configuration *Config, wg *sync.WaitGrou
                        assertions.Contains(body, "grant_type="+configuration.GrantType)
                        contentType := req.Header.Get("content-type")
                        assertions.Equal("application/x-www-form-urlencoded", contentType)
-                       wg.Done()
+
                        return &http.Response{
                                StatusCode: 200,
                                Body:       ioutil.NopCloser(bytes.NewBuffer(toBody(token))),
@@ -78,16 +77,11 @@ func TestFetchAndStoreToken(t *testing.T) {
        accessToken := "Access_token" + fmt.Sprint(time.Now().UnixNano())
        token := JwtToken{Access_token: accessToken, Expires_in: 7, Token_type: "Token_type"}
 
-       wg := sync.WaitGroup{}
-       wg.Add(2) // Get token two times
-       clientMock := createHttpClientMock(t, configuration, &wg, token)
+       clientMock := createHttpClientMock(t, configuration, token)
 
        go periodicRefreshIwtToken(clientMock, context)
 
-       if waitTimeout(&wg, 12*time.Second) {
-               t.Error("Not all calls to server were made")
-               t.Fail()
-       }
+       await(func() bool { return fileExists(configuration.AuthTokenOutputFileName) }, t)
 
        tokenFileContent, err := ioutil.ReadFile(configuration.AuthTokenOutputFileName)
        check(err)
@@ -97,13 +91,37 @@ func TestFetchAndStoreToken(t *testing.T) {
        context.Running = false
 }
 
+func fileExists(fileName string) bool {
+       if _, err := os.Stat(fileName); err == nil {
+               return true
+       }
+       log.Debug("Waiting for file: " + fileName)
+       return false
+}
+
+func await(predicate func() bool, t *testing.T) {
+       MAX_TIME_SECONDS := 30
+       for i := 1; i < MAX_TIME_SECONDS; i++ {
+               if predicate() {
+                       return
+               }
+               time.Sleep(time.Second)
+       }
+       t.Error("Predicate not fulfilled")
+       t.Fail()
+}
+
 func TestStart(t *testing.T) {
        assertions := require.New(t)
        log.SetLevel(log.TraceLevel)
 
        configuration := NewConfig()
        configuration.AuthTokenOutputFileName = "/tmp/authToken" + fmt.Sprint(time.Now().UnixNano())
+       configuration.CACertsPath = configuration.CertPath
        context := NewContext(configuration)
+       t.Cleanup(func() {
+               os.Remove(configuration.AuthTokenOutputFileName)
+       })
 
        start(context)
 
@@ -134,22 +152,6 @@ func NewTestClient(fn RoundTripFunc) *http.Client {
        }
 }
 
-// waitTimeout waits for the waitgroup for the specified max timeout.
-// Returns true if waiting timed out.
-func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
-       c := make(chan struct{})
-       go func() {
-               defer close(c)
-               wg.Wait()
-       }()
-       select {
-       case <-c:
-               return false // completed normally
-       case <-time.After(timeout):
-               return true // timed out
-       }
-}
-
 func getBodyAsString(req *http.Request, t *testing.T) string {
        buf := new(bytes.Buffer)
        if _, err := buf.ReadFrom(req.Body); err != nil {