RIC:1060: Change in PTL
[ric-plt/a1.git] / pkg / restapi / server.go
1 /*
2 ==================================================================================
3   Copyright (c) 2021 Samsung
4
5    Licensed under the Apache License, Version 2.0 (the "License");
6    you may not use this file except in compliance with the License.
7    You may obtain a copy of the License at
8
9        http://www.apache.org/licenses/LICENSE-2.0
10
11    Unless required by applicable law or agreed to in writing, software
12    distributed under the License is distributed on an "AS IS" BASIS,
13    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14    See the License for the specific language governing permissions and
15    limitations under the License.
16
17    This source code is part of the near-RT RIC (RAN Intelligent Controller)
18    platform project (RICP).
19 ==================================================================================
20 */
21 // Code generated by go-swagger; DO NOT EDIT.
22
23 package restapi
24
25 import (
26         "context"
27         "crypto/tls"
28         "crypto/x509"
29         "errors"
30         "fmt"
31         "io/ioutil"
32         "log"
33         "net"
34         "net/http"
35         "os"
36         "os/signal"
37         "strconv"
38         "sync"
39         "sync/atomic"
40         "syscall"
41         "time"
42
43         "github.com/go-openapi/runtime/flagext"
44         "github.com/go-openapi/swag"
45         flags "github.com/jessevdk/go-flags"
46         "golang.org/x/net/netutil"
47
48         "gerrit.o-ran-sc.org/r/ric-plt/a1/pkg/restapi/operations"
49 )
50
51 const (
52         schemeHTTP  = "http"
53         schemeHTTPS = "https"
54         schemeUnix  = "unix"
55 )
56
57 var defaultSchemes []string
58
59 func init() {
60         defaultSchemes = []string{
61                 schemeHTTP,
62         }
63 }
64
65 // NewServer creates a new api a1 server but does not configure it
66 func NewServer(api *operations.A1API) *Server {
67         s := new(Server)
68
69         s.shutdown = make(chan struct{})
70         s.api = api
71         s.interrupt = make(chan os.Signal, 1)
72         return s
73 }
74
75 // ConfigureAPI configures the API and handlers.
76 func (s *Server) ConfigureAPI() {
77         if s.api != nil {
78                 s.handler = configureAPI(s.api)
79         }
80 }
81
82 // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
83 func (s *Server) ConfigureFlags() {
84         if s.api != nil {
85                 configureFlags(s.api)
86         }
87 }
88
89 // Server for the a1 API
90 type Server struct {
91         EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
92         CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
93         GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
94         MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`
95
96         SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/a1.sock"`
97         domainSocketL net.Listener
98
99         Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
100         Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
101         ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
102         KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
103         ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
104         WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
105         httpServerL  net.Listener
106
107         TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
108         TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
109         TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
110         TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
111         TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
112         TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
113         TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
114         TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
115         TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
116         httpsServerL      net.Listener
117
118         api          *operations.A1API
119         handler      http.Handler
120         hasListeners bool
121         shutdown     chan struct{}
122         shuttingDown int32
123         interrupted  bool
124         interrupt    chan os.Signal
125 }
126
127 // Logf logs message either via defined user logger or via system one if no user logger is defined.
128 func (s *Server) Logf(f string, args ...interface{}) {
129         if s.api != nil && s.api.Logger != nil {
130                 s.api.Logger(f, args...)
131         } else {
132                 log.Printf(f, args...)
133         }
134 }
135
136 // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
137 // Exits with non-zero status after printing
138 func (s *Server) Fatalf(f string, args ...interface{}) {
139         if s.api != nil && s.api.Logger != nil {
140                 s.api.Logger(f, args...)
141                 os.Exit(1)
142         } else {
143                 log.Fatalf(f, args...)
144         }
145 }
146
147 // SetAPI configures the server with the specified API. Needs to be called before Serve
148 func (s *Server) SetAPI(api *operations.A1API) {
149         if api == nil {
150                 s.api = nil
151                 s.handler = nil
152                 return
153         }
154
155         s.api = api
156         s.handler = configureAPI(api)
157 }
158
159 func (s *Server) hasScheme(scheme string) bool {
160         schemes := s.EnabledListeners
161         if len(schemes) == 0 {
162                 schemes = defaultSchemes
163         }
164
165         for _, v := range schemes {
166                 if v == scheme {
167                         return true
168                 }
169         }
170         return false
171 }
172
173 // Serve the api
174 func (s *Server) Serve() (err error) {
175         if !s.hasListeners {
176                 if err = s.Listen(); err != nil {
177                         return err
178                 }
179         }
180
181         // set default handler, if none is set
182         if s.handler == nil {
183                 if s.api == nil {
184                         return errors.New("can't create the default handler, as no api is set")
185                 }
186
187                 s.SetHandler(s.api.Serve(nil))
188         }
189
190         wg := new(sync.WaitGroup)
191         once := new(sync.Once)
192         signalNotify(s.interrupt)
193         go handleInterrupt(once, s)
194
195         servers := []*http.Server{}
196
197         if s.hasScheme(schemeUnix) {
198                 domainSocket := new(http.Server)
199                 domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
200                 domainSocket.Handler = s.handler
201                 if int64(s.CleanupTimeout) > 0 {
202                         domainSocket.IdleTimeout = s.CleanupTimeout
203                 }
204
205                 configureServer(domainSocket, "unix", string(s.SocketPath))
206
207                 servers = append(servers, domainSocket)
208                 wg.Add(1)
209                 s.Logf("Serving a1 at unix://%s", s.SocketPath)
210                 go func(l net.Listener) {
211                         defer wg.Done()
212                         if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
213                                 s.Fatalf("%v", err)
214                         }
215                         s.Logf("Stopped serving a1 at unix://%s", s.SocketPath)
216                 }(s.domainSocketL)
217         }
218
219         if s.hasScheme(schemeHTTP) {
220                 httpServer := new(http.Server)
221                 httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
222                 httpServer.ReadTimeout = s.ReadTimeout
223                 httpServer.WriteTimeout = s.WriteTimeout
224                 httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
225                 if s.ListenLimit > 0 {
226                         s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
227                 }
228
229                 if int64(s.CleanupTimeout) > 0 {
230                         httpServer.IdleTimeout = s.CleanupTimeout
231                 }
232
233                 httpServer.Handler = s.handler
234
235                 configureServer(httpServer, "http", s.httpServerL.Addr().String())
236
237                 servers = append(servers, httpServer)
238                 wg.Add(1)
239                 s.Logf("Serving a1 at http://%s", s.httpServerL.Addr())
240                 go func(l net.Listener) {
241                         defer wg.Done()
242                         if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
243                                 s.Fatalf("%v", err)
244                         }
245                         s.Logf("Stopped serving a1 at http://%s", l.Addr())
246                 }(s.httpServerL)
247         }
248
249         if s.hasScheme(schemeHTTPS) {
250                 httpsServer := new(http.Server)
251                 httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
252                 httpsServer.ReadTimeout = s.TLSReadTimeout
253                 httpsServer.WriteTimeout = s.TLSWriteTimeout
254                 httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
255                 if s.TLSListenLimit > 0 {
256                         s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
257                 }
258                 if int64(s.CleanupTimeout) > 0 {
259                         httpsServer.IdleTimeout = s.CleanupTimeout
260                 }
261                 httpsServer.Handler = s.handler
262
263                 // Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
264                 httpsServer.TLSConfig = &tls.Config{
265                         // Causes servers to use Go's default ciphersuite preferences,
266                         // which are tuned to avoid attacks. Does nothing on clients.
267                         PreferServerCipherSuites: true,
268                         // Only use curves which have assembly implementations
269                         // https://github.com/golang/go/tree/master/src/crypto/elliptic
270                         CurvePreferences: []tls.CurveID{tls.CurveP256},
271                         // Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
272                         NextProtos: []string{"h2", "http/1.1"},
273                         // https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
274                         MinVersion: tls.VersionTLS12,
275                         // These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
276                         CipherSuites: []uint16{
277                                 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
278                                 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
279                                 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
280                                 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
281                                 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
282                                 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
283                         },
284                 }
285
286                 // build standard config from server options
287                 if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
288                         httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
289                         httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
290                         if err != nil {
291                                 return err
292                         }
293                 }
294
295                 if s.TLSCACertificate != "" {
296                         // include specified CA certificate
297                         caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
298                         if caCertErr != nil {
299                                 return caCertErr
300                         }
301                         caCertPool := x509.NewCertPool()
302                         ok := caCertPool.AppendCertsFromPEM(caCert)
303                         if !ok {
304                                 return fmt.Errorf("cannot parse CA certificate")
305                         }
306                         httpsServer.TLSConfig.ClientCAs = caCertPool
307                         httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
308                 }
309
310                 // call custom TLS configurator
311                 configureTLS(httpsServer.TLSConfig)
312
313                 if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
314                         // after standard and custom config are passed, this ends up with no certificate
315                         if s.TLSCertificate == "" {
316                                 if s.TLSCertificateKey == "" {
317                                         s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
318                                 }
319                                 s.Fatalf("the required flag `--tls-certificate` was not specified")
320                         }
321                         if s.TLSCertificateKey == "" {
322                                 s.Fatalf("the required flag `--tls-key` was not specified")
323                         }
324                         // this happens with a wrong custom TLS configurator
325                         s.Fatalf("no certificate was configured for TLS")
326                 }
327
328                 configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
329
330                 servers = append(servers, httpsServer)
331                 wg.Add(1)
332                 s.Logf("Serving a1 at https://%s", s.httpsServerL.Addr())
333                 go func(l net.Listener) {
334                         defer wg.Done()
335                         if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
336                                 s.Fatalf("%v", err)
337                         }
338                         s.Logf("Stopped serving a1 at https://%s", l.Addr())
339                 }(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
340         }
341
342         wg.Add(1)
343         go s.handleShutdown(wg, &servers)
344
345         wg.Wait()
346         return nil
347 }
348
349 // Listen creates the listeners for the server
350 func (s *Server) Listen() error {
351         if s.hasListeners { // already done this
352                 return nil
353         }
354
355         if s.hasScheme(schemeHTTPS) {
356                 // Use http host if https host wasn't defined
357                 if s.TLSHost == "" {
358                         s.TLSHost = s.Host
359                 }
360                 // Use http listen limit if https listen limit wasn't defined
361                 if s.TLSListenLimit == 0 {
362                         s.TLSListenLimit = s.ListenLimit
363                 }
364                 // Use http tcp keep alive if https tcp keep alive wasn't defined
365                 if int64(s.TLSKeepAlive) == 0 {
366                         s.TLSKeepAlive = s.KeepAlive
367                 }
368                 // Use http read timeout if https read timeout wasn't defined
369                 if int64(s.TLSReadTimeout) == 0 {
370                         s.TLSReadTimeout = s.ReadTimeout
371                 }
372                 // Use http write timeout if https write timeout wasn't defined
373                 if int64(s.TLSWriteTimeout) == 0 {
374                         s.TLSWriteTimeout = s.WriteTimeout
375                 }
376         }
377
378         if s.hasScheme(schemeUnix) {
379                 domSockListener, err := net.Listen("unix", string(s.SocketPath))
380                 if err != nil {
381                         return err
382                 }
383                 s.domainSocketL = domSockListener
384         }
385
386         if s.hasScheme(schemeHTTP) {
387                 listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
388                 if err != nil {
389                         return err
390                 }
391
392                 h, p, err := swag.SplitHostPort(listener.Addr().String())
393                 if err != nil {
394                         return err
395                 }
396                 s.Host = h
397                 s.Port = p
398                 s.httpServerL = listener
399         }
400
401         if s.hasScheme(schemeHTTPS) {
402                 tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
403                 if err != nil {
404                         return err
405                 }
406
407                 sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
408                 if err != nil {
409                         return err
410                 }
411                 s.TLSHost = sh
412                 s.TLSPort = sp
413                 s.httpsServerL = tlsListener
414         }
415
416         s.hasListeners = true
417         return nil
418 }
419
420 // Shutdown server and clean up resources
421 func (s *Server) Shutdown() error {
422         if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
423                 close(s.shutdown)
424         }
425         return nil
426 }
427
428 func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
429         // wg.Done must occur last, after s.api.ServerShutdown()
430         // (to preserve old behaviour)
431         defer wg.Done()
432
433         <-s.shutdown
434
435         servers := *serversPtr
436
437         ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
438         defer cancel()
439
440         // first execute the pre-shutdown hook
441         s.api.PreServerShutdown()
442
443         shutdownChan := make(chan bool)
444         for i := range servers {
445                 server := servers[i]
446                 go func() {
447                         var success bool
448                         defer func() {
449                                 shutdownChan <- success
450                         }()
451                         if err := server.Shutdown(ctx); err != nil {
452                                 // Error from closing listeners, or context timeout:
453                                 s.Logf("HTTP server Shutdown: %v", err)
454                         } else {
455                                 success = true
456                         }
457                 }()
458         }
459
460         // Wait until all listeners have successfully shut down before calling ServerShutdown
461         success := true
462         for range servers {
463                 success = success && <-shutdownChan
464         }
465         if success {
466                 s.api.ServerShutdown()
467         }
468 }
469
470 // GetHandler returns a handler useful for testing
471 func (s *Server) GetHandler() http.Handler {
472         return s.handler
473 }
474
475 // SetHandler allows for setting a http handler on this server
476 func (s *Server) SetHandler(handler http.Handler) {
477         s.handler = handler
478 }
479
480 // UnixListener returns the domain socket listener
481 func (s *Server) UnixListener() (net.Listener, error) {
482         if !s.hasListeners {
483                 if err := s.Listen(); err != nil {
484                         return nil, err
485                 }
486         }
487         return s.domainSocketL, nil
488 }
489
490 // HTTPListener returns the http listener
491 func (s *Server) HTTPListener() (net.Listener, error) {
492         if !s.hasListeners {
493                 if err := s.Listen(); err != nil {
494                         return nil, err
495                 }
496         }
497         return s.httpServerL, nil
498 }
499
500 // TLSListener returns the https listener
501 func (s *Server) TLSListener() (net.Listener, error) {
502         if !s.hasListeners {
503                 if err := s.Listen(); err != nil {
504                         return nil, err
505                 }
506         }
507         return s.httpsServerL, nil
508 }
509
510 func handleInterrupt(once *sync.Once, s *Server) {
511         once.Do(func() {
512                 for range s.interrupt {
513                         if s.interrupted {
514                                 s.Logf("Server already shutting down")
515                                 continue
516                         }
517                         s.interrupted = true
518                         s.Logf("Shutting down... ")
519                         if err := s.Shutdown(); err != nil {
520                                 s.Logf("HTTP server Shutdown: %v", err)
521                         }
522                 }
523         })
524 }
525
526 func signalNotify(interrupt chan<- os.Signal) {
527         signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
528 }