384 lines
11 KiB
Go
384 lines
11 KiB
Go
|
package storage
|
||
|
|
||
|
import (
|
||
|
"ariga.io/atlas/sql/postgres"
|
||
|
"ariga.io/atlas/sql/schema"
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"github.com/lib/pq"
|
||
|
_ "github.com/lib/pq"
|
||
|
"github.com/spf13/viper"
|
||
|
"os"
|
||
|
"strconv"
|
||
|
)
|
||
|
|
||
|
// PostgresqlStorage implements group and group-member persistence on top of a
// PostgreSQL database.
type PostgresqlStorage struct {
	// DbConnection is the database handle every query in this file goes through.
	DbConnection *sql.DB

	// Tables maps the logical table names "groups" and "group_members" to the
	// schema-qualified names ("<schema>.<table>") interpolated into queries.
	Tables map[string]string
}
|
||
|
|
||
|
func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) {
|
||
|
var (
|
||
|
host = cfg.GetString("storage.db.psql.host")
|
||
|
port = cfg.GetString("storage.db.psql.port")
|
||
|
user = cfg.GetString("storage.db.psql.user")
|
||
|
password = cfg.GetString("storage.db.psql.password")
|
||
|
dbname = cfg.GetString("storage.db.psql.dbname")
|
||
|
sslmode = cfg.GetString("storage.db.psql.sslmode")
|
||
|
pg_schema = cfg.GetString("storage.db.psql.schema")
|
||
|
pgtables_group_members = cfg.GetString("storage.db.psql.tables.group_members")
|
||
|
pgtables_groups = cfg.GetString("storage.db.psql.tables.groups")
|
||
|
)
|
||
|
portInt, _ := strconv.Atoi(port)
|
||
|
psqlconn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, portInt,
|
||
|
user, password, dbname, sslmode)
|
||
|
db, err := sql.Open("postgres", psqlconn)
|
||
|
if err != nil {
|
||
|
fmt.Println("error", err)
|
||
|
return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed")
|
||
|
}
|
||
|
err = db.Ping()
|
||
|
if err != nil {
|
||
|
fmt.Println(err)
|
||
|
return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed")
|
||
|
}
|
||
|
return PostgresqlStorage{
|
||
|
DbConnection: db,
|
||
|
Tables: map[string]string{
|
||
|
"groups": fmt.Sprintf("%s.%s", pg_schema, pgtables_groups),
|
||
|
"group_members": fmt.Sprintf("%s.%s", pg_schema, pgtables_group_members),
|
||
|
},
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) CreateGroup(group Group) error {
|
||
|
dataJson, err := json.Marshal(group.Data)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
insertGroupStmt, err := psql.DbConnection.Prepare(fmt.Sprintf("INSERT INTO %s (id, namespace, members, data)"+
|
||
|
" VALUES ($1, $2, $3, $4)",
|
||
|
psql.Tables["groups"]))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer insertGroupStmt.Close()
|
||
|
_, err = insertGroupStmt.Exec(group.ID, group.Namespace, pq.Array(group.Members), dataJson)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) GetGroup(id string) (*Group, error) {
|
||
|
var (
|
||
|
data []byte
|
||
|
members pq.StringArray
|
||
|
)
|
||
|
group := &Group{}
|
||
|
req := fmt.Sprintf(`SELECT id, namespace, members, data
|
||
|
FROM %s WHERE id = $1`, psql.Tables["groups"])
|
||
|
err := psql.DbConnection.QueryRow(req, id).Scan(
|
||
|
&group.ID,
|
||
|
&group.Namespace,
|
||
|
&members,
|
||
|
&data)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql select group query failed : %s", err)
|
||
|
}
|
||
|
group.Members = members
|
||
|
err = json.Unmarshal(data, &group.Data)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return group, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) GetGroups(namespaces []string) ([]Group, error) {
|
||
|
var groups []Group
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`SELECT id, namespace, members, data
|
||
|
FROM %s WHERE namespace = ANY($1)`, psql.Tables["groups"]))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql prepare select groups query failed : %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
rows, err := stmt.Query(pq.Array(namespaces))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql select groups query failed : %s", err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
for rows.Next() {
|
||
|
var (
|
||
|
id string
|
||
|
ns string
|
||
|
data []byte
|
||
|
members pq.StringArray
|
||
|
)
|
||
|
err := rows.Scan(&id, &ns, &members, &data)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql scan groups row failed : %s", err)
|
||
|
}
|
||
|
var groupData map[string]interface{}
|
||
|
err = json.Unmarshal(data, &groupData)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql unmarshal group data failed : %s", err)
|
||
|
}
|
||
|
groups = append(groups, Group{
|
||
|
ID: id,
|
||
|
Namespace: ns,
|
||
|
Members: members,
|
||
|
Data: groupData,
|
||
|
})
|
||
|
}
|
||
|
if err = rows.Err(); err != nil {
|
||
|
return nil, fmt.Errorf("psql iterate groups rows failed : %s", err)
|
||
|
}
|
||
|
return groups, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) GetGroupsByIds(ids []string) ([]Group, error) {
|
||
|
groups := make([]Group, 0)
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`SELECT id, namespace, members, data
|
||
|
FROM %s WHERE id = ANY($1)`, psql.Tables["groups"]))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql prepare select groups query failed : %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
rows, err := stmt.Query(pq.Array(ids))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql select groups query failed: %s", err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
for rows.Next() {
|
||
|
var (
|
||
|
group Group
|
||
|
data []byte
|
||
|
members pq.StringArray
|
||
|
)
|
||
|
if err := rows.Scan(&group.ID, &group.Namespace, &members, &data); err != nil {
|
||
|
return nil, fmt.Errorf("psql select groups query failed: %s", err)
|
||
|
}
|
||
|
group.Members = members
|
||
|
if err := json.Unmarshal(data, &group.Data); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
groups = append(groups, group)
|
||
|
}
|
||
|
|
||
|
if err := rows.Err(); err != nil {
|
||
|
return nil, fmt.Errorf("psql select groups query failed: %s", err)
|
||
|
}
|
||
|
return groups, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) UpdateGroup(group Group) error {
|
||
|
members := pq.StringArray(group.Members)
|
||
|
data, err := json.Marshal(group.Data)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to serialize data field: %s", err)
|
||
|
}
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`UPDATE %s SET namespace=$2, members=$3, data=$4 WHERE id=$1`,
|
||
|
psql.Tables["groups"]))
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql prepare update group query failed : %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
|
||
|
_, err = stmt.Exec(group.ID, group.Namespace, members, data)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql update group query failed : %s", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) CreateGroupMember(member GroupMember) error {
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`INSERT INTO %s (id, member_id, group_id, data)
|
||
|
VALUES ($1, $2, $3, $4)`, psql.Tables["group_members"]))
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql prepare insert group member query failed: %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
data, err := json.Marshal(member.Data)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = stmt.Exec(member.ID, member.Memberid, member.Groupid, data)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql insert group member query failed: %s", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) GetGroupMember(id string) (*GroupMember, error) {
|
||
|
member := &GroupMember{}
|
||
|
req := fmt.Sprintf(`SELECT id, member_id, group_id, data
|
||
|
FROM %s WHERE id = $1`, psql.Tables["group_members"])
|
||
|
var data []byte
|
||
|
err := psql.DbConnection.QueryRow(req, id).Scan(
|
||
|
&member.ID,
|
||
|
&member.Memberid,
|
||
|
&member.Groupid,
|
||
|
&data,
|
||
|
)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql select group member query failed: %s", err)
|
||
|
}
|
||
|
err = json.Unmarshal(data, &member.Data)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return member, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) GetGroupsMember(namespaces []string) ([]GroupMember, error) {
|
||
|
groupMembers := make([]GroupMember, 0)
|
||
|
if len(namespaces) == 0 {
|
||
|
return groupMembers, nil
|
||
|
}
|
||
|
var placeholders []string
|
||
|
for i := range namespaces {
|
||
|
placeholders = append(placeholders, fmt.Sprintf("$%d", i+1))
|
||
|
}
|
||
|
query := fmt.Sprintf(`
|
||
|
SELECT gm.id, gm.member_id, gm.group_id, gm.data
|
||
|
FROM %s gm
|
||
|
INNER JOIN %s g ON g.id = gm.group_id AND g.namespace = ANY($1)
|
||
|
`, psql.Tables["group_members"], psql.Tables["groups"])
|
||
|
|
||
|
rows, err := psql.DbConnection.Query(query, pq.Array(namespaces))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql select group members query failed: %s", err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
|
||
|
for rows.Next() {
|
||
|
var (
|
||
|
groupMember GroupMember
|
||
|
data []byte
|
||
|
)
|
||
|
if err := rows.Scan(&groupMember.ID, &groupMember.Memberid, &groupMember.Groupid, &data); err != nil {
|
||
|
return nil, fmt.Errorf("psql select group members query failed: %s", err)
|
||
|
}
|
||
|
if err := json.Unmarshal(data, &groupMember.Data); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
groupMembers = append(groupMembers, groupMember)
|
||
|
}
|
||
|
|
||
|
if err := rows.Err(); err != nil {
|
||
|
return nil, fmt.Errorf("psql select group members query failed: %s", err)
|
||
|
}
|
||
|
return groupMembers, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) GetGroupsMemberByIds(ids []string) ([]GroupMember, error) {
|
||
|
groupMembers := make([]GroupMember, 0)
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`SELECT id, member_id, group_id, data
|
||
|
FROM %s WHERE group_id = ANY($1)`, psql.Tables["group_members"]))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql prepare select group members query failed: %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
rows, err := stmt.Query(pq.Array(ids))
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("psql select group members query failed: %s", err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
for rows.Next() {
|
||
|
var (
|
||
|
groupMember GroupMember
|
||
|
data []byte
|
||
|
)
|
||
|
if err := rows.Scan(&groupMember.ID, &groupMember.Memberid, &groupMember.Groupid, &data); err != nil {
|
||
|
return nil, fmt.Errorf("psql select group members query failed: %s", err)
|
||
|
}
|
||
|
if err := json.Unmarshal(data, &groupMember.Data); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
groupMembers = append(groupMembers, groupMember)
|
||
|
}
|
||
|
if err := rows.Err(); err != nil {
|
||
|
return nil, fmt.Errorf("psql select group members query failed: %s", err)
|
||
|
}
|
||
|
return groupMembers, nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) UpdateGroupMember(groupMember GroupMember) error {
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`UPDATE %s SET member_id=$1, group_id=$2,
|
||
|
data=$3 WHERE id=$4`, psql.Tables["group_members"]))
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql prepare update group member query failed : %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
|
||
|
data, err := json.Marshal(groupMember.Data)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to marshal data field: %s", err)
|
||
|
}
|
||
|
|
||
|
res, err := stmt.Exec(groupMember.Memberid, groupMember.Groupid, data, groupMember.ID)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql update group member query failed: %s", err)
|
||
|
}
|
||
|
|
||
|
rowsAffected, err := res.RowsAffected()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to get rows affected by update query: %s", err)
|
||
|
}
|
||
|
|
||
|
if rowsAffected == 0 {
|
||
|
return fmt.Errorf("no rows were updated for group member with id %s", groupMember.ID)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) DeleteGroupMember(id string) error {
|
||
|
stmt, err := psql.DbConnection.Prepare(fmt.Sprintf(`DELETE FROM %s WHERE id=$1`, psql.Tables["group_members"]))
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("psql prepare delete group member query failed: %s", err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
if _, err := stmt.Exec(id); err != nil {
|
||
|
return fmt.Errorf("psql delete group member query failed: %s", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (psql PostgresqlStorage) Migrate() error {
|
||
|
ctx := context.Background()
|
||
|
driver, err := postgres.Open(psql.DbConnection)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
existing, err := driver.InspectRealm(ctx, nil)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
var desired schema.Realm
|
||
|
|
||
|
hcl, err := os.ReadFile("postgresql/schema.hcl")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
err = postgres.EvalHCLBytes(hcl, &desired, nil)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
diff, err := driver.RealmDiff(existing, &desired)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
err = driver.ApplyChanges(ctx, diff)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|