package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"strconv"
	"strings"

	"ariga.io/atlas/sql/postgres"
	"ariga.io/atlas/sql/schema"
	"github.com/google/uuid"
	// Importing pq by name also runs its init, which registers the
	// "postgres" database/sql driver; a separate blank import is redundant.
	"github.com/lib/pq"
	"github.com/spf13/viper"
)

// PostgresqlStorage persists groups and group members in PostgreSQL.
//
// DbConnection is the connection pool every query runs on. Tables maps the
// logical keys "groups" and "group_members" to the table names (optionally
// schema-qualified) that are interpolated into the SQL statements.
type PostgresqlStorage struct {
	DbConnection *sql.DB
	Tables       map[string]string
}

// connValueEscaper backslash-escapes the two characters that are special
// inside a single-quoted lib/pq connection-string value.
var connValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// NewPostgresqlStorage reads the storage.db.psql.* settings from cfg, opens a
// connection pool, verifies it with Ping, and returns a storage bound to the
// configured groups / group_members tables.
//
// It returns an error when the configured port is not an integer or when the
// database cannot be reached; in the latter case the pool is closed before
// returning. Configuration values (notably the password) are never printed.
func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) {
	var (
		host              = cfg.GetString("storage.db.psql.host")
		port              = cfg.GetString("storage.db.psql.port")
		user              = cfg.GetString("storage.db.psql.user")
		password          = cfg.GetString("storage.db.psql.password")
		dbname            = cfg.GetString("storage.db.psql.dbname")
		sslmode           = cfg.GetString("storage.db.psql.sslmode")
		pgSchema          = cfg.GetString("storage.db.psql.schema")
		groupMembersTable = cfg.GetString("storage.db.psql.tables.group_members")
		groupsTable       = cfg.GetString("storage.db.psql.tables.groups")
	)

	// Previously a parse failure was ignored and silently became port 0.
	portInt, err := strconv.Atoi(port)
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("invalid postgresql port %q: %w", port, err)
	}

	connStr := buildConnString([][2]string{
		{"host", host},
		{"port", strconv.Itoa(portInt)},
		{"user", user},
		{"password", password},
		{"dbname", dbname},
		{"sslmode", sslmode},
	})

	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed: %w", err)
	}
	if err := db.Ping(); err != nil {
		// The caller never receives db, so release the pool here.
		db.Close()
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed: %w", err)
	}

	return PostgresqlStorage{
		DbConnection: db,
		Tables: map[string]string{
			"groups":        qualifiedTableName(pgSchema, groupsTable),
			"group_members": qualifiedTableName(pgSchema, groupMembersTable),
		},
	}, nil
}

// buildConnString renders ordered key/value pairs as a lib/pq key=value
// connection string. Every value is single-quoted with ' and \ escaped, so an
// empty value or one containing spaces cannot swallow the following key.
func buildConnString(params [][2]string) string {
	parts := make([]string, 0, len(params))
	for _, kv := range params {
		parts = append(parts, kv[0]+"='"+connValueEscaper.Replace(kv[1])+"'")
	}
	return strings.Join(parts, " ")
}

// qualifiedTableName joins schemaName and table with a dot, or returns table
// unchanged when no schema is configured (avoiding an invalid ".table").
// Names are not quoted; they come from trusted configuration.
func qualifiedTableName(schemaName, table string) string {
	if schemaName == "" {
		return table
	}
	return schemaName + "." + table
}

// CreateGroup inserts group into the groups table, storing Members as a
// PostgreSQL array and Data as JSON.
func (psql PostgresqlStorage) CreateGroup(group Group) error {
	dataJSON, err := json.Marshal(group.Data)
	if err != nil {
		return err
	}
	query := fmt.Sprintf("INSERT INTO %s (id, namespace, members, data) VALUES ($1, $2, $3, $4)", psql.Tables["groups"])
	_, err = psql.DbConnection.Exec(query, group.ID, group.Namespace, pq.Array(group.Members), dataJSON)
	return err
}

// GetGroup loads the group with the given id. A missing row yields an error
// that wraps sql.ErrNoRows (testable with errors.Is).
func (psql PostgresqlStorage) GetGroup(id string) (*Group, error) {
	var (
		data    []byte
		members pq.StringArray
	)
	group := &Group{}
	req := fmt.Sprintf(`SELECT id, namespace, members, data FROM %s WHERE id = $1`, psql.Tables["groups"])
	err := psql.DbConnection.QueryRow(req, id).Scan(&group.ID, &group.Namespace, &members, &data)
	if err != nil {
		return nil, fmt.Errorf("psql select group query failed : %w", err)
	}
	group.Members = members
	if err := json.Unmarshal(data, &group.Data); err != nil {
		return nil, err
	}
	return group, nil
}

