package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"strconv"
	"strings"

	"ariga.io/atlas/sql/postgres"
	"ariga.io/atlas/sql/schema"
	"github.com/lib/pq" // also registers the "postgres" database/sql driver
	"github.com/spf13/viper"
)

// pgAccountColumns is the column list selected by every account query in this
// file. Its order must match the destinations returned by pgAccountRow.dest.
const pgAccountColumns = `id, namespace, data, metadata, username, password, email, email_validation, phone_number, phone_number_validation`

// PostgresqlStorage stores accounts and their optional local authentication
// data in PostgreSQL.
type PostgresqlStorage struct {
	DbConnection *sql.DB
	// Tables maps the logical names "accounts" and "accounts_auth_local" to
	// schema-qualified table names (built by NewPostgresqlStorage). The values
	// are interpolated into SQL text, so they must come from trusted config.
	Tables map[string]string
}

// NewPostgresqlStorage opens a PostgreSQL connection described by the
// storage.db.psql.* keys of cfg and verifies it with a ping.
//
// It returns an error if the configured port is not an integer or if the
// database cannot be opened or reached; on ping failure the pool is closed.
func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) {
	var (
		host                         = cfg.GetString("storage.db.psql.host")
		port                         = cfg.GetString("storage.db.psql.port")
		user                         = cfg.GetString("storage.db.psql.user")
		password                     = cfg.GetString("storage.db.psql.password")
		dbname                       = cfg.GetString("storage.db.psql.dbname")
		sslmode                      = cfg.GetString("storage.db.psql.sslmode")
		pgSchema                     = cfg.GetString("storage.db.psql.schema")
		pgTablesAccounts             = cfg.GetString("storage.db.psql.tables.accounts")
		pgTablesAccountsAuthLocal    = cfg.GetString("storage.db.psql.tables.accounts_auth_local")
	)

	portInt, err := strconv.Atoi(port)
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("invalid postgresql port %q: %w", port, err)
	}

	// Values are quoted so that empty settings and passwords containing spaces,
	// quotes or backslashes do not corrupt the key=value connection string.
	psqlconn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		pgQuoteDSNValue(host), portInt, pgQuoteDSNValue(user), pgQuoteDSNValue(password),
		pgQuoteDSNValue(dbname), pgQuoteDSNValue(sslmode))

	db, err := sql.Open("postgres", psqlconn)
	if err != nil {
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed: %w", err)
	}
	if err := db.Ping(); err != nil {
		db.Close() // don't leak the pool when the database is unreachable
		return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed: %w", err)
	}

	return PostgresqlStorage{
		DbConnection: db,
		Tables: map[string]string{
			"accounts":            fmt.Sprintf("%s.%s", pgSchema, pgTablesAccounts),
			"accounts_auth_local": fmt.Sprintf("%s.%s", pgSchema, pgTablesAccountsAuthLocal),
		},
	}, nil
}

// pgQuoteDSNValue single-quotes v for a lib/pq "key=value" connection string,
// escaping backslashes and single quotes.
func pgQuoteDSNValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}

// pgAccountRow holds the raw column values of one row selected with
// pgAccountColumns, before JSON decoding.
type pgAccountRow struct {
	account                                          Account
	data, metadata, emailValidation, phoneValidation []byte
	username, password, email, phoneNumber           *string
}

// dest returns the Scan destinations in pgAccountColumns order.
func (r *pgAccountRow) dest() []interface{} {
	return []interface{}{
		&r.account.ID, &r.account.Namespace, &r.data, &r.metadata,
		&r.username, &r.password, &r.email, &r.emailValidation,
		&r.phoneNumber, &r.phoneValidation,
	}
}

// decode turns the scanned columns into an Account. Local authentication is
// attached only when at least one of username, password, email or phone number
// is non-NULL (i.e. an auth row was joined); a NULL password becomes "".
func (r *pgAccountRow) decode() (Account, error) {
	account := r.account
	if err := json.Unmarshal(r.data, &account.Data); err != nil {
		return Account{}, err
	}
	if err := json.Unmarshal(r.metadata, &account.Metadata); err != nil {
		return Account{}, err
	}
	if r.username == nil && r.password == nil && r.email == nil && r.phoneNumber == nil {
		return account, nil
	}

	local := &LocalAuth{
		Username:    r.username,
		Email:       r.email,
		PhoneNumber: r.phoneNumber,
	}
	if r.password != nil {
		local.Password = *r.password
	}
	if err := pgUnmarshalNullableJSON(r.emailValidation, &local.EmailValidation); err != nil {
		return Account{}, err
	}
	if err := pgUnmarshalNullableJSON(r.phoneValidation, &local.PhoneNumberValidation); err != nil {
		return Account{}, err
	}
	account.Authentication.Local = local
	return account, nil
}

