Fetch of authorization token 76/7976/2
authorPatrikBuhr <patrik.buhr@est.tech>
Wed, 23 Mar 2022 07:15:21 +0000 (08:15 +0100)
committerPatrikBuhr <patrik.buhr@est.tech>
Wed, 23 Mar 2022 07:18:35 +0000 (08:18 +0100)
Added support for configuration of root CAs for trust validation.

Signed-off-by: PatrikBuhr <patrik.buhr@est.tech>
Issue-ID: NONRTRIC-735
Change-Id: I9ee9e73eeb1f9f94a7ea73342d4ddee25066729f

auth-token-fetch/HTTPClient.go
auth-token-fetch/HTTPClient_test.go
auth-token-fetch/config.go
auth-token-fetch/config_test.go
auth-token-fetch/main.go
auth-token-fetch/main_test.go

index ab76b13..a765461 100644 (file)
@@ -23,6 +23,7 @@ package main
 import (
        "bytes"
        "crypto/tls"
+       "crypto/x509"
        "fmt"
        "io"
 
@@ -38,10 +39,10 @@ type HTTPClient interface {
        Do(*http.Request) (*http.Response, error)
 }
 
-func CreateHttpClient(cert tls.Certificate, timeout time.Duration) *http.Client {
+func CreateHttpClient(cert tls.Certificate, caCerts *x509.CertPool, timeout time.Duration) *http.Client {
        return &http.Client{
                Timeout:   timeout,
-               Transport: createTransport(cert),
+               Transport: createTransport(cert, caCerts),
        }
 }
 
@@ -89,9 +90,11 @@ func getRequestError(response *http.Response) RequestError {
        return putError
 }
 
-func createTransport(cert tls.Certificate) *http.Transport {
+func createTransport(cert tls.Certificate, caCerts *x509.CertPool) *http.Transport {
        return &http.Transport{
                TLSClientConfig: &tls.Config{
+                       ClientCAs: caCerts,
+                       RootCAs:   caCerts,
                        Certificates: []tls.Certificate{
                                cert,
                        },
index e0a4cd1..7b8deb0 100644 (file)
@@ -43,7 +43,7 @@ func TestRequestError_Error(t *testing.T) {
 func Test_CreateClient(t *testing.T) {
        assertions := require.New(t)
 
-       client := CreateHttpClient(tls.Certificate{}, 5*time.Second)
+       client := CreateHttpClient(tls.Certificate{}, nil, 5*time.Second)
 
        transport := client.Transport
        assertions.Equal("*http.Transport", reflect.TypeOf(transport).String())
index 18d610d..dfeeb96 100644 (file)
@@ -33,6 +33,7 @@ import (
 type Config struct {
        LogLevel                log.Level
        CertPath                string
+       CACertsPath             string
        KeyPath                 string
        AuthServiceUrl          string
        GrantType               string
@@ -44,14 +45,15 @@ type Config struct {
 
 func NewConfig() *Config {
        return &Config{
-               CertPath:                getEnv("CERT_PATH", "security/tls.crt"),
-               KeyPath:                 getEnv("CERT_KEY_PATH", "security/tls.key"),
+               CertPath:                getEnv("CERT_PATH", "security/tls.crt", false),
+               KeyPath:                 getEnv("CERT_KEY_PATH", "security/tls.key", false),
+               CACertsPath:             getEnv("ROOT_CA_CERTS_PATH", "", false),
                LogLevel:                getLogLevel(),
-               GrantType:               getEnv("CREDS_GRANT_TYPE", ""),
-               ClientSecret:            getEnv("CREDS_CLIENT_SECRET", ""),
-               ClientId:                getEnv("CREDS_CLIENT_ID", ""),
-               AuthTokenOutputFileName: getEnv("OUTPUT_FILE", "/tmp/authToken.txt"),
-               AuthServiceUrl:          getEnv("AUTH_SERVICE_URL", "https://localhost:39687/example-singlelogin-sever/login"),
+               GrantType:               getEnv("CREDS_GRANT_TYPE", "", false),
+               ClientSecret:            getEnv("CREDS_CLIENT_SECRET", "", true),
+               ClientId:                getEnv("CREDS_CLIENT_ID", "", false),
+               AuthTokenOutputFileName: getEnv("OUTPUT_FILE", "/tmp/authToken.txt", false),
+               AuthServiceUrl:          getEnv("AUTH_SERVICE_URL", "https://localhost:39687/example-singlelogin-sever/login", false),
                RefreshMarginSeconds:    getEnvAsInt("REFRESH_MARGIN_SECONDS", 5, 1, 3600),
        }
 }
@@ -61,21 +63,29 @@ func validateConfiguration(configuration *Config) error {
                return fmt.Errorf("missing CERT_PATH and/or CERT_KEY_PATH")
        }
 
+       if configuration.CACertsPath == "" {
+               log.Warn("No Root CA certs loaded, no trust validation may be performed")
+       }
+
        return nil
 }
 
-func getEnv(key string, defaultVal string) string {
+func getEnv(key string, defaultVal string, secret bool) string {
        if value, exists := os.LookupEnv(key); exists {
-               log.Debugf("Using value: '%v' for '%v'", value, key)
+               if !secret {
+                       log.Debugf("Using value: '%v' for '%v'", value, key)
+               }
                return value
        } else {
-               log.Debugf("Using default value: '%v' for '%v'", defaultVal, key)
+               if !secret {
+                       log.Debugf("Using default value: '%v' for '%v'", defaultVal, key)
+               }
                return defaultVal
        }
 }
 
 func getEnvAsInt(name string, defaultVal int, min int, max int) int {
-       valueStr := getEnv(name, "")
+       valueStr := getEnv(name, fmt.Sprint(defaultVal), false)
        if value, err := strconv.Atoi(valueStr); err == nil {
                if value < min || value > max {
                        log.Warnf("Value out of range: '%v' for variable: '%v'. Default value: '%v' will be used", valueStr, name, defaultVal)
@@ -90,7 +100,7 @@ func getEnvAsInt(name string, defaultVal int, min int, max int) int {
 }
 
 func getLogLevel() log.Level {
-       logLevelStr := getEnv("LOG_LEVEL", "Info")
+       logLevelStr := getEnv("LOG_LEVEL", "Info", false)
        if loglevel, err := log.ParseLevel(logLevelStr); err == nil {
                return loglevel
        } else {
index 8b441c1..92a63d8 100644 (file)
@@ -39,6 +39,7 @@ func TestNew_envVarsSetConfigContainSetValues(t *testing.T) {
        os.Setenv("OUTPUT_FILE", "OUTPUT_FILE")
        os.Setenv("AUTH_SERVICE_URL", "AUTH_SERVICE_URL")
        os.Setenv("REFRESH_MARGIN_SECONDS", "33")
+       os.Setenv("ROOT_CA_CERTS_PATH", "ROOT_CA_CERTS_PATH")
 
        t.Cleanup(func() {
                os.Clearenv()
@@ -52,9 +53,11 @@ func TestNew_envVarsSetConfigContainSetValues(t *testing.T) {
                ClientSecret:            "CREDS_CLIENT_SECRET",
                ClientId:                "CREDS_CLIENT_ID",
                AuthTokenOutputFileName: "OUTPUT_FILE",
+               CACertsPath:             "ROOT_CA_CERTS_PATH",
                RefreshMarginSeconds:    33,
        }
        got := NewConfig()
+       assertions.Equal(nil, validateConfiguration(got))
 
        assertions.Equal(&wantConfig, got)
 }
index 9a63534..41f49d3 100644 (file)
@@ -22,8 +22,8 @@ package main
 
 import (
        "crypto/tls"
+       "crypto/x509"
        "encoding/json"
-       "fmt"
        "io/ioutil"
        "net/http"
        "net/url"
@@ -74,14 +74,10 @@ func start(context *Context) {
                log.Fatalf("Stopping due to error: %v", err)
        }
 
-       var cert tls.Certificate
-       if c, err := loadCertificate(context.Config.CertPath, context.Config.KeyPath); err == nil {
-               cert = c
-       } else {
-               log.Fatalf("Stopping due to error: %v", err)
-       }
+       cert := loadCertificate(context.Config.CertPath, context.Config.KeyPath)
+       caCerts := loadCaCerts(context.Config.CACertsPath)
 
-       webClient := CreateHttpClient(cert, 10*time.Second)
+       webClient := CreateHttpClient(cert, caCerts, 10*time.Second)
 
        go periodicRefreshIwtToken(webClient, context)
 }
@@ -142,15 +138,29 @@ func fetchJwtToken(webClient *http.Client, configuration *Config) (JwtToken, err
        return jwt, err
 }
 
-func loadCertificate(certPath string, keyPath string) (tls.Certificate, error) {
+func loadCertificate(certPath string, keyPath string) tls.Certificate {
        log.WithFields(log.Fields{"certPath": certPath, "keyPath": keyPath}).Debug("Loading cert")
-       if cert, err := tls.LoadX509KeyPair(certPath, keyPath); err == nil {
-               return cert, nil
+       cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+       if check(err) {
+               return cert
        } else {
-               return tls.Certificate{}, fmt.Errorf("cannot create x509 keypair from cert file %s and key file %s due to: %v", certPath, keyPath, err)
+               log.Fatalf("cannot create x509 keypair from cert file %s and key file %s due to: %v", certPath, keyPath, err)
+               return tls.Certificate{}
        }
 }
 
+func loadCaCerts(caCertsPath string) *x509.CertPool {
+       var err error
+       if caCertsPath == "" {
+               return nil
+       }
+       caCert, err := ioutil.ReadFile(caCertsPath)
+       check(err)
+       caCertPool := x509.NewCertPool()
+       caCertPool.AppendCertsFromPEM(caCert)
+       return caCertPool
+}
+
 func keepAlive() {
        channel := make(chan int)
        <-channel
index c575614..2222bd9 100644 (file)
@@ -103,6 +103,7 @@ func TestStart(t *testing.T) {
 
        configuration := NewConfig()
        configuration.AuthTokenOutputFileName = "/tmp/authToken" + fmt.Sprint(time.Now().UnixNano())
+       configuration.CACertsPath = configuration.CertPath
        context := NewContext(configuration)
 
        start(context)