Sessions in etcd KV store instead of cookies

This commit is contained in:
2022-10-30 20:11:36 +01:00
parent c2c6a72f81
commit f4c2d61dc3
41 changed files with 1008 additions and 202 deletions

104
utils/cache/cache.go vendored
View File

@@ -1,104 +0,0 @@
package cache
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/spf13/viper"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/namespace"
)
type CacheHandler struct {
*clientv3.Client
}
func NewCacheHandler(cfg *viper.Viper) (*CacheHandler, error) {
var (
endpoints = cfg.GetStringSlice("cache.storage.etcd.endpoints")
prefix = cfg.GetString("cache.storage.etcd.prefix")
)
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: 5 * time.Second,
})
if err != nil {
return nil, err
}
cli.KV = namespace.NewKV(cli.KV, prefix)
cli.Watcher = namespace.NewWatcher(cli.Watcher, prefix)
cli.Lease = namespace.NewLease(cli.Lease, prefix)
return &CacheHandler{
Client: cli,
}, nil
}
func (s *CacheHandler) Put(k string, v any) error {
data, err := json.Marshal(v)
if err != nil {
return err
}
// _, err = s.Client.KV.Put(context.TODO(), k, data.String())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err = s.Client.KV.Put(ctx, k, string(data))
cancel()
if err != nil {
return err
}
return nil
}
func (s *CacheHandler) PutWithTTL(k string, v any, duration time.Duration) error {
lease, err := s.Client.Lease.Grant(context.TODO(), int64(duration.Seconds()))
if err != nil {
return err
}
data, err := json.Marshal(v)
if err != nil {
return err
}
// _, err = s.Client.KV.Put(context.TODO(), k, data.String(), clientv3.WithLease(lease.ID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err = s.Client.KV.Put(ctx, k, string(data), clientv3.WithLease(lease.ID))
cancel()
if err != nil {
return err
}
return nil
}
func (s *CacheHandler) Get(k string) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
resp, err := s.Client.KV.Get(ctx, k)
cancel()
if err != nil {
return nil, err
}
for _, v := range resp.Kvs {
var data any
err := json.Unmarshal([]byte(v.Value), &data)
if err != nil {
return nil, err
}
// We return directly as we want to last revision of value
return data, nil
}
return nil, errors.New(fmt.Sprintf("no value %v", k))
}
func (s *CacheHandler) Delete(k string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := s.Client.KV.Delete(ctx, k)
cancel()
if err != nil {
return err
}
return nil
}

View File

