978fbeef52cbdbb64268d80628b2b8485b80a771
[ric-plt/a1.git] / a1-go / pkg / restapi / server.go
1 /*
2 ==================================================================================
3   Copyright (c) 2021 Samsung
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16
17    This source code is part of the near-RT RIC (RAN Intelligent Controller)
18    platform project (RICP).
19 ==================================================================================
20 */
21 // Code generated by go-swagger; DO NOT EDIT.
22
23 package restapi
24
25 import (
26         "context"
27         "crypto/tls"
28         "crypto/x509"
29         "errors"
30         "fmt"
31         "io/ioutil"
32         "log"
33         "net"
34         "net/http"
35         "os"
36         "os/signal"
37         "strconv"
38         "sync"
39         "sync/atomic"
40         "syscall"
41         "time"
42
43         "github.com/go-openapi/runtime/flagext"
44         "github.com/go-openapi/swag"
45         flags "github.com/jessevdk/go-flags"
46         "golang.org/x/net/netutil"
47
48        "gerrit.o-ran-sc.org/r/ric-plt/a1/pkg/restapi/operations"
49 )
50
51 const (
52         schemeHTTP  = "http"
53         schemeHTTPS = "https"
54         schemeUnix  = "unix"
55 )
56
57 var defaultSchemes []string
58
59 func init() {
60         defaultSchemes = []string{
61                 schemeHTTP,
62         }
63 }
64
65 // NewServer creates a new api a1 server but does not configure it
66 func NewServer(api *operations.A1API) *Server {
67         s := new(Server)
68
69         s.shutdown = make(chan struct{})
70         s.api = api
71         s.interrupt = make(chan os.Signal, 1)
72         return s
73 }
74
75 // ConfigureAPI configures the API and handlers.
76 func (s *Server) ConfigureAPI() {
77         if s.api != nil {
78                 s.handler = configureAPI(s.api)
79         }
80 }
81
82 // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
83 func (s *Server) ConfigureFlags() {
84         if s.api != nil {
85                 configureFlags(s.api)
86         }
87 }
88
89 // Server for the a1 API
90 type Server struct {
91         EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
92         CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
93         GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
94         MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`
95
96         SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/a1.sock"`
97         domainSocketL net.Listener
98
99         Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
100         Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
101         ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
102         KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
103         ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
104         WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
105         httpServerL  net.Listener
106
107         TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
108         TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
109         TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
110         TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
111         TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
112         TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
113         TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
114         TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
115         TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
116         httpsServerL      net.Listener
117
118         api          *operations.A1API
119         handler      http.Handler
120         hasListeners bool
121         shutdown     chan struct{}
122         shuttingDown int32
123         interrupted  bool
124         interrupt    chan os.Signal
125 }
126
127 // Logf logs message either via defined user logger or via system one if no user logger is defined.
128 func (s *Server) Logf(f string, args ...interface{}) {
129         if s.api != nil && s.api.Logger != nil {
130                 s.api.Logger(f, args...)
131         } else {
132                 log.Printf(f, args...)
133         }
134 }
135
136 // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
137 // Exits with non-zero status after printing
138 func (s *Server) Fatalf(f string, args ...interface{}) {
139         if s.api != nil && s.api.Logger != nil {
140                 s.api.Logger(f, args...)
141                 os.Exit(1)
142         } else {
143                 log.Fatalf(f, args...)
144         }
145 }
146
147 // SetAPI configures the server with the specified API. Needs to be called before Serve
148 func (s *Server) SetAPI(api *operations.A1API) {
149         if api == nil {
150                 s.api = nil
151                 s.handler = nil
152                 return
153         }
154
155         s.api = api
156         s.handler = configureAPI(api)
157 }
158
159 func (s *Server) hasScheme(scheme string) bool {
160         schemes := s.EnabledListeners
161         if len(schemes) == 0 {
162                 schemes = defaultSchemes
163         }
164
165         for _, v := range schemes {
166                 if v == scheme {
167                         return true
168                 }
169         }
170         return false
171 }
172
173 // Serve the api
174 func (s *Server) Serve() (err error) {
175         if !s.hasListeners {
176                 if err = s.Listen(); err != nil {
177                         return err
178                 }
179         }
180
181         // set default handler, if none is set
182         if s.handler == nil {
183                 if s.api == nil {
184                         return errors.New("can't create the default handler, as no api is set")
185                 }
186
187                 s.SetHandler(s.api.Serve(nil))
188         }
189
190         wg := new(sync.WaitGroup)
191         once := new(sync.Once)
192         signalNotify(s.interrupt)
193         go handleInterrupt(once, s)
194
195         servers := []*http.Server{}
196
197         if s.hasScheme(schemeUnix) {
198                 domainSocket := new(http.Server)
199                 domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
200                 domainSocket.Handler = s.handler
201                 if int64(s.CleanupTimeout) > 0 {
202                         domainSocket.IdleTimeout = s.CleanupTimeout
203                 }
204
205                 configureServer(domainSocket, "unix", string(s.SocketPath))
206
207                 servers = append(servers, domainSocket)
208                 wg.Add(1)
209                 s.Logf("Serving a1 at unix://%s", s.SocketPath)
210                 go func(l net.Listener) {
211                         defer wg.Done()
212                         if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
213                                 s.Fatalf("%v", err)
214                         }
215                         s.Logf("Stopped serving a1 at unix://%s", s.SocketPath)
216                 }(s.domainSocketL)
217         }
218
219         if s.hasScheme(schemeHTTP) {
220                 httpServer := new(http.Server)
221                 httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
222                 httpServer.ReadTimeout = s.ReadTimeout
223                 httpServer.WriteTimeout = s.WriteTimeout
224                 httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
225                 if s.ListenLimit > 0 {
226                         s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
227                 }
228
229                 if int64(s.CleanupTimeout) > 0 {
230                         httpServer.IdleTimeout = s.CleanupTimeout
231                 }
232
233                 httpServer.Handler = s.handler
234
235                 configureServer(httpServer, "http", s.httpServerL.Addr().String())
236
237                 servers = append(servers, httpServer)
238                 wg.Add(1)
239                 s.Logf("Serving a1 at http://%s", s.httpServerL.Addr())
240                 go func(l net.Listener) {
241                         defer wg.Done()
242                         if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
243                                 s.Fatalf("%v", err)
244                         }
245                         s.Logf("Stopped serving a1 at http://%s", l.Addr())
246                 }(s.httpServerL)
247         }
248
249         if s.hasScheme(schemeHTTPS) {
250                 httpsServer := new(http.Server)
251                 httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
252                 httpsServer.ReadTimeout = s.TLSReadTimeout
253                 httpsServer.WriteTimeout = s.TLSWriteTimeout
254                 httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
255                 if s.TLSListenLimit > 0 {
256                         s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
257                 }
258                 if int64(s.CleanupTimeout) > 0 {
259                         httpsServer.IdleTimeout = s.CleanupTimeout
260                 }
261                 httpsServer.Handler = s.handler
262
263                 // Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
264                 httpsServer.TLSConfig = &tls.Config{
265                         // Causes servers to use Go's default ciphersuite preferences,
266                         // which are tuned to avoid attacks. Does nothing on clients.
267                         PreferServerCipherSuites: true,
268                         // Only use curves which have assembly implementations
269                         // https://github.com/golang/go/tree/master/src/crypto/elliptic
270                         CurvePreferences: []tls.CurveID{tls.CurveP256},
271                         // Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
272                         NextProtos: []string{"h2", "http/1.1"},
273                         // https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
274                         MinVersion: tls.VersionTLS12,
275                         // These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
276                         CipherSuites: []uint16{
277                                 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
278                                 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
279                                 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
280                                 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
281                                 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
282                                 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
283                         },
284                 }
285
286                 // build standard config from server options
287                 if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
288                         httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
289                         httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
290                         if err != nil {
291                                 return err
292                         }
293                 }
294
295                 if s.TLSCACertificate != "" {
296                         // include specified CA certificate
297                         caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
298                         if caCertErr != nil {
299                                 return caCertErr
300                         }
301                         caCertPool := x509.NewCertPool()
302                         ok := caCertPool.AppendCertsFromPEM(caCert)
303                         if !ok {
304                                 return fmt.Errorf("cannot parse CA certificate")
305                         }
306                         httpsServer.TLSConfig.ClientCAs = caCertPool
307                         httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
308                 }
309
310                 // call custom TLS configurator
311                 configureTLS(httpsServer.TLSConfig)
312
313                 if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
314                         // after standard and custom config are passed, this ends up with no certificate
315                         if s.TLSCertificate == "" {
316                                 if s.TLSCertificateKey == "" {
317                                         s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
318                                 }
319                                 s.Fatalf("the required flag `--tls-certificate` was not specified")
320                         }
321                         if s.TLSCertificateKey == "" {
322                                 s.Fatalf("the required flag `--tls-key` was not specified")
323                         }
324                         // this happens with a wrong custom TLS configurator
325                         s.Fatalf("no certificate was configured for TLS")
326                 }
327
328                 // must have at least one certificate or panics
329                 httpsServer.TLSConfig.BuildNameToCertificate()
330
331                 configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
332
333                 servers = append(servers, httpsServer)
334                 wg.Add(1)
335                 s.Logf("Serving a1 at https://%s", s.httpsServerL.Addr())
336                 go func(l net.Listener) {
337                         defer wg.Done()
338                         if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
339                                 s.Fatalf("%v", err)
340                         }
341                         s.Logf("Stopped serving a1 at https://%s", l.Addr())
342                 }(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
343         }
344
345         wg.Add(1)
346         go s.handleShutdown(wg, &servers)
347
348         wg.Wait()
349         return nil
350 }
351
352 // Listen creates the listeners for the server
353 func (s *Server) Listen() error {
354         if s.hasListeners { // already done this
355                 return nil
356         }
357
358         if s.hasScheme(schemeHTTPS) {
359                 // Use http host if https host wasn't defined
360                 if s.TLSHost == "" {
361                         s.TLSHost = s.Host
362                 }
363                 // Use http listen limit if https listen limit wasn't defined
364                 if s.TLSListenLimit == 0 {
365                         s.TLSListenLimit = s.ListenLimit
366                 }
367                 // Use http tcp keep alive if https tcp keep alive wasn't defined
368                 if int64(s.TLSKeepAlive) == 0 {
369                         s.TLSKeepAlive = s.KeepAlive
370                 }
371                 // Use http read timeout if https read timeout wasn't defined
372                 if int64(s.TLSReadTimeout) == 0 {
373                         s.TLSReadTimeout = s.ReadTimeout
374                 }
375                 // Use http write timeout if https write timeout wasn't defined
376                 if int64(s.TLSWriteTimeout) == 0 {
377                         s.TLSWriteTimeout = s.WriteTimeout
378                 }
379         }
380
381         if s.hasScheme(schemeUnix) {
382                 domSockListener, err := net.Listen("unix", string(s.SocketPath))
383                 if err != nil {
384                         return err
385                 }
386                 s.domainSocketL = domSockListener
387         }
388
389         if s.hasScheme(schemeHTTP) {
390                 listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
391                 if err != nil {
392                         return err
393                 }
394
395                 h, p, err := swag.SplitHostPort(listener.Addr().String())
396                 if err != nil {
397                         return err
398                 }
399                 s.Host = h
400                 s.Port = p
401                 s.httpServerL = listener
402         }
403
404         if s.hasScheme(schemeHTTPS) {
405                 tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
406                 if err != nil {
407                         return err
408                 }
409
410                 sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
411                 if err != nil {
412                         return err
413                 }
414                 s.TLSHost = sh
415                 s.TLSPort = sp
416                 s.httpsServerL = tlsListener
417         }
418
419         s.hasListeners = true
420         return nil
421 }
422
423 // Shutdown server and clean up resources
424 func (s *Server) Shutdown() error {
425         if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
426                 close(s.shutdown)
427         }
428         return nil
429 }
430
431 func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
432         // wg.Done must occur last, after s.api.ServerShutdown()
433         // (to preserve old behaviour)
434         defer wg.Done()
435
436         <-s.shutdown
437
438         servers := *serversPtr
439
440         ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
441         defer cancel()
442
443         // first execute the pre-shutdown hook
444         s.api.PreServerShutdown()
445
446         shutdownChan := make(chan bool)
447         for i := range servers {
448                 server := servers[i]
449                 go func() {
450                         var success bool
451                         defer func() {
452                                 shutdownChan <- success
453                         }()
454                         if err := server.Shutdown(ctx); err != nil {
455                                 // Error from closing listeners, or context timeout:
456                                 s.Logf("HTTP server Shutdown: %v", err)
457                         } else {
458                                 success = true
459                         }
460                 }()
461         }
462
463         // Wait until all listeners have successfully shut down before calling ServerShutdown
464         success := true
465         for range servers {
466                 success = success && <-shutdownChan
467         }
468         if success {
469                 s.api.ServerShutdown()
470         }
471 }
472
473 // GetHandler returns a handler useful for testing
474 func (s *Server) GetHandler() http.Handler {
475         return s.handler
476 }
477
478 // SetHandler allows for setting a http handler on this server
479 func (s *Server) SetHandler(handler http.Handler) {
480         s.handler = handler
481 }
482
483 // UnixListener returns the domain socket listener
484 func (s *Server) UnixListener() (net.Listener, error) {
485         if !s.hasListeners {
486                 if err := s.Listen(); err != nil {
487                         return nil, err
488                 }
489         }
490         return s.domainSocketL, nil
491 }
492
493 // HTTPListener returns the http listener
494 func (s *Server) HTTPListener() (net.Listener, error) {
495         if !s.hasListeners {
496                 if err := s.Listen(); err != nil {
497                         return nil, err
498                 }
499         }
500         return s.httpServerL, nil
501 }
502
503 // TLSListener returns the https listener
504 func (s *Server) TLSListener() (net.Listener, error) {
505         if !s.hasListeners {
506                 if err := s.Listen(); err != nil {
507                         return nil, err
508                 }
509         }
510         return s.httpsServerL, nil
511 }
512
513 func handleInterrupt(once *sync.Once, s *Server) {
514         once.Do(func() {
515                 for range s.interrupt {
516                         if s.interrupted {
517                                 s.Logf("Server already shutting down")
518                                 continue
519                         }
520                         s.interrupted = true
521                         s.Logf("Shutting down... ")
522                         if err := s.Shutdown(); err != nil {
523                                 s.Logf("HTTP server Shutdown: %v", err)
524                         }
525                 }
526         })
527 }
528
529 func signalNotify(interrupt chan<- os.Signal) {
530         signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
531 }