parcoursmob/utils/identification/oidc.go

122 lines
3.2 KiB
Go

package identification
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"git.coopgo.io/coopgo-apps/parcoursmob/services"
"git.coopgo.io/coopgo-apps/parcoursmob/utils/storage"
"github.com/coreos/go-oidc"
"github.com/gorilla/sessions"
"github.com/spf13/viper"
"golang.org/x/oauth2"
)
// ContextKey is the type of the keys under which Middleware stores
// identification data in a request context. Using a dedicated named type
// keeps these keys from colliding with plain string keys set by other code.
type ContextKey string

const (
	// IdtokenKey is the context key for the verified *oidc.IDToken.
	IdtokenKey ContextKey = "idtoken"

	// ClaimsKey is the context key for the token's decoded claims,
	// stored as a map[string]any.
	ClaimsKey ContextKey = "claims"
)
// IdentificationProvider groups the pieces needed to authenticate users
// against an OpenID Connect provider. These are the session store, the
// discovered provider, the OAuth2 client configuration and the ID-token
// verifier. Construct it with NewIdentificationProvider.
type IdentificationProvider struct {
	// SessionsStore loads and saves the user session, which holds the raw
	// ID token and the pending OAuth2 state value.
	SessionsStore sessions.Store

	// Provider is the OIDC provider obtained through discovery.
	Provider *oidc.Provider

	// OAuth2Config carries the client credentials, redirect URL, provider
	// endpoints and requested scopes used to build authorization URLs.
	OAuth2Config oauth2.Config

	// TokenVerifier checks raw ID tokens issued to this client.
	TokenVerifier *oidc.IDTokenVerifier

	// Services is a handle to the application's backend services.
	// NOTE(review): not read by the code in this file; presumably used by
	// handlers elsewhere.
	Services *services.ServicesHandler
}
// NewIdentificationProvider builds an IdentificationProvider from configuration.
//
// It reads these keys from cfg:
//   - identification.oidc.provider: issuer URL used for OIDC discovery
//   - identification.oidc.client_id: OAuth2 client ID (required)
//   - identification.oidc.client_secret: OAuth2 client secret
//   - identification.oidc.redirect_url: callback URL sent to the provider
//   - identification.sessions.secret: key material for the session store (required)
//
// oidc.NewProvider performs discovery against the provider URL. Its failure is
// returned wrapped with that URL. An empty client ID or sessions secret is
// rejected up front. Without a client ID the verifier would reject every
// token, so every request would loop back to the login page.
func NewIdentificationProvider(cfg *viper.Viper, services *services.ServicesHandler, kv storage.KVHandler) (*IdentificationProvider, error) {
	var (
		providerURL    = cfg.GetString("identification.oidc.provider")
		clientID       = cfg.GetString("identification.oidc.client_id")
		clientSecret   = cfg.GetString("identification.oidc.client_secret")
		redirectURL    = cfg.GetString("identification.oidc.redirect_url")
		sessionsSecret = cfg.GetString("identification.sessions.secret")
	)

	if clientID == "" {
		return nil, fmt.Errorf("identification.oidc.client_id is not configured")
	}
	if sessionsSecret == "" {
		return nil, fmt.Errorf("identification.sessions.secret is not configured")
	}

	provider, err := oidc.NewProvider(context.Background(), providerURL)
	if err != nil {
		return nil, fmt.Errorf("discovering OIDC provider %q: %w", providerURL, err)
	}

	oauth2Config := oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  redirectURL,
		// Discovery returns the OAuth2 endpoints.
		Endpoint: provider.Endpoint(),
		// "openid" is a required scope for OpenID Connect flows.
		Scopes: []string{oidc.ScopeOpenID, "groups", "first_name", "last_name", "display_name", "email"},
	}

	store := storage.NewSessionStore(kv, []byte(sessionsSecret))
	// The verifier checks the token's audience against our client ID.
	verifier := provider.Verifier(&oidc.Config{ClientID: oauth2Config.ClientID})

	return &IdentificationProvider{
		SessionsStore: store,
		Provider:      provider,
		OAuth2Config:  oauth2Config,
		TokenVerifier: verifier,
		Services:      services,
	}, nil
}
// Middleware returns an http.Handler that requires a valid OIDC session
// before calling next.
//
// For each request it loads the "parcoursmob_session" session and reads the
// raw ID token stored under the "idtoken" key. The user is sent to the
// provider's authorization endpoint in either of these cases:
//   - the token is missing, empty or not a string;
//   - the token fails verification.
//
// Before redirecting, the stored token is dropped and a fresh random state is
// saved in the session. When the token is valid, the verified *oidc.IDToken
// and its decoded claims are added to the request context under IdtokenKey
// and ClaimsKey.
//
// Internal failures (no session, state generation, session save) are answered
// with 500 Internal Server Error instead of panicking.
func (p *IdentificationProvider) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, err := p.SessionsStore.Get(r, "parcoursmob_session")
		if err != nil {
			// Stores may return a fresh, usable session together with a
			// decode error (e.g. a stale cookie), so this is only logged.
			fmt.Println("identification: loading session:", err)
		}
		if session == nil {
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		// redirectToLogin starts a new authorization-code flow. The session
		// must be saved before redirecting. Otherwise the new state never
		// reaches the client, and the OAuth2 callback (not in this file)
		// presumably cannot match it.
		redirectToLogin := func() {
			state, err := newState()
			if err != nil {
				fmt.Println("identification: generating state:", err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			delete(session.Values, "idtoken")
			session.Values["state"] = state
			if err := session.Save(r, w); err != nil {
				fmt.Println("identification: saving session:", err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			http.Redirect(w, r, p.OAuth2Config.AuthCodeURL(state), http.StatusFound)
		}

		// A missing or non-string value yields "", which is treated like no token.
		rawToken, _ := session.Values["idtoken"].(string)
		if rawToken == "" {
			redirectToLogin()
			return
		}

		idtoken, err := p.TokenVerifier.Verify(r.Context(), rawToken)
		if err != nil {
			redirectToLogin()
			return
		}

		var claims map[string]any
		if err := idtoken.Claims(&claims); err != nil {
			// Best effort: the token itself is valid, so the request proceeds
			// with nil claims, as before.
			fmt.Println("identification: decoding claims:", err)
		}

		ctx := context.WithValue(r.Context(), IdtokenKey, idtoken)
		ctx = context.WithValue(ctx, ClaimsKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// newState generates a value for the OAuth2 "state" parameter. It reads 16
// bytes from crypto/rand and encodes them as unpadded base64url, giving a
// 22-character URL-safe string. It returns an error only when the random
// source cannot be read.
func newState() (string, error) {
	var raw [16]byte
	_, err := io.ReadFull(rand.Reader, raw[:])
	if err != nil {
		return "", err
	}
	state := base64.RawURLEncoding.EncodeToString(raw[:])
	return state, nil
}