Merge pull request #280 from qdentity/dvic/fix-postgres-switch-database

Fix PostgreSQL cross-database queries and switching
This commit is contained in:
Jorge Rojas
2026-03-12 09:42:38 -04:00
committed by GitHub
+100 -119
View File
@@ -64,7 +64,7 @@ func (db *Postgres) Connect(urlstr string) error {
}
func (db *Postgres) GetDatabases() ([]string, error) {
rows, err := db.Connection.Query("SELECT datname FROM pg_database;")
rows, err := db.Connection.Query("SELECT datname FROM pg_database WHERE datallowconn AND has_database_privilege(current_user, datname, 'CONNECT');")
if err != nil {
return nil, err
}
@@ -91,21 +91,16 @@ func (db *Postgres) GetTables(database string) (map[string][]string, error) {
return nil, errors.New("database name is required")
}
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, err
}
if needsClose {
defer conn.Close()
}
query := "SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = $1"
rows, err := db.Connection.Query(query, database)
rows, err := conn.Query(query, database)
if err != nil {
return nil, err
}
@@ -144,17 +139,12 @@ func (db *Postgres) GetTableColumns(database, table string) ([][]string, error)
return nil, errors.New("table must be in the format schema.table")
}
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, err
}
if needsClose {
defer conn.Close()
}
tableSchema := splitTableString[0]
@@ -162,7 +152,7 @@ func (db *Postgres) GetTableColumns(database, table string) ([][]string, error)
query := "SELECT c.column_name, c.data_type, c.is_nullable, c.column_default, COALESCE(pd.description, '') as comment FROM information_schema.columns c LEFT JOIN pg_class pc ON pc.relname = c.table_name LEFT JOIN pg_namespace pn ON pn.nspname = c.table_schema AND pn.oid = pc.relnamespace LEFT JOIN pg_description pd ON pd.objoid = pc.oid AND pd.objsubid = c.ordinal_position WHERE c.table_catalog = $1 AND c.table_schema = $2 AND c.table_name = $3 ORDER by c.ordinal_position"
rows, err := db.Connection.Query(query, database, tableSchema, tableName)
rows, err := conn.Query(query, database, tableSchema, tableName)
if err != nil {
return nil, err
}
@@ -213,23 +203,18 @@ func (db *Postgres) GetConstraints(database, table string) ([][]string, error) {
return nil, errors.New("table must be in the format schema.table")
}
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, err
}
if needsClose {
defer conn.Close()
}
tableSchema := splitTableString[0]
tableName := splitTableString[1]
rows, err := db.Connection.Query(fmt.Sprintf(`
rows, err := conn.Query(fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
@@ -293,23 +278,18 @@ func (db *Postgres) GetForeignKeys(database, table string) ([][]string, error) {
return nil, errors.New("table must be in the format schema.table")
}
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, err
}
if needsClose {
defer conn.Close()
}
tableSchema := splitTableString[0]
tableName := splitTableString[1]
rows, err := db.Connection.Query(fmt.Sprintf(`
rows, err := conn.Query(fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
@@ -374,23 +354,18 @@ func (db *Postgres) GetIndexes(database, table string) ([][]string, error) {
return nil, errors.New("table must be in the format schema.table")
}
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, err
}
if needsClose {
defer conn.Close()
}
tableSchema := splitTableString[0]
tableName := splitTableString[1]
rows, err := db.Connection.Query(fmt.Sprintf(`
rows, err := conn.Query(fmt.Sprintf(`
SELECT
i.relname AS index_name,
a.attname AS column_name,
@@ -464,17 +439,12 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
return nil, 0, "", err
}
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, 0, "", err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, 0, "", err
}
if needsClose {
defer conn.Close()
}
queryString = "SELECT * FROM "
@@ -494,7 +464,7 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
limit = DefaultRowLimit
}
paginatedRows, err := db.Connection.Query(queryString, limit, offset)
paginatedRows, err := conn.Query(queryString, limit, offset)
if err != nil {
return nil, 0, queryString, err
}
@@ -549,7 +519,7 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
countQuery += fmt.Sprintf(" %s", where)
}
countRow := db.Connection.QueryRow(countQuery)
countRow := conn.QueryRow(countQuery)
if err := countRow.Scan(&totalRecords); err != nil {
return records, 0, queryString, err
@@ -588,24 +558,19 @@ func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColum
return formatErr
}
switchDatabaseOnError := false
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return err
}
switchDatabaseOnError = true
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return err
}
if needsClose {
defer conn.Close()
}
query := "UPDATE "
query += formattedTableName
query += fmt.Sprintf(" SET \"%s\" = $1 WHERE \"%s\" = $2", column, primaryKeyColumnName)
_, err := db.Connection.Exec(query, value, primaryKeyValue)
if err != nil && switchDatabaseOnError {
err = db.SwitchDatabase(db.PreviousDatabase)
}
_, err = conn.Exec(query, value, primaryKeyValue)
return err
}
@@ -628,24 +593,19 @@ func (db *Postgres) DeleteRecord(database, table, primaryKeyColumnName, primaryK
return formatErr
}
switchDatabaseOnError := false
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return err
}
switchDatabaseOnError = true
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return err
}
if needsClose {
defer conn.Close()
}
query := "DELETE FROM "
query += formattedTableName
query += fmt.Sprintf(" WHERE \"%s\" = $1", primaryKeyColumnName)
_, err := db.Connection.Exec(query, primaryKeyValue)
if err != nil && switchDatabaseOnError {
err = db.SwitchDatabase(db.PreviousDatabase)
}
_, err = conn.Exec(query, primaryKeyValue)
return err
}
@@ -742,20 +702,15 @@ func (db *Postgres) GetPrimaryKeyColumnNames(database, table string) ([]string,
schemaName := splitTableString[0]
tableName := splitTableString[1]
if database != db.CurrentDatabase {
err := db.SwitchDatabase(database)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.SwitchDatabase(db.PreviousDatabase)
}
}()
conn, needsClose, err := db.connectionFor(database)
if err != nil {
return nil, err
}
if needsClose {
defer conn.Close()
}
row, err := db.Connection.Query(`
row, err := conn.Query(`
SELECT
a.attname AS column_name
FROM
@@ -804,37 +759,63 @@ func (db *Postgres) GetProvider() string {
return db.Provider
}
func (db *Postgres) SwitchDatabase(database string) error {
// connectToDatabase opens a new connection to the given database without
// mutating the receiver. The caller must close the returned connection.
func (db *Postgres) connectToDatabase(database string) (*sql.DB, error) {
parsedConn, err := dburl.Parse(db.Urlstr)
if err != nil {
return err
return nil, err
}
user := parsedConn.User.Username()
password, _ := parsedConn.User.Password()
password, hasPassword := parsedConn.User.Password()
host := parsedConn.Hostname()
port := parsedConn.Port()
dbname := parsedConn.Path
if port == "" {
port = defaultPort
}
if dbname == "" {
dbname = database
dsn := fmt.Sprintf("host=%s port=%s user=%s dbname=%s sslmode=disable", host, port, user, database)
if hasPassword {
dsn += fmt.Sprintf(" password=%s", password)
}
connection, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname))
conn, err := sql.Open("postgres", dsn)
if err != nil {
return nil, err
}
return conn, nil
}
// connectionFor returns a connection to the given database. If it matches
// the current database, the existing connection is returned (caller must NOT
// close it). Otherwise a new temporary connection is opened and returned
// (caller MUST close it).
func (db *Postgres) connectionFor(database string) (conn *sql.DB, needsClose bool, err error) {
if database == db.CurrentDatabase {
return db.Connection, false, nil
}
conn, err = db.connectToDatabase(database)
if err != nil {
return nil, false, err
}
return conn, true, nil
}
func (db *Postgres) SwitchDatabase(database string) error {
conn, err := db.connectToDatabase(database)
if err != nil {
return err
}
err = db.Connection.Close()
if err != nil {
conn.Close()
return err
}
db.Connection = connection
db.Connection = conn
db.PreviousDatabase = db.CurrentDatabase
db.CurrentDatabase = database