// groups-management/storage/postgresql.go

package storage
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"strconv"
	"strings"

	"ariga.io/atlas/sql/postgres"
	"ariga.io/atlas/sql/schema"
	"github.com/google/uuid"
	"github.com/lib/pq"
	_ "github.com/lib/pq"
	"github.com/spf13/viper"
)
// PostgresqlStorage persists groups and group members in PostgreSQL.
type PostgresqlStorage struct {
	// DbConnection is the database handle every query in this file runs on.
	DbConnection *sql.DB
	// Tables maps a logical table key ("groups", "group_members") to the
	// table name that is interpolated into the SQL statements.
	Tables map[string]string
}
// NewPostgresqlStorage opens a PostgreSQL connection pool configured from the
// "storage.db.psql.*" keys of cfg, verifies it with a ping, and records the
// "<schema>.<table>" names of the groups and group_members tables.
//
// It returns an error when the configured port is not an integer, when the
// connection cannot be opened, or when the ping fails. The underlying cause
// is wrapped with %w. No configuration value (in particular the password) is
// printed.
func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) {
	var (
		host              = cfg.GetString("storage.db.psql.host")
		port              = cfg.GetString("storage.db.psql.port")
		user              = cfg.GetString("storage.db.psql.user")
		password          = cfg.GetString("storage.db.psql.password")
		dbname            = cfg.GetString("storage.db.psql.dbname")
		sslmode           = cfg.GetString("storage.db.psql.sslmode")
		pgSchema          = cfg.GetString("storage.db.psql.schema")
		groupMembersTable = cfg.GetString("storage.db.psql.tables.group_members")
		groupsTable       = cfg.GetString("storage.db.psql.tables.groups")
	)
	portInt, err := strconv.Atoi(port)
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("invalid postgresql port %q: %w", port, err)
	}
	db, err := sql.Open("postgres", postgresqlConnString(host, portInt, user, password, dbname, sslmode))
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed: %w", err)
	}
	if err := db.Ping(); err != nil {
		// Release the pool; the caller never receives the handle on failure.
		_ = db.Close()
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed: %w", err)
	}
	return PostgresqlStorage{
		DbConnection: db,
		Tables: map[string]string{
			"groups":        fmt.Sprintf("%s.%s", pgSchema, groupsTable),
			"group_members": fmt.Sprintf("%s.%s", pgSchema, groupMembersTable),
		},
	}, nil
}

