diff --git a/config/config.go b/config/config.go index 828e19b..5d3a9a9 100644 --- a/config/config.go +++ b/config/config.go @@ -243,7 +243,7 @@ type SectionGRPC struct { } // LoadConf load config from file and read in environment variables that match -func LoadConf(confPath string) (ConfYaml, error) { +func LoadConf(confPath ...string) (ConfYaml, error) { var conf ConfYaml viper.SetConfigType("yaml") @@ -251,8 +251,8 @@ func LoadConf(confPath string) (ConfYaml, error) { viper.SetEnvPrefix("gorush") // will be uppercased automatically viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - if confPath != "" { - content, err := ioutil.ReadFile(confPath) + if len(confPath) > 0 && confPath[0] != "" { + content, err := ioutil.ReadFile(confPath[0]) if err != nil { return conf, err } diff --git a/config/config_test.go b/config/config_test.go index ef7e5d3..5405040 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -25,7 +25,7 @@ type ConfigTestSuite struct { func (suite *ConfigTestSuite) SetupTest() { var err error - suite.ConfGorushDefault, err = LoadConf("") + suite.ConfGorushDefault, err = LoadConf() if err != nil { panic("failed to load default config.yml") } @@ -215,6 +215,6 @@ func TestLoadConfigFromEnv(t *testing.T) { func TestLoadWrongDefaultYAMLConfig(t *testing.T) { defaultConf = []byte(`a`) - _, err := LoadConf("") + _, err := LoadConf() assert.Error(t, err) } diff --git a/gorush/const.go b/gorush/const.go deleted file mode 100644 index d803d15..0000000 --- a/gorush/const.go +++ /dev/null @@ -1,12 +0,0 @@ -package gorush - -// Stat variable for redis -const ( - TotalCountKey = "gorush-total-count" - IosSuccessKey = "gorush-ios-success-count" - IosErrorKey = "gorush-ios-error-count" - AndroidSuccessKey = "gorush-android-success-count" - AndroidErrorKey = "gorush-android-error-count" - HuaweiSuccessKey = "gorush-huawei-success-count" - HuaweiErrorKey = "gorush-huawei-error-count" -) diff --git a/gorush/feedback_test.go b/gorush/feedback_test.go index 2ef91ec..dbf460d 100644 --- a/gorush/feedback_test.go +++ b/gorush/feedback_test.go @@ -13,7 +13,7 @@ import ( ) func TestEmptyFeedbackURL(t *testing.T) { - // PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() logEntry := logx.LogPushEntry{ ID: "", Type: "", @@ -23,13 +23,13 @@ func TestEmptyFeedbackURL(t *testing.T) { Error: "", } - err := DispatchFeedback(logEntry, PushConf.Core.FeedbackURL, PushConf.Core.FeedbackTimeout) + err := DispatchFeedback(logEntry, cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) assert.NotNil(t, err) } func TestHTTPErrorInFeedbackCall(t *testing.T) { - config, _ := config.LoadConf("") - config.Core.FeedbackURL = "http://test.example.com/api/" + cfg, _ := config.LoadConf() + cfg.Core.FeedbackURL = "http://test.example.com/api/" logEntry := logx.LogPushEntry{ ID: "", Type: "", @@ -39,7 +39,7 @@ func TestHTTPErrorInFeedbackCall(t *testing.T) { Error: "", } - err := DispatchFeedback(logEntry, config.Core.FeedbackURL, config.Core.FeedbackTimeout) + err := DispatchFeedback(logEntry, cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) assert.NotNil(t, err) } @@ -59,8 +59,8 @@ func TestSuccessfulFeedbackCall(t *testing.T) { ) defer httpMock.Close() - config, _ := config.LoadConf("") - config.Core.FeedbackURL = httpMock.URL + cfg, _ := config.LoadConf() + cfg.Core.FeedbackURL = httpMock.URL logEntry := logx.LogPushEntry{ ID: "", Type: "", @@ -70,6 +70,6 @@ func TestSuccessfulFeedbackCall(t *testing.T) { Error: "", } - err := DispatchFeedback(logEntry, config.Core.FeedbackURL, config.Core.FeedbackTimeout) + err := DispatchFeedback(logEntry, cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) assert.Nil(t, err) } diff --git a/gorush/global.go b/gorush/global.go index 1f51c43..cdee529 100644 --- a/gorush/global.go +++ b/gorush/global.go @@ -1,18 +1,12 @@ package gorush import ( - "github.com/appleboy/gorush/config" - "github.com/appleboy/go-fcm" "github.com/msalihkarakasli/go-hms-push/push/core" "github.com/sideshow/apns2" ) var ( - // PushConf is gorush config - PushConf config.ConfYaml - // QueueNotification is chan type - QueueNotification chan PushNotification // ApnsClient is apns client ApnsClient *apns2.Client // FCMClient is apns client diff --git a/gorush/main_test.go b/gorush/main_test.go index 81e9cd0..5832fe7 100644 --- a/gorush/main_test.go +++ b/gorush/main_test.go @@ -1,9 +1,7 @@ package gorush import ( - "context" "log" - "sync" "testing" "github.com/appleboy/gorush/config" @@ -12,33 +10,19 @@ import ( ) func TestMain(m *testing.M) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() if err := logx.InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, ); err != nil { log.Fatal(err) } - if err := status.InitAppStatus(PushConf); err != nil { + if err := status.InitAppStatus(cfg); err != nil { log.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) - wg := &sync.WaitGroup{} - wg.Add(int(PushConf.Core.WorkerNum)) - InitWorkers(ctx, wg, PushConf.Core.WorkerNum, PushConf.Core.QueueNum) - - if err := status.InitAppStatus(PushConf); err != nil { - log.Fatal(err) - } - - defer func() { - close(QueueNotification) - cancel() - }() - m.Run() } diff --git a/gorush/notification.go b/gorush/notification.go index 07fb5f8..8d6192d 100644 --- a/gorush/notification.go +++ b/gorush/notification.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/appleboy/go-fcm" + "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/logx" "github.com/msalihkarakasli/go-hms-push/push/model" @@ -54,8 +55,8 @@ type RequestPush struct { // PushNotification is single notification request type PushNotification struct { - wg *sync.WaitGroup - log *[]logx.LogPushEntry + Wg *sync.WaitGroup + Log *[]logx.LogPushEntry // Common ID string `json:"notif_id,omitempty"` @@ -112,22 +113,22 @@ type PushNotification struct { // WaitDone decrements the WaitGroup counter. func (p *PushNotification) WaitDone() { - if p.wg != nil { - p.wg.Done() + if p.Wg != nil { + p.Wg.Done() } } // AddWaitCount increments the WaitGroup counter. func (p *PushNotification) AddWaitCount() { - if p.wg != nil { - p.wg.Add(1) + if p.Wg != nil { + p.Wg.Add(1) } } // AddLog record fail log of notification func (p *PushNotification) AddLog(log logx.LogPushEntry) { - if p.log != nil { - *p.log = append(*p.log, log) + if p.Log != nil { + *p.Log = append(*p.Log, log) } } @@ -199,39 +200,55 @@ func SetProxy(proxy string) error { } // CheckPushConf provide check your yml config. -func CheckPushConf() error { - if !PushConf.Ios.Enabled && !PushConf.Android.Enabled && !PushConf.Huawei.Enabled { +func CheckPushConf(cfg config.ConfYaml) error { + if !cfg.Ios.Enabled && !cfg.Android.Enabled && !cfg.Huawei.Enabled { return errors.New("Please enable iOS, Android or Huawei config in yml config") } - if PushConf.Ios.Enabled { - if PushConf.Ios.KeyPath == "" && PushConf.Ios.KeyBase64 == "" { + if cfg.Ios.Enabled { + if cfg.Ios.KeyPath == "" && cfg.Ios.KeyBase64 == "" { return errors.New("Missing iOS certificate key") } // check certificate file exist - if PushConf.Ios.KeyPath != "" { - if _, err := os.Stat(PushConf.Ios.KeyPath); os.IsNotExist(err) { + if cfg.Ios.KeyPath != "" { + if _, err := os.Stat(cfg.Ios.KeyPath); os.IsNotExist(err) { return errors.New("certificate file does not exist") } } } - if PushConf.Android.Enabled { - if PushConf.Android.APIKey == "" { + if cfg.Android.Enabled { + if cfg.Android.APIKey == "" { return errors.New("Missing Android API Key") } } - if PushConf.Huawei.Enabled { - if PushConf.Huawei.AppSecret == "" { + if cfg.Huawei.Enabled { + if cfg.Huawei.AppSecret == "" { return errors.New("Missing Huawei App Secret") } - if PushConf.Huawei.AppID == "" { + if cfg.Huawei.AppID == "" { return errors.New("Missing Huawei App ID") } } return nil } + +// SendNotification send notification +func SendNotification(cfg config.ConfYaml, req PushNotification) { + defer func() { + req.WaitDone() + }() + + switch req.Platform { + case core.PlatFormIos: + PushToIOS(cfg, req) + case core.PlatFormAndroid: + PushToAndroid(cfg, req) + case core.PlatFormHuawei: + PushToHuawei(cfg, req) + } +} diff --git a/gorush/notification_apns.go b/gorush/notification_apns.go index 1952ce1..e018c24 100644 --- a/gorush/notification_apns.go +++ b/gorush/notification_apns.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/logx" "github.com/appleboy/gorush/status" @@ -50,23 +51,23 @@ type Sound struct { } // InitAPNSClient use for initialize APNs Client. -func InitAPNSClient() error { - if PushConf.Ios.Enabled { +func InitAPNSClient(cfg config.ConfYaml) error { + if cfg.Ios.Enabled { var err error var authKey *ecdsa.PrivateKey var certificateKey tls.Certificate var ext string - if PushConf.Ios.KeyPath != "" { - ext = filepath.Ext(PushConf.Ios.KeyPath) + if cfg.Ios.KeyPath != "" { + ext = filepath.Ext(cfg.Ios.KeyPath) switch ext { case ".p12": - certificateKey, err = certificate.FromP12File(PushConf.Ios.KeyPath, PushConf.Ios.Password) + certificateKey, err = certificate.FromP12File(cfg.Ios.KeyPath, cfg.Ios.Password) case ".pem": - certificateKey, err = certificate.FromPemFile(PushConf.Ios.KeyPath, PushConf.Ios.Password) + certificateKey, err = certificate.FromPemFile(cfg.Ios.KeyPath, cfg.Ios.Password) case ".p8": - authKey, err = token.AuthKeyFromFile(PushConf.Ios.KeyPath) + authKey, err = token.AuthKeyFromFile(cfg.Ios.KeyPath) default: err = errors.New("wrong certificate key extension") } @@ -76,9 +77,9 @@ func InitAPNSClient() error { return err } - } else if PushConf.Ios.KeyBase64 != "" { - ext = "." + PushConf.Ios.KeyType - key, err := base64.StdEncoding.DecodeString(PushConf.Ios.KeyBase64) + } else if cfg.Ios.KeyBase64 != "" { + ext = "." + cfg.Ios.KeyType + key, err := base64.StdEncoding.DecodeString(cfg.Ios.KeyBase64) if err != nil { logx.LogError.Error("base64 decode error:", err.Error()) @@ -86,9 +87,9 @@ func InitAPNSClient() error { } switch ext { case ".p12": - certificateKey, err = certificate.FromP12Bytes(key, PushConf.Ios.Password) + certificateKey, err = certificate.FromP12Bytes(key, cfg.Ios.Password) case ".pem": - certificateKey, err = certificate.FromPemBytes(key, PushConf.Ios.Password) + certificateKey, err = certificate.FromPemBytes(key, cfg.Ios.Password) case ".p8": authKey, err = token.AuthKeyFromBytes(key) default: @@ -103,7 +104,7 @@ func InitAPNSClient() error { } if ext == ".p8" { - if PushConf.Ios.KeyID == "" || PushConf.Ios.TeamID == "" { + if cfg.Ios.KeyID == "" || cfg.Ios.TeamID == "" { msg := "You should provide ios.KeyID and ios.TeamID for P8 token" logx.LogError.Error(msg) return errors.New(msg) @@ -111,14 +112,14 @@ func InitAPNSClient() error { token := &token.Token{ AuthKey: authKey, // KeyID from developer account (Certificates, Identifiers & Profiles -> Keys) - KeyID: PushConf.Ios.KeyID, + KeyID: cfg.Ios.KeyID, // TeamID from developer account (View Account -> Membership) - TeamID: PushConf.Ios.TeamID, + TeamID: cfg.Ios.TeamID, } - ApnsClient, err = newApnsTokenClient(token) + ApnsClient, err = newApnsTokenClient(cfg, token) } else { - ApnsClient, err = newApnsClient(certificateKey) + ApnsClient, err = newApnsClient(cfg, certificateKey) } if h2Transport, ok := ApnsClient.HTTPClient.Transport.(*http2.Transport); ok { @@ -130,21 +131,23 @@ func InitAPNSClient() error { return err } + + MaxConcurrentIOSPushes = make(chan struct{}, cfg.Ios.MaxConcurrentPushes) } return nil } -func newApnsClient(certificate tls.Certificate) (*apns2.Client, error) { +func newApnsClient(cfg config.ConfYaml, certificate tls.Certificate) (*apns2.Client, error) { var client *apns2.Client - if PushConf.Ios.Production { + if cfg.Ios.Production { client = apns2.NewClient(certificate).Production() } else { client = apns2.NewClient(certificate).Development() } - if PushConf.Core.HTTPProxy == "" { + if cfg.Core.HTTPProxy == "" { return client, nil } @@ -175,16 +178,16 @@ func newApnsClient(certificate tls.Certificate) (*apns2.Client, error) { return client, nil } -func newApnsTokenClient(token *token.Token) (*apns2.Client, error) { +func newApnsTokenClient(cfg config.ConfYaml, token *token.Token) (*apns2.Client, error) { var client *apns2.Client - if PushConf.Ios.Production { + if cfg.Ios.Production { client = apns2.NewTokenClient(token).Production() } else { client = apns2.NewTokenClient(token).Development() } - if PushConf.Core.HTTPProxy == "" { + if cfg.Core.HTTPProxy == "" { return client, nil } @@ -365,13 +368,13 @@ func GetIOSNotification(req PushNotification) *apns2.Notification { return notification } -func getApnsClient(req PushNotification) (client *apns2.Client) { +func getApnsClient(cfg config.ConfYaml, req PushNotification) (client *apns2.Client) { if req.Production { client = ApnsClient.Production() } else if req.Development { client = ApnsClient.Development() } else { - if PushConf.Ios.Production { + if cfg.Ios.Production { client = ApnsClient.Production() } else { client = ApnsClient.Development() @@ -381,12 +384,12 @@ func getApnsClient(req PushNotification) (client *apns2.Client) { } // PushToIOS provide send notification to APNs server. -func PushToIOS(req PushNotification) { +func PushToIOS(cfg config.ConfYaml, req PushNotification) { logx.LogAccess.Debug("Start push notification for iOS") var ( retryCount = 0 - maxRetry = PushConf.Ios.MaxRetry + maxRetry = cfg.Ios.MaxRetry ) if req.Retry > 0 && req.Retry < maxRetry { @@ -397,7 +400,7 @@ Retry: var newTokens []string notification := GetIOSNotification(req) - client := getApnsClient(req) + client := getApnsClient(cfg, req) var wg sync.WaitGroup for _, token := range req.Tokens { @@ -416,17 +419,17 @@ Retry: err = errors.New(res.Reason) } // apns server error - logPush(core.FailedPush, token, req, err) + logPush(cfg, core.FailedPush, token, req, err) - if PushConf.Core.Sync { - req.AddLog(createLogPushEntry(core.FailedPush, token, req, err)) - } else if PushConf.Core.FeedbackURL != "" { + if cfg.Core.Sync { + req.AddLog(createLogPushEntry(cfg, core.FailedPush, token, req, err)) + } else if cfg.Core.FeedbackURL != "" { go func(logger *logrus.Logger, log logx.LogPushEntry, url string, timeout int64) { err := DispatchFeedback(log, url, timeout) if err != nil { logger.Error(err) } - }(logx.LogError, createLogPushEntry(core.FailedPush, token, req, err), PushConf.Core.FeedbackURL, PushConf.Core.FeedbackTimeout) + }(logx.LogError, createLogPushEntry(cfg, core.FailedPush, token, req, err), cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) } status.StatStorage.AddIosError(1) @@ -438,7 +441,7 @@ Retry: } if res != nil && res.Sent() { - logPush(core.SucceededPush, token, req, nil) + logPush(cfg, core.SucceededPush, token, req, nil) status.StatStorage.AddIosSuccess(1) } // free push slot diff --git a/gorush/notification_apns_test.go b/gorush/notification_apns_test.go index eca561c..748f9cf 100644 --- a/gorush/notification_apns_test.go +++ b/gorush/notification_apns_test.go @@ -1,17 +1,14 @@ package gorush import ( - "context" "encoding/json" "log" "net/http" "net/url" - "os" "testing" "time" "github.com/appleboy/gorush/config" - "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/status" "github.com/buger/jsonparser" "github.com/sideshow/apns2" @@ -27,29 +24,29 @@ const authkeyInvalidP8 = `TUlHSEFnRUFNQk1HQnlxR1NNNDlBZ0VHQ0NxR1NNNDlBd0VIQkcwd2 const authkeyValidP8 = `LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JR0hBZ0VBTUJNR0J5cUdTTTQ5QWdFR0NDcUdTTTQ5QXdFSEJHMHdhd0lCQVFRZ0ViVnpmUG5aUHhmQXl4cUUKWlYwNWxhQW9KQWwrLzZYdDJPNG1PQjYxMXNPaFJBTkNBQVNnRlRLandKQUFVOTVnKysvdnpLV0hrekFWbU5NSQp0QjV2VGpaT09Jd25FYjcwTXNXWkZJeVVGRDFQOUd3c3R6NCtha0hYN3ZJOEJINmhIbUJtZmVRbAotLS0tLUVORCBQUklWQVRFIEtFWS0tLS0tCg==` func TestDisabledAndroidIosConf(t *testing.T) { - PushConf, _ = config.LoadConf("") - PushConf.Android.Enabled = false - PushConf.Huawei.Enabled = false + cfg, _ := config.LoadConf() + cfg.Android.Enabled = false + cfg.Huawei.Enabled = false - err := CheckPushConf() + err := CheckPushConf(cfg) assert.Error(t, err) assert.Equal(t, "Please enable iOS, Android or Huawei config in yml config", err.Error()) } func TestMissingIOSCertificate(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "" - PushConf.Ios.KeyBase64 = "" - err := CheckPushConf() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "" + cfg.Ios.KeyBase64 = "" + err := CheckPushConf(cfg) assert.Error(t, err) assert.Equal(t, "Missing iOS certificate key", err.Error()) - PushConf.Ios.KeyPath = "test.pem" - err = CheckPushConf() + cfg.Ios.KeyPath = "test.pem" + err = CheckPushConf(cfg) assert.Error(t, err) assert.Equal(t, "certificate file does not exist", err.Error()) @@ -566,166 +563,130 @@ func TestIOSAlertNotificationStructure(t *testing.T) { assert.Contains(t, locArgs, "b") } -func TestDisabledIosNotifications(t *testing.T) { - ctx := context.Background() - PushConf, _ = config.LoadConf("") - - PushConf.Ios.Enabled = false - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) - - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") - - androidToken := os.Getenv("ANDROID_TEST_TOKEN") - - req := RequestPush{ - Notifications: []PushNotification{ - // ios - { - Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, - Platform: core.PlatFormIos, - Message: "Welcome", - }, - // android - { - Tokens: []string{androidToken, androidToken + "_"}, - Platform: core.PlatFormAndroid, - Message: "Welcome", - }, - }, - } - - count, logs := HandleNotification(ctx, req) - assert.Equal(t, 2, count) - assert.Equal(t, 0, len(logs)) -} - func TestWrongIosCertificateExt(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "test" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "test" + err := InitAPNSClient(cfg) assert.Error(t, err) assert.Equal(t, "wrong certificate key extension", err.Error()) - PushConf.Ios.KeyPath = "" - PushConf.Ios.KeyBase64 = "abcd" - PushConf.Ios.KeyType = "abcd" - err = InitAPNSClient() + cfg.Ios.KeyPath = "" + cfg.Ios.KeyBase64 = "abcd" + cfg.Ios.KeyType = "abcd" + err = InitAPNSClient(cfg) assert.Error(t, err) assert.Equal(t, "wrong certificate key type", err.Error()) } func TestAPNSClientDevHost(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.p12" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.p12" + err := InitAPNSClient(cfg) assert.Nil(t, err) assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) - PushConf.Ios.KeyPath = "" - PushConf.Ios.KeyBase64 = certificateValidP12 - PushConf.Ios.KeyType = "p12" - err = InitAPNSClient() + cfg.Ios.KeyPath = "" + cfg.Ios.KeyBase64 = certificateValidP12 + cfg.Ios.KeyType = "p12" + err = InitAPNSClient(cfg) assert.Nil(t, err) assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) } func TestAPNSClientProdHost(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.Production = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.Production = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := InitAPNSClient(cfg) assert.Nil(t, err) assert.Equal(t, apns2.HostProduction, ApnsClient.Host) - PushConf.Ios.KeyPath = "" - PushConf.Ios.KeyBase64 = certificateValidPEM - PushConf.Ios.KeyType = "pem" - err = InitAPNSClient() + cfg.Ios.KeyPath = "" + cfg.Ios.KeyBase64 = certificateValidPEM + cfg.Ios.KeyType = "pem" + err = InitAPNSClient(cfg) assert.Nil(t, err) assert.Equal(t, apns2.HostProduction, ApnsClient.Host) } func TestAPNSClientInvaildToken(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/authkey-invalid.p8" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/authkey-invalid.p8" + err := InitAPNSClient(cfg) assert.Error(t, err) - PushConf.Ios.KeyPath = "" - PushConf.Ios.KeyBase64 = authkeyInvalidP8 - PushConf.Ios.KeyType = "p8" - err = InitAPNSClient() + cfg.Ios.KeyPath = "" + cfg.Ios.KeyBase64 = authkeyInvalidP8 + cfg.Ios.KeyType = "p8" + err = InitAPNSClient(cfg) assert.Error(t, err) // empty key-id or team-id - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/authkey-valid.p8" - err = InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/authkey-valid.p8" + err = InitAPNSClient(cfg) assert.Error(t, err) - PushConf.Ios.KeyID = "key-id" - PushConf.Ios.TeamID = "" - err = InitAPNSClient() + cfg.Ios.KeyID = "key-id" + cfg.Ios.TeamID = "" + err = InitAPNSClient(cfg) assert.Error(t, err) - PushConf.Ios.KeyID = "" - PushConf.Ios.TeamID = "team-id" - err = InitAPNSClient() + cfg.Ios.KeyID = "" + cfg.Ios.TeamID = "team-id" + err = InitAPNSClient(cfg) assert.Error(t, err) } func TestAPNSClientVaildToken(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/authkey-valid.p8" - PushConf.Ios.KeyID = "key-id" - PushConf.Ios.TeamID = "team-id" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/authkey-valid.p8" + cfg.Ios.KeyID = "key-id" + cfg.Ios.TeamID = "team-id" + err := InitAPNSClient(cfg) assert.NoError(t, err) assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) - PushConf.Ios.Production = true - err = InitAPNSClient() + cfg.Ios.Production = true + err = InitAPNSClient(cfg) assert.NoError(t, err) assert.Equal(t, apns2.HostProduction, ApnsClient.Host) // test base64 - PushConf.Ios.Production = false - PushConf.Ios.KeyPath = "" - PushConf.Ios.KeyBase64 = authkeyValidP8 - PushConf.Ios.KeyType = "p8" - err = InitAPNSClient() + cfg.Ios.Production = false + cfg.Ios.KeyPath = "" + cfg.Ios.KeyBase64 = authkeyValidP8 + cfg.Ios.KeyType = "p8" + err = InitAPNSClient(cfg) assert.NoError(t, err) assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) - PushConf.Ios.Production = true - err = InitAPNSClient() + cfg.Ios.Production = true + err = InitAPNSClient(cfg) assert.NoError(t, err) assert.Equal(t, apns2.HostProduction, ApnsClient.Host) } func TestAPNSClientUseProxy(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.p12" - PushConf.Core.HTTPProxy = "http://127.0.0.1:8080" - _ = SetProxy(PushConf.Core.HTTPProxy) - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.p12" + cfg.Core.HTTPProxy = "http://127.0.0.1:8080" + _ = SetProxy(cfg.Core.HTTPProxy) + err := InitAPNSClient(cfg) assert.Nil(t, err) assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) @@ -733,13 +694,13 @@ func TestAPNSClientUseProxy(t *testing.T) { actualProxyURL, err := ApnsClient.HTTPClient.Transport.(*http.Transport).Proxy(req) assert.Nil(t, err) - expectedProxyURL, _ := url.ParseRequestURI(PushConf.Core.HTTPProxy) + expectedProxyURL, _ := url.ParseRequestURI(cfg.Core.HTTPProxy) assert.Equal(t, expectedProxyURL, actualProxyURL) - PushConf.Ios.KeyPath = "../certificate/authkey-valid.p8" - PushConf.Ios.TeamID = "example.team" - PushConf.Ios.KeyID = "example.key" - err = InitAPNSClient() + cfg.Ios.KeyPath = "../certificate/authkey-valid.p8" + cfg.Ios.TeamID = "example.team" + cfg.Ios.KeyID = "example.key" + err = InitAPNSClient(cfg) assert.Nil(t, err) assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) assert.NotNil(t, ApnsClient.Token) @@ -748,21 +709,21 @@ func TestAPNSClientUseProxy(t *testing.T) { actualProxyURL, err = ApnsClient.HTTPClient.Transport.(*http.Transport).Proxy(req) assert.Nil(t, err) - expectedProxyURL, _ = url.ParseRequestURI(PushConf.Core.HTTPProxy) + expectedProxyURL, _ = url.ParseRequestURI(cfg.Core.HTTPProxy) assert.Equal(t, expectedProxyURL, actualProxyURL) http.DefaultTransport.(*http.Transport).Proxy = nil } func TestPushToIOS(t *testing.T) { - PushConf, _ = config.LoadConf("") - MaxConcurrentIOSPushes = make(chan struct{}, PushConf.Ios.MaxConcurrentPushes) + cfg, _ := config.LoadConf() + MaxConcurrentIOSPushes = make(chan struct{}, cfg.Ios.MaxConcurrentPushes) - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := InitAPNSClient(cfg) assert.Nil(t, err) - err = status.InitAppStatus(PushConf) + err = status.InitAppStatus(cfg) assert.Nil(t, err) req := PushNotification{ @@ -772,37 +733,37 @@ func TestPushToIOS(t *testing.T) { } // send fail - PushToIOS(req) + PushToIOS(cfg, req) } func TestApnsHostFromRequest(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := InitAPNSClient(cfg) assert.Nil(t, err) - err = status.InitAppStatus(PushConf) + err = status.InitAppStatus(cfg) assert.Nil(t, err) req := PushNotification{ Production: true, } - client := getApnsClient(req) + client := getApnsClient(cfg, req) assert.Equal(t, apns2.HostProduction, client.Host) req = PushNotification{ Development: true, } - client = getApnsClient(req) + client = getApnsClient(cfg, req) assert.Equal(t, apns2.HostDevelopment, client.Host) req = PushNotification{} - PushConf.Ios.Production = true - client = getApnsClient(req) + cfg.Ios.Production = true + client = getApnsClient(cfg, req) assert.Equal(t, apns2.HostProduction, client.Host) - PushConf.Ios.Production = false - client = getApnsClient(req) + cfg.Ios.Production = false + client = getApnsClient(cfg, req) assert.Equal(t, apns2.HostDevelopment, client.Host) } diff --git a/gorush/notification_fcm.go b/gorush/notification_fcm.go index fe65b64..f02b709 100644 --- a/gorush/notification_fcm.go +++ b/gorush/notification_fcm.go @@ -4,27 +4,29 @@ import ( "errors" "fmt" - "github.com/appleboy/go-fcm" + "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/logx" "github.com/appleboy/gorush/status" + + "github.com/appleboy/go-fcm" "github.com/sirupsen/logrus" ) // InitFCMClient use for initialize FCM Client. -func InitFCMClient(key string) (*fcm.Client, error) { +func InitFCMClient(cfg config.ConfYaml, key string) (*fcm.Client, error) { var err error - if key == "" { + if key == "" && cfg.Android.APIKey == "" { return nil, errors.New("Missing Android API Key") } - if key != PushConf.Android.APIKey { + if key != "" && key != cfg.Android.APIKey { return fcm.NewClient(key) } if FCMClient == nil { - FCMClient, err = fcm.NewClient(key) + FCMClient, err = fcm.NewClient(cfg.Android.APIKey) return FCMClient, err } @@ -104,13 +106,13 @@ func GetAndroidNotification(req PushNotification) *fcm.Message { } // PushToAndroid provide send notification to Android server. -func PushToAndroid(req PushNotification) { +func PushToAndroid(cfg config.ConfYaml, req PushNotification) { logx.LogAccess.Debug("Start push notification for Android") var ( client *fcm.Client retryCount = 0 - maxRetry = PushConf.Android.MaxRetry + maxRetry = cfg.Android.MaxRetry ) if req.Retry > 0 && req.Retry < maxRetry { @@ -128,9 +130,9 @@ Retry: notification := GetAndroidNotification(req) if req.APIKey != "" { - client, err = InitFCMClient(req.APIKey) + client, err = InitFCMClient(cfg, req.APIKey) } else { - client, err = InitFCMClient(PushConf.Android.APIKey) + client, err = InitFCMClient(cfg, cfg.Android.APIKey) } if err != nil { @@ -145,28 +147,28 @@ Retry: logx.LogError.Error("FCM server send message error: " + err.Error()) if req.IsTopic() { - if PushConf.Core.Sync { - req.AddLog(createLogPushEntry(core.FailedPush, req.To, req, err)) - } else if PushConf.Core.FeedbackURL != "" { + if cfg.Core.Sync { + req.AddLog(createLogPushEntry(cfg, core.FailedPush, req.To, req, err)) + } else if cfg.Core.FeedbackURL != "" { go func(logger *logrus.Logger, log logx.LogPushEntry, url string, timeout int64) { err := DispatchFeedback(log, url, timeout) if err != nil { logger.Error(err) } - }(logx.LogError, createLogPushEntry(core.FailedPush, req.To, req, err), PushConf.Core.FeedbackURL, PushConf.Core.FeedbackTimeout) + }(logx.LogError, createLogPushEntry(cfg, core.FailedPush, req.To, req, err), cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) } status.StatStorage.AddAndroidError(1) } else { for _, token := range req.Tokens { - if PushConf.Core.Sync { - req.AddLog(createLogPushEntry(core.FailedPush, token, req, err)) - } else if PushConf.Core.FeedbackURL != "" { + if cfg.Core.Sync { + req.AddLog(createLogPushEntry(cfg, core.FailedPush, token, req, err)) + } else if cfg.Core.FeedbackURL != "" { go func(logger *logrus.Logger, log logx.LogPushEntry, url string, timeout int64) { err := DispatchFeedback(log, url, timeout) if err != nil { logger.Error(err) } - }(logx.LogError, createLogPushEntry(core.FailedPush, token, req, err), PushConf.Core.FeedbackURL, PushConf.Core.FeedbackTimeout) + }(logx.LogError, createLogPushEntry(cfg, core.FailedPush, token, req, err), cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) } } status.StatStorage.AddAndroidError(int64(len(req.Tokens))) @@ -198,21 +200,21 @@ Retry: newTokens = append(newTokens, to) } - logPush(core.FailedPush, to, req, result.Error) - if PushConf.Core.Sync { - req.AddLog(createLogPushEntry(core.FailedPush, to, req, result.Error)) - } else if PushConf.Core.FeedbackURL != "" { + logPush(cfg, core.FailedPush, to, req, result.Error) + if cfg.Core.Sync { + req.AddLog(createLogPushEntry(cfg, core.FailedPush, to, req, result.Error)) + } else if cfg.Core.FeedbackURL != "" { go func(logger *logrus.Logger, log logx.LogPushEntry, url string, timeout int64) { err := DispatchFeedback(log, url, timeout) if err != nil { logger.Error(err) } - }(logx.LogError, createLogPushEntry(core.FailedPush, to, req, result.Error), PushConf.Core.FeedbackURL, PushConf.Core.FeedbackTimeout) + }(logx.LogError, createLogPushEntry(cfg, core.FailedPush, to, req, result.Error), cfg.Core.FeedbackURL, cfg.Core.FeedbackTimeout) } continue } - logPush(core.SucceededPush, to, req, nil) + logPush(cfg, core.SucceededPush, to, req, nil) } // result from Send messages to topics @@ -226,12 +228,12 @@ Retry: logx.LogAccess.Debug("Send Topic Message: ", to) // Success if res.MessageID != 0 { - logPush(core.SucceededPush, to, req, nil) + logPush(cfg, core.SucceededPush, to, req, nil) } else { // failure - logPush(core.FailedPush, to, req, res.Error) - if PushConf.Core.Sync { - req.AddLog(createLogPushEntry(core.FailedPush, to, req, res.Error)) + logPush(cfg, core.FailedPush, to, req, res.Error) + if cfg.Core.Sync { + req.AddLog(createLogPushEntry(cfg, core.FailedPush, to, req, res.Error)) } } } @@ -240,9 +242,9 @@ Retry: if len(res.FailedRegistrationIDs) > 0 { newTokens = append(newTokens, res.FailedRegistrationIDs...) - logPush(core.FailedPush, notification.To, req, errors.New("device group: partial success or all fails")) - if PushConf.Core.Sync { - req.AddLog(createLogPushEntry(core.FailedPush, notification.To, req, errors.New("device group: partial success or all fails"))) + logPush(cfg, core.FailedPush, notification.To, req, errors.New("device group: partial success or all fails")) + if cfg.Core.Sync { + req.AddLog(createLogPushEntry(cfg, core.FailedPush, notification.To, req, errors.New("device group: partial success or all fails"))) } } @@ -255,7 +257,7 @@ Retry: } } -func createLogPushEntry(status, token string, req PushNotification, err error) logx.LogPushEntry { +func createLogPushEntry(cfg config.ConfYaml, status, token string, req PushNotification, err error) logx.LogPushEntry { return logx.GetLogPushEntry(&logx.InputLog{ ID: req.ID, Status: status, @@ -263,12 +265,12 @@ func createLogPushEntry(status, token string, req PushNotification, err error) l Message: req.Message, Platform: req.Platform, Error: err, - HideToken: PushConf.Log.HideToken, - Format: PushConf.Log.Format, + HideToken: cfg.Log.HideToken, + Format: cfg.Log.Format, }) } -func logPush(status, token string, req PushNotification, err error) { +func logPush(cfg config.ConfYaml, status, token string, req PushNotification, err error) { logx.LogPush(&logx.InputLog{ ID: req.ID, Status: status, @@ -276,7 +278,7 @@ func logPush(status, token string, req PushNotification, err error) { Message: req.Message, Platform: req.Platform, Error: err, - HideToken: PushConf.Log.HideToken, - Format: PushConf.Log.Format, + HideToken: cfg.Log.HideToken, + Format: cfg.Log.Format, }) } diff --git a/gorush/notification_fcm_test.go b/gorush/notification_fcm_test.go index ff324a6..a3bd9f9 100644 --- a/gorush/notification_fcm_test.go +++ b/gorush/notification_fcm_test.go @@ -13,19 +13,21 @@ import ( ) func TestMissingAndroidAPIKey(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Android.Enabled = true - PushConf.Android.APIKey = "" + cfg.Android.Enabled = true + cfg.Android.APIKey = "" - err := CheckPushConf() + err := CheckPushConf(cfg) assert.Error(t, err) assert.Equal(t, "Missing Android API Key", err.Error()) } func TestMissingKeyForInitFCMClient(t *testing.T) { - client, err := InitFCMClient("") + cfg, _ := config.LoadConf() + cfg.Android.APIKey = "" + client, err := InitFCMClient(cfg, "") assert.Nil(t, client) assert.Error(t, err) @@ -33,10 +35,10 @@ func TestMissingKeyForInitFCMClient(t *testing.T) { } func TestPushToAndroidWrongToken(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") req := PushNotification{ Tokens: []string{"aaaaaa", "bbbbb"}, @@ -45,16 +47,16 @@ func TestPushToAndroidWrongToken(t *testing.T) { } // Android Success count: 0, Failure count: 2 - PushToAndroid(req) + PushToAndroid(cfg, req) } func TestPushToAndroidRightTokenForJSONLog(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") // log for json - PushConf.Log.Format = "json" + cfg.Log.Format = "json" androidToken := os.Getenv("ANDROID_TEST_TOKEN") @@ -64,14 +66,14 @@ func TestPushToAndroidRightTokenForJSONLog(t *testing.T) { Message: "Welcome", } - PushToAndroid(req) + PushToAndroid(cfg, req) } func TestPushToAndroidRightTokenForStringLog(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") androidToken := os.Getenv("ANDROID_TEST_TOKEN") @@ -81,15 +83,15 @@ func TestPushToAndroidRightTokenForStringLog(t *testing.T) { Message: "Welcome", } - PushToAndroid(req) + PushToAndroid(cfg, req) } func TestOverwriteAndroidAPIKey(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Core.Sync = true - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Core.Sync = true + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") androidToken := os.Getenv("ANDROID_TEST_TOKEN") @@ -100,13 +102,13 @@ func TestOverwriteAndroidAPIKey(t *testing.T) { // overwrite android api key APIKey: "1234", - log: &[]logx.LogPushEntry{}, + Log: &[]logx.LogPushEntry{}, } // FCM server error: 401 error: 401 Unauthorized (Wrong API Key) - PushToAndroid(req) + PushToAndroid(cfg, req) - assert.Len(t, *req.log, 2) + assert.Len(t, *req.Log, 2) } func TestFCMMessage(t *testing.T) { @@ -188,10 +190,10 @@ func TestFCMMessage(t *testing.T) { } func TestCheckAndroidMessage(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") timeToLive := uint(2419201) req := PushNotification{ @@ -201,7 +203,7 @@ func TestCheckAndroidMessage(t *testing.T) { TimeToLive: &timeToLive, } - PushToAndroid(req) + PushToAndroid(cfg, req) } func TestAndroidNotificationStructure(t *testing.T) { diff --git a/gorush/notification_hms.go b/gorush/notification_hms.go index 3fb03f7..8447e96 100644 --- a/gorush/notification_hms.go +++ b/gorush/notification_hms.go @@ -6,10 +6,11 @@ import ( "errors" "sync" + "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/logx" "github.com/appleboy/gorush/status" - "github.com/msalihkarakasli/go-hms-push/push/config" + c "github.com/msalihkarakasli/go-hms-push/push/config" "github.com/msalihkarakasli/go-hms-push/push/core" "github.com/msalihkarakasli/go-hms-push/push/model" ) @@ -21,7 +22,7 @@ var ( ) // GetPushClient use for create HMS Push -func GetPushClient(conf *config.Config) (*core.HMSClient, error) { +func GetPushClient(conf *c.Config) (*core.HMSClient, error) { once.Do(func() { client, err := core.NewHttpClient(conf) if err != nil { @@ -35,7 +36,7 @@ func GetPushClient(conf *config.Config) (*core.HMSClient, error) { } // InitHMSClient use for initialize HMS Client. -func InitHMSClient(appSecret, appID string) (*core.HMSClient, error) { +func InitHMSClient(cfg config.ConfYaml, appSecret, appID string) (*core.HMSClient, error) { if appSecret == "" { return nil, errors.New("Missing Huawei App Secret") } @@ -44,14 +45,14 @@ func InitHMSClient(appSecret, appID string) (*core.HMSClient, error) { return nil, errors.New("Missing Huawei App ID") } - conf := &config.Config{ + conf := &c.Config{ AppId: appID, AppSecret: appSecret, AuthUrl: "https://oauth-login.cloud.huawei.com/oauth2/v3/token", PushUrl: "https://push-api.cloud.huawei.com", } - if appSecret != PushConf.Huawei.AppSecret || appID != PushConf.Huawei.AppID { + if appSecret != cfg.Huawei.AppSecret || appID != cfg.Huawei.AppID { return GetPushClient(conf) } @@ -165,13 +166,13 @@ func GetHuaweiNotification(req PushNotification) (*model.MessageRequest, error) } // PushToHuawei provide send notification to Android server. -func PushToHuawei(req PushNotification) bool { +func PushToHuawei(cfg config.ConfYaml, req PushNotification) bool { logx.LogAccess.Debug("Start push notification for Huawei") var ( client *core.HMSClient retryCount = 0 - maxRetry = PushConf.Huawei.MaxRetry + maxRetry = cfg.Huawei.MaxRetry ) if req.Retry > 0 && req.Retry < maxRetry { @@ -190,7 +191,7 @@ Retry: notification, _ := GetHuaweiNotification(req) - client, err = InitHMSClient(PushConf.Huawei.AppSecret, PushConf.Huawei.AppID) + client, err = InitHMSClient(cfg, cfg.Huawei.AppSecret, cfg.Huawei.AppID) if err != nil { // HMS server error diff --git a/gorush/notification_hms_test.go b/gorush/notification_hms_test.go index d939dfb..9e00fe3 100644 --- a/gorush/notification_hms_test.go +++ b/gorush/notification_hms_test.go @@ -8,31 +8,32 @@ import ( ) func TestMissingHuaweiAppSecret(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Huawei.Enabled = true - PushConf.Huawei.AppSecret = "" + cfg.Huawei.Enabled = true + cfg.Huawei.AppSecret = "" - err := CheckPushConf() + err := CheckPushConf(cfg) assert.Error(t, err) assert.Equal(t, "Missing Huawei App Secret", err.Error()) } func TestMissingHuaweiAppID(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Huawei.Enabled = true - PushConf.Huawei.AppID = "" + cfg.Huawei.Enabled = true + cfg.Huawei.AppID = "" - err := CheckPushConf() + err := CheckPushConf(cfg) assert.Error(t, err) assert.Equal(t, "Missing Huawei App ID", err.Error()) } func TestMissingAppSecretForInitHMSClient(t *testing.T) { - client, err := InitHMSClient("", "APP_SECRET") + cfg, _ := config.LoadConf() + client, err := InitHMSClient(cfg, "", "APP_SECRET") assert.Nil(t, client) assert.Error(t, err) @@ -40,7 +41,8 @@ func TestMissingAppSecretForInitHMSClient(t *testing.T) { } func TestMissingAppIDForInitHMSClient(t *testing.T) { - client, err := InitHMSClient("APP_ID", "") + cfg, _ := config.LoadConf() + client, err := InitHMSClient(cfg, "APP_ID", "") assert.Nil(t, client) assert.Error(t, err) diff --git a/gorush/notification_test.go b/gorush/notification_test.go index d00ae27..7e77562 100644 --- a/gorush/notification_test.go +++ b/gorush/notification_test.go @@ -1,210 +1,26 @@ package gorush import ( - "context" - "os" "testing" "github.com/appleboy/gorush/config" - "github.com/appleboy/gorush/core" "github.com/stretchr/testify/assert" ) func TestCorrectConf(t *testing.T) { - PushConf, _ = config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Android.Enabled = true - PushConf.Android.APIKey = "xxxxx" + cfg.Android.Enabled = true + cfg.Android.APIKey = "xxxxx" - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := CheckPushConf() + err := CheckPushConf(cfg) assert.NoError(t, err) } -func TestSenMultipleNotifications(t *testing.T) { - ctx := context.Background() - PushConf, _ = config.LoadConf("") - - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) - - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") - - androidToken := os.Getenv("ANDROID_TEST_TOKEN") - - req := RequestPush{ - Notifications: []PushNotification{ - // ios - { - Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, - Platform: core.PlatFormIos, - Message: "Welcome", - }, - // android - { - Tokens: []string{androidToken, "bbbbb"}, - Platform: core.PlatFormAndroid, - Message: "Welcome", - }, - }, - } - - count, logs := HandleNotification(ctx, req) - assert.Equal(t, 3, count) - assert.Equal(t, 0, len(logs)) -} - -func TestDisabledAndroidNotifications(t *testing.T) { - ctx := context.Background() - PushConf, _ = config.LoadConf("") - - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) - - PushConf.Android.Enabled = false - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") - - androidToken := os.Getenv("ANDROID_TEST_TOKEN") - - req := RequestPush{ - Notifications: []PushNotification{ - // ios - { - Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, - Platform: core.PlatFormIos, - Message: "Welcome", - }, - // android - { - Tokens: []string{androidToken, "bbbbb"}, - Platform: core.PlatFormAndroid, - Message: "Welcome", - }, - }, - } - - count, logs := HandleNotification(ctx, req) - assert.Equal(t, 1, count) - assert.Equal(t, 0, len(logs)) -} - -func TestSyncModeForNotifications(t *testing.T) { - ctx := context.Background() - PushConf, _ = config.LoadConf("") - - PushConf.Ios.Enabled = true - PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) - - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") - - // enable sync mode - PushConf.Core.Sync = true - - androidToken := os.Getenv("ANDROID_TEST_TOKEN") - - req := RequestPush{ - Notifications: []PushNotification{ - // ios - { - Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, - Platform: core.PlatFormIos, - Message: "Welcome", - }, - // android - { - Tokens: []string{androidToken, "bbbbb"}, - Platform: core.PlatFormAndroid, - Message: "Welcome", - }, - }, - } - - count, logs := HandleNotification(ctx, req) - assert.Equal(t, 3, count) - assert.Equal(t, 2, len(logs)) -} - -func TestSyncModeForTopicNotification(t *testing.T) { - ctx := context.Background() - PushConf, _ = config.LoadConf("") - - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") - PushConf.Log.HideToken = false - - // enable sync mode - PushConf.Core.Sync = true - - req := RequestPush{ - Notifications: []PushNotification{ - // android - { - // error:InvalidParameters - // Check that the provided parameters have the right name and type. - To: "/topics/foo-bar@@@##", - Platform: core.PlatFormAndroid, - Message: "This is a Firebase Cloud Messaging Topic Message!", - }, - // android - { - // success - To: "/topics/foo-bar", - Platform: core.PlatFormAndroid, - Message: "This is a Firebase Cloud Messaging Topic Message!", - }, - // android - { - // success - Condition: "'dogs' in topics || 'cats' in topics", - Platform: core.PlatFormAndroid, - Message: "This is a Firebase Cloud Messaging Topic Message!", - }, - }, - } - - count, logs := HandleNotification(ctx, req) - assert.Equal(t, 2, count) - assert.Equal(t, 1, len(logs)) -} - -func TestSyncModeForDeviceGroupNotification(t *testing.T) { - ctx := context.Background() - PushConf, _ = config.LoadConf("") - - PushConf.Android.Enabled = true - PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") - PushConf.Log.HideToken = false - - // enable sync mode - PushConf.Core.Sync = true - - req := RequestPush{ - Notifications: []PushNotification{ - // android - { - To: "aUniqueKey", - Platform: core.PlatFormAndroid, - Message: "This is a Firebase Cloud Messaging Device Group Message!", - }, - }, - } - - count, logs := HandleNotification(ctx, req) - assert.Equal(t, 1, count) - assert.Equal(t, 1, len(logs)) -} - func TestSetProxyURL(t *testing.T) { err := SetProxy("87.236.233.92:8080") assert.Error(t, err) diff --git a/gorush/worker.go b/gorush/worker.go deleted file mode 100644 index 6ca0cdb..0000000 --- a/gorush/worker.go +++ /dev/null @@ -1,124 +0,0 @@ -package gorush - -import ( - "context" - "errors" - "sync" - - "github.com/appleboy/gorush/core" - "github.com/appleboy/gorush/logx" - "github.com/appleboy/gorush/status" -) - -// InitWorkers for initialize all workers. -func InitWorkers(ctx context.Context, wg *sync.WaitGroup, workerNum, queueNum int64) { - logx.LogAccess.Info("worker number is ", workerNum, ", queue number is ", queueNum) - QueueNotification = make(chan PushNotification, queueNum) - for i := int64(0); i < workerNum; i++ { - go startWorker(ctx, wg, i) - } -} - -// SendNotification is send message to iOS, Android or Huawei -func SendNotification(ctx context.Context, req PushNotification) { - if PushConf.Core.Sync { - defer req.WaitDone() - } - - switch req.Platform { - case core.PlatFormIos: - PushToIOS(req) - case core.PlatFormAndroid: - PushToAndroid(req) - case core.PlatFormHuawei: - PushToHuawei(req) - } -} - -func startWorker(ctx context.Context, wg *sync.WaitGroup, num int64) { - defer wg.Done() - for notification := range QueueNotification { - SendNotification(ctx, notification) - } - logx.LogAccess.Info("closed the worker num ", num) -} - -// markFailedNotification adds failure logs for all tokens in push notification -func markFailedNotification(notification *PushNotification, reason string) { - logx.LogError.Error(reason) - for _, token := range notification.Tokens { - notification.AddLog(logx.GetLogPushEntry(&logx.InputLog{ - ID: notification.ID, - Status: core.FailedPush, - Token: token, - Message: notification.Message, - Platform: notification.Platform, - Error: errors.New(reason), - HideToken: PushConf.Log.HideToken, - Format: PushConf.Log.Format, - })) - } - notification.WaitDone() -} - -// HandleNotification add notification to queue list. -func HandleNotification(ctx context.Context, req RequestPush) (int, []logx.LogPushEntry) { - var count int - wg := sync.WaitGroup{} - newNotification := []*PushNotification{} - for i := range req.Notifications { - notification := &req.Notifications[i] - switch notification.Platform { - case core.PlatFormIos: - if !PushConf.Ios.Enabled { - continue - } - case core.PlatFormAndroid: - if !PushConf.Android.Enabled { - continue - } - case core.PlatFormHuawei: - if !PushConf.Huawei.Enabled { - continue - } - } - newNotification = append(newNotification, notification) - } - - log := make([]logx.LogPushEntry, 0, count) - for _, notification := range newNotification { - if PushConf.Core.Sync { - notification.wg = &wg - notification.log = &log - notification.AddWaitCount() - } - if !tryEnqueue(*notification, QueueNotification) { - markFailedNotification(notification, "max capacity reached") - } - count += len(notification.Tokens) - // Count topic message - if notification.To != "" { - count++ - } - } - - if PushConf.Core.Sync { - wg.Wait() - } - - status.StatStorage.AddTotalCount(int64(count)) - - return count, log -} - -// tryEnqueue tries to enqueue a job to the given job channel. Returns true if -// the operation was successful, and false if enqueuing would not have been -// possible without blocking. Job is not enqueued in the latter case. -func tryEnqueue(job PushNotification, jobChan chan<- PushNotification) bool { - select { - case jobChan <- job: - return true - default: - return false - } -} diff --git a/gorush/worker_test.go b/gorush/worker_test.go deleted file mode 100644 index 1f22319..0000000 --- a/gorush/worker_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package gorush - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestTryEnqueue(t *testing.T) { - chn := make(chan PushNotification, 2) - assert.True(t, tryEnqueue(PushNotification{}, chn)) - assert.Equal(t, 1, len(chn)) - assert.True(t, tryEnqueue(PushNotification{}, chn)) - assert.Equal(t, 2, len(chn)) - assert.False(t, tryEnqueue(PushNotification{}, chn)) - assert.Equal(t, 2, len(chn)) -} diff --git a/logx/log_test.go b/logx/log_test.go index b9608b2..ee1ed33 100644 --- a/logx/log_test.go +++ b/logx/log_test.go @@ -38,75 +38,75 @@ func TestSetLogOut(t *testing.T) { } func TestInitDefaultLog(t *testing.T) { - PushConf, _ := config.LoadConf("") + cfg, _ := config.LoadConf() // no errors on default config assert.Nil(t, InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, )) - PushConf.Log.AccessLevel = "invalid" + cfg.Log.AccessLevel = "invalid" assert.NotNil(t, InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, )) } func TestAccessLevel(t *testing.T) { - PushConf, _ := config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Log.AccessLevel = "invalid" + cfg.Log.AccessLevel = "invalid" assert.NotNil(t, InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, )) } func TestErrorLevel(t *testing.T) { - PushConf, _ := config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Log.ErrorLevel = "invalid" + cfg.Log.ErrorLevel = "invalid" assert.NotNil(t, InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, )) } func TestAccessLogPath(t *testing.T) { - PushConf, _ := config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Log.AccessLog = "logs/access.log" + cfg.Log.AccessLog = "logs/access.log" assert.NotNil(t, InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, )) } func TestErrorLogPath(t *testing.T) { - PushConf, _ := config.LoadConf("") + cfg, _ := config.LoadConf() - PushConf.Log.ErrorLog = "logs/error.log" + cfg.Log.ErrorLog = "logs/error.log" assert.NotNil(t, InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, )) } diff --git a/main.go b/main.go index 79ca450..089934e 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,6 @@ import ( "os/signal" "path/filepath" "strconv" - "sync" "syscall" "time" @@ -19,6 +18,7 @@ import ( "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/gorush" "github.com/appleboy/gorush/logx" + "github.com/appleboy/gorush/queue" "github.com/appleboy/gorush/router" "github.com/appleboy/gorush/rpc" "github.com/appleboy/gorush/status" @@ -105,10 +105,8 @@ func main() { os.Exit(0) } - var err error - // set default parameters. - gorush.PushConf, err = config.LoadConf(configFile) + cfg, err := config.LoadConf(configFile) if err != nil { log.Printf("Load yaml config file error: '%v'", err) @@ -116,67 +114,67 @@ func main() { } // Initialize push slots for concurrent iOS pushes - gorush.MaxConcurrentIOSPushes = make(chan struct{}, gorush.PushConf.Ios.MaxConcurrentPushes) + gorush.MaxConcurrentIOSPushes = make(chan struct{}, cfg.Ios.MaxConcurrentPushes) if opts.Ios.KeyPath != "" { - gorush.PushConf.Ios.KeyPath = opts.Ios.KeyPath + cfg.Ios.KeyPath = opts.Ios.KeyPath } if opts.Ios.KeyID != "" { - gorush.PushConf.Ios.KeyID = opts.Ios.KeyID + cfg.Ios.KeyID = opts.Ios.KeyID } if opts.Ios.TeamID != "" { - gorush.PushConf.Ios.TeamID = opts.Ios.TeamID + cfg.Ios.TeamID = opts.Ios.TeamID } if opts.Ios.Password != "" { - gorush.PushConf.Ios.Password = opts.Ios.Password + cfg.Ios.Password = opts.Ios.Password } if opts.Android.APIKey != "" { - gorush.PushConf.Android.APIKey = opts.Android.APIKey + cfg.Android.APIKey = opts.Android.APIKey } if opts.Huawei.AppSecret != "" { - gorush.PushConf.Huawei.AppSecret = opts.Huawei.AppSecret + cfg.Huawei.AppSecret = opts.Huawei.AppSecret } if opts.Huawei.AppID != "" { - gorush.PushConf.Huawei.AppID = opts.Huawei.AppID + cfg.Huawei.AppID = opts.Huawei.AppID } if opts.Stat.Engine != "" { - gorush.PushConf.Stat.Engine = opts.Stat.Engine + cfg.Stat.Engine = opts.Stat.Engine } if opts.Stat.Redis.Addr != "" { - gorush.PushConf.Stat.Redis.Addr = opts.Stat.Redis.Addr + cfg.Stat.Redis.Addr = opts.Stat.Redis.Addr } // overwrite server port and address if opts.Core.Port != "" { - gorush.PushConf.Core.Port = opts.Core.Port + cfg.Core.Port = opts.Core.Port } if opts.Core.Address != "" { - gorush.PushConf.Core.Address = opts.Core.Address + cfg.Core.Address = opts.Core.Address } if err = logx.InitLog( - gorush.PushConf.Log.AccessLevel, - gorush.PushConf.Log.AccessLog, - gorush.PushConf.Log.ErrorLevel, - gorush.PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, ); err != nil { log.Fatalf("Can't load log module, error: %v", err) } if opts.Core.HTTPProxy != "" { - gorush.PushConf.Core.HTTPProxy = opts.Core.HTTPProxy + cfg.Core.HTTPProxy = opts.Core.HTTPProxy } - if gorush.PushConf.Core.HTTPProxy != "" { - err = gorush.SetProxy(gorush.PushConf.Core.HTTPProxy) + if cfg.Core.HTTPProxy != "" { + err = gorush.SetProxy(cfg.Core.HTTPProxy) if err != nil { logx.LogError.Fatalf("Set Proxy error: %v", err) @@ -184,7 +182,7 @@ func main() { } if ping { - if err := pinger(); err != nil { + if err := pinger(cfg); err != nil { logx.LogError.Warnf("ping server error: %v", err) } return @@ -192,7 +190,7 @@ func main() { // send android notification if opts.Android.Enabled { - gorush.PushConf.Android.Enabled = opts.Android.Enabled + cfg.Android.Enabled = opts.Android.Enabled req := gorush.PushNotification{ Platform: core.PlatFormAndroid, Message: message, @@ -214,18 +212,18 @@ func main() { logx.LogError.Fatal(err) } - if err := status.InitAppStatus(gorush.PushConf); err != nil { + if err := status.InitAppStatus(cfg); err != nil { return } - gorush.PushToAndroid(req) + gorush.PushToAndroid(cfg, req) return } // send huawei notification if opts.Huawei.Enabled { - gorush.PushConf.Huawei.Enabled = opts.Huawei.Enabled + cfg.Huawei.Enabled = opts.Huawei.Enabled req := gorush.PushNotification{ Platform: core.PlatFormHuawei, Message: message, @@ -247,11 +245,11 @@ func main() { logx.LogError.Fatal(err) } - if err := status.InitAppStatus(gorush.PushConf); err != nil { + if err := status.InitAppStatus(cfg); err != nil { return } - gorush.PushToHuawei(req) + gorush.PushToHuawei(cfg, req) return } @@ -259,10 +257,10 @@ func main() { // send ios notification if opts.Ios.Enabled { if opts.Ios.Production { - gorush.PushConf.Ios.Production = opts.Ios.Production + cfg.Ios.Production = opts.Ios.Production } - gorush.PushConf.Ios.Enabled = opts.Ios.Enabled + cfg.Ios.Enabled = opts.Ios.Enabled req := gorush.PushNotification{ Platform: core.PlatFormIos, Message: message, @@ -284,68 +282,71 @@ func main() { logx.LogError.Fatal(err) } - if err := status.InitAppStatus(gorush.PushConf); err != nil { + if err := status.InitAppStatus(cfg); err != nil { return } - if err := gorush.InitAPNSClient(); err != nil { + if err := gorush.InitAPNSClient(cfg); err != nil { return } - gorush.PushToIOS(req) + gorush.PushToIOS(cfg, req) return } - if err = gorush.CheckPushConf(); err != nil { + if err = gorush.CheckPushConf(cfg); err != nil { logx.LogError.Fatal(err) } if opts.Core.PID.Path != "" { - gorush.PushConf.Core.PID.Path = opts.Core.PID.Path - gorush.PushConf.Core.PID.Enabled = true - gorush.PushConf.Core.PID.Override = true + cfg.Core.PID.Path = opts.Core.PID.Path + cfg.Core.PID.Enabled = true + cfg.Core.PID.Override = true } - if err = createPIDFile(); err != nil { + if err = createPIDFile(cfg); err != nil { logx.LogError.Fatal(err) } - if err = status.InitAppStatus(gorush.PushConf); err != nil { + if err = status.InitAppStatus(cfg); err != nil { logx.LogError.Fatal(err) } + q := queue.NewQueue(cfg) + q.Start() + finished := make(chan struct{}) - wg := &sync.WaitGroup{} - wg.Add(int(gorush.PushConf.Core.WorkerNum)) ctx := withContextFunc(context.Background(), func() { - logx.LogAccess.Info("close the notification queue channel, current queue len: ", len(gorush.QueueNotification)) - close(gorush.QueueNotification) - wg.Wait() - logx.LogAccess.Info("the notification queue has been clear") + logx.LogAccess.Info("close the queue system, current queue usage: ", q.Usage()) + // stop queue system + q.Stop() + // wait job completed + q.Wait() close(finished) // close the connection with storage - logx.LogAccess.Info("close the storage connection: ", gorush.PushConf.Stat.Engine) + logx.LogAccess.Info("close the storage connection: ", cfg.Stat.Engine) if err := status.StatStorage.Close(); err != nil { logx.LogError.Fatal("can't close the storage connection: ", err.Error()) } }) - gorush.InitWorkers(ctx, wg, gorush.PushConf.Core.WorkerNum, gorush.PushConf.Core.QueueNum) + // gorush.InitQueue(cfg.Core.WorkerNum, cfg.Core.QueueNum) + // gorush.InitWorkers(ctx, wg, cfg.Core.WorkerNum, cfg.Core.QueueNum) - if gorush.PushConf.Ios.Enabled { - if err = gorush.InitAPNSClient(); err != nil { + if cfg.Ios.Enabled { + if err = gorush.InitAPNSClient(cfg); err != nil { logx.LogError.Fatal(err) } } - if gorush.PushConf.Android.Enabled { - if _, err = gorush.InitFCMClient(gorush.PushConf.Android.APIKey); err != nil { + if cfg.Android.Enabled { + if _, err = gorush.InitFCMClient(cfg, cfg.Android.APIKey); err != nil { logx.LogError.Fatal(err) } } - if gorush.PushConf.Huawei.Enabled { - if _, err = gorush.InitHMSClient(gorush.PushConf.Huawei.AppSecret, gorush.PushConf.Huawei.AppID); err != nil { + if cfg.Huawei.Enabled { + if _, err = gorush.InitHMSClient(cfg, cfg.Huawei.AppSecret, cfg.Huawei.AppID); err != nil { logx.LogError.Fatal(err) } } @@ -354,12 +355,12 @@ func main() { // Run httpd server g.Go(func() error { - return router.RunHTTPServer(ctx, gorush.PushConf) + return router.RunHTTPServer(ctx, cfg, q) }) // Run gRPC internal server g.Go(func() error { - return rpc.RunGRPCServer(ctx) + return rpc.RunGRPCServer(ctx, cfg) }) // check job completely @@ -426,7 +427,7 @@ func usage() { // handles pinging the endpoint and returns an error if the // agent is in an unhealthy state. -func pinger() error { +func pinger(cfg config.ConfYaml) error { transport := &http.Transport{ Dial: (&net.Dialer{ Timeout: 5 * time.Second, @@ -437,7 +438,7 @@ func pinger() error { Timeout: time.Second * 10, Transport: transport, } - resp, err := client.Get("http://localhost:" + gorush.PushConf.Core.Port + gorush.PushConf.API.HealthURI) + resp, err := client.Get("http://localhost:" + cfg.Core.Port + cfg.API.HealthURI) if err != nil { return err } @@ -448,14 +449,14 @@ func pinger() error { return nil } -func createPIDFile() error { - if !gorush.PushConf.Core.PID.Enabled { +func createPIDFile(cfg config.ConfYaml) error { + if !cfg.Core.PID.Enabled { return nil } - pidPath := gorush.PushConf.Core.PID.Path + pidPath := cfg.Core.PID.Path _, err := os.Stat(pidPath) - if os.IsNotExist(err) || gorush.PushConf.Core.PID.Override { + if os.IsNotExist(err) || cfg.Core.PID.Override { currentPid := os.Getpid() if err := os.MkdirAll(filepath.Dir(pidPath), os.ModePerm); err != nil { return fmt.Errorf("Can't create PID folder on %v", err) diff --git a/queue/queue.go b/queue/queue.go new file mode 100644 index 0000000..5cc6fc8 --- /dev/null +++ b/queue/queue.go @@ -0,0 +1,89 @@ +package queue + +import ( + "runtime" + + "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/logx" + "github.com/appleboy/gorush/queue/simple" +) + +type ( + // A Queue is a message queue. + Queue struct { + workerCount int + queueCount int + routineGroup *routineGroup + quit chan struct{} + worker Worker + } +) + +// NewQueue returns a Queue. +func NewQueue(cfg config.ConfYaml) *Queue { + q := &Queue{ + workerCount: int(cfg.Core.WorkerNum), + queueCount: int(cfg.Core.QueueNum), + routineGroup: newRoutineGroup(), + quit: make(chan struct{}), + worker: simple.NewWorker(cfg), + } + + if q.workerCount != 0 { + q.workerCount = runtime.NumCPU() + } + + if q.queueCount == 0 { + q.queueCount = runtime.NumCPU() << 1 + } + + return q +} + +// Capacity for queue max size +func (q *Queue) Capacity() int { + return q.worker.Capacity() +} + +// Usage for count of queue usage +func (q *Queue) Usage() int { + return q.worker.Usage() +} + +// Config update current config +func (q *Queue) Config(cfg config.ConfYaml) { + q.worker.Config(cfg) +} + +// Start to enable all worker +func (q *Queue) Start() { + q.startWorker() +} + +// Stop stops q. +func (q *Queue) Stop() { + q.worker.Stop() + close(q.quit) +} + +// Wait all process +func (q *Queue) Wait() { + q.routineGroup.Wait() +} + +// Enqueue queue all job +func (q *Queue) Enqueue(job interface{}) error { + return q.worker.Enqueue(job) +} + +func (q *Queue) startWorker() { + for i := 0; i < q.workerCount; i++ { + go func(num int) { + q.routineGroup.Run(func() { + logx.LogAccess.Info("started the worker num ", num) + q.worker.Run(q.quit) + logx.LogAccess.Info("closed the worker num ", num) + }) + }(i) + } +} diff --git a/queue/simple/simple.go b/queue/simple/simple.go new file mode 100644 index 0000000..e8e4988 --- /dev/null +++ b/queue/simple/simple.go @@ -0,0 +1,59 @@ +package simple + +import ( + "errors" + + "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/gorush" +) + +// Worker for simple queue using channel +type Worker struct { + cfg config.ConfYaml + queueNotification chan gorush.PushNotification +} + +// Run start the worker +func (s *Worker) Run(_ chan struct{}) { + for notification := range s.queueNotification { + gorush.SendNotification(s.cfg, notification) + } +} + +// Stop worker +func (s *Worker) Stop() { + close(s.queueNotification) +} + +// Capacity for channel +func (s *Worker) Capacity() int { + return cap(s.queueNotification) +} + +// Usage for count of channel usage +func (s *Worker) Usage() int { + return len(s.queueNotification) +} + +// Enqueue send notification to queue +func (s *Worker) Enqueue(job interface{}) error { + select { + case s.queueNotification <- job.(gorush.PushNotification): + return nil + default: + return errors.New("max capacity reached") + } +} + +// Config update current config +func (s *Worker) Config(cfg config.ConfYaml) { + s.cfg = cfg +} + +// NewWorker for struct +func NewWorker(cfg config.ConfYaml) *Worker { + return &Worker{ + cfg: cfg, + queueNotification: make(chan gorush.PushNotification, cfg.Core.QueueNum), + } +} diff --git a/queue/thread.go b/queue/thread.go new file mode 100644 index 0000000..473c351 --- /dev/null +++ b/queue/thread.go @@ -0,0 +1,24 @@ +package queue + +import "sync" + +type routineGroup struct { + waitGroup sync.WaitGroup +} + +func newRoutineGroup() *routineGroup { + return new(routineGroup) +} + +func (g *routineGroup) Run(fn func()) { + g.waitGroup.Add(1) + + go func() { + defer g.waitGroup.Done() + fn() + }() +} + +func (g *routineGroup) Wait() { + g.waitGroup.Wait() +} diff --git a/queue/worker.go b/queue/worker.go new file mode 100644 index 0000000..39cf42b --- /dev/null +++ b/queue/worker.go @@ -0,0 +1,13 @@ +package queue + +import "github.com/appleboy/gorush/config" + +// Worker interface +type Worker interface { + Run(chan struct{}) + Stop() + Enqueue(job interface{}) error + Capacity() int + Usage() int + Config(config.ConfYaml) +} diff --git a/router/server.go b/router/server.go index a77e4d3..8fa3164 100644 --- a/router/server.go +++ b/router/server.go @@ -3,14 +3,18 @@ package router import ( "context" "crypto/tls" + "errors" "fmt" "net/http" "os" + "sync" "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/gorush" "github.com/appleboy/gorush/logx" "github.com/appleboy/gorush/metric" + "github.com/appleboy/gorush/queue" "github.com/appleboy/gorush/status" api "github.com/appleboy/gin-status-api" @@ -26,14 +30,12 @@ import ( "golang.org/x/crypto/acme/autocert" ) -var isTerm bool +var ( + isTerm bool + doOnce sync.Once +) func init() { - // Support metrics - m := metric.NewMetrics(func() int { - return len(gorush.QueueNotification) - }) - prometheus.MustRegister(m) isTerm = isatty.IsTerminal(os.Stdout.Fd()) } @@ -61,7 +63,7 @@ func versionHandler(c *gin.Context) { }) } -func pushHandler(cfg config.ConfYaml) gin.HandlerFunc { +func pushHandler(cfg config.ConfYaml, q *queue.Queue) gin.HandlerFunc { return func(c *gin.Context) { var form gorush.RequestPush var msg string @@ -101,7 +103,7 @@ func pushHandler(cfg config.ConfYaml) gin.HandlerFunc { } }() - counts, logs := gorush.HandleNotification(ctx, form) + counts, logs := handleNotification(ctx, cfg, form, q) c.JSON(http.StatusOK, gin.H{ "success": "ok", @@ -121,21 +123,23 @@ func metricsHandler(c *gin.Context) { promhttp.Handler().ServeHTTP(c.Writer, c.Request) } -func appStatusHandler(c *gin.Context) { - result := status.App{} +func appStatusHandler(q *queue.Queue) gin.HandlerFunc { + return func(c *gin.Context) { + result := status.App{} - result.Version = GetVersion() - result.QueueMax = cap(gorush.QueueNotification) - result.QueueUsage = len(gorush.QueueNotification) - result.TotalCount = status.StatStorage.GetTotalCount() - result.Ios.PushSuccess = status.StatStorage.GetIosSuccess() - result.Ios.PushError = status.StatStorage.GetIosError() - result.Android.PushSuccess = status.StatStorage.GetAndroidSuccess() - result.Android.PushError = status.StatStorage.GetAndroidError() - result.Huawei.PushSuccess = status.StatStorage.GetHuaweiSuccess() - result.Huawei.PushError = status.StatStorage.GetHuaweiError() + result.Version = GetVersion() + result.QueueMax = q.Capacity() + result.QueueUsage = q.Usage() + result.TotalCount = status.StatStorage.GetTotalCount() + result.Ios.PushSuccess = status.StatStorage.GetIosSuccess() + result.Ios.PushError = status.StatStorage.GetIosError() + result.Android.PushSuccess = status.StatStorage.GetAndroidSuccess() + result.Android.PushError = status.StatStorage.GetAndroidError() + result.Huawei.PushSuccess = status.StatStorage.GetHuaweiSuccess() + result.Huawei.PushError = status.StatStorage.GetHuaweiError() - c.JSON(http.StatusOK, result) + c.JSON(http.StatusOK, result) + } } func sysStatsHandler() gin.HandlerFunc { @@ -153,7 +157,7 @@ func StatMiddleware() gin.HandlerFunc { } } -func autoTLSServer(cfg config.ConfYaml) *http.Server { +func autoTLSServer(cfg config.ConfYaml, q *queue.Queue) *http.Server { m := autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(cfg.Core.AutoTLS.Host), @@ -163,11 +167,11 @@ func autoTLSServer(cfg config.ConfYaml) *http.Server { return &http.Server{ Addr: ":https", TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, - Handler: routerEngine(cfg), + Handler: routerEngine(cfg, q), } } -func routerEngine(cfg config.ConfYaml) *gin.Engine { +func routerEngine(cfg config.ConfYaml, q *queue.Queue) *gin.Engine { zerolog.SetGlobalLevel(zerolog.InfoLevel) if cfg.Core.Mode == "debug" { zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -184,6 +188,14 @@ func routerEngine(cfg config.ConfYaml) *gin.Engine { ) } + // Support metrics + doOnce.Do(func() { + m := metric.NewMetrics(func() int { + return q.Usage() + }) + prometheus.MustRegister(m) + }) + // set server mode gin.SetMode(cfg.Core.Mode) @@ -202,10 +214,10 @@ func routerEngine(cfg config.ConfYaml) *gin.Engine { r.Use(StatMiddleware()) r.GET(cfg.API.StatGoURI, api.GinHandler) - r.GET(cfg.API.StatAppURI, appStatusHandler) + r.GET(cfg.API.StatAppURI, appStatusHandler(q)) r.GET(cfg.API.ConfigURI, configHandler(cfg)) r.GET(cfg.API.SysStatURI, sysStatsHandler()) - r.POST(cfg.API.PushURI, pushHandler(cfg)) + r.POST(cfg.API.PushURI, pushHandler(cfg, q)) r.GET(cfg.API.MetricURI, metricsHandler) r.GET(cfg.API.HealthURI, heartbeatHandler) r.HEAD(cfg.API.HealthURI, heartbeatHandler) @@ -214,3 +226,73 @@ func routerEngine(cfg config.ConfYaml) *gin.Engine { return r } + +// markFailedNotification adds failure logs for all tokens in push notification +func markFailedNotification(cfg config.ConfYaml, notification *gorush.PushNotification, reason string) { + logx.LogError.Error(reason) + for _, token := range notification.Tokens { + notification.AddLog(logx.GetLogPushEntry(&logx.InputLog{ + ID: notification.ID, + Status: core.FailedPush, + Token: token, + Message: notification.Message, + Platform: notification.Platform, + Error: errors.New(reason), + HideToken: cfg.Log.HideToken, + Format: cfg.Log.Format, + })) + } + notification.WaitDone() +} + +// HandleNotification add notification to queue list. +func handleNotification(ctx context.Context, cfg config.ConfYaml, req gorush.RequestPush, q *queue.Queue) (int, []logx.LogPushEntry) { + var count int + wg := sync.WaitGroup{} + newNotification := []*gorush.PushNotification{} + for i := range req.Notifications { + notification := &req.Notifications[i] + switch notification.Platform { + case core.PlatFormIos: + if !cfg.Ios.Enabled { + continue + } + case core.PlatFormAndroid: + if !cfg.Android.Enabled { + continue + } + case core.PlatFormHuawei: + if !cfg.Huawei.Enabled { + continue + } + } + newNotification = append(newNotification, notification) + } + + log := make([]logx.LogPushEntry, 0, count) + for _, notification := range newNotification { + if cfg.Core.Sync { + notification.Wg = &wg + notification.Log = &log + notification.AddWaitCount() + } + + if err := q.Enqueue(*notification); err != nil { + markFailedNotification(cfg, notification, "max capacity reached") + } + + count += len(notification.Tokens) + // Count topic message + if notification.To != "" { + count++ + } + } + + if cfg.Core.Sync { + wg.Wait() + } + + status.StatStorage.AddTotalCount(int64(count)) + + return count, log +} diff --git a/router/server_lambda.go b/router/server_lambda.go index 40f1b05..299e44e 100644 --- a/router/server_lambda.go +++ b/router/server_lambda.go @@ -8,12 +8,13 @@ import ( "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/logx" + "github.com/appleboy/gorush/queue" "github.com/apex/gateway" ) // RunHTTPServer provide run http or https protocol. -func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, s ...*http.Server) (err error) { +func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, q *queue.Queue, s ...*http.Server) (err error) { if !cfg.Core.Enabled { logx.LogAccess.Debug("httpd server is disabled.") return nil @@ -21,5 +22,5 @@ func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, s ...*http.Server) logx.LogAccess.Info("HTTPD server is running on " + cfg.Core.Port + " port.") - return gateway.ListenAndServe(cfg.Core.Address+":"+cfg.Core.Port, routerEngine(cfg)) + return gateway.ListenAndServe(cfg.Core.Address+":"+cfg.Core.Port, routerEngine(cfg, q)) } diff --git a/router/server_normal.go b/router/server_normal.go index 00445ba..7884246 100644 --- a/router/server_normal.go +++ b/router/server_normal.go @@ -12,12 +12,13 @@ import ( "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/logx" + "github.com/appleboy/gorush/queue" "golang.org/x/sync/errgroup" ) // RunHTTPServer provide run http or https protocol. -func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, s ...*http.Server) (err error) { +func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, q *queue.Queue, s ...*http.Server) (err error) { var server *http.Server if !cfg.Core.Enabled { @@ -28,7 +29,7 @@ func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, s ...*http.Server) if len(s) == 0 { server = &http.Server{ Addr: cfg.Core.Address + ":" + cfg.Core.Port, - Handler: routerEngine(cfg), + Handler: routerEngine(cfg, q), } } else { server = s[0] @@ -36,7 +37,7 @@ func RunHTTPServer(ctx context.Context, cfg config.ConfYaml, s ...*http.Server) logx.LogAccess.Info("HTTPD server is running on " + cfg.Core.Port + " port.") if cfg.Core.AutoTLS.Enabled { - return startServer(ctx, autoTLSServer(cfg), cfg) + return startServer(ctx, autoTLSServer(cfg, q), cfg) } else if cfg.Core.SSL { config := &tls.Config{ MinVersion: tls.VersionTLS10, diff --git a/router/server_test.go b/router/server_test.go index 663e4a2..716ee94 100644 --- a/router/server_test.go +++ b/router/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/gorush" "github.com/appleboy/gorush/logx" + "github.com/appleboy/gorush/queue" "github.com/appleboy/gorush/status" "github.com/appleboy/gofight/v2" @@ -23,7 +24,10 @@ import ( "github.com/stretchr/testify/assert" ) -var goVersion = runtime.Version() +var ( + goVersion = runtime.Version() + q *queue.Queue +) func TestMain(m *testing.M) { cfg := initTest() @@ -40,11 +44,14 @@ func TestMain(m *testing.M) { log.Fatal(err) } + q = queue.NewQueue(cfg) + q.Start() + defer q.Stop() m.Run() } func initTest() config.ConfYaml { - cfg, _ := config.LoadConf("") + cfg, _ := config.LoadConf() cfg.Core.Mode = "test" return cfg } @@ -88,7 +95,7 @@ func TestRunNormalServer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - assert.NoError(t, RunHTTPServer(ctx, cfg)) + assert.NoError(t, RunHTTPServer(ctx, cfg, q)) }() defer func() { @@ -112,7 +119,7 @@ func TestRunTLSServer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - assert.NoError(t, RunHTTPServer(ctx, cfg)) + assert.NoError(t, RunHTTPServer(ctx, cfg, q)) }() defer func() { @@ -140,7 +147,7 @@ func TestRunTLSBase64Server(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - assert.NoError(t, RunHTTPServer(ctx, cfg)) + assert.NoError(t, RunHTTPServer(ctx, cfg, q)) }() defer func() { @@ -159,7 +166,7 @@ func TestRunAutoTLSServer(t *testing.T) { cfg.Core.AutoTLS.Enabled = true ctx, cancel := context.WithCancel(context.Background()) go func() { - assert.NoError(t, RunHTTPServer(ctx, cfg)) + assert.NoError(t, RunHTTPServer(ctx, cfg, q)) }() defer func() { @@ -179,7 +186,7 @@ func TestLoadTLSCertError(t *testing.T) { cfg.Core.CertPath = "../config/config.yml" cfg.Core.KeyPath = "../config/config.yml" - assert.Error(t, RunHTTPServer(context.Background(), cfg)) + assert.Error(t, RunHTTPServer(context.Background(), cfg, q)) } func TestMissingTLSCertcfgg(t *testing.T) { @@ -192,8 +199,8 @@ func TestMissingTLSCertcfgg(t *testing.T) { cfg.Core.CertBase64 = "" cfg.Core.KeyBase64 = "" - err := RunHTTPServer(context.Background(), cfg) - assert.Error(t, RunHTTPServer(context.Background(), cfg)) + err := RunHTTPServer(context.Background(), cfg, q) + assert.Error(t, RunHTTPServer(context.Background(), cfg, q)) assert.Equal(t, "missing https cert config", err.Error()) } @@ -206,7 +213,7 @@ func TestRootHandler(t *testing.T) { cfg.Log.Format = "json" r.GET("/"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { data := r.Body.Bytes() value, _ := jsonparser.GetString(data, "text") @@ -223,7 +230,7 @@ func TestAPIStatusGoHandler(t *testing.T) { r := gofight.New() r.GET("/api/stat/go"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { data := r.Body.Bytes() value, _ := jsonparser.GetString(data, "go_version") @@ -242,7 +249,7 @@ func TestAPIStatusAppHandler(t *testing.T) { SetVersion(appVersion) r.GET("/api/stat/app"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { data := r.Body.Bytes() value, _ := jsonparser.GetString(data, "version") @@ -258,7 +265,7 @@ func TestAPIConfigHandler(t *testing.T) { r := gofight.New() r.GET("/api/config"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusCreated, r.Code) }) } @@ -270,7 +277,7 @@ func TestMissingNotificationsParameter(t *testing.T) { // missing notifications parameter. r.POST("/api/push"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusBadRequest, r.Code) assert.Equal(t, "application/json; charset=utf-8", r.HeaderMap.Get("Content-Type")) }) @@ -286,7 +293,7 @@ func TestEmptyNotifications(t *testing.T) { SetJSON(gofight.D{ "notifications": []gorush.PushNotification{}, }). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusBadRequest, r.Code) }) } @@ -314,8 +321,8 @@ func TestMutableContent(t *testing.T) { }, }, }). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { - // json: cannot unmarshal number into Go struct field PushNotification.mutable_content of type bool + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + // json: cannot unmarshal number into Go struct field gorush.PushNotification.mutable_content of type bool assert.Equal(t, http.StatusBadRequest, r.Code) }) } @@ -343,7 +350,7 @@ func TestOutOfRangeMaxNotifications(t *testing.T) { }, }, }). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusBadRequest, r.Code) }) } @@ -369,7 +376,7 @@ func TestSuccessPushHandler(t *testing.T) { }, }, }). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusOK, r.Code) }) } @@ -380,7 +387,7 @@ func TestSysStatsHandler(t *testing.T) { r := gofight.New() r.GET("/sys/stats"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusOK, r.Code) }) } @@ -391,7 +398,7 @@ func TestMetricsHandler(t *testing.T) { r := gofight.New() r.GET("/metrics"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusOK, r.Code) }) } @@ -402,7 +409,7 @@ func TestGETHeartbeatHandler(t *testing.T) { r := gofight.New() r.GET("/healthz"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusOK, r.Code) }) } @@ -413,7 +420,7 @@ func TestHEADHeartbeatHandler(t *testing.T) { r := gofight.New() r.HEAD("/healthz"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusOK, r.Code) }) } @@ -425,7 +432,7 @@ func TestVersionHandler(t *testing.T) { r := gofight.New() r.GET("/version"). - Run(routerEngine(cfg), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + Run(routerEngine(cfg, q), func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { assert.Equal(t, http.StatusOK, r.Code) data := r.Body.Bytes() @@ -438,8 +445,231 @@ func TestVersionHandler(t *testing.T) { func TestDisabledHTTPServer(t *testing.T) { cfg := initTest() cfg.Core.Enabled = false - err := RunHTTPServer(context.Background(), cfg) + err := RunHTTPServer(context.Background(), cfg, q) cfg.Core.Enabled = true assert.Nil(t, err) } + +func TestSenMultipleNotifications(t *testing.T) { + ctx := context.Background() + cfg := initTest() + + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := gorush.InitAPNSClient(cfg) + assert.Nil(t, err) + + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") + q.Config(cfg) + + androidToken := os.Getenv("ANDROID_TEST_TOKEN") + + req := gorush.RequestPush{ + Notifications: []gorush.PushNotification{ + // ios + { + Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, + Platform: core.PlatFormIos, + Message: "Welcome", + }, + // android + { + Tokens: []string{androidToken, "bbbbb"}, + Platform: core.PlatFormAndroid, + Message: "Welcome", + }, + }, + } + + count, logs := handleNotification(ctx, cfg, req, q) + assert.Equal(t, 3, count) + assert.Equal(t, 0, len(logs)) +} + +func TestDisabledAndroidNotifications(t *testing.T) { + ctx := context.Background() + cfg := initTest() + + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := gorush.InitAPNSClient(cfg) + assert.Nil(t, err) + + cfg.Android.Enabled = false + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") + q.Config(cfg) + + androidToken := os.Getenv("ANDROID_TEST_TOKEN") + + req := gorush.RequestPush{ + Notifications: []gorush.PushNotification{ + // ios + { + Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, + Platform: core.PlatFormIos, + Message: "Welcome", + }, + // android + { + Tokens: []string{androidToken, "bbbbb"}, + Platform: core.PlatFormAndroid, + Message: "Welcome", + }, + }, + } + + count, logs := handleNotification(ctx, cfg, req, q) + assert.Equal(t, 1, count) + assert.Equal(t, 0, len(logs)) +} + +func TestSyncModeForNotifications(t *testing.T) { + ctx := context.Background() + cfg := initTest() + + cfg.Ios.Enabled = true + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := gorush.InitAPNSClient(cfg) + assert.Nil(t, err) + + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") + + // enable sync mode + cfg.Core.Sync = true + q.Config(cfg) + + androidToken := os.Getenv("ANDROID_TEST_TOKEN") + + req := gorush.RequestPush{ + Notifications: []gorush.PushNotification{ + // ios + { + Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, + Platform: core.PlatFormIos, + Message: "Welcome", + }, + // android + { + Tokens: []string{androidToken, "bbbbb"}, + Platform: core.PlatFormAndroid, + Message: "Welcome", + }, + }, + } + + count, logs := handleNotification(ctx, cfg, req, q) + assert.Equal(t, 3, count) + assert.Equal(t, 2, len(logs)) +} + +func TestSyncModeForTopicNotification(t *testing.T) { + ctx := context.Background() + cfg := initTest() + + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Log.HideToken = false + + // enable sync mode + cfg.Core.Sync = true + q.Config(cfg) + + req := gorush.RequestPush{ + Notifications: []gorush.PushNotification{ + // android + { + // error:InvalidParameters + // Check that the provided parameters have the right name and type. + To: "/topics/foo-bar@@@##", + Platform: core.PlatFormAndroid, + Message: "This is a Firebase Cloud Messaging Topic Message!", + }, + // android + { + // success + To: "/topics/foo-bar", + Platform: core.PlatFormAndroid, + Message: "This is a Firebase Cloud Messaging Topic Message!", + }, + // android + { + // success + Condition: "'dogs' in topics || 'cats' in topics", + Platform: core.PlatFormAndroid, + Message: "This is a Firebase Cloud Messaging Topic Message!", + }, + }, + } + + count, logs := handleNotification(ctx, cfg, req, q) + assert.Equal(t, 2, count) + assert.Equal(t, 1, len(logs)) +} + +func TestSyncModeForDeviceGroupNotification(t *testing.T) { + ctx := context.Background() + cfg := initTest() + + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") + cfg.Log.HideToken = false + + // enable sync mode + cfg.Core.Sync = true + q.Config(cfg) + + req := gorush.RequestPush{ + Notifications: []gorush.PushNotification{ + // android + { + To: "aUniqueKey", + Platform: core.PlatFormAndroid, + Message: "This is a Firebase Cloud Messaging Device Group Message!", + }, + }, + } + + count, logs := handleNotification(ctx, cfg, req, q) + assert.Equal(t, 1, count) + assert.Equal(t, 1, len(logs)) +} + +func TestDisabledIosNotifications(t *testing.T) { + ctx := context.Background() + cfg := initTest() + + cfg.Ios.Enabled = false + cfg.Ios.KeyPath = "../certificate/certificate-valid.pem" + err := gorush.InitAPNSClient(cfg) + assert.Nil(t, err) + + cfg.Android.Enabled = true + cfg.Android.APIKey = os.Getenv("ANDROID_API_KEY") + q.Config(cfg) + + androidToken := os.Getenv("ANDROID_TEST_TOKEN") + + req := gorush.RequestPush{ + Notifications: []gorush.PushNotification{ + // ios + { + Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, + Platform: core.PlatFormIos, + Message: "Welcome", + }, + // android + { + Tokens: []string{androidToken, androidToken + "_"}, + Platform: core.PlatFormAndroid, + Message: "Welcome", + }, + }, + } + + count, logs := handleNotification(ctx, cfg, req, q) + assert.Equal(t, 2, count) + assert.Equal(t, 0, len(logs)) +} diff --git a/rpc/server.go b/rpc/server.go index cfa3c1b..5aa4077 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -6,6 +6,7 @@ import ( "strings" "sync" + "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/core" "github.com/appleboy/gorush/gorush" "github.com/appleboy/gorush/logx" @@ -19,14 +20,16 @@ import ( // Server is used to implement gorush grpc server. type Server struct { - mu sync.Mutex + cfg config.ConfYaml + mu sync.Mutex // statusMap stores the serving status of the services this Server monitors. statusMap map[string]proto.HealthCheckResponse_ServingStatus } // NewServer returns a new Server. -func NewServer() *Server { +func NewServer(cfg config.ConfYaml) *Server { return &Server{ + cfg: cfg, statusMap: make(map[string]proto.HealthCheckResponse_ServingStatus), } } @@ -98,7 +101,7 @@ func (s *Server) Send(ctx context.Context, in *proto.NotificationRequest) (*prot } } - go gorush.SendNotification(ctx, notification) + go gorush.SendNotification(s.cfg, notification) return &proto.NotificationReply{ Success: true, @@ -107,26 +110,26 @@ func (s *Server) Send(ctx context.Context, in *proto.NotificationRequest) (*prot } // RunGRPCServer run gorush grpc server -func RunGRPCServer(ctx context.Context) error { - if !gorush.PushConf.GRPC.Enabled { +func RunGRPCServer(ctx context.Context, cfg config.ConfYaml) error { + if !cfg.GRPC.Enabled { logx.LogAccess.Info("gRPC server is disabled.") return nil } s := grpc.NewServer() - rpcSrv := NewServer() + rpcSrv := NewServer(cfg) proto.RegisterGorushServer(s, rpcSrv) proto.RegisterHealthServer(s, rpcSrv) // Register reflection service on gRPC server. reflection.Register(s) - lis, err := net.Listen("tcp", ":"+gorush.PushConf.GRPC.Port) + lis, err := net.Listen("tcp", ":"+cfg.GRPC.Port) if err != nil { logx.LogError.Fatalln(err) return err } - logx.LogAccess.Info("gRPC server is running on " + gorush.PushConf.GRPC.Port + " port.") + logx.LogAccess.Info("gRPC server is running on " + cfg.GRPC.Port + " port.") go func() { select { case <-ctx.Done(): diff --git a/rpc/server_test.go b/rpc/server_test.go index 75c30c0..545e308 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/appleboy/gorush/gorush" + "github.com/appleboy/gorush/config" "github.com/appleboy/gorush/logx" "google.golang.org/grpc" @@ -13,22 +13,29 @@ import ( const gRPCAddr = "localhost:9000" +func initTest() config.ConfYaml { + cfg, _ := config.LoadConf() + cfg.Core.Mode = "test" + return cfg +} + func TestGracefulShutDownGRPCServer(t *testing.T) { + cfg := initTest() // server configs logx.InitLog( - gorush.PushConf.Log.AccessLevel, - gorush.PushConf.Log.AccessLog, - gorush.PushConf.Log.ErrorLevel, - gorush.PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, ) - gorush.PushConf.GRPC.Enabled = true - gorush.PushConf.GRPC.Port = "9000" - gorush.PushConf.Log.Format = "json" + cfg.GRPC.Enabled = true + cfg.GRPC.Port = "9000" + cfg.Log.Format = "json" // Run gRPC server ctx, gRPCContextCancel := context.WithCancel(context.Background()) go func() { - if err := RunGRPCServer(ctx); err != nil { + if err := RunGRPCServer(ctx, cfg); err != nil { panic(err) } }() diff --git a/status/status_test.go b/status/status_test.go index 589eacc..126425f 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -12,12 +12,12 @@ import ( ) func TestMain(m *testing.M) { - PushConf, _ := config.LoadConf("") + cfg, _ := config.LoadConf() if err := logx.InitLog( - PushConf.Log.AccessLevel, - PushConf.Log.AccessLog, - PushConf.Log.ErrorLevel, - PushConf.Log.ErrorLog, + cfg.Log.AccessLevel, + cfg.Log.AccessLog, + cfg.Log.ErrorLevel, + cfg.Log.ErrorLog, ); err != nil { log.Fatal(err) } @@ -26,9 +26,9 @@ func TestMain(m *testing.M) { } func TestStorageDriverExist(t *testing.T) { - PushConf, _ := config.LoadConf("") - PushConf.Stat.Engine = "Test" - err := InitAppStatus(PushConf) + cfg, _ := config.LoadConf() + cfg.Stat.Engine = "Test" + err := InitAppStatus(cfg) assert.Error(t, err) } @@ -37,9 +37,9 @@ func TestStatForMemoryEngine(t *testing.T) { time.Sleep(5 * time.Second) var val int64 - PushConf, _ := config.LoadConf("") - PushConf.Stat.Engine = "memory" - err := InitAppStatus(PushConf) + cfg, _ := config.LoadConf() + cfg.Stat.Engine = "memory" + err := InitAppStatus(cfg) assert.Nil(t, err) StatStorage.AddTotalCount(100) @@ -61,31 +61,31 @@ func TestStatForMemoryEngine(t *testing.T) { } func TestRedisServerSuccess(t *testing.T) { - PushConf, _ := config.LoadConf("") - PushConf.Stat.Engine = "redis" - PushConf.Stat.Redis.Addr = "redis:6379" + cfg, _ := config.LoadConf() + cfg.Stat.Engine = "redis" + cfg.Stat.Redis.Addr = "redis:6379" - err := InitAppStatus(PushConf) + err := InitAppStatus(cfg) assert.NoError(t, err) } func TestRedisServerError(t *testing.T) { - PushConf, _ := config.LoadConf("") - PushConf.Stat.Engine = "redis" - PushConf.Stat.Redis.Addr = "redis:6370" + cfg, _ := config.LoadConf() + cfg.Stat.Engine = "redis" + cfg.Stat.Redis.Addr = "redis:6370" - err := InitAppStatus(PushConf) + err := InitAppStatus(cfg) assert.Error(t, err) } func TestStatForRedisEngine(t *testing.T) { var val int64 - PushConf, _ := config.LoadConf("") - PushConf.Stat.Engine = "redis" - PushConf.Stat.Redis.Addr = "redis:6379" - err := InitAppStatus(PushConf) + cfg, _ := config.LoadConf() + cfg.Stat.Engine = "redis" + cfg.Stat.Redis.Addr = "redis:6379" + err := InitAppStatus(cfg) assert.Nil(t, err) StatStorage.Init() @@ -112,8 +112,8 @@ func TestStatForRedisEngine(t *testing.T) { func TestDefaultEngine(t *testing.T) { var val int64 // defaul engine as memory - PushConf, _ := config.LoadConf("") - err := InitAppStatus(PushConf) + cfg, _ := config.LoadConf() + err := InitAppStatus(cfg) assert.Nil(t, err) StatStorage.Reset() @@ -138,9 +138,9 @@ func TestDefaultEngine(t *testing.T) { func TestStatForBoltDBEngine(t *testing.T) { var val int64 - PushConf, _ := config.LoadConf("") - PushConf.Stat.Engine = "boltdb" - err := InitAppStatus(PushConf) + cfg, _ := config.LoadConf() + cfg.Stat.Engine = "boltdb" + err := InitAppStatus(cfg) assert.Nil(t, err) StatStorage.Reset() @@ -165,7 +165,7 @@ func TestStatForBoltDBEngine(t *testing.T) { // func TestStatForBuntDBEngine(t *testing.T) { // var val int64 -// PushConf.Stat.Engine = "buntdb" +// cfg.Stat.Engine = "buntdb" // err := InitAppStatus() // assert.Nil(t, err) @@ -191,7 +191,7 @@ func TestStatForBoltDBEngine(t *testing.T) { // func TestStatForLevelDBEngine(t *testing.T) { // var val int64 -// PushConf.Stat.Engine = "leveldb" +// cfg.Stat.Engine = "leveldb" // err := InitAppStatus() // assert.Nil(t, err) @@ -217,7 +217,7 @@ func TestStatForBoltDBEngine(t *testing.T) { // func TestStatForBadgerEngine(t *testing.T) { // var val int64 -// PushConf.Stat.Engine = "badger" +// cfg.Stat.Engine = "badger" // err := InitAppStatus() // assert.Nil(t, err) diff --git a/storage/badger/badger_test.go b/storage/badger/badger_test.go index 033e11d..3c26b27 100644 --- a/storage/badger/badger_test.go +++ b/storage/badger/badger_test.go @@ -3,16 +3,16 @@ package badger import ( "testing" - c "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/config" "github.com/stretchr/testify/assert" ) func TestBadgerEngine(t *testing.T) { var val int64 - config, _ := c.LoadConf("") + cfg, _ := config.LoadConf() - badger := New(config) + badger := New(cfg) err := badger.Init() assert.Nil(t, err) badger.Reset() diff --git a/storage/boltdb/boltdb_test.go b/storage/boltdb/boltdb_test.go index 8ea2670..07e1687 100644 --- a/storage/boltdb/boltdb_test.go +++ b/storage/boltdb/boltdb_test.go @@ -3,16 +3,16 @@ package boltdb import ( "testing" - c "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/config" "github.com/stretchr/testify/assert" ) func TestBoltDBEngine(t *testing.T) { var val int64 - config, _ := c.LoadConf("") + cfg, _ := config.LoadConf() - boltDB := New(config) + boltDB := New(cfg) err := boltDB.Init() assert.Nil(t, err) boltDB.Reset() diff --git a/storage/buntdb/buntdb_test.go b/storage/buntdb/buntdb_test.go index 5a65f03..022ce42 100644 --- a/storage/buntdb/buntdb_test.go +++ b/storage/buntdb/buntdb_test.go @@ -4,21 +4,21 @@ import ( "os" "testing" - c "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/config" "github.com/stretchr/testify/assert" ) func TestBuntDBEngine(t *testing.T) { var val int64 - config, _ := c.LoadConf("") + cfg, _ := config.LoadConf() - if _, err := os.Stat(config.Stat.BuntDB.Path); os.IsNotExist(err) { - err := os.RemoveAll(config.Stat.BuntDB.Path) + if _, err := os.Stat(cfg.Stat.BuntDB.Path); os.IsNotExist(err) { + err := os.RemoveAll(cfg.Stat.BuntDB.Path) assert.Nil(t, err) } - buntDB := New(config) + buntDB := New(cfg) err := buntDB.Init() assert.Nil(t, err) buntDB.Reset() diff --git a/storage/leveldb/leveldb_test.go b/storage/leveldb/leveldb_test.go index aa69189..8268214 100644 --- a/storage/leveldb/leveldb_test.go +++ b/storage/leveldb/leveldb_test.go @@ -4,21 +4,21 @@ import ( "os" "testing" - c "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/config" "github.com/stretchr/testify/assert" ) func TestLevelDBEngine(t *testing.T) { var val int64 - config, _ := c.LoadConf("") + cfg, _ := config.LoadConf() - if _, err := os.Stat(config.Stat.LevelDB.Path); os.IsNotExist(err) { - err = os.RemoveAll(config.Stat.LevelDB.Path) + if _, err := os.Stat(cfg.Stat.LevelDB.Path); os.IsNotExist(err) { + err = os.RemoveAll(cfg.Stat.LevelDB.Path) assert.Nil(t, err) } - levelDB := New(config) + levelDB := New(cfg) err := levelDB.Init() assert.Nil(t, err) levelDB.Reset() diff --git a/storage/redis/redis_test.go b/storage/redis/redis_test.go index 3075df7..0f31072 100644 --- a/storage/redis/redis_test.go +++ b/storage/redis/redis_test.go @@ -4,15 +4,15 @@ import ( "sync" "testing" - c "github.com/appleboy/gorush/config" + "github.com/appleboy/gorush/config" "github.com/stretchr/testify/assert" ) func TestRedisServerError(t *testing.T) { - config, _ := c.LoadConf("") - config.Stat.Redis.Addr = "redis:6370" + cfg, _ := config.LoadConf() + cfg.Stat.Redis.Addr = "redis:6370" - redis := New(config) + redis := New(cfg) err := redis.Init() assert.Error(t, err) @@ -21,10 +21,10 @@ func TestRedisServerError(t *testing.T) { func TestRedisEngine(t *testing.T) { var val int64 - config, _ := c.LoadConf("") - config.Stat.Redis.Addr = "redis:6379" + cfg, _ := config.LoadConf() + cfg.Stat.Redis.Addr = "redis:6379" - redis := New(config) + redis := New(cfg) err := redis.Init() assert.Nil(t, err) redis.Reset()