package router

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net/http"
	"os"
	"sync"

	"github.com/appleboy/gorush/config"
	"github.com/appleboy/gorush/core"
	"github.com/appleboy/gorush/logx"
	"github.com/appleboy/gorush/metric"
	"github.com/appleboy/gorush/notify"
	"github.com/appleboy/gorush/status"

	api "github.com/appleboy/gin-status-api"
	"github.com/gin-contrib/logger"
	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/golang-queue/queue"
	"github.com/mattn/go-isatty"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/thoas/stats"
	"golang.org/x/crypto/acme/autocert"
)

// doOnce makes routerEngine register its metrics collector with Prometheus at
// most once per process: prometheus.MustRegister panics when registration
// fails, for example because an equivalent collector is already registered.
var doOnce = sync.Once{}

// abortWithError stops the remaining handler chain and replies with the given
// HTTP status and a JSON body of the form {"code": code, "message": message}.
func abortWithError(c *gin.Context, code int, message string) {
	payload := gin.H{}
	payload["code"] = code
	payload["message"] = message
	c.AbortWithStatusJSON(code, payload)
}

// rootHandler replies 200 with a fixed JSON welcome message.
func rootHandler(c *gin.Context) {
	const welcome = "Welcome to notification server."
	c.JSON(http.StatusOK, gin.H{"text": welcome})
}

// heartbeatHandler is the health-check endpoint: it aborts the handler chain
// and answers 200 OK with an empty body.
func heartbeatHandler(c *gin.Context) {
	code := http.StatusOK
	c.AbortWithStatus(code)
}

// versionHandler replies 200 with the running version (from GetVersion) and
// the project's source repository URL as JSON.
func versionHandler(c *gin.Context) {
	body := gin.H{
		"version": GetVersion(),
		"source":  "https://github.com/appleboy/gorush",
	}
	c.JSON(http.StatusOK, body)
}

// pushHandler returns the handler for the push endpoint.
//
// The request body is bound as JSON into notify.RequestPush. The request is
// rejected with 400 (via abortWithError) when binding fails, when the
// notifications list is empty, or when it holds more than
// cfg.Core.MaxNotification entries. Otherwise the notifications are passed to
// handleNotification and the handler replies 200 with
// {"success": "ok", "counts": <count>, "logs": <logs>}.
func pushHandler(cfg *config.ConfYaml, q *queue.Queue) gin.HandlerFunc {
	return func(c *gin.Context) {
		var form notify.RequestPush
		var msg string

		if err := c.ShouldBindWith(&form, binding.JSON); err != nil {
			msg = "Missing notifications field."
			logx.LogAccess.Debug(err)
			abortWithError(c, http.StatusBadRequest, msg)
			return
		}

		if len(form.Notifications) == 0 {
			msg = "Notifications field is empty."
			logx.LogAccess.Debug(msg)
			abortWithError(c, http.StatusBadRequest, msg)
			return
		}

		if int64(len(form.Notifications)) > cfg.Core.MaxNotification {
			msg = fmt.Sprintf("Number of notifications(%d) over limit(%d)", len(form.Notifications), cfg.Core.MaxNotification)
			logx.LogAccess.Debug(msg)
			abortWithError(c, http.StatusBadRequest, msg)
			return
		}

		// Sync mode only applies to local queue engines (the same rule that
		// handleNotification enforces). Resolve it now so the watcher goroutine
		// below never reads cfg while it may be updated elsewhere.
		syncMode := cfg.Core.Sync && core.IsLocalQueue(core.Queue(cfg.Queue.Engine))

		// Capture the request context before spawning the goroutine: gin
		// recycles *gin.Context once the handler returns, so the goroutine
		// must never touch c itself.
		reqCtx := c.Request.Context()

		ctx, cancel := context.WithCancel(context.Background())
		// Always release ctx. This also stops the watcher goroutine when the
		// request context is never cancelled (e.g. a request built on
		// context.Background(), whose Done channel is nil).
		defer cancel()

		go func() {
			select {
			case <-reqCtx.Done():
				// Don't send notification after client timeout or disconnected.
				// See https://github.com/appleboy/gorush/issues/422
				if syncMode {
					cancel()
				}
			case <-ctx.Done():
				// Handler finished; nothing left to watch.
			}
		}()

		counts, logs := handleNotification(ctx, cfg, form, q)

		c.JSON(http.StatusOK, gin.H{
			"success": "ok",
			"counts":  counts,
			"logs":    logs,
		})
	}
}

