import (
"bytes"
"crypto/tls"
+ "crypto/x509"
"fmt"
"io"
Do(*http.Request) (*http.Response, error)
}
-func CreateHttpClient(cert tls.Certificate, timeout time.Duration) *http.Client {
+func CreateHttpClient(cert tls.Certificate, caCerts *x509.CertPool, timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
- Transport: createTransport(cert),
+ Transport: createTransport(cert, caCerts),
}
}
return putError
}
-func createTransport(cert tls.Certificate) *http.Transport {
+func createTransport(cert tls.Certificate, caCerts *x509.CertPool) *http.Transport {
return &http.Transport{
TLSClientConfig: &tls.Config{
+ ClientCAs: caCerts,
+ RootCAs: caCerts,
Certificates: []tls.Certificate{
cert,
},
func Test_CreateClient(t *testing.T) {
assertions := require.New(t)
- client := CreateHttpClient(tls.Certificate{}, 5*time.Second)
+ client := CreateHttpClient(tls.Certificate{}, nil, 5*time.Second)
transport := client.Transport
assertions.Equal("*http.Transport", reflect.TypeOf(transport).String())
type Config struct {
LogLevel log.Level
CertPath string
+ CACertsPath string
KeyPath string
AuthServiceUrl string
GrantType string
func NewConfig() *Config {
return &Config{
- CertPath: getEnv("CERT_PATH", "security/tls.crt"),
- KeyPath: getEnv("CERT_KEY_PATH", "security/tls.key"),
+ CertPath: getEnv("CERT_PATH", "security/tls.crt", false),
+ KeyPath: getEnv("CERT_KEY_PATH", "security/tls.key", false),
+ CACertsPath: getEnv("ROOT_CA_CERTS_PATH", "", false),
LogLevel: getLogLevel(),
- GrantType: getEnv("CREDS_GRANT_TYPE", ""),
- ClientSecret: getEnv("CREDS_CLIENT_SECRET", ""),
- ClientId: getEnv("CREDS_CLIENT_ID", ""),
- AuthTokenOutputFileName: getEnv("OUTPUT_FILE", "/tmp/authToken.txt"),
- AuthServiceUrl: getEnv("AUTH_SERVICE_URL", "https://localhost:39687/example-singlelogin-sever/login"),
+ GrantType: getEnv("CREDS_GRANT_TYPE", "", false),
+ ClientSecret: getEnv("CREDS_CLIENT_SECRET", "", true),
+ ClientId: getEnv("CREDS_CLIENT_ID", "", false),
+ AuthTokenOutputFileName: getEnv("OUTPUT_FILE", "/tmp/authToken.txt", false),
+ AuthServiceUrl: getEnv("AUTH_SERVICE_URL", "https://localhost:39687/example-singlelogin-sever/login", false),
RefreshMarginSeconds: getEnvAsInt("REFRESH_MARGIN_SECONDS", 5, 1, 3600),
}
}
return fmt.Errorf("missing CERT_PATH and/or CERT_KEY_PATH")
}
+ if configuration.CACertsPath == "" {
+ log.Warn("No Root CA certs loaded, no trust validation may be performed")
+ }
+
return nil
}
-func getEnv(key string, defaultVal string) string {
+func getEnv(key string, defaultVal string, secret bool) string {
if value, exists := os.LookupEnv(key); exists {
- log.Debugf("Using value: '%v' for '%v'", value, key)
+ if !secret {
+ log.Debugf("Using value: '%v' for '%v'", value, key)
+ }
return value
} else {
- log.Debugf("Using default value: '%v' for '%v'", defaultVal, key)
+ if !secret {
+ log.Debugf("Using default value: '%v' for '%v'", defaultVal, key)
+ }
return defaultVal
}
}
func getEnvAsInt(name string, defaultVal int, min int, max int) int {
- valueStr := getEnv(name, "")
+ valueStr := getEnv(name, fmt.Sprint(defaultVal), false)
if value, err := strconv.Atoi(valueStr); err == nil {
if value < min || value > max {
log.Warnf("Value out of range: '%v' for variable: '%v'. Default value: '%v' will be used", valueStr, name, defaultVal)
}
func getLogLevel() log.Level {
- logLevelStr := getEnv("LOG_LEVEL", "Info")
+ logLevelStr := getEnv("LOG_LEVEL", "Info", false)
if loglevel, err := log.ParseLevel(logLevelStr); err == nil {
return loglevel
} else {
os.Setenv("OUTPUT_FILE", "OUTPUT_FILE")
os.Setenv("AUTH_SERVICE_URL", "AUTH_SERVICE_URL")
os.Setenv("REFRESH_MARGIN_SECONDS", "33")
+ os.Setenv("ROOT_CA_CERTS_PATH", "ROOT_CA_CERTS_PATH")
t.Cleanup(func() {
os.Clearenv()
ClientSecret: "CREDS_CLIENT_SECRET",
ClientId: "CREDS_CLIENT_ID",
AuthTokenOutputFileName: "OUTPUT_FILE",
+ CACertsPath: "ROOT_CA_CERTS_PATH",
RefreshMarginSeconds: 33,
}
got := NewConfig()
+ assertions.Equal(nil, validateConfiguration(got))
assertions.Equal(&wantConfig, got)
}
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect
+ golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y=
-golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 h1:OH54vjqzRWmbJ62fjuhxy7AxFFgoHN0/DPc/UrL8cAs=
+golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
import (
"crypto/tls"
+ "crypto/x509"
"encoding/json"
- "fmt"
"io/ioutil"
"net/http"
"net/url"
log.Fatalf("Stopping due to error: %v", err)
}
- var cert tls.Certificate
- if c, err := loadCertificate(context.Config.CertPath, context.Config.KeyPath); err == nil {
- cert = c
- } else {
- log.Fatalf("Stopping due to error: %v", err)
- }
+ cert := loadCertificate(context.Config.CertPath, context.Config.KeyPath)
+ caCerts := loadCaCerts(context.Config.CACertsPath)
- webClient := CreateHttpClient(cert, 10*time.Second)
+ webClient := CreateHttpClient(cert, caCerts, 10*time.Second)
go periodicRefreshIwtToken(webClient, context)
}
return jwt, err
}
-func loadCertificate(certPath string, keyPath string) (tls.Certificate, error) {
+func loadCertificate(certPath string, keyPath string) tls.Certificate {
log.WithFields(log.Fields{"certPath": certPath, "keyPath": keyPath}).Debug("Loading cert")
- if cert, err := tls.LoadX509KeyPair(certPath, keyPath); err == nil {
- return cert, nil
+ cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+ if check(err) {
+ return cert
} else {
- return tls.Certificate{}, fmt.Errorf("cannot create x509 keypair from cert file %s and key file %s due to: %v", certPath, keyPath, err)
+ log.Fatalf("cannot create x509 keypair from cert file %s and key file %s due to: %v", certPath, keyPath, err)
+ return tls.Certificate{}
}
}
+func loadCaCerts(caCertsPath string) *x509.CertPool {
+ var err error
+ if caCertsPath == "" {
+ return nil
+ }
+ caCert, err := ioutil.ReadFile(caCertsPath)
+ check(err)
+ caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
+ return caCertPool
+}
+
func keepAlive() {
channel := make(chan int)
<-channel
"io/ioutil"
"net/http"
"os"
- "sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
-func createHttpClientMock(t *testing.T, configuration *Config, wg *sync.WaitGroup, token JwtToken) *http.Client {
+func createHttpClientMock(t *testing.T, configuration *Config, token JwtToken) *http.Client {
assertions := require.New(t)
clientMock := NewTestClient(func(req *http.Request) *http.Response {
if req.URL.String() == configuration.AuthServiceUrl {
assertions.Contains(body, "grant_type="+configuration.GrantType)
contentType := req.Header.Get("content-type")
assertions.Equal("application/x-www-form-urlencoded", contentType)
- wg.Done()
+
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBuffer(toBody(token))),
accessToken := "Access_token" + fmt.Sprint(time.Now().UnixNano())
token := JwtToken{Access_token: accessToken, Expires_in: 7, Token_type: "Token_type"}
- wg := sync.WaitGroup{}
- wg.Add(2) // Get token two times
- clientMock := createHttpClientMock(t, configuration, &wg, token)
+ clientMock := createHttpClientMock(t, configuration, token)
go periodicRefreshIwtToken(clientMock, context)
- if waitTimeout(&wg, 12*time.Second) {
- t.Error("Not all calls to server were made")
- t.Fail()
- }
+ await(func() bool { return fileExists(configuration.AuthTokenOutputFileName) }, t)
tokenFileContent, err := ioutil.ReadFile(configuration.AuthTokenOutputFileName)
check(err)
context.Running = false
}
+func fileExists(fileName string) bool {
+ if _, err := os.Stat(fileName); err == nil {
+ return true
+ }
+ log.Debug("Waiting for file: " + fileName)
+ return false
+}
+
+func await(predicate func() bool, t *testing.T) {
+ MAX_TIME_SECONDS := 30
+ for i := 1; i < MAX_TIME_SECONDS; i++ {
+ if predicate() {
+ return
+ }
+ time.Sleep(time.Second)
+ }
+ t.Error("Predicate not fulfilled")
+ t.Fail()
+}
+
func TestStart(t *testing.T) {
assertions := require.New(t)
log.SetLevel(log.TraceLevel)
configuration := NewConfig()
configuration.AuthTokenOutputFileName = "/tmp/authToken" + fmt.Sprint(time.Now().UnixNano())
+ configuration.CACertsPath = configuration.CertPath
context := NewContext(configuration)
+ t.Cleanup(func() {
+ os.Remove(configuration.AuthTokenOutputFileName)
+ })
start(context)
}
}
-// waitTimeout waits for the waitgroup for the specified max timeout.
-// Returns true if waiting timed out.
-func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
- c := make(chan struct{})
- go func() {
- defer close(c)
- wg.Wait()
- }()
- select {
- case <-c:
- return false // completed normally
- case <-time.After(timeout):
- return true // timed out
- }
-}
-
func getBodyAsString(req *http.Request, t *testing.T) string {
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(req.Body); err != nil {