package grpcserver import ( "context" "fmt" "git.coopgo.io/coopgo-apps/silvermobi/handler" grpcproto "git.coopgo.io/coopgo-apps/silvermobi/servers/grpcapi/proto" "github.com/golang-jwt/jwt/v4" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/rs/zerolog/log" "github.com/spf13/viper" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "net" ) type contextKey string var ( contextKeyUser = contextKey("user") ) func NoAuth(method string) bool { noAuthMethods := []string{ "/SilvermobiGRPC/ForgetAccount", "/SilvermobiGRPC/UpdatePassword", "/SilvermobiGRPC/AuthRegister", "/SilvermobiGRPC/AuthLogin", "/SilvermobiGRPC/GeoAutocomplete", "/SilvermobiGRPC/GeoRouteWithReturn", "/SilvermobiGRPC/GeoRoute", } for _, m := range noAuthMethods { if method == m { return true } } return false } func UnaryAuthServerInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { print(info.FullMethod) if NoAuth(info.FullMethod) { return handler(ctx, req) } var newCtx context.Context var err error newCtx, err = authFunc(ctx) if err != nil { return nil, err } return handler(newCtx, req) } } func StreamAuthServerInterceptor(authFunc grpc_auth.AuthFunc) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if NoAuth(info.FullMethod) { wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = stream.Context() return handler(srv, wrapped) } var newCtx context.Context var err error newCtx, err = authFunc(stream.Context()) if err != nil { return err } wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = newCtx return handler(srv, wrapped) } } type SilvermobiGRPCService struct { Config *viper.Viper Handler *handler.SilvermobiHandler grpcproto.UnimplementedSilvermobiGRPCServer } type SolidarityService struct { Config *viper.Viper Handler *handler.SilvermobiHandler SolidarityClient grpcproto.SolidarityServiceClient grpcproto.UnimplementedSolidarityServiceServer // Add this client } func NewSolidarityService(cfg *viper.Viper, handler *handler.SilvermobiHandler) SolidarityService { solidarityServiceAddress := cfg.GetString("solidarity_service.address") conn, err := grpc.Dial(solidarityServiceAddress, grpc.WithInsecure()) if err != nil { log.Fatal().Err(err) } solidarityClient := grpcproto.NewSolidarityServiceClient(conn) return SolidarityService{ Config: cfg, Handler: handler, SolidarityClient: solidarityClient, } } func NewSilvermobiGRPCService(cfg *viper.Viper, handler *handler.SilvermobiHandler) SilvermobiGRPCService { return SilvermobiGRPCService{ Config: cfg, Handler: handler, } } func Run(done chan error, cfg *viper.Viper, handler *handler.SilvermobiHandler) { var ( address = cfg.GetString("services.external.grpc.ip") + ":" + cfg.GetString("services.external.grpc.port") jwt_secret = cfg.GetString("identification.local.jwt_secret") ) log.Info().Msg("GRPC server on " + address) server := grpc.NewServer( grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_ctxtags.StreamServerInterceptor(), StreamAuthServerInterceptor(GRPCAuthFunc(jwt_secret)), )), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( grpc_ctxtags.UnaryServerInterceptor(), UnaryAuthServerInterceptor(GRPCAuthFunc(jwt_secret)), )), ) solidarity_service := NewSolidarityService(cfg, handler) silvermobi_service := NewSilvermobiGRPCService(cfg, handler) grpcproto.RegisterSilvermobiGRPCServer(server, silvermobi_service) grpcproto.RegisterSolidarityServiceServer(server, &solidarity_service) l, err := net.Listen("tcp", address) if err != nil { log.Fatal().Err(err) } if err := server.Serve(l); err != nil { log.Error().Err(err).Msg("gRPC service ended") done <- err } } func GRPCAuthFunc(jwtKey string) grpc_auth.AuthFunc { return func(ctx context.Context) (context.Context, error) { tokenString, err := grpc_auth.AuthFromMD(ctx, "bearer") if err != nil { return nil, err } claims := jwt.MapClaims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { return []byte(jwtKey), nil }) if err != nil || !token.Valid { fmt.Println(err) return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired token") } ctx = context.WithValue(ctx, contextKeyUser, claims["sub"].(string)) return ctx, nil } }