// configHandler returns a handler that serializes the active configuration
// cfg as YAML. It answers with 201 Created rather than 200.
//
// NOTE(review): this dumps the whole config, which may include credentials —
// confirm the route is not publicly reachable.
func configHandler(cfg *config.ConfYaml) gin.HandlerFunc {
	handler := func(c *gin.Context) {
		c.YAML(http.StatusCreated, cfg)
	}
	return handler
}

// metricsHandler exposes Prometheus metrics by delegating the request to
// promhttp's default handler. A new promhttp handler is constructed on every
// request.
func metricsHandler(c *gin.Context) {
	h := promhttp.Handler()
	h.ServeHTTP(c.Writer, c.Request)
}

// appStatusHandler returns a handler that reports application status as JSON.
// The report contains the version, the queue's worker and task counters, and
// the overall and per-platform (iOS, Android, Huawei) push success/error
// totals read from status.StatStorage.
func appStatusHandler(q *queue.Queue) gin.HandlerFunc {
	return func(c *gin.Context) {
		var app status.App

		app.Version = GetVersion()

		// Queue counters.
		app.BusyWorkers = q.BusyWorkers()
		app.SuccessTasks = q.SuccessTasks()
		app.FailureTasks = q.FailureTasks()
		app.SubmittedTasks = q.SubmittedTasks()

		// Push statistics: overall total, then per platform.
		app.TotalCount = status.StatStorage.GetTotalCount()

		ios := &app.Ios
		ios.PushSuccess = status.StatStorage.GetIosSuccess()
		ios.PushError = status.StatStorage.GetIosError()

		android := &app.Android
		android.PushSuccess = status.StatStorage.GetAndroidSuccess()
		android.PushError = status.StatStorage.GetAndroidError()

		huawei := &app.Huawei
		huawei.PushSuccess = status.StatStorage.GetHuaweiSuccess()
		huawei.PushError = status.StatStorage.GetHuaweiError()

		c.JSON(http.StatusOK, app)
	}
}

// sysStatsHandler returns a handler that serves status.Stats.Data() as JSON.
// This is the data StatMiddleware records through status.Stats.
func sysStatsHandler() gin.HandlerFunc {
	serve := func(c *gin.Context) {
		data := status.Stats.Data()
		c.JSON(http.StatusOK, data)
	}
	return serve
}

// StatMiddleware returns middleware that records per-request statistics
// (response time, status code counts, etc.) in status.Stats. Recording wraps
// the rest of the handler chain.
func StatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		startedAt, rw := status.Stats.Begin(c.Writer)
		c.Next()
		opt := stats.WithRecorder(rw)
		status.Stats.End(startedAt, opt)
	}
}

