// Package identification implements OpenID Connect login for HTTP handlers,
// keeping the OAuth2 state and raw ID token in a session store.
package identification

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"io"
	"net/http"

	"git.coopgo.io/coopgo-apps/parcoursmob/services"
	"git.coopgo.io/coopgo-apps/parcoursmob/utils/storage"
	"github.com/coreos/go-oidc"
	"github.com/gorilla/sessions"
	"github.com/spf13/viper"
	"golang.org/x/oauth2"
)

// ContextKey is the type of the keys under which Middleware stores
// authentication data in the request context.
type ContextKey string

const (
	// IdtokenKey is the context key holding the verified *oidc.IDToken.
	IdtokenKey ContextKey = "idtoken"
	// ClaimsKey is the context key holding the ID token claims as a
	// map[string]any (nil if the claims could not be decoded).
	ClaimsKey ContextKey = "claims"
)

// IdentificationProvider bundles what is needed to authenticate requests
// against an OpenID Connect provider.
type IdentificationProvider struct {
	// SessionsStore holds the per-client session ("state" and "idtoken" values).
	SessionsStore sessions.Store
	// Provider is the OIDC provider obtained through discovery.
	Provider *oidc.Provider
	// OAuth2Config is the OAuth2 client configuration used to build the
	// authorization code URL.
	OAuth2Config oauth2.Config
	// TokenVerifier validates ID tokens issued for OAuth2Config.ClientID.
	TokenVerifier *oidc.IDTokenVerifier
	// Services gives access to the application's service handlers.
	Services *services.ServicesHandler
}

// NewIdentificationProvider builds an IdentificationProvider from the
// "identification.*" keys of cfg.
//
// It runs OIDC discovery against identification.oidc.provider (a network call
// made with a background context). It also sets up the OAuth2 client, an ID
// token verifier bound to the client ID, and a session store created by
// storage.NewSessionStore over kv, keyed with identification.sessions.secret.
// It returns an error wrapping the discovery failure if the provider cannot be
// reached or described.
func NewIdentificationProvider(cfg *viper.Viper, services *services.ServicesHandler, kv storage.KVHandler) (*IdentificationProvider, error) {
	var (
		providerURL    = cfg.GetString("identification.oidc.provider")
		clientID       = cfg.GetString("identification.oidc.client_id")
		clientSecret   = cfg.GetString("identification.oidc.client_secret")
		redirectURL    = cfg.GetString("identification.oidc.redirect_url")
		sessionsSecret = cfg.GetString("identification.sessions.secret")
	)

	// The provider keeps this context for later key-set refreshes, so it must
	// not be tied to any single request.
	provider, err := oidc.NewProvider(context.Background(), providerURL)
	if err != nil {
		return nil, fmt.Errorf("oidc discovery for provider %q: %w", providerURL, err)
	}

	oauth2Config := oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  redirectURL,

		// Discovery returns the OAuth2 endpoints.
		Endpoint: provider.Endpoint(),

		// "openid" is a required scope for OpenID Connect flows.
		Scopes: []string{oidc.ScopeOpenID, "groups", "first_name", "last_name", "display_name", "email"},
	}

	store := storage.NewSessionStore(kv, []byte(sessionsSecret))

	verifier := provider.Verifier(&oidc.Config{ClientID: oauth2Config.ClientID})

	return &IdentificationProvider{
		SessionsStore: store,
		Provider:      provider,
		OAuth2Config:  oauth2Config,
		TokenVerifier: verifier,
		Services:      services,
	}, nil
}

// Middleware returns an http.Handler that requires a valid OIDC session
// before calling next.
//
// The raw ID token is read from the "parcoursmob_session" session. If it is
// missing, empty, not a string, or fails verification, Middleware starts a new
// login. It stores a fresh random state in the session, saves the session and
// redirects (302) to the provider's authorization URL. Otherwise the verified
// token and its claims are attached to the request context under IdtokenKey
// and ClaimsKey before next is called. Failures to create the state or save
// the session are answered with 500 Internal Server Error.
func (p *IdentificationProvider) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("IDP Middleware")

		session, err := p.SessionsStore.Get(r, "parcoursmob_session")
		if err != nil {
			// Stores commonly return a fresh session alongside a decode
			// error; log it and continue, which forces a new login below.
			fmt.Println(err)
		}
		if session == nil {
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		// redirectToLogin starts a new authorization code flow. The state must
		// be persisted before redirecting; presumably the OAuth2 callback
		// (not visible here) compares it with the state returned by the provider.
		redirectToLogin := func() {
			state, err := newState()
			if err != nil {
				fmt.Println(err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			session.Values["state"] = state
			delete(session.Values, "idtoken")
			if err := session.Save(r, w); err != nil {
				fmt.Println(err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			http.Redirect(w, r, p.OAuth2Config.AuthCodeURL(state), http.StatusFound)
		}

		// Comma-ok assertion: a nil, empty or non-string value all mean
		// "not logged in" instead of panicking.
		rawIDToken, ok := session.Values["idtoken"].(string)
		if !ok || rawIDToken == "" {
			redirectToLogin()
			return
		}

		idtoken, err := p.TokenVerifier.Verify(r.Context(), rawIDToken)
		if err != nil {
			if r.Context().Err() != nil {
				// The client went away mid-verification; don't discard a
				// possibly valid token because of that.
				return
			}
			redirectToLogin()
			return
		}

		var claims map[string]any
		if err := idtoken.Claims(&claims); err != nil {
			fmt.Println(err)
		}

		ctx := context.WithValue(r.Context(), IdtokenKey, idtoken)
		ctx = context.WithValue(ctx, ClaimsKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// newState returns a random OAuth2 state value: 16 bytes from crypto/rand,
// encoded as unpadded URL-safe base64.
func newState() (string, error) {
	b := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, b); err != nil {
		return "", fmt.Errorf("generating oauth2 state: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}