// pgUnmarshalNullableJSON decodes raw into v, leaving v untouched when the
// column was SQL NULL (database/sql scans NULL into a nil []byte).
func pgUnmarshalNullableJSON(raw []byte, v interface{}) error {
	if raw == nil {
		return nil
	}
	return json.Unmarshal(raw, v)
}

// pgLocalAuthValidationJSON marshals the email and phone-number validation
// fields of local for storage.
func pgLocalAuthValidationJSON(local *LocalAuth) (emailJSON, phoneJSON []byte, err error) {
	if emailJSON, err = json.Marshal(local.EmailValidation); err != nil {
		return nil, nil, err
	}
	if phoneJSON, err = json.Marshal(local.PhoneNumberValidation); err != nil {
		return nil, nil, err
	}
	return emailJSON, phoneJSON, nil
}

// GetAccount returns the account with the given id, including its local
// authentication data when an auth row exists. A missing account yields an
// error that wraps sql.ErrNoRows.
func (psql PostgresqlStorage) GetAccount(id string) (*Account, error) {
	query := fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s auth ON id = account_id WHERE id = $1`,
		pgAccountColumns, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])

	var row pgAccountRow
	if err := psql.DbConnection.QueryRow(query, id).Scan(row.dest()...); err != nil {
		return nil, fmt.Errorf("psql select account query failed : %w", err)
	}
	account, err := row.decode()
	if err != nil {
		return nil, err
	}
	return &account, nil
}

// LocalAuthentication looks up the account in namespace whose local auth row
// matches the first non-nil identifier among username, email and phone_number
// (checked in that order). It returns an error if all three are nil, if no row
// matches (sql.ErrNoRows), or if the matched row has no password.
func (psql PostgresqlStorage) LocalAuthentication(namespace string, username *string, email *string, phone_number *string) (*Account, error) {
	// field is interpolated into the SQL text, so it must only ever be one of
	// these fixed column names.
	var field, value string
	switch {
	case username != nil:
		field, value = "username", *username
	case email != nil:
		field, value = "email", *email
	case phone_number != nil:
		field, value = "phone_number", *phone_number
	default:
		return nil, fmt.Errorf("localauthentication func error PSQL")
	}

	query := fmt.Sprintf(`SELECT %s FROM %s INNER JOIN %s ON id = account_id WHERE account_namespace = $1 AND %s = $2;`,
		pgAccountColumns, psql.Tables["accounts"], psql.Tables["accounts_auth_local"], field)

	var row pgAccountRow
	if err := psql.DbConnection.QueryRow(query, namespace, value).Scan(row.dest()...); err != nil {
		return nil, err
	}
	// A credential lookup must never succeed against a row without a password.
	if row.password == nil {
		return nil, fmt.Errorf("psql local authentication: matched %s has no password", field)
	}
	account, err := row.decode()
	if err != nil {
		return nil, err
	}
	return &account, nil
}

// GetAccounts returns every account whose namespace is one of namespaces,
// with local authentication data when present. The namespaces are sent as a
// bound array parameter, never spliced into the SQL text.
func (psql PostgresqlStorage) GetAccounts(namespaces []string) ([]Account, error) {
	query := fmt.Sprintf(`SELECT %s FROM %s LEFT JOIN %s ON id = account_id WHERE namespace = ANY ($1)`,
		pgAccountColumns, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
	return psql.queryAccounts(query, pq.Array(namespaces))
}

// GetAccountsByIds returns the accounts whose id is one of accountids, with
// local authentication data when present.
func (psql PostgresqlStorage) GetAccountsByIds(accountids []string) ([]Account, error) {
	query := fmt.Sprintf(`SELECT %s FROM %s LEFT JOIN %s ON id = account_id WHERE id = ANY ($1)`,
		pgAccountColumns, psql.Tables["accounts"], psql.Tables["accounts_auth_local"])
	return psql.queryAccounts(query, pq.Array(accountids))
}

// queryAccounts runs a query selecting pgAccountColumns and decodes every row.
// It returns a nil slice when no rows match.
func (psql PostgresqlStorage) queryAccounts(query string, args ...interface{}) ([]Account, error) {
	rows, err := psql.DbConnection.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var accounts []Account
	for rows.Next() {
		var row pgAccountRow
		if err := rows.Scan(row.dest()...); err != nil {
			return nil, err
		}
		account, err := row.decode()
		if err != nil {
			return nil, err
		}
		accounts = append(accounts, account)
	}
	// Next returns false on both exhaustion and failure; only Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return accounts, nil
}

// CreateAccount inserts account, and its local authentication row when
// account.Authentication.Local is set, in a single transaction.
func (psql PostgresqlStorage) CreateAccount(account Account) error {
	dataAccountJSON, err := json.Marshal(account.Data)
	if err != nil {
		return err
	}
	metadataAccountJSON, err := json.Marshal(account.Metadata)
	if err != nil {
		return err
	}

	tx, err := psql.DbConnection.BeginTx(context.Background(), nil)
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	insertAccount := fmt.Sprintf("INSERT INTO %s (id, namespace, data, metadata) VALUES ($1, $2, $3, $4)", psql.Tables["accounts"])
	if _, err := tx.Exec(insertAccount, account.ID, account.Namespace, dataAccountJSON, metadataAccountJSON); err != nil {
		return err
	}

	if account.Authentication.Local != nil {
		if err := psql.insertLocalAuth(tx, account); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// insertLocalAuth inserts the local authentication row for account inside tx.
// account.Authentication.Local must be non-nil.
func (psql PostgresqlStorage) insertLocalAuth(tx *sql.Tx, account Account) error {
	local := account.Authentication.Local
	emailValidationJSON, phoneValidationJSON, err := pgLocalAuthValidationJSON(local)
	if err != nil {
		return err
	}
	query := fmt.Sprintf(`INSERT INTO %s (account_id, account_namespace, username, password, email, email_validation, phone_number, phone_number_validation) VALUES($1, $2, $3, $4, $5, $6, $7, $8);`,
		psql.Tables["accounts_auth_local"])
	_, err = tx.Exec(query, account.ID, account.Namespace, local.Username, local.Password,
		local.Email, emailValidationJSON, local.PhoneNumber, phoneValidationJSON)
	return err
}

// UpdateAccount overwrites the namespace, data and metadata of the account
// with account.ID and, when account.Authentication.Local is set, its local
// authentication row (creating that row if it does not exist yet). All writes
// happen in one transaction.
func (psql PostgresqlStorage) UpdateAccount(account Account) error {
	dataAccountJSON, err := json.Marshal(account.Data)
	if err != nil {
		return err
	}
	metadataAccountJSON, err := json.Marshal(account.Metadata)
	if err != nil {
		return err
	}

	tx, err := psql.DbConnection.BeginTx(context.Background(), nil)
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	updateAccount := fmt.Sprintf("UPDATE %s SET namespace=$1, data=$2, metadata=$3 where id= $4", psql.Tables["accounts"])
	if _, err := tx.Exec(updateAccount, account.Namespace, dataAccountJSON, metadataAccountJSON, account.ID); err != nil {
		return err
	}

	if account.Authentication.Local != nil {
		if err := psql.saveLocalAuth(tx, account); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// saveLocalAuth updates the local authentication row of account inside tx,
// keeping account_namespace in sync with the account's namespace (it is what
// LocalAuthentication filters on). When no row exists yet it inserts one.
func (psql PostgresqlStorage) saveLocalAuth(tx *sql.Tx, account Account) error {
	local := account.Authentication.Local
	emailValidationJSON, phoneValidationJSON, err := pgLocalAuthValidationJSON(local)
	if err != nil {
		return err
	}
	query := fmt.Sprintf(`UPDATE %s SET account_namespace = $1, username = $2, password = $3, email = $4, email_validation = $5, phone_number = $6, phone_number_validation = $7 where account_id = $8`,
		psql.Tables["accounts_auth_local"])
	res, err := tx.Exec(query, account.Namespace, local.Username, local.Password, local.Email,
		emailValidationJSON, local.PhoneNumber, phoneValidationJSON, account.ID)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return psql.insertLocalAuth(tx, account)
	}
	return nil
}

// Migrate brings the database schema in line with postgresql/schema.hcl
// (resolved relative to the process working directory) using Atlas.
//
// Only the schemas declared in the HCL file are inspected, so schemas the file
// does not mention are not diffed (and therefore not dropped).
func (psql PostgresqlStorage) Migrate() error {
	ctx := context.Background()

	driver, err := postgres.Open(psql.DbConnection)
	if err != nil {
		return fmt.Errorf("opening atlas driver: %w", err)
	}

	hcl, err := os.ReadFile("postgresql/schema.hcl")
	if err != nil {
		return fmt.Errorf("reading schema file: %w", err)
	}
	var desired schema.Realm
	if err := postgres.EvalHCLBytes(hcl, &desired, nil); err != nil {
		return fmt.Errorf("evaluating schema file: %w", err)
	}

	var inspectOpts *schema.InspectRealmOption
	if len(desired.Schemas) > 0 {
		names := make([]string, 0, len(desired.Schemas))
		for _, s := range desired.Schemas {
			names = append(names, s.Name)
		}
		inspectOpts = &schema.InspectRealmOption{Schemas: names}
	}
	existing, err := driver.InspectRealm(ctx, inspectOpts)
	if err != nil {
		return fmt.Errorf("inspecting database: %w", err)
	}

	diff, err := driver.RealmDiff(existing, &desired)
	if err != nil {
		return fmt.Errorf("computing schema diff: %w", err)
	}
	if err := driver.ApplyChanges(ctx, diff); err != nil {
		return fmt.Errorf("applying schema changes: %w", err)
	}
	return nil
}