diff --git a/rpc/server.go b/rpc/server.go index 5e6b587..c675b53 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -2,13 +2,11 @@ package rpc import ( "context" - "net/http" + "net" "sync" - "time" "github.com/appleboy/gorush/gorush" "github.com/appleboy/gorush/rpc/proto" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -110,27 +108,24 @@ func RunGRPCServer(ctx context.Context) error { rpcSrv := NewServer() proto.RegisterGorushServer(s, rpcSrv) proto.RegisterHealthServer(s, rpcSrv) + // Register reflection service on gRPC server. reflection.Register(s) - gorush.LogAccess.Info("gRPC server is running on " + gorush.PushConf.GRPC.Port + " port.") - srv := &http.Server{ - Addr: ":" + gorush.PushConf.GRPC.Port, - Handler: s, + lis, err := net.Listen("tcp", ":"+gorush.PushConf.GRPC.Port) + if err != nil { + gorush.LogError.Fatalln(err) + return err } - - var g errgroup.Group - g.Go(func() error { + gorush.LogAccess.Info("gRPC server is running on " + gorush.PushConf.GRPC.Port + " port.") + go func() { select { case <-ctx.Done(): - timeout := time.Duration(gorush.PushConf.Core.ShutdownTimeout) * time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return srv.Shutdown(ctx) + s.GracefulStop() // graceful shutdown } - }) - g.Go(func() error { - return srv.ListenAndServe() - }) - return g.Wait() + }() + if err = s.Serve(lis); err != nil { + gorush.LogError.Fatalln(err) + } + return err } diff --git a/rpc/server_test.go b/rpc/server_test.go index 9ab1e3e..aae9c00 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -1 +1,47 @@ package rpc + +import ( + "context" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + + "github.com/appleboy/gorush/gorush" +) + +const gRPCAddr = "localhost:9000" + +func TestGracefulShutDownGRPCServer(t *testing.T) { + // server configs + gorush.InitLog() + gorush.PushConf.GRPC.Enabled = true + gorush.PushConf.GRPC.Port = "9000" + gorush.PushConf.Log.Format = "json" + + // Run gRPC server + ctx, gRPCContextCancel := context.WithCancel(context.Background()) + go func() { + if err := RunGRPCServer(ctx); err != nil { + panic(err) + } + }() + + // gRPC client conn + conn, err := grpc.Dial( + gRPCAddr, + grpc.WithInsecure(), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + ) // wait for server ready + if err != nil { + t.Error(err) + } + + // Stop gRPC server + go gRPCContextCancel() + + // wait for client connection would be closed + for conn.GetState() != connectivity.TransientFailure { + } + conn.Close() +}