mobility-accounts/storage/postgresql.go

458 lines
12 KiB
Go

package storage
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"ariga.io/atlas/sql/postgres"
"ariga.io/atlas/sql/schema"
"github.com/lib/pq"
_ "github.com/lib/pq"
"github.com/spf13/viper"
)
// PostgresqlStorage stores accounts and their local authentication data in
// PostgreSQL. NewPostgresqlStorage builds a value with all three fields set.
type PostgresqlStorage struct {
// DbConnection is the database handle every query and transaction of this
// type runs on.
DbConnection *sql.DB
// Schema is the PostgreSQL schema name read from configuration; Migrate
// inspects this schema.
Schema string
// Tables maps the logical keys "accounts" and "accounts_auth_local" to the
// schema-qualified table names ("<schema>.<table>") that the methods
// interpolate into their SQL text.
Tables map[string]string
}
// NewPostgresqlStorage connects to the PostgreSQL database described by the
// "storage.db.psql.*" configuration keys and checks the connection with a
// ping. It returns a storage whose table names are qualified with the
// configured schema ("<schema>.<table>").
//
// Every string parameter is single-quoted and escaped before it goes into
// the key/value connection string. Without quoting, a password containing a
// space or quote corrupts the string. An empty value (e.g. an unset
// password) is worse: the driver reads the following "key=value" pair as
// that value.
//
// An error is returned in three cases: the configured port is not an
// integer, sql.Open fails, or the ping fails. The cause is wrapped with %w.
// When the ping fails, the connection pool is closed before returning.
func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) {
	host := cfg.GetString("storage.db.psql.host")
	port := cfg.GetString("storage.db.psql.port")
	user := cfg.GetString("storage.db.psql.user")
	password := cfg.GetString("storage.db.psql.password")
	dbname := cfg.GetString("storage.db.psql.dbname")
	sslmode := cfg.GetString("storage.db.psql.sslmode")
	pgSchema := cfg.GetString("storage.db.psql.schema")
	accountsTable := cfg.GetString("storage.db.psql.tables.accounts")
	authLocalTable := cfg.GetString("storage.db.psql.tables.accounts_auth_local")

	// Reject a malformed port instead of silently dialling port 0.
	portInt, err := strconv.Atoi(strings.TrimSpace(port))
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("invalid postgresql port %q: %w", port, err)
	}
	psqlconn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(host), portInt, quoteDSNValue(user), quoteDSNValue(password),
		quoteDSNValue(dbname), quoteDSNValue(sslmode))

	db, err := sql.Open("postgres", psqlconn)
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed: %w", err)
	}
	if err := db.Ping(); err != nil {
		// Nothing else references the pool yet: close it so it does not leak.
		// The ping error is the one worth reporting.
		_ = db.Close()
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed: %w", err)
	}
	return PostgresqlStorage{
		DbConnection: db,
		Schema:       pgSchema,
		Tables: map[string]string{
			"accounts":            fmt.Sprintf("%s.%s", pgSchema, accountsTable),
			"accounts_auth_local": fmt.Sprintf("%s.%s", pgSchema, authLocalTable),
		},
	}, nil
}