// autoTLSServer builds an HTTPS server listening on ":https".
//
// Certificates are obtained automatically through autocert for the single
// host cfg.Core.AutoTLS.Host and cached in the directory
// cfg.Core.AutoTLS.Folder. The server's handler is routerEngine(cfg, q).
//
// NOTE(review): no read/write timeouts are configured on the returned server.
func autoTLSServer(cfg *config.ConfYaml, q *queue.Queue) *http.Server {
	m := &autocert.Manager{
		Prompt:     autocert.AcceptTOS,
		HostPolicy: autocert.HostWhitelist(cfg.Core.AutoTLS.Host),
		Cache:      autocert.DirCache(cfg.Core.AutoTLS.Folder),
	}

	// Manager.TLSConfig sets GetCertificate and also lists the "acme-tls/1"
	// ALPN protocol in NextProtos, which the tls-alpn-01 challenge requires.
	// A config that only sets GetCertificate cannot complete that challenge.
	tlsConfig := m.TLSConfig()
	tlsConfig.MinVersion = tls.VersionTLS12

	return &http.Server{
		Addr:      ":https",
		TLSConfig: tlsConfig,
		Handler:   routerEngine(cfg, q),
	}
}

// routerEngine builds the gin engine that serves the gorush HTTP API.
//
// Besides constructing the router, it changes process-wide state:
//   - zerolog's global level: debug when cfg.Core.Mode is "debug", info otherwise;
//   - log.Logger: a timestamped stdout logger, switched to a console writer
//     when stdout is a terminal;
//   - the default Prometheus registry: the metrics collector for q is
//     registered only on the first call (guarded by doOnce);
//   - gin's global mode, set to cfg.Core.Mode.
//
// Middleware runs in this order:
//  1. request logging (skipping the health and metric URIs);
//  2. panic recovery;
//  3. VersionMiddleware;
//  4. StatMiddleware.
func routerEngine(cfg *config.ConfYaml, q *queue.Queue) *gin.Engine {
	logLevel := zerolog.InfoLevel
	if cfg.Core.Mode == "debug" {
		logLevel = zerolog.DebugLevel
	}
	zerolog.SetGlobalLevel(logLevel)

	log.Logger = zerolog.New(os.Stdout).With().Timestamp().Logger()
	if isatty.IsTerminal(os.Stdout.Fd()) {
		// Human-readable output for interactive sessions.
		log.Logger = log.Output(zerolog.ConsoleWriter{
			Out:     os.Stdout,
			NoColor: false,
		})
	}

	// prometheus.MustRegister panics if registration fails (e.g. a duplicate),
	// so only the first call registers; q passed on later calls is not exported.
	doOnce.Do(func() {
		prometheus.MustRegister(metric.NewMetrics(q))
	})

	gin.SetMode(cfg.Core.Mode)

	engine := gin.New()
	engine.Use(
		logger.SetLogger(
			logger.WithUTC(true),
			logger.WithSkipPath([]string{cfg.API.HealthURI, cfg.API.MetricURI}),
		),
		gin.Recovery(),
		VersionMiddleware(),
		StatMiddleware(),
	)

	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodGet, cfg.API.StatGoURI, api.GinHandler},
		{http.MethodGet, cfg.API.StatAppURI, appStatusHandler(q)},
		{http.MethodGet, cfg.API.ConfigURI, configHandler(cfg)},
		{http.MethodGet, cfg.API.SysStatURI, sysStatsHandler()},
		{http.MethodPost, cfg.API.PushURI, pushHandler(cfg, q)},
		{http.MethodGet, cfg.API.MetricURI, metricsHandler},
		{http.MethodGet, cfg.API.HealthURI, heartbeatHandler},
		{http.MethodHead, cfg.API.HealthURI, heartbeatHandler},
		{http.MethodGet, "/version", versionHandler},
		{http.MethodGet, "/", rootHandler},
	}
	for _, rt := range routes {
		engine.Handle(rt.method, rt.path, rt.handler)
	}

	return engine
}