@@ -19,6 +19,7 @@ func (p *IdentificationProvider) GroupsMiddleware(next http.Handler) http.Handle
o, ok := session.Values["organization"]
if !ok || o == nil {
fmt.Println("no organization")
http.Redirect(w, r, "/auth/groups/", http.StatusFound)
return
}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"git.coopgo.io/coopgo-apps/parcoursmob/services"
"git.coopgo.io/coopgo-apps/parcoursmob/utils/storage"
"github.com/coreos/go-oidc"
"github.com/gorilla/sessions"
"github.com/spf13/viper"
@@ -28,7 +29,7 @@ type IdentificationProvider struct {
Services *services.ServicesHandler
}
func NewIdentificationProvider(cfg *viper.Viper, services *services.ServicesHandler) (*IdentificationProvider, error) {
func NewIdentificationProvider(cfg *viper.Viper, services *services.ServicesHandler, kv storage.KVHandler) (*IdentificationProvider, error) {
var (
providerURL = cfg.GetString("identification.oidc.provider")
clientID = cfg.GetString("identification.oidc.client_id")
@@ -54,7 +55,7 @@ func NewIdentificationProvider(cfg *viper.Viper, services *services.ServicesHand
Scopes: []string{oidc.ScopeOpenID, "groups", "first_name", "last_name", "display_name"},
}
var store = sessions.NewCookieStore([]byte(sessionsSecret))
store := storage.NewSessionStore(kv, []byte(sessionsSecret))
verifier := provider.Verifier(&oidc.Config{ClientID: oauth2Config.ClientID})
return &IdentificationProvider{

11
utils/storage/cache.go Normal file
View File

@@ -0,0 +1,11 @@
package storage
import (
"github.com/spf13/viper"
)
type CacheHandler KVHandler
func NewCacheHandler(cfg *viper.Viper) (CacheHandler, error) {
return NewKVHandler(cfg)
}

149
utils/storage/etcd.go Normal file
View File

@@ -0,0 +1,149 @@
package storage
import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"time"
"github.com/spf13/viper"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/namespace"
)
type EtcdSerializer interface {
Deserialize(d []byte, m *any) error
Serialize(m any) ([]byte, error)
}
type JSONEtcdSerializer struct{}
// Serialize to JSON. Will err if there are unmarshalable key values
func (s JSONEtcdSerializer) Serialize(m any) ([]byte, error) {
return json.Marshal(m)
}
// Deserialize back to map[string]interface{}
func (s JSONEtcdSerializer) Deserialize(d []byte, m *any) (err error) {
err = json.Unmarshal(d, &m)
if err != nil {
fmt.Printf("JSONSerializer.deserialize() Error: %v", err)
return err
}
return
}
// GobEtcdSerializer uses gob package to encode the session map
type GobEtcdSerializer struct{}
// Serialize using gob
func (s GobEtcdSerializer) Serialize(m any) ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(m)
if err == nil {
return buf.Bytes(), nil
}
return nil, err
}
// Deserialize back to map[interface{}]interface{}
func (s GobEtcdSerializer) Deserialize(d []byte, m any) error {
dec := gob.NewDecoder(bytes.NewBuffer(d))
return dec.Decode(&m)
}
type EtcdHandler struct {
*clientv3.Client
serializer EtcdSerializer
}
func NewEtcdHandler(cfg *viper.Viper) (*EtcdHandler, error) {
var (
endpoints = cfg.GetStringSlice("storage.kv.etcd.endpoints")
prefix = cfg.GetString("storage.kv.etcd.prefix")
)
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: 5 * time.Second,
})
if err != nil {
return nil, err
}
cli.KV = namespace.NewKV(cli.KV, prefix)
cli.Watcher = namespace.NewWatcher(cli.Watcher, prefix)
cli.Lease = namespace.NewLease(cli.Lease, prefix)
return &EtcdHandler{
Client: cli,
serializer: JSONEtcdSerializer{},
}, nil
}
func (s *EtcdHandler) Put(k string, v any) error {
data, err := s.serializer.Serialize(v)
if err != nil {
return err
}
// _, err = s.Client.KV.Put(context.TODO(), k, data.String())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err = s.Client.KV.Put(ctx, k, string(data))
cancel()
if err != nil {
return err
}
return nil
}
func (s *EtcdHandler) PutWithTTL(k string, v any, duration time.Duration) error {
lease, err := s.Client.Lease.Grant(context.TODO(), int64(duration.Seconds()))
if err != nil {
return err
}
data, err := s.serializer.Serialize(v)
if err != nil {
return err
}
// _, err = s.Client.KV.Put(context.TODO(), k, data.String(), clientv3.WithLease(lease.ID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err = s.Client.KV.Put(ctx, k, string(data), clientv3.WithLease(lease.ID))
cancel()
if err != nil {
return err
}
return nil
}
func (s *EtcdHandler) Get(k string) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
resp, err := s.Client.KV.Get(ctx, k)
cancel()
if err != nil {
return nil, err
}
for _, v := range resp.Kvs {
var data any
err := s.serializer.Deserialize([]byte(v.Value), &data)
if err != nil {
return nil, err
}
// We return directly as we want to last revision of value
return data, nil
}
return nil, fmt.Errorf("no value %v", k)
}
func (s *EtcdHandler) Delete(k string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := s.Client.KV.Delete(ctx, k)
cancel()
if err != nil {
return err
}
return nil
}