// quoteDSNValue wraps v in single quotes for a lib/pq key/value connection
// string, backslash-escaping embedded backslashes and single quotes.
func quoteDSNValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
// GetAccount loads the account identified by id, together with its local
// authentication row if one exists (LEFT JOIN on account_id).
//
// Authentication.Local is set only when at least one of username, password,
// email or phone_number is non-NULL. A NULL password yields an empty
// LocalAuth.Password rather than a nil-pointer panic. A NULL JSON column
// leaves its field at the zero value instead of failing to decode.
//
// The query error is wrapped with %w, so callers can detect a missing account
// with errors.Is(err, sql.ErrNoRows).
func (psql PostgresqlStorage) GetAccount(id string) (*Account, error) {
	// decodeJSON skips SQL NULLs. They scan into nil byte slices, which
	// json.Unmarshal rejects with "unexpected end of JSON input".
	decodeJSON := func(raw []byte, dst interface{}) error {
		if len(raw) == 0 {
			return nil
		}
		return json.Unmarshal(raw, dst)
	}

	var data, metadata, emailValidation, phoneValidation []byte
	var username, password, email, phonenumber *string
	account := &Account{}
	req := fmt.Sprintf(`SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
FROM %s a
LEFT JOIN %s auth ON id = account_id WHERE id = $1`, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
	err := psql.DbConnection.QueryRow(req, id).Scan(
		&account.ID, &account.Namespace, &data, &metadata,
		&username, &password, &email, &emailValidation,
		&phonenumber, &phoneValidation,
	)
	if err != nil {
		return nil, fmt.Errorf("psql select account query failed : %w", err)
	}
	if err := decodeJSON(data, &account.Data); err != nil {
		return nil, err
	}
	if err := decodeJSON(metadata, &account.Metadata); err != nil {
		return nil, err
	}
	if password == nil && username == nil && email == nil && phonenumber == nil {
		// No local authentication row (or one whose identifiers are all NULL).
		return account, nil
	}
	local := &LocalAuth{Username: username, Email: email, PhoneNumber: phonenumber}
	if password != nil {
		local.Password = *password
	}
	account.Authentication.Local = local
	if err := decodeJSON(emailValidation, &local.EmailValidation); err != nil {
		return nil, err
	}
	if err := decodeJSON(phoneValidation, &local.PhoneNumberValidation); err != nil {
		return nil, err
	}
	return account, nil
}
// LocalAuthentication looks up the account in namespace whose local
// authentication row matches every non-empty identifier among username,
// email and phone_number. It uses an INNER JOIN on account_id, so accounts
// without a local authentication row never match.
//
// Every caller-supplied value is sent as a bind parameter. Only fixed column
// names and the configured table names appear in the SQL text. At least one
// identifier must be non-empty; without one the query would match an
// arbitrary account of the namespace, so an error is returned instead.
//
// When several rows match, the first row returned by the database is used.
// Query errors (including sql.ErrNoRows when nothing matches) and JSON
// decoding errors are returned unchanged. NULL JSON columns leave their
// field at the zero value.
func (psql PostgresqlStorage) LocalAuthentication(namespace string, username *string, email *string, phone_number *string) (*Account, error) {
	// decodeJSON skips SQL NULLs, which json.Unmarshal would reject.
	decodeJSON := func(raw []byte, dst interface{}) error {
		if len(raw) == 0 {
			return nil
		}
		return json.Unmarshal(raw, dst)
	}

	args := []interface{}{namespace}
	req := fmt.Sprintf(`SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
FROM %s INNER JOIN %s ON id = account_id
WHERE account_namespace = $1`, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
	filters := []struct {
		column string
		value  *string
	}{
		{"username", username},
		{"email", email},
		{"phone_number", phone_number},
	}
	for _, f := range filters {
		if f.value == nil || *f.value == "" {
			continue
		}
		args = append(args, *f.value)
		req += fmt.Sprintf(" AND %s = $%d", f.column, len(args))
	}
	if len(args) == 1 {
		// Only the namespace is bound: the query would pick any account in it.
		return nil, fmt.Errorf("local authentication requires a username, email or phone number")
	}

	var data, metadata, emailValidation, phoneValidation []byte
	account := &Account{}
	local := &LocalAuth{}
	account.Authentication.Local = local
	err := psql.DbConnection.QueryRow(req, args...).Scan(
		&account.ID, &account.Namespace, &data, &metadata,
		&local.Username, &local.Password, &local.Email, &emailValidation,
		&local.PhoneNumber, &phoneValidation,
	)
	if err != nil {
		return nil, err
	}
	if err := decodeJSON(metadata, &account.Metadata); err != nil {
		return nil, err
	}
	if err := decodeJSON(data, &account.Data); err != nil {
		return nil, err
	}
	if err := decodeJSON(emailValidation, &local.EmailValidation); err != nil {
		return nil, err
	}
	if err := decodeJSON(phoneValidation, &local.PhoneNumberValidation); err != nil {
		return nil, err
	}
	return account, nil
}
// GetAccounts returns every account whose namespace is one of namespaces.
// Each account is joined with its local authentication row when one exists.
//
// The namespaces are bound as a single PostgreSQL array parameter
// (namespace = ANY($1)) rather than spliced into the SQL text, so namespace
// values cannot alter the statement. An empty namespaces slice matches no
// accounts, and the result is nil when nothing matches.
//
// Authentication.Local is set only when at least one of username, password,
// email or phone_number is non-NULL. A NULL password yields an empty
// LocalAuth.Password, and NULL JSON columns leave their field at the zero
// value. Errors raised while iterating the rows are returned rather than
// silently truncating the result.
func (psql PostgresqlStorage) GetAccounts(namespaces []string) ([]Account, error) {
	// decodeJSON skips SQL NULLs, which json.Unmarshal would reject.
	decodeJSON := func(raw []byte, dst interface{}) error {
		if len(raw) == 0 {
			return nil
		}
		return json.Unmarshal(raw, dst)
	}

	query := fmt.Sprintf(`
SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
FROM %s LEFT JOIN %s ON id = account_id
WHERE namespace = ANY ($1)
`, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
	rows, err := psql.DbConnection.Query(query, pq.Array(namespaces))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var accounts []Account
	for rows.Next() {
		var account Account
		var dataBytes, metadataBytes, emailValidation, phoneValidation []byte
		var username, password, email, phonenumber *string
		err := rows.Scan(
			&account.ID, &account.Namespace, &dataBytes, &metadataBytes,
			&username, &password, &email, &emailValidation,
			&phonenumber, &phoneValidation,
		)
		if err != nil {
			return nil, err
		}
		if err := decodeJSON(dataBytes, &account.Data); err != nil {
			return nil, err
		}
		if err := decodeJSON(metadataBytes, &account.Metadata); err != nil {
			return nil, err
		}
		if password != nil || username != nil || email != nil || phonenumber != nil {
			local := &LocalAuth{Username: username, Email: email, PhoneNumber: phonenumber}
			if password != nil {
				local.Password = *password
			}
			if err := decodeJSON(emailValidation, &local.EmailValidation); err != nil {
				return nil, err
			}
			if err := decodeJSON(phoneValidation, &local.PhoneNumberValidation); err != nil {
				return nil, err
			}
			account.Authentication.Local = local
		}
		accounts = append(accounts, account)
	}
	// rows.Next also returns false on a driver/network error mid-iteration.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return accounts, nil
}
// GetAccountsByIds returns the accounts whose id is in accountids. Each
// account is joined with its local authentication row when one exists. The
// ids are bound as a single PostgreSQL array parameter (id = ANY($1)). Unknown
// ids are skipped, and the result is nil when nothing matches.
//
// Authentication.Local is set only when at least one of username, password,
// email or phone_number is non-NULL. A NULL password yields an empty
// LocalAuth.Password, and NULL JSON columns leave their field at the zero
// value. Errors raised while iterating the rows are returned rather than
// silently truncating the result.
func (psql PostgresqlStorage) GetAccountsByIds(accountids []string) ([]Account, error) {
	// decodeJSON skips SQL NULLs, which json.Unmarshal would reject.
	decodeJSON := func(raw []byte, dst interface{}) error {
		if len(raw) == 0 {
			return nil
		}
		return json.Unmarshal(raw, dst)
	}

	query := fmt.Sprintf(`
SELECT id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation
FROM %s
LEFT JOIN %s ON id = account_id
WHERE id = ANY ($1)
`, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
	rows, err := psql.DbConnection.Query(query, pq.Array(accountids))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var accounts []Account
	for rows.Next() {
		var account Account
		var dataBytes, metadataBytes, emailValidation, phoneValidation []byte
		var username, password, email, phonenumber *string
		err := rows.Scan(
			&account.ID, &account.Namespace, &dataBytes, &metadataBytes,
			&username, &password, &email, &emailValidation,
			&phonenumber, &phoneValidation,
		)
		if err != nil {
			return nil, err
		}
		if err := decodeJSON(dataBytes, &account.Data); err != nil {
			return nil, err
		}
		if err := decodeJSON(metadataBytes, &account.Metadata); err != nil {
			return nil, err
		}
		if password != nil || username != nil || email != nil || phonenumber != nil {
			local := &LocalAuth{Username: username, Email: email, PhoneNumber: phonenumber}
			if password != nil {
				local.Password = *password
			}
			if err := decodeJSON(emailValidation, &local.EmailValidation); err != nil {
				return nil, err
			}
			if err := decodeJSON(phoneValidation, &local.PhoneNumberValidation); err != nil {
				return nil, err
			}
			account.Authentication.Local = local
		}
		accounts = append(accounts, account)
	}
	// rows.Next also returns false on a driver/network error mid-iteration.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return accounts, nil
}
// CreateAccount inserts account into the accounts table in a single
// transaction. When account.Authentication.Local is set, it also inserts the
// local authentication row in the same transaction, so either both rows are
// written or neither is.
//
// Data, Metadata and the two validation fields are stored as their JSON
// encoding. Database errors are returned unwrapped so callers can still
// inspect the driver error (e.g. a unique-constraint violation).
func (psql PostgresqlStorage) CreateAccount(account Account) error {
	dataAccountJSON, err := json.Marshal(account.Data)
	if err != nil {
		return err
	}
	metadataAccountJSON, err := json.Marshal(account.Metadata)
	if err != nil {
		return err
	}

	tx, err := psql.DbConnection.BeginTx(context.Background(), nil)
	if err != nil {
		return err
	}
	// Undoes a partial insert on any early return; a no-op after Commit.
	defer tx.Rollback()

	// Each statement runs exactly once, so it is executed directly on the
	// transaction (as UpdateAccount does). An explicit Prepare/Close pair
	// would only add round trips.
	_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (id, namespace, data, metadata) VALUES ($1, $2, $3, $4)", psql.Tables["accounts"]),
		account.ID, account.Namespace, dataAccountJSON, metadataAccountJSON)
	if err != nil {
		return err
	}

	if local := account.Authentication.Local; local != nil {
		emailValidationJSON, err := json.Marshal(local.EmailValidation)
		if err != nil {
			return err
		}
		phoneValidationJSON, err := json.Marshal(local.PhoneNumberValidation)
		if err != nil {
			return err
		}
		_, err = tx.Exec(fmt.Sprintf(`INSERT INTO %s (account_id, account_namespace, username, password, email, email_validation, phone_number, phone_number_validation)
VALUES($1, $2, $3, $4, $5, $6, $7, $8);`, psql.Tables["accounts_auth_local"]),
			account.ID, account.Namespace, local.Username, local.Password, local.Email,
			emailValidationJSON, local.PhoneNumber, phoneValidationJSON)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// UpdateAccount overwrites the namespace, data and metadata of the account
// identified by account.ID. When account.Authentication.Local is set, it also
// overwrites the credential columns of the account's local authentication
// row. Both updates run in one transaction.
//
// The local authentication row's account_namespace is always rewritten to
// account.Namespace. CreateAccount stores the namespace in both tables and
// LocalAuthentication filters on account_namespace. Leaving that column stale
// would make the account unreachable by local login after a namespace change.
//
// NOTE(review): an UPDATE that matches no row is not an error here. Updating
// an unknown account, or setting Local on an account without a local
// authentication row, silently changes nothing. Confirm whether callers rely
// on that.
func (psql PostgresqlStorage) UpdateAccount(account Account) error {
	dataAccountJSON, err := json.Marshal(account.Data)
	if err != nil {
		return err
	}
	metadataAccountJSON, err := json.Marshal(account.Metadata)
	if err != nil {
		return err
	}

	tx, err := psql.DbConnection.BeginTx(context.Background(), nil)
	if err != nil {
		return err
	}
	// Undoes partial updates on any early return; a no-op after Commit.
	defer tx.Rollback()

	_, err = tx.Exec(fmt.Sprintf("UPDATE %s SET namespace=$1, data=$2, metadata=$3 where id= $4", psql.Tables["accounts"]),
		account.Namespace, dataAccountJSON, metadataAccountJSON, account.ID)
	if err != nil {
		return err
	}

	authTable := psql.Tables["accounts_auth_local"]
	if local := account.Authentication.Local; local != nil {
		emailValidationJSON, err := json.Marshal(local.EmailValidation)
		if err != nil {
			return err
		}
		phoneValidationJSON, err := json.Marshal(local.PhoneNumberValidation)
		if err != nil {
			return err
		}
		_, err = tx.Exec(fmt.Sprintf(`UPDATE %s
SET username = $1, password = $2, email = $3, email_validation = $4,
phone_number = $5, phone_number_validation = $6, account_namespace = $7 where account_id = $8`, authTable),
			local.Username, local.Password, local.Email, emailValidationJSON,
			local.PhoneNumber, phoneValidationJSON, account.Namespace, account.ID)
		if err != nil {
			return err
		}
	} else {
		// No credentials supplied: still keep an existing row's namespace in sync.
		_, err = tx.Exec(fmt.Sprintf("UPDATE %s SET account_namespace = $1 WHERE account_id = $2", authTable),
			account.Namespace, account.ID)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// Migrate brings the database in line with the desired state declared in
// postgresql/schema.hcl. It inspects the live schema named psql.Schema,
// diffs it against the evaluated HCL definition, and applies the resulting
// changes.
//
// NOTE(review): the HCL path is relative, so it is resolved against the
// process working directory. Confirm that deployments run from the expected
// directory.
// NOTE(review): if the HCL declares a schema whose name differs from
// psql.Schema, the realm diff presumably drops/creates schemas. Confirm the
// two names always match.
//
// Each error is wrapped with the step that failed. The cause remains
// reachable through errors.Is / errors.As.
func (psql PostgresqlStorage) Migrate() error {
	const schemaFile = "postgresql/schema.hcl"
	ctx := context.Background()

	driver, err := postgres.Open(psql.DbConnection)
	if err != nil {
		return fmt.Errorf("opening atlas postgres driver: %w", err)
	}
	existing, err := driver.InspectRealm(ctx, &schema.InspectRealmOption{Schemas: []string{psql.Schema}})
	if err != nil {
		return fmt.Errorf("inspecting schema %q: %w", psql.Schema, err)
	}
	hcl, err := os.ReadFile(schemaFile)
	if err != nil {
		return fmt.Errorf("reading desired schema: %w", err)
	}
	var desired schema.Realm
	if err := postgres.EvalHCLBytes(hcl, &desired, nil); err != nil {
		return fmt.Errorf("evaluating %s: %w", schemaFile, err)
	}
	diff, err := driver.RealmDiff(existing, &desired)
	if err != nil {
		return fmt.Errorf("computing schema diff: %w", err)
	}
	if err := driver.ApplyChanges(ctx, diff); err != nil {
		return fmt.Errorf("applying %d schema changes: %w", len(diff), err)
	}
	return nil
}