// markFailedNotification reports a notification that could not be delivered.
//
// It logs reason once via logx.LogError, then builds and returns one
// core.FailedPush log entry per token of notification, each carrying reason as
// its error. The result is an empty, non-nil slice when the notification has
// no tokens.
func markFailedNotification(cfg *config.ConfYaml, notification *notify.PushNotification, reason string) []logx.LogPushEntry {
	logx.LogError.Error(reason)

	// Error values from errors.New are immutable, so one instance can back
	// every entry.
	failure := errors.New(reason)

	logs := make([]logx.LogPushEntry, 0, len(notification.Tokens))
	for _, token := range notification.Tokens {
		logs = append(logs, logx.GetLogPushEntry(&logx.InputLog{
			ID:        notification.ID,
			Status:    core.FailedPush,
			Token:     token,
			Message:   notification.Message,
			Platform:  notification.Platform,
			Error:     failure,
			HideToken: cfg.Log.HideToken,
			Format:    cfg.Log.Format,
		}))
	}

	return logs
}

// handleNotification submits the notifications in req to q and reports the
// outcome.
//
// Notifications whose platform (iOS, Android, Huawei) is disabled in cfg are
// skipped. Sync mode (cfg.Core.Sync) is only honoured for local queue engines;
// with any other engine, cfg.Core.Sync is reset to false on the shared cfg.
//
//   - Sync mode: each notification runs as a queue task that calls
//     notify.SendNotification, and this function blocks until every queued
//     task has finished. The returned logs hold the entries from each
//     successful send, plus failure entries for notifications whose task
//     could not be queued.
//   - Async mode: notifications are only enqueued. The returned logs hold
//     failure entries for notifications the queue rejected.
//
// The returned count is the number of device tokens across the accepted
// notifications plus one per topic message (non-empty To). It is also added to
// status.StatStorage's total count. ctx is currently not consulted.
func handleNotification(ctx context.Context, cfg *config.ConfYaml, req notify.RequestPush, q *queue.Queue) (int, []logx.LogPushEntry) {
	if cfg.Core.Sync && !core.IsLocalQueue(core.Queue(cfg.Queue.Engine)) {
		cfg.Core.Sync = false
	}
	// After the reset above, sync mode implies a local queue engine.
	syncMode := cfg.Core.Sync

	// Keep only notifications whose platform is enabled.
	accepted := make([]*notify.PushNotification, 0, len(req.Notifications))
	for i := range req.Notifications {
		notification := &req.Notifications[i]
		switch notification.Platform {
		case core.PlatFormIos:
			if !cfg.Ios.Enabled {
				continue
			}
		case core.PlatFormAndroid:
			if !cfg.Android.Enabled {
				continue
			}
		case core.PlatFormHuawei:
			if !cfg.Huawei.Enabled {
				continue
			}
		}
		accepted = append(accepted, notification)
	}

	var (
		count int
		wg    sync.WaitGroup
		mu    sync.Mutex // guards logs: sync-mode tasks append from worker goroutines
	)
	logs := make([]logx.LogPushEntry, 0, len(accepted))
	addLogs := func(entries []logx.LogPushEntry) {
		mu.Lock()
		defer mu.Unlock()
		logs = append(logs, entries...)
	}

	for _, notification := range accepted {
		if syncMode {
			msg := notification // per-iteration copy for the task closure (pre-Go 1.22 loop semantics)
			wg.Add(1)
			if err := q.QueueTask(func(_ context.Context) error {
				defer wg.Done()
				resp, err := notify.SendNotification(msg, cfg)
				if err != nil {
					return err
				}
				addLogs(resp.Logs)
				return nil
			}); err != nil {
				// The task was never queued, so it will never call wg.Done.
				// Release it here, or wg.Wait below would block forever.
				wg.Done()
				addLogs(markFailedNotification(cfg, notification, err.Error()))
			}
		} else if err := q.Queue(notification); err != nil {
			// No wg.Add is made for async notifications, so no wg.Done here.
			addLogs(markFailedNotification(cfg, notification, "max capacity reached"))
		}

		count += len(notification.Tokens)
		// A topic message counts as one push on top of its tokens.
		if notification.To != "" {
			count++
		}
	}

	// Returns immediately in async mode, where wg is never incremented.
	wg.Wait()

	status.StatStorage.AddTotalCount(int64(count))

	return count, logs
}