// GetGroups returns every group whose namespace is in namespaces. When no row
// matches, the returned slice is nil.
func (psql PostgresqlStorage) GetGroups(namespaces []string) ([]Group, error) {
	var groups []Group
	query := fmt.Sprintf(`SELECT id, namespace, members, data FROM %s WHERE namespace = ANY($1)`, psql.Tables["groups"])
	rows, err := psql.DbConnection.Query(query, pq.Array(namespaces))
	if err != nil {
		return nil, fmt.Errorf("psql select groups query failed : %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id      string
			ns      string
			data    []byte
			members pq.StringArray
		)
		if err := rows.Scan(&id, &ns, &members, &data); err != nil {
			return nil, fmt.Errorf("psql scan groups row failed : %w", err)
		}
		var groupData map[string]interface{}
		if err := json.Unmarshal(data, &groupData); err != nil {
			return nil, fmt.Errorf("psql unmarshal group data failed : %w", err)
		}
		groups = append(groups, Group{
			ID:        id,
			Namespace: ns,
			Members:   members,
			Data:      groupData,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql iterate groups rows failed : %w", err)
	}
	return groups, nil
}

// GetGroupsByIds returns the groups whose id is in ids. Entries that do not
// parse as UUIDs are ignored (presumably because the id column is a uuid and
// the cast would fail — confirm against schema.hcl). The caller's slice is not
// modified. The result is never nil.
func (psql PostgresqlStorage) GetGroupsByIds(ids []string) ([]Group, error) {
	ids = validUUIDs(ids)
	groups := make([]Group, 0, len(ids))
	if len(ids) == 0 {
		// `id = ANY('{}')` matches nothing; skip the round-trip.
		return groups, nil
	}

	query := fmt.Sprintf(`SELECT id, namespace, members, data FROM %s WHERE id = ANY($1)`, psql.Tables["groups"])
	rows, err := psql.DbConnection.Query(query, pq.Array(ids))
	if err != nil {
		return nil, fmt.Errorf("psql select groups query failed: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			group   Group
			data    []byte
			members pq.StringArray
		)
		if err := rows.Scan(&group.ID, &group.Namespace, &members, &data); err != nil {
			return nil, fmt.Errorf("psql select groups query failed: %w", err)
		}
		group.Members = members
		if err := json.Unmarshal(data, &group.Data); err != nil {
			return nil, err
		}
		groups = append(groups, group)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql select groups query failed: %w", err)
	}
	return groups, nil
}

// UpdateGroup overwrites namespace, members and data of the group identified
// by group.ID. Updating a non-existent id is not reported as an error.
func (psql PostgresqlStorage) UpdateGroup(group Group) error {
	members := pq.StringArray(group.Members)
	data, err := json.Marshal(group.Data)
	if err != nil {
		return fmt.Errorf("failed to serialize data field: %w", err)
	}
	query := fmt.Sprintf(`UPDATE %s SET namespace=$2, members=$3, data=$4 WHERE id=$1`, psql.Tables["groups"])
	if _, err := psql.DbConnection.Exec(query, group.ID, group.Namespace, members, data); err != nil {
		return fmt.Errorf("psql update group query failed : %w", err)
	}
	return nil
}

// CreateGroupMember inserts member into the group_members table with Data
// stored as JSON.
func (psql PostgresqlStorage) CreateGroupMember(member GroupMember) error {
	dataJSON, err := json.Marshal(member.Data)
	if err != nil {
		return err
	}
	query := fmt.Sprintf("INSERT INTO %s (id, member_id, group_id, data)"+
		" VALUES ($1, $2, $3, $4)", psql.Tables["group_members"])
	_, err = psql.DbConnection.Exec(query, member.ID, member.Memberid, member.Groupid, dataJSON)
	return err
}

// GetGroupMember loads the group member with the given id. A missing row
// yields an error that wraps sql.ErrNoRows.
func (psql PostgresqlStorage) GetGroupMember(id string) (*GroupMember, error) {
	member := &GroupMember{}
	req := fmt.Sprintf(`SELECT id, member_id, group_id, data FROM %s WHERE id = $1`, psql.Tables["group_members"])
	var data []byte
	err := psql.DbConnection.QueryRow(req, id).Scan(&member.ID, &member.Memberid, &member.Groupid, &data)
	if err != nil {
		return nil, fmt.Errorf("psql select group member query failed: %w", err)
	}
	if err := json.Unmarshal(data, &member.Data); err != nil {
		return nil, err
	}
	return member, nil
}

// GetGroupsMember returns every group member that belongs to a group whose
// namespace is in namespaces. The result is never nil.
func (psql PostgresqlStorage) GetGroupsMember(namespaces []string) ([]GroupMember, error) {
	if len(namespaces) == 0 {
		return make([]GroupMember, 0), nil
	}

	query := fmt.Sprintf(`
		SELECT gm.id, gm.member_id, gm.group_id, gm.data
		FROM %s gm
		INNER JOIN %s g ON g.id = gm.group_id AND g.namespace = ANY($1)
	`, psql.Tables["group_members"], psql.Tables["groups"])
	rows, err := psql.DbConnection.Query(query, pq.Array(namespaces))
	if err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	defer rows.Close()
	return scanGroupMembers(rows)
}

// GetGroupsMemberByIds returns the group members whose group_id is in ids.
// NOTE(review): the filter is on group_id, not the member's own id — confirm
// this matches callers' expectations given the method name.
// Entries that are not valid UUIDs are ignored; the caller's slice is not
// modified. The result is never nil.
func (psql PostgresqlStorage) GetGroupsMemberByIds(ids []string) ([]GroupMember, error) {
	ids = validUUIDs(ids)
	if len(ids) == 0 {
		// `group_id = ANY('{}')` matches nothing; skip the round-trip.
		return make([]GroupMember, 0), nil
	}

	req := fmt.Sprintf(`SELECT id, member_id, group_id, data FROM %s WHERE group_id = ANY($1)`, psql.Tables["group_members"])
	rows, err := psql.DbConnection.Query(req, pq.Array(ids))
	if err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	defer rows.Close()
	return scanGroupMembers(rows)
}

// scanGroupMembers drains rows of (id, member_id, group_id, data) into a
// non-nil slice, decoding data as JSON. The caller owns and closes rows.
func scanGroupMembers(rows *sql.Rows) ([]GroupMember, error) {
	groupMembers := make([]GroupMember, 0)
	for rows.Next() {
		var (
			groupMember GroupMember
			data        []byte
		)
		if err := rows.Scan(&groupMember.ID, &groupMember.Memberid, &groupMember.Groupid, &data); err != nil {
			return nil, fmt.Errorf("psql select group members query failed: %w", err)
		}
		if err := json.Unmarshal(data, &groupMember.Data); err != nil {
			return nil, err
		}
		groupMembers = append(groupMembers, groupMember)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("psql select group members query failed: %w", err)
	}
	return groupMembers, nil
}

// UpdateGroupMember overwrites member_id, group_id and data of the row
// identified by groupMember.ID, and returns an error if no row was updated.
func (psql PostgresqlStorage) UpdateGroupMember(groupMember GroupMember) error {
	data, err := json.Marshal(groupMember.Data)
	if err != nil {
		return fmt.Errorf("failed to marshal data field: %w", err)
	}
	req := fmt.Sprintf(`UPDATE %s SET member_id=$1, group_id=$2, data=$3 WHERE id=$4`, psql.Tables["group_members"])
	res, err := psql.DbConnection.Exec(req, groupMember.Memberid, groupMember.Groupid, data, groupMember.ID)
	if err != nil {
		return fmt.Errorf("psql update group member query failed: %w", err)
	}
	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected by update query: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("no rows were updated for group member with id %s", groupMember.ID)
	}
	return nil
}

// DeleteGroupMember removes the group member with the given id and returns an
// error if no row was deleted.
func (psql PostgresqlStorage) DeleteGroupMember(id string) error {
	query := fmt.Sprintf(`DELETE FROM %s WHERE id=$1`, psql.Tables["group_members"])
	res, err := psql.DbConnection.Exec(query, id)
	if err != nil {
		return fmt.Errorf("psql delete group member query failed: %w", err)
	}
	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected by delete query: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("no rows were deleted for group member with id %s", id)
	}
	return nil
}

