458 lines
12 KiB
Go
458 lines
12 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"ariga.io/atlas/sql/postgres"
|
|
"ariga.io/atlas/sql/schema"
|
|
"github.com/lib/pq"
|
|
_ "github.com/lib/pq"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// PostgresqlStorage is the PostgreSQL-backed account store used by the
// methods in this file.
type PostgresqlStorage struct {
	// DbConnection is the connection pool every query runs against.
	DbConnection *sql.DB
	// Schema is the PostgreSQL schema name; Migrate inspects only this schema.
	Schema string
	// Tables maps a logical table name ("accounts", "accounts_auth_local")
	// to the "<schema>.<table>" identifier that NewPostgresqlStorage builds.
	Tables map[string]string
}
|
|
|
|
func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) {
|
|
var (
|
|
host = cfg.GetString("storage.db.psql.host")
|
|
port = cfg.GetString("storage.db.psql.port")
|
|
user = cfg.GetString("storage.db.psql.user")
|
|
password = cfg.GetString("storage.db.psql.password")
|
|
dbname = cfg.GetString("storage.db.psql.dbname")
|
|
sslmode = cfg.GetString("storage.db.psql.sslmode")
|
|
pg_schema = cfg.GetString("storage.db.psql.schema")
|
|
pgtables_accounts = cfg.GetString("storage.db.psql.tables.accounts")
|
|
pgtables_accounts_auth_local = cfg.GetString("storage.db.psql.tables.accounts_auth_local")
|
|
)
|
|
portInt, _ := strconv.Atoi(port)
|
|
psqlconn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, portInt, user, password, dbname, sslmode)
|
|
db, err := sql.Open("postgres", psqlconn)
|
|
if err != nil {
|
|
fmt.Println("error", err)
|
|
return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed")
|
|
}
|
|
err = db.Ping()
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed")
|
|
}
|
|
return PostgresqlStorage{
|
|
DbConnection: db,
|
|
Schema: pg_schema,
|
|
Tables: map[string]string{
|
|
"accounts": fmt.Sprintf("%s.%s", pg_schema, pgtables_accounts),
|
|
"accounts_auth_local": fmt.Sprintf("%s.%s", pg_schema, pgtables_accounts_auth_local),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) GetAccount(id string) (*Account, error) {
|
|
var (
|
|
data, metadata, emailValidation, phoneValidation []byte
|
|
)
|
|
account := &Account{}
|
|
|
|
var username, password, email, phonenumber *string
|
|
|
|
req := fmt.Sprintf(`SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
|
|
FROM %s a
|
|
LEFT JOIN %s auth ON id = account_id WHERE id = $1`, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
|
|
err := psql.DbConnection.QueryRow(req, id).Scan(&account.ID,
|
|
&account.Namespace,
|
|
&data,
|
|
&metadata,
|
|
&username,
|
|
&password,
|
|
&email,
|
|
&emailValidation,
|
|
&phonenumber,
|
|
&phoneValidation)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("psql select account query failed : %s", err)
|
|
}
|
|
|
|
err = json.Unmarshal(data, &account.Data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(metadata, &account.Metadata)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if password != nil || username != nil || email != nil || phonenumber != nil {
|
|
account.Authentication.Local = &LocalAuth{
|
|
Username: username,
|
|
Password: *password,
|
|
Email: email,
|
|
PhoneNumber: phonenumber,
|
|
}
|
|
err = json.Unmarshal(emailValidation, &account.Authentication.Local.EmailValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(phoneValidation, &account.Authentication.Local.PhoneNumberValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return account, nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) LocalAuthentication(namespace string, username *string, email *string, phone_number *string) (*Account, error) {
|
|
account := &Account{}
|
|
var (
|
|
data, metadata, emailValidation, phoneValidation []byte
|
|
)
|
|
req := fmt.Sprintf(`SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
|
|
FROM %s INNER JOIN %s ON id = account_id
|
|
WHERE account_namespace = '%s'`, psql.Tables["accounts"], psql.Tables["accounts_auth_local"], namespace)
|
|
|
|
if username != nil && *username != "" {
|
|
req += fmt.Sprintf(` AND username = '%s'`, *username)
|
|
}
|
|
if email != nil && *email != "" {
|
|
req += fmt.Sprintf(` AND email = '%s'`, *email)
|
|
}
|
|
if phone_number != nil && *phone_number != "" {
|
|
req += fmt.Sprintf(` AND phone_number = '%s'`, *phone_number)
|
|
}
|
|
req += ";"
|
|
fmt.Println(req)
|
|
account.Authentication.Local = &LocalAuth{}
|
|
err := psql.DbConnection.QueryRow(req).Scan(
|
|
&account.ID,
|
|
&account.Namespace, &data, &metadata,
|
|
&account.Authentication.Local.Username,
|
|
&account.Authentication.Local.Password,
|
|
&account.Authentication.Local.Email,
|
|
&emailValidation,
|
|
&account.Authentication.Local.PhoneNumber,
|
|
&phoneValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(metadata, &account.Metadata)
|
|
if err != nil {
|
|
fmt.Println("one")
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(data, &account.Data)
|
|
if err != nil {
|
|
fmt.Println("two")
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(emailValidation, &account.Authentication.Local.EmailValidation)
|
|
if err != nil {
|
|
fmt.Println("three")
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(phoneValidation, &account.Authentication.Local.PhoneNumberValidation)
|
|
if err != nil {
|
|
fmt.Println("four")
|
|
return nil, err
|
|
}
|
|
return account, nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) GetAccounts(namespaces []string) ([]Account, error) {
|
|
var accounts []Account
|
|
namespacesStr := "'" + strings.Join(namespaces, "', '") + "'"
|
|
query := `
|
|
SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
|
|
FROM %s LEFT JOIN %s ON id = account_id
|
|
WHERE namespace IN (%s)
|
|
`
|
|
query = fmt.Sprintf(query, psql.Tables["accounts"], psql.Tables["accounts_auth_local"], namespacesStr)
|
|
rows, err := psql.DbConnection.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var account Account
|
|
var dataBytes []byte
|
|
var metadataBytes []byte
|
|
var emailValidation []byte
|
|
var phoneValidation []byte
|
|
var username, password, email, phonenumber *string
|
|
err := rows.Scan(
|
|
&account.ID,
|
|
&account.Namespace,
|
|
&dataBytes,
|
|
&metadataBytes,
|
|
&username,
|
|
&password,
|
|
&email,
|
|
&emailValidation,
|
|
&phonenumber,
|
|
&phoneValidation,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = json.Unmarshal(dataBytes, &account.Data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = json.Unmarshal(metadataBytes, &account.Metadata)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if password != nil || username != nil || email != nil || phonenumber != nil {
|
|
account.Authentication.Local = &LocalAuth{
|
|
Username: username,
|
|
Password: *password,
|
|
Email: email,
|
|
PhoneNumber: phonenumber,
|
|
}
|
|
err = json.Unmarshal(emailValidation, &account.Authentication.Local.EmailValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(phoneValidation, &account.Authentication.Local.PhoneNumberValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
accounts = append(accounts, account)
|
|
}
|
|
return accounts, nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) GetAccountsByIds(accountids []string) ([]Account, error) {
|
|
var accounts []Account
|
|
|
|
query := `
|
|
SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
|
|
FROM %s
|
|
LEFT JOIN %s ON id = account_id
|
|
WHERE id = ANY ($1)
|
|
`
|
|
query = fmt.Sprintf(query, psql.Tables["accounts"], psql.Tables["accounts_auth_local"]) //, accountIdsStr)
|
|
rows, err := psql.DbConnection.Query(query, pq.Array(accountids))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var account Account
|
|
var dataBytes []byte
|
|
var metadataBytes []byte
|
|
var emailValidation []byte
|
|
var phoneValidation []byte
|
|
var username, password, email, phonenumber *string
|
|
err := rows.Scan(
|
|
&account.ID,
|
|
&account.Namespace,
|
|
&dataBytes,
|
|
&metadataBytes,
|
|
&username,
|
|
&password,
|
|
&email,
|
|
&emailValidation,
|
|
&phonenumber,
|
|
&phoneValidation,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = json.Unmarshal(dataBytes, &account.Data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = json.Unmarshal(metadataBytes, &account.Metadata)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if password != nil || username != nil || email != nil || phonenumber != nil {
|
|
account.Authentication.Local = &LocalAuth{
|
|
Username: username,
|
|
Password: *password,
|
|
Email: email,
|
|
PhoneNumber: phonenumber,
|
|
}
|
|
err = json.Unmarshal(emailValidation, &account.Authentication.Local.EmailValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = json.Unmarshal(phoneValidation, &account.Authentication.Local.PhoneNumberValidation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
accounts = append(accounts, account)
|
|
}
|
|
return accounts, nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) CreateAccount(account Account) error {
|
|
dataAccountJson, err := json.Marshal(account.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
metadataAccountJson, err := json.Marshal(account.Metadata)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tx, err := psql.DbConnection.BeginTx(context.Background(), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
insertAccountStmt, err := tx.Prepare(fmt.Sprintf("INSERT INTO %s (id, namespace, data, metadata) VALUES ($1, $2, $3, $4)", psql.Tables["accounts"]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer insertAccountStmt.Close()
|
|
_, err = insertAccountStmt.Exec(account.ID, account.Namespace, dataAccountJson, metadataAccountJson)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if account.Authentication.Local != nil {
|
|
|
|
emailValidationJson, err := json.Marshal(account.Authentication.Local.EmailValidation)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
phoneValidationJson, err := json.Marshal(account.Authentication.Local.PhoneNumberValidation)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
insertAccountAuthStmt, err := tx.Prepare(fmt.Sprintf(`INSERT INTO %s (account_id, account_namespace, username, password, email, email_validation, phone_number, phone_number_validation)
|
|
VALUES($1, $2, $3, $4, $5, $6, $7, $8);`, psql.Tables["accounts_auth_local"]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer insertAccountAuthStmt.Close()
|
|
_, err = insertAccountAuthStmt.Exec(account.ID,
|
|
account.Namespace,
|
|
account.Authentication.Local.Username,
|
|
account.Authentication.Local.Password,
|
|
account.Authentication.Local.Email,
|
|
emailValidationJson,
|
|
account.Authentication.Local.PhoneNumber,
|
|
phoneValidationJson)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) UpdateAccount(account Account) error {
|
|
|
|
tx, err := psql.DbConnection.BeginTx(context.Background(), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
dataAccountJson, err := json.Marshal(account.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
metadataAccountJson, err := json.Marshal(account.Metadata)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req := fmt.Sprintf("UPDATE %s SET namespace=$1, data=$2, metadata=$3 where id= $4", psql.Tables["accounts"])
|
|
_, err = tx.Exec(req, account.Namespace, dataAccountJson, metadataAccountJson, account.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if account.Authentication.Local != nil {
|
|
emailValidationJson, err := json.Marshal(account.Authentication.Local.EmailValidation)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
phoneValidationJson, err := json.Marshal(account.Authentication.Local.PhoneNumberValidation)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req = fmt.Sprintf(`UPDATE %s
|
|
SET username = $1, password = $2, email = $3, email_validation = $4,
|
|
phone_number = $5,phone_number_validation = $6 where account_id = $7`, psql.Tables["accounts_auth_local"])
|
|
_, err = tx.Exec(req, account.Authentication.Local.Username,
|
|
account.Authentication.Local.Password,
|
|
account.Authentication.Local.Email,
|
|
emailValidationJson,
|
|
account.Authentication.Local.PhoneNumber,
|
|
phoneValidationJson, account.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (psql PostgresqlStorage) Migrate() error {
|
|
ctx := context.Background()
|
|
driver, err := postgres.Open(psql.DbConnection)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
existing, err := driver.InspectRealm(ctx, &schema.InspectRealmOption{Schemas: []string{psql.Schema}})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var desired schema.Realm
|
|
|
|
hcl, err := os.ReadFile("postgresql/schema.hcl")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = postgres.EvalHCLBytes(hcl, &desired, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
diff, err := driver.RealmDiff(existing, &desired)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = driver.ApplyChanges(ctx, diff)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|