mirror of
https://github.com/jorgerojas26/lazysql.git
synced 2026-05-06 08:56:58 -04:00
Merge pull request #280 from qdentity/dvic/fix-postgres-switch-database
Fix PostgreSQL cross-database queries and switching
This commit is contained in:
+100
-119
@@ -64,7 +64,7 @@ func (db *Postgres) Connect(urlstr string) error {
|
||||
}
|
||||
|
||||
func (db *Postgres) GetDatabases() ([]string, error) {
|
||||
rows, err := db.Connection.Query("SELECT datname FROM pg_database;")
|
||||
rows, err := db.Connection.Query("SELECT datname FROM pg_database WHERE datallowconn AND has_database_privilege(current_user, datname, 'CONNECT');")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -91,21 +91,16 @@ func (db *Postgres) GetTables(database string) (map[string][]string, error) {
|
||||
return nil, errors.New("database name is required")
|
||||
}
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
query := "SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = $1"
|
||||
rows, err := db.Connection.Query(query, database)
|
||||
rows, err := conn.Query(query, database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -144,17 +139,12 @@ func (db *Postgres) GetTableColumns(database, table string) ([][]string, error)
|
||||
return nil, errors.New("table must be in the format schema.table")
|
||||
}
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
tableSchema := splitTableString[0]
|
||||
@@ -162,7 +152,7 @@ func (db *Postgres) GetTableColumns(database, table string) ([][]string, error)
|
||||
|
||||
query := "SELECT c.column_name, c.data_type, c.is_nullable, c.column_default, COALESCE(pd.description, '') as comment FROM information_schema.columns c LEFT JOIN pg_class pc ON pc.relname = c.table_name LEFT JOIN pg_namespace pn ON pn.nspname = c.table_schema AND pn.oid = pc.relnamespace LEFT JOIN pg_description pd ON pd.objoid = pc.oid AND pd.objsubid = c.ordinal_position WHERE c.table_catalog = $1 AND c.table_schema = $2 AND c.table_name = $3 ORDER by c.ordinal_position"
|
||||
|
||||
rows, err := db.Connection.Query(query, database, tableSchema, tableName)
|
||||
rows, err := conn.Query(query, database, tableSchema, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -213,23 +203,18 @@ func (db *Postgres) GetConstraints(database, table string) ([][]string, error) {
|
||||
return nil, errors.New("table must be in the format schema.table")
|
||||
}
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
tableSchema := splitTableString[0]
|
||||
tableName := splitTableString[1]
|
||||
|
||||
rows, err := db.Connection.Query(fmt.Sprintf(`
|
||||
rows, err := conn.Query(fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
@@ -293,23 +278,18 @@ func (db *Postgres) GetForeignKeys(database, table string) ([][]string, error) {
|
||||
return nil, errors.New("table must be in the format schema.table")
|
||||
}
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
tableSchema := splitTableString[0]
|
||||
tableName := splitTableString[1]
|
||||
|
||||
rows, err := db.Connection.Query(fmt.Sprintf(`
|
||||
rows, err := conn.Query(fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
@@ -374,23 +354,18 @@ func (db *Postgres) GetIndexes(database, table string) ([][]string, error) {
|
||||
return nil, errors.New("table must be in the format schema.table")
|
||||
}
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
tableSchema := splitTableString[0]
|
||||
tableName := splitTableString[1]
|
||||
|
||||
rows, err := db.Connection.Query(fmt.Sprintf(`
|
||||
rows, err := conn.Query(fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname AS index_name,
|
||||
a.attname AS column_name,
|
||||
@@ -464,17 +439,12 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
queryString = "SELECT * FROM "
|
||||
@@ -494,7 +464,7 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
|
||||
limit = DefaultRowLimit
|
||||
}
|
||||
|
||||
paginatedRows, err := db.Connection.Query(queryString, limit, offset)
|
||||
paginatedRows, err := conn.Query(queryString, limit, offset)
|
||||
if err != nil {
|
||||
return nil, 0, queryString, err
|
||||
}
|
||||
@@ -549,7 +519,7 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
|
||||
countQuery += fmt.Sprintf(" %s", where)
|
||||
}
|
||||
|
||||
countRow := db.Connection.QueryRow(countQuery)
|
||||
countRow := conn.QueryRow(countQuery)
|
||||
|
||||
if err := countRow.Scan(&totalRecords); err != nil {
|
||||
return records, 0, queryString, err
|
||||
@@ -588,24 +558,19 @@ func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColum
|
||||
return formatErr
|
||||
}
|
||||
|
||||
switchDatabaseOnError := false
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switchDatabaseOnError = true
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
query := "UPDATE "
|
||||
query += formattedTableName
|
||||
query += fmt.Sprintf(" SET \"%s\" = $1 WHERE \"%s\" = $2", column, primaryKeyColumnName)
|
||||
|
||||
_, err := db.Connection.Exec(query, value, primaryKeyValue)
|
||||
if err != nil && switchDatabaseOnError {
|
||||
err = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
|
||||
_, err = conn.Exec(query, value, primaryKeyValue)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -628,24 +593,19 @@ func (db *Postgres) DeleteRecord(database, table, primaryKeyColumnName, primaryK
|
||||
return formatErr
|
||||
}
|
||||
|
||||
switchDatabaseOnError := false
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switchDatabaseOnError = true
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
query := "DELETE FROM "
|
||||
query += formattedTableName
|
||||
query += fmt.Sprintf(" WHERE \"%s\" = $1", primaryKeyColumnName)
|
||||
|
||||
_, err := db.Connection.Exec(query, primaryKeyValue)
|
||||
if err != nil && switchDatabaseOnError {
|
||||
err = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
|
||||
_, err = conn.Exec(query, primaryKeyValue)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -742,20 +702,15 @@ func (db *Postgres) GetPrimaryKeyColumnNames(database, table string) ([]string,
|
||||
schemaName := splitTableString[0]
|
||||
tableName := splitTableString[1]
|
||||
|
||||
if database != db.CurrentDatabase {
|
||||
err := db.SwitchDatabase(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = db.SwitchDatabase(db.PreviousDatabase)
|
||||
}
|
||||
}()
|
||||
conn, needsClose, err := db.connectionFor(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsClose {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
row, err := db.Connection.Query(`
|
||||
row, err := conn.Query(`
|
||||
SELECT
|
||||
a.attname AS column_name
|
||||
FROM
|
||||
@@ -804,37 +759,63 @@ func (db *Postgres) GetProvider() string {
|
||||
return db.Provider
|
||||
}
|
||||
|
||||
func (db *Postgres) SwitchDatabase(database string) error {
|
||||
// connectToDatabase opens a new connection to the given database without
|
||||
// mutating the receiver. The caller must close the returned connection.
|
||||
func (db *Postgres) connectToDatabase(database string) (*sql.DB, error) {
|
||||
parsedConn, err := dburl.Parse(db.Urlstr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := parsedConn.User.Username()
|
||||
password, _ := parsedConn.User.Password()
|
||||
password, hasPassword := parsedConn.User.Password()
|
||||
host := parsedConn.Hostname()
|
||||
port := parsedConn.Port()
|
||||
dbname := parsedConn.Path
|
||||
|
||||
if port == "" {
|
||||
port = defaultPort
|
||||
}
|
||||
|
||||
if dbname == "" {
|
||||
dbname = database
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s dbname=%s sslmode=disable", host, port, user, database)
|
||||
if hasPassword {
|
||||
dsn += fmt.Sprintf(" password=%s", password)
|
||||
}
|
||||
|
||||
connection, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname))
|
||||
conn, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionFor returns a connection to the given database. If it matches
|
||||
// the current database, the existing connection is returned (caller must NOT
|
||||
// close it). Otherwise a new temporary connection is opened and returned
|
||||
// (caller MUST close it).
|
||||
func (db *Postgres) connectionFor(database string) (conn *sql.DB, needsClose bool, err error) {
|
||||
if database == db.CurrentDatabase {
|
||||
return db.Connection, false, nil
|
||||
}
|
||||
conn, err = db.connectToDatabase(database)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return conn, true, nil
|
||||
}
|
||||
|
||||
func (db *Postgres) SwitchDatabase(database string) error {
|
||||
conn, err := db.connectToDatabase(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Connection.Close()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
db.Connection = connection
|
||||
db.Connection = conn
|
||||
db.PreviousDatabase = db.CurrentDatabase
|
||||
db.CurrentDatabase = database
|
||||
|
||||
|
||||
Reference in New Issue
Block a user