// Migrate inspects the live database with Atlas, diffs it against the realm
// described in postgresql/schema.hcl (resolved relative to the process working
// directory), and applies the resulting changes.
// NOTE(review): the realm is inspected with nil options, so the diff may cover
// objects outside the configured schema — confirm that is intended.
func (psql PostgresqlStorage) Migrate() error {
	ctx := context.Background()
	driver, err := postgres.Open(psql.DbConnection)
	if err != nil {
		return fmt.Errorf("opening atlas postgres driver: %w", err)
	}
	existing, err := driver.InspectRealm(ctx, nil)
	if err != nil {
		return fmt.Errorf("inspecting current database schema: %w", err)
	}
	hcl, err := os.ReadFile("postgresql/schema.hcl")
	if err != nil {
		return fmt.Errorf("reading desired schema: %w", err)
	}
	var desired schema.Realm
	if err := postgres.EvalHCLBytes(hcl, &desired, nil); err != nil {
		return fmt.Errorf("evaluating desired schema: %w", err)
	}
	diff, err := driver.RealmDiff(existing, &desired)
	if err != nil {
		return fmt.Errorf("diffing database schema: %w", err)
	}
	if err := driver.ApplyChanges(ctx, diff); err != nil {
		return fmt.Errorf("applying schema changes: %w", err)
	}
	return nil
}

// validUUIDs returns, in their original order, the elements of ids that parse
// as UUIDs. It always allocates a new slice so the caller's backing array is
// never rewritten.
func validUUIDs(ids []string) []string {
	valid := make([]string, 0, len(ids))
	for _, id := range ids {
		if isValidUUID(id) {
			valid = append(valid, id)
		}
	}
	return valid
}

// isValidUUID reports whether s is accepted by uuid.Parse.
func isValidUUID(s string) bool {
	_, err := uuid.Parse(s)
	return err == nil
}