From bcd0e70252966293fa322601edbd9bee14d0e9d2 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Tue, 4 Feb 2020 17:10:12 +0800 Subject: [PATCH] feat(server): support graceful shutdown (#461) * feat(server): support graceful shutdown for http server Signed-off-by: Bo-Yi Wu --- gorush/server_lambda.go | 4 +++- gorush/server_normal.go | 48 ++++++++++++++++++++++++++++++++++------- gorush/server_test.go | 18 ++++++++-------- main.go | 17 +++++++++++++-- 4 files changed, 67 insertions(+), 20 deletions(-) diff --git a/gorush/server_lambda.go b/gorush/server_lambda.go index 852f243..b4bdbc2 100644 --- a/gorush/server_lambda.go +++ b/gorush/server_lambda.go @@ -3,11 +3,13 @@ package gorush import ( + "context" + "github.com/apex/gateway" ) // RunHTTPServer provide run http or https protocol. -func RunHTTPServer() error { +func RunHTTPServer(ctx context.Context) error { if !PushConf.Core.Enabled { LogAccess.Debug("httpd server is disabled.") return nil diff --git a/gorush/server_normal.go b/gorush/server_normal.go index a96a61c..6613115 100644 --- a/gorush/server_normal.go +++ b/gorush/server_normal.go @@ -3,16 +3,19 @@ package gorush import ( + "context" "crypto/tls" "encoding/base64" "errors" "net/http" + + "golang.org/x/sync/errgroup" ) // RunHTTPServer provide run http or https protocol. -func RunHTTPServer() (err error) { +func RunHTTPServer(ctx context.Context) (err error) { if !PushConf.Core.Enabled { - LogAccess.Debug("httpd server is disabled.") + LogAccess.Info("httpd server is disabled.") return nil } @@ -23,7 +26,7 @@ func RunHTTPServer() (err error) { LogAccess.Info("HTTPD server is running on " + PushConf.Core.Port + " port.") if PushConf.Core.AutoTLS.Enabled { - return startServer(autoTLSServer()) + return startServer(ctx, autoTLSServer()) } else if PushConf.Core.SSL { config := &tls.Config{ MinVersion: tls.VersionTLS10, @@ -62,12 +65,41 @@ func RunHTTPServer() (err error) { server.TLSConfig = config } - return startServer(server) + return startServer(ctx, server) } -func startServer(s *http.Server) error { - if s.TLSConfig == nil { +func listenAndServe(ctx context.Context, s *http.Server) error { + var g errgroup.Group + g.Go(func() error { + select { + case <-ctx.Done(): + return s.Shutdown(ctx) + } + }) + g.Go(func() error { return s.ListenAndServe() - } - return s.ListenAndServeTLS("", "") + }) + return g.Wait() +} + +func listenAndServeTLS(ctx context.Context, s *http.Server) error { + var g errgroup.Group + g.Go(func() error { + select { + case <-ctx.Done(): + return s.Shutdown(ctx) + } + }) + g.Go(func() error { + return s.ListenAndServeTLS("", "") + }) + return g.Wait() +} + +func startServer(ctx context.Context, s *http.Server) error { + if s.TLSConfig == nil { + return listenAndServe(ctx, s) + } + + return listenAndServeTLS(ctx, s) } diff --git a/gorush/server_test.go b/gorush/server_test.go index 20b2115..e8b8b8e 100644 --- a/gorush/server_test.go +++ b/gorush/server_test.go @@ -1,6 +1,7 @@ package gorush import ( + "context" "crypto/tls" "io/ioutil" "log" @@ -63,13 +64,12 @@ func TestRunNormalServer(t *testing.T) { gin.SetMode(gin.TestMode) go func() { - assert.NoError(t, RunHTTPServer()) + assert.NoError(t, RunHTTPServer(context.Background())) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete time.Sleep(5 * time.Millisecond) - assert.Error(t, RunHTTPServer()) testRequest(t, "http://localhost:8088/api/stat/go") } @@ -82,7 +82,7 @@ func TestRunTLSServer(t *testing.T) { PushConf.Core.KeyPath = "../certificate/localhost.key" go func() { - assert.NoError(t, RunHTTPServer()) + assert.NoError(t, RunHTTPServer(context.Background())) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete @@ -104,7 +104,7 @@ func TestRunTLSBase64Server(t *testing.T) { PushConf.Core.KeyBase64 = key go func() { - assert.NoError(t, RunHTTPServer()) + assert.NoError(t, RunHTTPServer(context.Background())) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete @@ -117,7 +117,7 @@ func TestRunAutoTLSServer(t *testing.T) { initTest() PushConf.Core.AutoTLS.Enabled = true go func() { - assert.NoError(t, RunHTTPServer()) + assert.NoError(t, RunHTTPServer(context.Background())) }() // have to wait for the goroutine to start and run the server // otherwise the main thread will complete @@ -132,7 +132,7 @@ func TestLoadTLSCertError(t *testing.T) { PushConf.Core.CertPath = "../config/config.yml" PushConf.Core.KeyPath = "../config/config.yml" - assert.Error(t, RunHTTPServer()) + assert.Error(t, RunHTTPServer(context.Background())) } func TestMissingTLSCertConfg(t *testing.T) { @@ -145,8 +145,8 @@ func TestMissingTLSCertConfg(t *testing.T) { PushConf.Core.CertBase64 = "" PushConf.Core.KeyBase64 = "" - err := RunHTTPServer() - assert.Error(t, RunHTTPServer()) + err := RunHTTPServer(context.Background()) + assert.Error(t, RunHTTPServer(context.Background())) assert.Equal(t, "missing https cert config", err.Error()) } @@ -383,7 +383,7 @@ func TestVersionHandler(t *testing.T) { func TestDisabledHTTPServer(t *testing.T) { initTest() PushConf.Core.Enabled = false - err := RunHTTPServer() + err := RunHTTPServer(context.Background()) PushConf.Core.Enabled = true assert.Nil(t, err) diff --git a/main.go b/main.go index 56371eb..089e4f2 100644 --- a/main.go +++ b/main.go @@ -248,12 +248,14 @@ func main() { gorush.LogError.Fatal(err) } + finished := make(chan struct{}) wg := &sync.WaitGroup{} wg.Add(int(gorush.PushConf.Core.WorkerNum)) ctx := withContextFunc(context.Background(), func() { gorush.LogAccess.Info("close the notification queue channel") close(gorush.QueueNotification) wg.Wait() + close(finished) gorush.LogAccess.Info("the notification queue has been clear") }) @@ -269,8 +271,19 @@ func main() { var g errgroup.Group - g.Go(gorush.RunHTTPServer) // Run httpd server - g.Go(rpc.RunGRPCServer) // Run gRPC internal server + g.Go(func() error { + return gorush.RunHTTPServer(ctx) + }) // Run httpd server + + g.Go(rpc.RunGRPCServer) // Run gRPC internal server + + // check job completely + g.Go(func() error { + select { + case <-finished: + } + return nil + }) if err = g.Wait(); err != nil { gorush.LogError.Fatal(err)