package storage import ( "ariga.io/atlas/sql/postgres" "ariga.io/atlas/sql/schema" "context" "database/sql" "encoding/json" "fmt" "github.com/lib/pq" "github.com/spf13/viper" "os" "strconv" ) type PostgresqlStorage struct { DbConnection *sql.DB Schema string Tables map[string]string } func NewPostgresqlStorage(cfg *viper.Viper) (PostgresqlStorage, error) { var ( host = cfg.GetString("storage.db.psql.host") port = cfg.GetString("storage.db.psql.port") user = cfg.GetString("storage.db.psql.user") password = cfg.GetString("storage.db.psql.password") dbname = cfg.GetString("storage.db.psql.dbname") sslmode = cfg.GetString("storage.db.psql.sslmode") pg_schema = cfg.GetString("storage.db.psql.schema") pgtables_event = cfg.GetString("storage.db.psql.tables.event") pgtables_subscription = cfg.GetString("storage.db.psql.tables.subscription") timezone = "Europe/Paris" ) portInt, _ := strconv.Atoi(port) psqlconn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", host, portInt, user, password, dbname, sslmode, timezone) db, err := sql.Open("postgres", psqlconn) if err != nil { fmt.Println("error", err) return PostgresqlStorage{}, fmt.Errorf("connection to postgresql failed") } err = db.Ping() if err != nil { fmt.Println(err) return PostgresqlStorage{}, fmt.Errorf("connection to postgresql database failed") } return PostgresqlStorage{ DbConnection: db, Schema: pg_schema, Tables: map[string]string{ "event": fmt.Sprintf("%s.%s", pg_schema, pgtables_event), "subscription": fmt.Sprintf("%s.%s", pg_schema, pgtables_subscription), }, }, nil } func (psql PostgresqlStorage) CreateEvent(e Event) error { tx, err := psql.DbConnection.Begin() if err != nil { return err } defer func() { if err != nil { tx.Rollback() return } tx.Commit() }() dataEvent, err := json.Marshal(e.Data) if err != nil { return err } deletedSubscriptionsJSON, err := json.Marshal(e.DeletedSubscription) if err != nil { return err } eventQuery := fmt.Sprintf(` INSERT INTO %s (id, namespace, owners, restricted_to, type, name, description, startdate, enddate, starttime, endtime, allday, maxsubscribers, data, deleted, deletedsubscriptions) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) `, psql.Tables["event"]) ownersArray, err := pq.Array(e.Owners).Value() if err != nil { return err } restrictedToArray, err := pq.Array(e.RestrictedTo).Value() if err != nil { return err } _, err = tx.Exec( eventQuery, e.ID, e.Namespace, ownersArray, restrictedToArray, e.Type, e.Name, e.Description, e.Startdate, e.Enddate, e.Starttime, e.Endtime, e.Allday, e.MaxSubscribers, dataEvent, e.Deleted, deletedSubscriptionsJSON, ) if err != nil { return err } for _, subscription := range e.Subscriptions { subscriptionQuery := fmt.Sprintf(` INSERT INTO %s (id, event_id, subscriber, tags, created_at, data) VALUES ($1, $2, $3, $4, $5, $6) `, psql.Tables["subscription"]) dataSubscription, err := json.Marshal(subscription.Data) if err != nil { return err } tagsArray, err := pq.Array(subscription.Tags).Value() if err != nil { return err } _, err = tx.Exec( subscriptionQuery, subscription.ID, e.ID, subscription.Subscriber, tagsArray, subscription.CreatedAt, dataSubscription, ) if err != nil { return err } } return nil } func (psql PostgresqlStorage) GetEvent(eventID string) (*Event, error) { var event Event eventQuery := fmt.Sprintf(` SELECT id, namespace, owners, restricted_to, type, name, description, startdate, enddate, starttime, endtime, allday, maxsubscribers, data, deleted, deletedsubscriptions FROM %s WHERE id = $1 `, psql.Tables["event"]) row := psql.DbConnection.QueryRow(eventQuery, eventID) owners := pq.StringArray{} restrictedTo := pq.StringArray{} dataEvent := []byte{} deletedSubscriptions := []byte{} err := row.Scan( &event.ID, &event.Namespace, &owners, &restrictedTo, &event.Type, &event.Name, &event.Description, &event.Startdate, &event.Enddate, &event.Starttime, &event.Endtime, &event.Allday, &event.MaxSubscribers, &dataEvent, &event.Deleted, &deletedSubscriptions, ) if err != nil { return nil, err } event.Owners = []string(owners) event.RestrictedTo = []string(restrictedTo) data := make(map[string]any) err = json.Unmarshal(dataEvent, &data) if err != nil { return nil, err } event.Data = data subscriptions, err := psql.getSubscriptions(eventID) if err != nil { return nil, err } event.Subscriptions = subscriptions deletedSubs := []Subscription{} err = json.Unmarshal(deletedSubscriptions, &deletedSubs) if err != nil { return nil, err } event.DeletedSubscription = deletedSubs return &event, nil } func (psql PostgresqlStorage) GetEvents(namespaces []string) ([]Event, error) { var events []Event eventQuery := fmt.Sprintf(` SELECT id, namespace, owners, restricted_to, type, name, description, startdate, enddate, starttime, endtime, allday, maxsubscribers, data, deletedsubscriptions, deleted FROM %s WHERE namespace = ANY($1::text[]) `, psql.Tables["event"]) rows, err := psql.DbConnection.Query(eventQuery, pq.Array(namespaces)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { var event Event var owners pq.StringArray var restrictedTo pq.StringArray var dataEvent []byte var deletedSubscriptions []byte err := rows.Scan( &event.ID, &event.Namespace, &owners, &restrictedTo, &event.Type, &event.Name, &event.Description, &event.Startdate, &event.Enddate, &event.Starttime, &event.Endtime, &event.Allday, &event.MaxSubscribers, &dataEvent, &deletedSubscriptions, &event.Deleted, ) if err != nil { return nil, err } event.Owners = []string(owners) event.RestrictedTo = []string(restrictedTo) err = json.Unmarshal(dataEvent, &event.Data) if err != nil { return nil, err } event.Subscriptions, err = psql.getSubscriptions(event.ID) if err != nil { return nil, err } err = json.Unmarshal(deletedSubscriptions, &event.DeletedSubscription) if err != nil { return nil, err } events = append(events, event) } return events, nil } func (psql PostgresqlStorage) UpdateEvent(e Event) error { tx, err := psql.DbConnection.Begin() if err != nil { return err } defer func() { if err != nil { tx.Rollback() return } tx.Commit() }() dataEvent, err := json.Marshal(e.Data) if err != nil { return err } deletedSubscriptions, err := json.Marshal(e.DeletedSubscription) if err != nil { return err } eventQuery := fmt.Sprintf(` UPDATE %s SET namespace = $2, owners = $3, restricted_to = $4, type = $5, name = $6, description = $7, startdate = $8, enddate = $9, starttime = $10, endtime = $11, allday = $12, maxsubscribers = $13, data = $14, deleted = $15, deletedsubscriptions = $16 WHERE id = $1 `, psql.Tables["event"]) ownersArray, err := pq.Array(e.Owners).Value() if err != nil { return err } restrictedToArray, err := pq.Array(e.RestrictedTo).Value() if err != nil { return err } _, err = tx.Exec( eventQuery, e.ID, e.Namespace, ownersArray, restrictedToArray, e.Type, e.Name, e.Description, e.Startdate, e.Enddate, e.Starttime, e.Endtime, e.Allday, e.MaxSubscribers, dataEvent, e.Deleted, deletedSubscriptions, ) if err != nil { return err } for _, subscription := range e.Subscriptions { subscriptionQuery := fmt.Sprintf(` UPDATE %s SET subscriber = $2, tags = $3, data = $4, created_at= $5 WHERE event_id = $1 `, psql.Tables["subscription"]) dataSubscription, err := json.Marshal(subscription.Data) if err != nil { return err } tagsArray, err := pq.Array(subscription.Tags).Value() if err != nil { return err } _, err = tx.Exec( subscriptionQuery, e.ID, subscription.Subscriber, tagsArray, dataSubscription, subscription.CreatedAt, ) if err != nil { return err } } return nil } func (psql PostgresqlStorage) AddSubscription(eventid string, subscription Subscription) error { tags := pq.Array(subscription.Tags) data, err := json.Marshal(subscription.Data) if err != nil { return err } _, err = psql.DbConnection.Exec(fmt.Sprintf(` INSERT INTO %s (id, event_id, subscriber, tags, created_at, data) VALUES ($1, $2, $3, $4, $5, $6) `, psql.Tables["subscription"]), subscription.ID, eventid, subscription.Subscriber, tags, subscription.CreatedAt, data, ) return err } func (psql PostgresqlStorage) UpdateSubscription(eventid string, subscriber string, deletesubscription Subscription) error { tx, err := psql.DbConnection.Begin() if err != nil { return err } defer func() { if err != nil { tx.Rollback() return } tx.Commit() }() subscriptionQuery := fmt.Sprintf(` DELETE FROM %s WHERE event_id = $1 AND subscriber = $2 `, psql.Tables["subscription"]) _, err = tx.Exec( subscriptionQuery, eventid, subscriber, ) if err != nil { fmt.Println(err) return err } eventQuery := fmt.Sprintf(` UPDATE %s SET deletedsubscriptions = deletedsubscriptions || $1 WHERE id = $2 `, psql.Tables["event"]) deletedSubscriptions, err := json.Marshal(deletesubscription) if err != nil { fmt.Println(err) return err } _, err = tx.Exec( eventQuery, deletedSubscriptions, eventid, ) if err != nil { fmt.Println(err) return err } return nil } func (psql PostgresqlStorage) getSubscriptions(eventID string) ([]Subscription, error) { var subscriptions []Subscription subscriptionQuery := fmt.Sprintf(` SELECT id, subscriber, tags, created_at, data FROM %s WHERE event_id = $1 `, psql.Tables["subscription"]) rows, err := psql.DbConnection.Query(subscriptionQuery, eventID) if err != nil { return nil, err } defer rows.Close() for rows.Next() { var subscription Subscription var tags pq.StringArray var dataSubscription []byte err := rows.Scan( &subscription.ID, &subscription.Subscriber, &tags, &subscription.CreatedAt, &dataSubscription, ) if err != nil { return nil, err } subscription.Tags = []string(tags) data := make(map[string]any) err = json.Unmarshal(dataSubscription, &data) if err != nil { return nil, err } subscription.Data = data subscriptions = append(subscriptions, subscription) } return subscriptions, nil } func (psql PostgresqlStorage) Migrate() error { ctx := context.Background() driver, err := postgres.Open(psql.DbConnection) if err != nil { return err } existing, err := driver.InspectRealm(ctx, &schema.InspectRealmOption{Schemas: []string{psql.Schema}}) if err != nil { return err } var desired schema.Realm hcl, err := os.ReadFile("postgresql/schema.hcl") if err != nil { return err } err = postgres.EvalHCLBytes(hcl, &desired, nil) if err != nil { return err } diff, err := driver.RealmDiff(existing, &desired) if err != nil { return err } err = driver.ApplyChanges(ctx, diff) if err != nil { return err } return nil }