18
utils/storage/kv.go Normal file
View File

@@ -0,0 +1,18 @@
package storage
import (
"time"
"github.com/spf13/viper"
)
type KVHandler interface {
Put(k string, v any) error
PutWithTTL(k string, v any, duration time.Duration) error
Get(k string) (any, error)
Delete(k string) error
}
func NewKVHandler(cfg *viper.Viper) (KVHandler, error) {
return NewEtcdHandler(cfg)
}

144
utils/storage/sessions.go Normal file
View File

@@ -0,0 +1,144 @@
package storage
import (
"context"
"encoding/base32"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)
// Amount of time for cookies/kv keys to expire.
var sessionExpire = 86400 * 30
type SessionStore struct {
KV KVHandler
Codecs []securecookie.Codec
options *sessions.Options // default configuration
DefaultMaxAge int // default TiKV TTL for a MaxAge == 0 session
//maxLength int
keyPrefix string
//serializer SessionSerializer
}
func NewSessionStore(client KVHandler, keyPairs ...[]byte) *SessionStore {
es := &SessionStore{
KV: client,
Codecs: securecookie.CodecsFromPairs(keyPairs...),
options: &sessions.Options{
Path: "/",
MaxAge: sessionExpire,
},
DefaultMaxAge: 60 * 20, // 20 minutes seems like a reasonable default
//maxLength: 4096,
keyPrefix: "session/",
}
return es
}
func (s *SessionStore) Get(r *http.Request, name string) (*sessions.Session, error) {
// session := sessions.NewSession(s, name)
// ok, err := s.load(r.Context(), session)
// if !(err == nil && ok) {
// if err == nil {
// err = errors.New("key does not exist")
// }
// }
// return session, err
return sessions.GetRegistry(r).Get(s, name)
}
func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, error) {
session := sessions.NewSession(s, name)
options := *s.options
session.Options = &options
session.IsNew = true
if c, errCookie := r.Cookie(name); errCookie == nil {
err := securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
if err != nil {
return session, err
}
ok, err := s.load(r.Context(), session)
session.IsNew = !(err == nil && ok) // not new if no error and data available
}
return session, nil
}
// Save adds a single session to the response.
func (s *SessionStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
// Marked for deletion.
if session.Options.MaxAge <= 0 {
if err := s.delete(r.Context(), session); err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
} else {
// Build an alphanumeric key for the kv store.
if session.ID == "" {
session.ID = strings.TrimRight(base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32)), "=")
}
if err := s.save(r.Context(), session); err != nil {
return err
}
encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...)
if err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
}
return nil
}
// save stores the session in kv.
func (s *SessionStore) save(ctx context.Context, session *sessions.Session) error {
m := make(map[string]interface{}, len(session.Values))
for k, v := range session.Values {
ks, ok := k.(string)
if !ok {
err := fmt.Errorf("non-string key value, cannot serialize session: %v", k)
return err
}
m[ks] = v
}
age := session.Options.MaxAge
if age == 0 {
age = s.DefaultMaxAge
}
return s.KV.PutWithTTL(s.keyPrefix+session.ID, m, time.Duration(age)*time.Second)
}
// load reads the session from kv store.
// returns true if there is a sessoin data in DB
func (s *SessionStore) load(ctx context.Context, session *sessions.Session) (bool, error) {
data, err := s.KV.Get(s.keyPrefix + session.ID)
if err != nil {
return false, err
}
if data == nil && err == nil {
return false, errors.New("key does not exist")
}
for k, v := range data.(map[string]any) {
session.Values[k] = v
}
return true, nil
}
// delete removes keys from tikv if MaxAge<0
func (s *SessionStore) delete(ctx context.Context, session *sessions.Session) error {
if err := s.KV.Delete(s.keyPrefix + session.ID); err != nil {
return err
}
return nil
}