// postgresqlConnString builds a lib/pq keyword/value connection string.
// Every text value is single-quoted, with backslashes and single quotes
// backslash-escaped, so that empty values and values containing spaces or
// quotes (typically passwords) are each read by the driver as one value.
func postgresqlConnString(host string, port int, user, password, dbname, sslmode string) string {
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	quote := func(v string) string { return "'" + escaper.Replace(v) + "'" }
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quote(host), port, quote(user), quote(password), quote(dbname), quote(sslmode))
}
// CreateGroup inserts group as a new row of the groups table.
//
// group.Data is JSON-encoded into the data column and group.Members is bound
// through pq.Array. Errors from JSON encoding or from the INSERT are returned
// as-is.
func (psql PostgresqlStorage) CreateGroup(group Group) error {
	payload, err := json.Marshal(group.Data)
	if err != nil {
		return err
	}
	insert := fmt.Sprintf("INSERT INTO %s (id, namespace, members, data) VALUES ($1, $2, $3, $4)", psql.Tables["groups"])
	_, err = psql.DbConnection.Exec(insert, group.ID, group.Namespace, pq.Array(group.Members), payload)
	return err
}
// GetGroup loads the groups row whose id column equals id and decodes its
// JSON data column into Group.Data.
//
// The query/scan error is wrapped with %w, so callers can detect a missing
// row with errors.Is(err, sql.ErrNoRows). JSON decoding errors are returned
// as-is.
func (psql PostgresqlStorage) GetGroup(id string) (*Group, error) {
	var (
		data    []byte
		members pq.StringArray
	)
	group := &Group{}
	query := fmt.Sprintf(`SELECT id, namespace, members, data
FROM %s WHERE id = $1`, psql.Tables["groups"])
	err := psql.DbConnection.QueryRow(query, id).Scan(&group.ID, &group.Namespace, &members, &data)
	if err != nil {
		return nil, fmt.Errorf("psql select group query failed : %w", err)
	}
	group.Members = members
	if err := json.Unmarshal(data, &group.Data); err != nil {
		return nil, err
	}
	return group, nil
}
// GetGroups returns every group whose namespace is one of namespaces.
//
// The result is always a non-nil slice (empty when nothing matches), as with
// GetGroupsByIds and GetGroupsMember. An empty namespaces argument returns
// immediately without touching the database. Driver and JSON errors are
// wrapped with %w so callers can inspect the cause.
func (psql PostgresqlStorage) GetGroups(namespaces []string) ([]Group, error) {
	groups := make([]Group, 0)
	if len(namespaces) == 0 {
		return groups, nil
	}
	query := fmt.Sprintf(`SELECT id, namespace, members, data FROM %s WHERE namespace = ANY($1)`, psql.Tables["groups"])
	rows, err := psql.DbConnection.Query(query, pq.Array(namespaces))
	if err != nil {
		return nil, fmt.Errorf("psql select groups query failed : %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var (
			id        string
			ns        string
			data      []byte
			members   pq.StringArray
			groupData map[string]interface{}
		)
		if err := rows.Scan(&id, &ns, &members, &data); err != nil {
			return nil, fmt.Errorf("psql scan groups row failed : %w", err)
		}
		if err := json.Unmarshal(data, &groupData); err != nil {
			return nil, fmt.Errorf("psql unmarshal group data failed : %w", err)
		}
		groups = append(groups, Group{
			ID:        id,
			Namespace: ns,
			Members:   members,
			Data:      groupData,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql iterate groups rows failed : %w", err)
	}
	return groups, nil
}
// GetGroupsByIds returns the groups whose id is in ids.
//
// Entries that do not parse as UUIDs are skipped. The caller's slice is left
// untouched: filtering happens in a fresh slice, because the previous
// in-place append(ids[:i], ids[i+1:]...) overwrote the caller's backing
// array. If no valid id remains, an empty slice is returned without
// querying. The result is always non-nil. Driver errors are wrapped with %w.
func (psql PostgresqlStorage) GetGroupsByIds(ids []string) ([]Group, error) {
	validIDs := make([]string, 0, len(ids))
	for _, id := range ids {
		if isValidUUID(id) {
			validIDs = append(validIDs, id)
		}
	}
	groups := make([]Group, 0)
	if len(validIDs) == 0 {
		return groups, nil
	}
	query := fmt.Sprintf(`SELECT id, namespace, members, data FROM %s WHERE id = ANY($1)`, psql.Tables["groups"])
	rows, err := psql.DbConnection.Query(query, pq.Array(validIDs))
	if err != nil {
		return nil, fmt.Errorf("psql select groups query failed: %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var (
			group   Group
			data    []byte
			members pq.StringArray
		)
		if err := rows.Scan(&group.ID, &group.Namespace, &members, &data); err != nil {
			return nil, fmt.Errorf("psql select groups query failed: %w", err)
		}
		group.Members = members
		if err := json.Unmarshal(data, &group.Data); err != nil {
			return nil, err
		}
		groups = append(groups, group)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql select groups query failed: %w", err)
	}
	return groups, nil
}
// UpdateGroup overwrites the namespace, members and data columns of the
// groups row whose id equals group.ID. group.Data is stored as JSON.
//
// Errors are wrapped with %w, so the underlying driver error (for example a
// *pq.Error) stays inspectable with errors.As.
//
// NOTE(review): unlike UpdateGroupMember, this does not check RowsAffected,
// so updating a non-existent id succeeds silently. Confirm that callers rely
// on that.
func (psql PostgresqlStorage) UpdateGroup(group Group) error {
	members := pq.StringArray(group.Members)
	data, err := json.Marshal(group.Data)
	if err != nil {
		return fmt.Errorf("failed to serialize data field: %w", err)
	}
	query := fmt.Sprintf(`UPDATE %s SET namespace=$2, members=$3, data=$4 WHERE id=$1`, psql.Tables["groups"])
	if _, err := psql.DbConnection.Exec(query, group.ID, group.Namespace, members, data); err != nil {
		return fmt.Errorf("psql update group query failed : %w", err)
	}
	return nil
}
// CreateGroupMember inserts member as a new row of the group_members table,
// storing member.Data as JSON. Errors from encoding or from the INSERT are
// returned as-is.
func (psql PostgresqlStorage) CreateGroupMember(member GroupMember) error {
	payload, err := json.Marshal(member.Data)
	if err != nil {
		return err
	}
	insert := fmt.Sprintf("INSERT INTO %s (id, member_id, group_id, data) VALUES ($1, $2, $3, $4)",
		psql.Tables["group_members"])
	_, err = psql.DbConnection.Exec(insert, member.ID, member.Memberid, member.Groupid, payload)
	return err
}
// GetGroupMember loads the group_members row whose id column equals id and
// decodes its JSON data column into GroupMember.Data.
//
// The query/scan error is wrapped with %w, so callers can detect a missing
// row with errors.Is(err, sql.ErrNoRows). JSON decoding errors are returned
// as-is.
func (psql PostgresqlStorage) GetGroupMember(id string) (*GroupMember, error) {
	query := fmt.Sprintf(`SELECT id, member_id, group_id, data
FROM %s WHERE id = $1`, psql.Tables["group_members"])
	member := &GroupMember{}
	var data []byte
	err := psql.DbConnection.QueryRow(query, id).Scan(&member.ID, &member.Memberid, &member.Groupid, &data)
	if err != nil {
		return nil, fmt.Errorf("psql select group member query failed: %w", err)
	}
	if err := json.Unmarshal(data, &member.Data); err != nil {
		return nil, err
	}
	return member, nil
}
// GetGroupsMember returns every group member whose group's namespace is one
// of namespaces. It joins group_members with groups on the group id.
//
// An empty namespaces argument returns an empty, non-nil slice without
// querying. The namespace list is bound as a single array parameter ($1), so
// no per-element placeholders are needed. Driver errors are wrapped with %w.
func (psql PostgresqlStorage) GetGroupsMember(namespaces []string) ([]GroupMember, error) {
	groupMembers := make([]GroupMember, 0)
	if len(namespaces) == 0 {
		return groupMembers, nil
	}
	query := fmt.Sprintf(`
SELECT gm.id, gm.member_id, gm.group_id, gm.data
FROM %s gm
INNER JOIN %s g ON g.id = gm.group_id AND g.namespace = ANY($1)
`, psql.Tables["group_members"], psql.Tables["groups"])
	rows, err := psql.DbConnection.Query(query, pq.Array(namespaces))
	if err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var (
			groupMember GroupMember
			data        []byte
		)
		if err := rows.Scan(&groupMember.ID, &groupMember.Memberid, &groupMember.Groupid, &data); err != nil {
			return nil, fmt.Errorf("psql select group members query failed: %w", err)
		}
		if err := json.Unmarshal(data, &groupMember.Data); err != nil {
			return nil, err
		}
		groupMembers = append(groupMembers, groupMember)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	return groupMembers, nil
}
// GetGroupsMemberByIds returns the group_members rows whose group_id is in
// ids, so ids are matched against the group column, not the member row id.
//
// NOTE(review): confirm callers pass group ids; the name alone is ambiguous.
//
// Entries that do not parse as UUIDs are skipped. Filtering builds a fresh
// slice so the caller's ids are never modified; the previous in-place append
// overwrote the caller's backing array. If no valid id remains, an empty
// slice is returned without querying. The result is always non-nil. Driver
// errors are wrapped with %w.
func (psql PostgresqlStorage) GetGroupsMemberByIds(ids []string) ([]GroupMember, error) {
	validIDs := make([]string, 0, len(ids))
	for _, id := range ids {
		if isValidUUID(id) {
			validIDs = append(validIDs, id)
		}
	}
	groupMembers := make([]GroupMember, 0)
	if len(validIDs) == 0 {
		return groupMembers, nil
	}
	query := fmt.Sprintf(`SELECT id, member_id, group_id, data
FROM %s WHERE group_id = ANY($1)`, psql.Tables["group_members"])
	rows, err := psql.DbConnection.Query(query, pq.Array(validIDs))
	if err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var (
			groupMember GroupMember
			data        []byte
		)
		if err := rows.Scan(&groupMember.ID, &groupMember.Memberid, &groupMember.Groupid, &data); err != nil {
			return nil, fmt.Errorf("psql select group members query failed: %w", err)
		}
		if err := json.Unmarshal(data, &groupMember.Data); err != nil {
			return nil, err
		}
		groupMembers = append(groupMembers, groupMember)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	return groupMembers, nil
}
// UpdateGroupMember overwrites member_id, group_id and data (stored as JSON)
// of the group_members row whose id equals groupMember.ID.
//
// It returns an error when no row matched. Driver and encoding errors are
// wrapped with %w so the cause stays inspectable.
func (psql PostgresqlStorage) UpdateGroupMember(groupMember GroupMember) error {
	data, err := json.Marshal(groupMember.Data)
	if err != nil {
		return fmt.Errorf("failed to marshal data field: %w", err)
	}
	query := fmt.Sprintf(`UPDATE %s SET member_id=$1, group_id=$2, data=$3 WHERE id=$4`, psql.Tables["group_members"])
	res, err := psql.DbConnection.Exec(query, groupMember.Memberid, groupMember.Groupid, data, groupMember.ID)
	if err != nil {
		return fmt.Errorf("psql update group member query failed: %w", err)
	}
	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected by update query: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("no rows were updated for group member with id %s", groupMember.ID)
	}
	return nil
}
// DeleteGroupMember deletes the group_members row whose id equals id.
//
// It returns an error when no row matched. Driver errors are wrapped with %w
// so the cause stays inspectable.
func (psql PostgresqlStorage) DeleteGroupMember(id string) error {
	query := fmt.Sprintf(`DELETE FROM %s WHERE id=$1`, psql.Tables["group_members"])
	res, err := psql.DbConnection.Exec(query, id)
	if err != nil {
		return fmt.Errorf("psql delete group member query failed: %w", err)
	}
	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected by delete query: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("no rows were deleted for group member with id %s", id)
	}
	return nil
}
// Migrate brings the database schema in line with the Atlas HCL definition
// read from "postgresql/schema.hcl". The path is relative to the process
// working directory.
//
// Only the schemas declared in that file are inspected and diffed. Atlas'
// RealmDiff emits a DropSchema change for every inspected schema that is
// absent from the desired realm, so inspecting the whole server (nil
// options) would drop unrelated schemas. For the same reason, a definition
// that declares no schemas is rejected instead of applied. Each step's error
// is wrapped with context.
func (psql PostgresqlStorage) Migrate() error {
	ctx := context.Background()
	hcl, err := os.ReadFile("postgresql/schema.hcl")
	if err != nil {
		return fmt.Errorf("read schema definition: %w", err)
	}
	var desired schema.Realm
	if err := postgres.EvalHCLBytes(hcl, &desired, nil); err != nil {
		return fmt.Errorf("evaluate schema definition: %w", err)
	}
	if len(desired.Schemas) == 0 {
		// With an empty Schemas filter InspectRealm inspects everything, and the
		// diff would then drop every schema on the server.
		return fmt.Errorf("schema definition declares no schemas")
	}
	names := make([]string, 0, len(desired.Schemas))
	for _, s := range desired.Schemas {
		names = append(names, s.Name)
	}
	driver, err := postgres.Open(psql.DbConnection)
	if err != nil {
		return fmt.Errorf("open atlas driver: %w", err)
	}
	existing, err := driver.InspectRealm(ctx, &schema.InspectRealmOption{Schemas: names})
	if err != nil {
		return fmt.Errorf("inspect existing schema: %w", err)
	}
	diff, err := driver.RealmDiff(existing, &desired)
	if err != nil {
		return fmt.Errorf("compute schema diff: %w", err)
	}
	if err := driver.ApplyChanges(ctx, diff); err != nil {
		return fmt.Errorf("apply schema changes: %w", err)
	}
	return nil
}
// isValidUUID reports whether s is accepted by github.com/google/uuid.Parse.
func isValidUUID(s string) bool {
	if _, err := uuid.Parse(s); err != nil {
		return false
	}
	return true
}