package grpcserver import ( "context" "net" "strings" "git.coopgo.io/coopgo-apps/silvermobi/handler" grpcproto "git.coopgo.io/coopgo-apps/silvermobi/servers/grpcapi/proto" "github.com/golang-jwt/jwt/v4" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/rs/zerolog/log" "github.com/spf13/viper" "google.golang.org/grpc" ) type contextKey string var ( contextKeyUser = contextKey("user") ) type SilvermobiGRPCService struct { Config *viper.Viper Handler *handler.SilvermobiHandler grpcproto.UnimplementedSilvermobiGRPCServer } func NewSilvermobiGRPCService(cfg *viper.Viper, handler *handler.SilvermobiHandler) SilvermobiGRPCService { return SilvermobiGRPCService{ Config: cfg, Handler: handler, } } func Run(done chan error, cfg *viper.Viper, handler *handler.SilvermobiHandler) { var ( address = "127.0.0.1:" + cfg.GetString("services.external.grpc.port") jwt_secret = cfg.GetString("identification.local.jwt_secret") ) log.Info().Msg("GRPC server on " + address) server := grpc.NewServer( grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_ctxtags.StreamServerInterceptor(), StreamAuthServerInterceptor(GRPCAuthFunc(jwt_secret)), )), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( grpc_ctxtags.UnaryServerInterceptor(), UnaryAuthServerInterceptor(GRPCAuthFunc(jwt_secret)), )), ) grpcproto.RegisterSilvermobiGRPCServer(server, NewSilvermobiGRPCService(cfg, handler)) l, err := net.Listen("tcp", address) if err != nil { log.Fatal().Err(err) } if err := server.Serve(l); err != nil { log.Error().Err(err).Msg("gRPC service ended") done <- err } } func UnaryAuthServerInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if NoAuth(info.FullMethod) { return handler(ctx, req) } var newCtx context.Context var err error newCtx, err = authFunc(ctx) if err != nil { return nil, err } return handler(newCtx, req) } } func StreamAuthServerInterceptor(authFunc grpc_auth.AuthFunc) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if NoAuth(info.FullMethod) { wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = stream.Context() return handler(srv, wrapped) } var newCtx context.Context var err error newCtx, err = authFunc(stream.Context()) if err != nil { return err } wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = newCtx return handler(srv, wrapped) } } func GRPCAuthFunc(jwtKey string) grpc_auth.AuthFunc { return func(ctx context.Context) (context.Context, error) { tokenString, err := grpc_auth.AuthFromMD(ctx, "bearer") if err != nil { log.Error().Err(err) return nil, err } claims := jwt.MapClaims{} jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { return []byte(jwtKey), nil }) ctx = context.WithValue(ctx, contextKeyUser, claims["sub"].(string)) return ctx, nil } } func NoAuth(method string) bool { return strings.Contains(method, "Auth") }