feat(studio): mark sql provenance for safety (#45336)

Mark provenance of SQL via the branded types SafeSqlFragment and
UntrustedSqlFragment. Only SafeSqlFragment should be executed;
UntrustedSqlFragments require some kind of implicit user approval (show
on screen + user has to click something) before they are promoted to
SafeSqlFragment.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Editor and RLS tester show loading states for inferred/generated SQL
and include a dedicated user SQL editor for safer edits.

* **Refactor**
* Platform-wide SQL handling tightened: snippets and AI-generated SQL
are treated as untrusted/display-only until promoted, improving safety
and consistency.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Charis
2026-05-04 13:08:06 -04:00
committed by GitHub
parent 2944a3f0ec
commit 0433eeb5f5
89 changed files with 1450 additions and 728 deletions
@@ -1,4 +1,3 @@
import type { PostgresPolicy } from '@supabase/postgres-meta'
import { useParams } from 'common'
import { isEmpty } from 'lodash'
import Link from 'next/link'
@@ -11,6 +10,7 @@ import {
PolicyTableRow,
PolicyTableRowProps,
} from '@/components/interfaces/Auth/Policies/PolicyTableRow'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { ProtectedSchemaWarning } from '@/components/interfaces/Database/ProtectedSchemaWarning'
import { NoSearchResults } from '@/components/ui/NoSearchResults'
import { useDatabasePolicyDeleteMutation } from '@/data/database-policies/database-policy-delete-mutation'
@@ -25,7 +25,7 @@ interface PoliciesProps {
isLocked: boolean
visibleTableIds: Set<number>
onSelectCreatePolicy: (table: string) => void
onSelectEditPolicy: (policy: PostgresPolicy) => void
onSelectEditPolicy: (policy: Policy) => void
onResetSearch?: () => void
}
@@ -83,13 +83,13 @@ export const Policies = ({
)
const onSelectEditPolicy = useCallback(
(policy: PostgresPolicy) => {
(policy: Policy) => {
onSelectEditPolicyAI(policy)
},
[onSelectEditPolicyAI]
)
const onSelectDeletePolicy = useCallback((policy: PostgresPolicy) => {
const onSelectDeletePolicy = useCallback((policy: Policy) => {
setSelectedPolicyToDelete(policy)
}, [])
@@ -1,13 +1,13 @@
import type { PostgresPolicy } from '@supabase/postgres-meta'
import type { PropsWithChildren } from 'react'
import { createContext, useCallback, useContext, useMemo } from 'react'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import type { ResponseError } from '@/types'
type TableKey = `${string}.${string}`
type PoliciesDataContextValue = {
getPoliciesForTable: (schema: string, table: string) => PostgresPolicy[]
getPoliciesForTable: (schema: string, table: string) => Array<Policy>
isPoliciesLoading: boolean
isPoliciesError: boolean
policiesError?: ResponseError | Error
@@ -23,7 +23,7 @@ export const usePoliciesData = () => {
}
type PoliciesDataProviderProps = {
policies: PostgresPolicy[]
policies: Array<Policy>
isPoliciesLoading: boolean
isPoliciesError: boolean
policiesError?: ResponseError | Error
@@ -39,7 +39,7 @@ export const PoliciesDataProvider = ({
exposedSchemas,
}: PropsWithChildren<PoliciesDataProviderProps>) => {
const policiesByTable = useMemo(() => {
const map = new Map<TableKey, PostgresPolicy[]>()
const map = new Map<TableKey, Array<Policy>>()
for (const policy of policies) {
const key = `${policy.schema}.${policy.table}` satisfies TableKey
@@ -1,4 +1,3 @@
import type { PostgresPolicy } from '@supabase/postgres-meta'
import { PermissionAction } from '@supabase/shared-types/out/constants'
import { noop } from 'lodash'
import { Edit, MoreVertical, Trash } from 'lucide-react'
@@ -17,6 +16,7 @@ import {
} from 'ui'
import { generatePolicyUpdateSQL } from './PolicyTableRow.utils'
import type { Policy } from './PolicyTableRow.utils'
import { SIDEBAR_KEYS } from '@/components/layouts/ProjectLayout/LayoutSidebar/LayoutSidebarProvider'
import { DropdownMenuItemTooltip } from '@/components/ui/DropdownMenuItemTooltip'
import { useAuthConfigQuery } from '@/data/auth/auth-config-query'
@@ -26,9 +26,9 @@ import { useAiAssistantStateSnapshot } from '@/state/ai-assistant-state'
import { useSidebarManagerSnapshot } from '@/state/sidebar-manager-state'
interface PolicyRowProps {
policy: PostgresPolicy
onSelectEditPolicy: (policy: PostgresPolicy) => void
onSelectDeletePolicy: (policy: PostgresPolicy) => void
policy: Policy
onSelectEditPolicy: (policy: Policy) => void
onSelectDeletePolicy: (policy: Policy) => void
isLocked?: boolean
}
@@ -1,7 +1,13 @@
import { ident, joinSqlFragments, safeSql, type SafeSqlFragment } from '@supabase/pg-meta'
import { PostgresPolicy } from '@supabase/postgres-meta'
import type { TableApiAccessData } from '@/data/privileges/table-api-access-query'
export type Policy = Omit<PostgresPolicy, 'definition' | 'check'> & {
definition: SafeSqlFragment | null
check: SafeSqlFragment | null
}
/**
* Single classifier for the RLS page's per-table admonition state. Shares the
* "granted / custom / revoked" grant semantics used by the Data API settings
@@ -61,21 +67,22 @@ export function getTableAdmonitionMessage(status: TableDataApiStatus): string |
}
}
export const generatePolicyUpdateSQL = (policy: PostgresPolicy) => {
let expression = ''
if (policy.definition !== null && policy.definition !== undefined) {
expression += `using (${policy.definition})${
policy.check === null || policy.check === undefined ? ';' : ''
}\n`
export const generatePolicyUpdateSQL = (policy: Policy): SafeSqlFragment => {
const parts: Array<SafeSqlFragment> = []
if (policy.definition != null) {
const semicolon = policy.check == null ? safeSql`;` : safeSql``
parts.push(safeSql`using (${policy.definition})${semicolon}`)
}
if (policy.check !== null && policy.check !== undefined) {
expression += `with check (${policy.check});\n`
if (policy.check != null) {
parts.push(safeSql`with check (${policy.check});`)
}
return `
alter policy "${policy.name}"
on "${policy.schema}"."${policy.table}"
to ${policy.roles.join(', ')}
${expression}
`.trim()
const expression = parts.length > 0 ? joinSqlFragments(parts, '\n') : safeSql``
return safeSql`
alter policy ${ident(policy.name)}
on ${ident(policy.schema)}.${ident(policy.table)}
to ${joinSqlFragments(policy.roles.map(ident), ', ')}
${expression}`
}
@@ -1,4 +1,3 @@
import type { PostgresPolicy } from '@supabase/postgres-meta'
import { useParams } from 'common'
import { noop } from 'lodash'
import { memo, useMemo } from 'react'
@@ -19,6 +18,7 @@ import { ShimmeringLoader } from 'ui-patterns/ShimmeringLoader'
import { usePoliciesData } from '../PoliciesDataContext'
import { PolicyRow } from './PolicyRow'
import type { PolicyTable } from './PolicyTableRow.types'
import type { Policy } from './PolicyTableRow.utils'
import { getTableAdmonitionMessage, getTableDataApiStatus } from './PolicyTableRow.utils'
import { PolicyTableRowHeader } from './PolicyTableRowHeader'
import AlertError from '@/components/ui/AlertError'
@@ -31,8 +31,8 @@ export interface PolicyTableRowProps {
isLocked: boolean
onSelectToggleRLS: (table: PolicyTable) => void
onSelectCreatePolicy: (table: PolicyTable) => void
onSelectEditPolicy: (policy: PostgresPolicy) => void
onSelectDeletePolicy: (policy: PostgresPolicy) => void
onSelectEditPolicy: (policy: Policy) => void
onSelectDeletePolicy: (policy: Policy) => void
}
const PolicyTableRowComponent = ({
@@ -1,24 +1,43 @@
import { UntrustedSqlFragment } from '@supabase/pg-meta'
import { Loader2 } from 'lucide-react'
import { Badge, Tooltip, TooltipContent, TooltipTrigger } from 'ui'
import CodeEditor from '@/components/ui/CodeEditor/CodeEditor'
export const InferredSQLViewer = ({ sql }: { sql: string }) => {
export const InferredSQLViewer = ({
sql,
isLoading = false,
}: {
sql: UntrustedSqlFragment | undefined
isLoading?: boolean
}) => {
return (
<>
<div className="flex items-center justify-between px-4 py-2">
<p className="text-sm">Inferred SQL:</p>
<Tooltip>
<TooltipTrigger>
<Badge variant="warning">Generated</Badge>
</TooltipTrigger>
<TooltipContent side="bottom" align="end" className="w-64 text-center">
This query is inferred from client library code with the help of the Assistant and may
not guarantee correctness.
</TooltipContent>
</Tooltip>
<div className="flex items-center gap-x-2">
<p className="text-sm">Inferred SQL:</p>
{isLoading && <Loader2 size={14} className="animate-spin text-foreground-lighter" />}
</div>
<div className="flex items-center gap-x-2">
<Tooltip>
<TooltipTrigger>
<Badge variant="warning">Generated</Badge>
</TooltipTrigger>
<TooltipContent side="bottom" align="end" className="w-64 text-center">
This query is inferred from client library code with the help of the Assistant and may
not guarantee correctness.
</TooltipContent>
</Tooltip>
</div>
</div>
<div className="h-44 relative">
<CodeEditor isReadOnly id="inferred-sql" language="pgsql" value={sql} />
{isLoading && !sql ? (
<div className="flex h-full items-center justify-center bg-surface-100 text-foreground-lighter">
<Loader2 size={20} className="animate-spin" />
</div>
) : (
<CodeEditor isReadOnly id="inferred-sql" language="pgsql" value={sql ?? ''} />
)}
</div>
</>
)
@@ -1,4 +1,3 @@
import { type PostgresPolicy } from '@supabase/postgres-meta'
import { Check, ChevronDown, Edit, X } from 'lucide-react'
import { useMemo } from 'react'
import {
@@ -9,13 +8,14 @@ import {
WarningIcon,
} from 'ui'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { ButtonTooltip } from '@/components/ui/ButtonTooltip'
interface RLSTableCardProps {
table: { schema: string; name: string; isRLSEnabled: boolean }
role?: string
policies: PostgresPolicy[]
handleSelectEditPolicy: (policy: PostgresPolicy) => void
policies: Policy[]
handleSelectEditPolicy: (policy: Policy) => void
}
export const RLSTableCard = ({
@@ -157,8 +157,8 @@ const TableAccessPolicySummary = ({
policies,
handleSelectEditPolicy,
}: {
policies: PostgresPolicy[]
handleSelectEditPolicy: (policy: PostgresPolicy) => void
policies: Policy[]
handleSelectEditPolicy: (policy: Policy) => void
}) => {
return (
<div className="border rounded-sm mt-4">
@@ -1,5 +1,4 @@
import { type PostgresPolicy } from '@supabase/postgres-meta'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { type User } from '@/data/auth/users-infinite-query'
import { type ParseSQLQueryResponse } from '@/data/misc/parse-query-mutation'
@@ -7,7 +6,7 @@ export type ParseQueryResults = {
tables: {
schema: string
table: string
tablePolicies: PostgresPolicy[]
tablePolicies: Array<Policy>
isRLSEnabled: boolean
}[]
operation: ParseSQLQueryResponse['operation']
@@ -1,4 +1,3 @@
import { type PostgresPolicy } from '@supabase/postgres-meta'
import {
Badge,
cn,
@@ -13,12 +12,13 @@ import { Results } from '../../SQLEditor/UtilityPanel/Results'
import { RLSTableCard } from './RLSTableCard'
import { ParseQueryResults } from './RLSTester.types'
import { useTestQueryRLS } from './useTestQueryRLS'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
interface RLSTesterResultsProps {
results: Object[]
autoLimit: boolean
parseQueryResults: ParseQueryResults
handleSelectEditPolicy: (policy: PostgresPolicy) => void
handleSelectEditPolicy: (policy: Policy) => void
}
export const RLSTesterResults = ({
@@ -1,4 +1,9 @@
import { type PostgresPolicy } from '@supabase/postgres-meta'
import {
acceptUntrustedSql,
safeSql,
type SafeSqlFragment,
type UntrustedSqlFragment,
} from '@supabase/pg-meta'
import {
Select,
SelectContent,
@@ -10,7 +15,7 @@ import {
} from '@ui/components/shadcn/ui/select'
import { LOCAL_STORAGE_KEYS } from 'common'
import { Code } from 'lucide-react'
import { useEffect, useState } from 'react'
import { useEffect, useRef, useState } from 'react'
import {
Button,
DialogSectionSeparator,
@@ -31,13 +36,14 @@ import { RLSTesterEmptyState } from './RLSTesterEmptyState'
import { RLSTesterResults } from './RLSTesterResults'
import { RoleSelector } from './RoleSelector'
import { UserSelector } from './UserSelector'
import { UserSqlEditor } from './UserSqlEditor'
import { useTestQueryRLS } from './useTestQueryRLS'
import { CodeEditor } from '@/components/ui/CodeEditor/CodeEditor'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { FeaturePreviewBadge } from '@/components/ui/FeaturePreviewBadge'
import { useRoleImpersonationStateSnapshot } from '@/state/role-impersonation-state'
interface RLSTesterSheetProps {
handleSelectEditPolicy: (policy: PostgresPolicy) => void
handleSelectEditPolicy: (policy: Policy) => void
}
export const RLSTesterSheet = ({ handleSelectEditPolicy }: RLSTesterSheetProps) => {
@@ -47,28 +53,53 @@ export const RLSTesterSheet = ({ handleSelectEditPolicy }: RLSTesterSheetProps)
const [selectedOption, setSelectedOption] = useState<'anon' | 'authenticated'>('anon')
const [format, setFormat] = useState<'sql' | 'lib'>('sql')
const [inferredSQL, setInferredSQL] = useState<string>()
const [inferredSQL, setInferredSQL] = useState<UntrustedSqlFragment>()
const [value, setValue] = useState<string>('')
const [value, setValue] = useState<SafeSqlFragment>(safeSql``)
const [results, setResults] = useState<Object[] | null>(null)
const [autoLimit, setAutoLimit] = useState(false)
const [parseQueryResults, setParseQueryResults] = useState<ParseQueryResults>()
const { testQuery, isLoading, executeSqlError, parseQueryError, parseClientCodeError } =
useTestQueryRLS()
const {
testQuery,
inferSQLFromLib,
isLoading,
isInferring,
executeSqlError,
parseQueryError,
parseClientCodeError,
} = useTestQueryRLS()
const debounceRef = useRef<ReturnType<typeof setTimeout> | null>(null)
const handleValueChange = (sql: SafeSqlFragment) => {
setValue(sql)
if (format !== 'lib') return
if (debounceRef.current !== null) clearTimeout(debounceRef.current)
if (!sql) return
debounceRef.current = setTimeout(() => {
inferSQLFromLib(sql, setInferredSQL)
}, 1500)
}
const executionCallbacks = {
option: selectedOption,
onExecuteSQL: ({ result, isAutoLimit }: { result: Object[] | null; isAutoLimit: boolean }) => {
setResults(result)
setAutoLimit(isAutoLimit)
},
onParseQuery: setParseQueryResults,
}
const onRunQuery = async () => {
await testQuery({
option: selectedOption,
format,
value,
onInferSQL: setInferredSQL,
onParseQuery: setParseQueryResults,
onExecuteSQL: ({ result, isAutoLimit }) => {
setResults(result)
setAutoLimit(isAutoLimit)
},
})
if (format === 'lib') {
if (!inferredSQL) return
await testQuery({ value: acceptUntrustedSql(inferredSQL), ...executionCallbacks })
} else {
await testQuery({ value, ...executionCallbacks })
}
}
useEffect(() => {
@@ -110,7 +141,17 @@ export const RLSTesterSheet = ({ handleSelectEditPolicy }: RLSTesterSheetProps)
<div className="flex items-center justify-between px-5 py-2">
<p className="text-sm">Query</p>
<div className="flex items-center gap-x-2">
<Select value={format} onValueChange={(x) => setFormat(x as 'sql' | 'lib')}>
<Select
value={format}
onValueChange={(x) => {
const newFormat = x as 'sql' | 'lib'
setFormat(newFormat)
if (newFormat !== 'lib') {
setInferredSQL(undefined)
if (debounceRef.current !== null) clearTimeout(debounceRef.current)
}
}}
>
<SelectTrigger size="tiny">
<SelectValue />
</SelectTrigger>
@@ -126,21 +167,20 @@ export const RLSTesterSheet = ({ handleSelectEditPolicy }: RLSTesterSheetProps)
</div>
<div className="h-40 relative">
<CodeEditor
<UserSqlEditor
id="rls-tester"
language="pgsql"
value={value}
placeholder={
format === 'sql'
? 'select * from table;'
: 'SQL will be inferred from client library code'
? safeSql`select * from table;`
: safeSql`SQL will be inferred from client library code`
}
onInputChange={(val) => setValue(val ?? '')}
onChange={handleValueChange}
actions={{
runQuery: {
enabled: open,
callback: () => {
if (!isLoading) onRunQuery()
if (!isInferring && !isLoading) onRunQuery()
},
},
}}
@@ -148,10 +188,10 @@ export const RLSTesterSheet = ({ handleSelectEditPolicy }: RLSTesterSheetProps)
</div>
</SheetSection>
{format === 'lib' && !!inferredSQL && (
{format === 'lib' && (
<div>
<DialogSectionSeparator />
<InferredSQLViewer sql={inferredSQL} />
<InferredSQLViewer sql={inferredSQL} isLoading={isInferring} />
</div>
)}
@@ -214,7 +254,12 @@ export const RLSTesterSheet = ({ handleSelectEditPolicy }: RLSTesterSheetProps)
<Button type="default" disabled={isLoading} onClick={() => setOpen(false)}>
Cancel
</Button>
<Button type="primary" loading={isLoading} onClick={onRunQuery}>
<Button
type="primary"
loading={isInferring || isLoading}
disabled={format === 'lib' && !inferredSQL}
onClick={onRunQuery}
>
Run query
</Button>
</div>
@@ -0,0 +1,28 @@
import { rawSql, type SafeSqlFragment } from '@supabase/pg-meta'
import type { ComponentProps } from 'react'
import { CodeEditor } from '@/components/ui/CodeEditor/CodeEditor'
interface UserSqlEditorProps {
id: string
value: SafeSqlFragment
placeholder?: SafeSqlFragment
actions?: ComponentProps<typeof CodeEditor>['actions']
onChange: (sql: SafeSqlFragment) => void
}
/**
* Wraps CodeEditor for user-authored SQL. The rawSql boundary lives here — any
* text the user types is immediately promoted to SafeSqlFragment so callers
* never handle plain strings.
*/
export const UserSqlEditor = ({ value, onChange, ...props }: UserSqlEditorProps) => {
return (
<CodeEditor
language="pgsql"
value={value}
onInputChange={(val) => onChange(rawSql(val ?? ''))}
{...props}
/>
)
}
@@ -1,3 +1,4 @@
import { type SafeSqlFragment, type UntrustedSqlFragment } from '@supabase/pg-meta'
import { useState } from 'react'
import { toast } from 'sonner'
@@ -20,7 +21,7 @@ import {
const limit = 100
/**
* [Joshen] Testing a SQL query for it's RLS access involves 3 async steps
* [Joshen] Testing a SQL query for its RLS access involves 3 async steps
* 0. (Optional) Inferring client library code to SQL query via the AI Assistant
* 1. Parsing the provided SQL query to retrieve its operation type + tables involved
* 2. Checking for tables involved if they've got RLS enabled
@@ -48,10 +49,26 @@ export const useTestQueryRLS = () => {
onError: () => {},
})
const { mutateAsync: parseClientCode, error: parseClientCodeError } = useParseClientCodeMutation({
const {
mutateAsync: parseClientCode,
isPending: isInferring,
error: parseClientCodeError,
} = useParseClientCodeMutation({
onError: () => {},
})
const inferSQLFromLib = async (
value: string,
onInferSQL: (unchecked_sql: UntrustedSqlFragment) => void
) => {
const { unchecked_sql, valid } = await parseClientCode({ code: value })
if (valid && unchecked_sql != null) {
onInferSQL(unchecked_sql)
} else {
toast.error('Client library code provided is not valid')
}
}
const { mutateAsync: parseQuery, error: parseQueryError } = useParseSQLQueryMutation({
onError: () => {},
})
@@ -62,17 +79,13 @@ export const useTestQueryRLS = () => {
})
const testQuery = async ({
option,
format,
value,
onInferSQL,
option,
onExecuteSQL,
onParseQuery,
}: {
value: SafeSqlFragment
option: 'anon' | 'authenticated'
format: 'lib' | 'sql'
value: string
onInferSQL: (sql: string) => void
onExecuteSQL: ({
result,
isAutoLimit,
@@ -90,20 +103,9 @@ export const useTestQueryRLS = () => {
try {
setIsLoading(true)
let formattedValue = value
if (format === 'lib') {
const { sql, valid } = await parseClientCode({ code: value })
if (valid && !!sql) {
formattedValue = sql
onInferSQL(sql)
} else {
return toast.error('Client library code provided is not valid')
}
}
const { appendAutoLimit } = checkIfAppendLimitRequired(formattedValue, limit)
const formattedSql = suffixWithLimit(formattedValue, limit)
const { appendAutoLimit } = checkIfAppendLimitRequired(value, limit)
const formattedSql = suffixWithLimit(value, limit)
const data = await parseQuery({ sql: formattedSql })
if (data.operation !== 'SELECT') {
@@ -175,7 +177,9 @@ export const useTestQueryRLS = () => {
return {
limit,
testQuery,
inferSQLFromLib,
isLoading,
isInferring,
executeSqlError,
parseQueryError,
parseClientCodeError,
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { PermissionAction } from '@supabase/shared-types/out/constants'
import { Search } from 'lucide-react'
import { parseAsBoolean, parseAsJson, parseAsString, useQueryState } from 'nuqs'
@@ -32,7 +33,7 @@ import SchemaSelector from '@/components/ui/SchemaSelector'
import { Shortcut } from '@/components/ui/Shortcut'
import { TextConfirmModal } from '@/components/ui/TextConfirmModalWrapper'
import { useDatabaseFunctionDeleteMutation } from '@/data/database-functions/database-functions-delete-mutation'
import type { DatabaseFunction } from '@/data/database-functions/database-functions-query'
import type { SavedDatabaseFunction } from '@/data/database-functions/database-functions-query'
import { useDatabaseFunctionsQuery } from '@/data/database-functions/database-functions-query'
import { useSchemasQuery } from '@/data/database/schemas-query'
import { useAsyncCheckPermissions } from '@/hooks/misc/useCheckPermissions'
@@ -46,7 +47,7 @@ import { SHORTCUT_IDS } from '@/state/shortcuts/registry'
import { useShortcut } from '@/state/shortcuts/useShortcut'
import { useSidebarManagerSnapshot } from '@/state/sidebar-manager-state'
const createFunctionSnippet = `create function function_name()
const createFunctionSnippet = safeSql`create function function_name()
returns void
language plpgsql
as $$
@@ -79,7 +80,7 @@ export const FunctionsList = () => {
}
}
const duplicateFunction = (fn: DatabaseFunction) => {
const duplicateFunction = (fn: SavedDatabaseFunction) => {
if (isInlineEditorEnabled) {
const dupFn = {
...fn,
@@ -94,7 +95,7 @@ export const FunctionsList = () => {
}
}
const editFunction = (fn: DatabaseFunction) => {
const editFunction = (fn: SavedDatabaseFunction) => {
setSelectedFunctionIdToDuplicate(null)
if (isInlineEditorEnabled) {
setEditorPanelValue(fn.complete_statement)
@@ -15,19 +15,19 @@ import {
TableRow,
} from 'ui'
import type { EventTrigger } from './EventTriggerList.utils'
import { SUPABASE_ROLES } from '@/components/interfaces/Database/Roles/Roles.constants'
import { ButtonTooltip } from '@/components/ui/ButtonTooltip'
import type { DatabaseEventTrigger } from '@/data/database-event-triggers/database-event-triggers-query'
interface EventTriggerListProps {
filterString: string
eventTriggers: DatabaseEventTrigger[]
eventTriggers: EventTrigger[]
ownerFilter: string[]
canEdit: boolean
onEditTrigger: (trigger: DatabaseEventTrigger) => void
onEditTriggerWithAssistant: (trigger: DatabaseEventTrigger) => void
onDuplicateTrigger: (trigger: DatabaseEventTrigger) => void
onDeleteTrigger: (trigger: DatabaseEventTrigger) => void
onEditTrigger: (trigger: EventTrigger) => void
onEditTriggerWithAssistant: (trigger: EventTrigger) => void
onDuplicateTrigger: (trigger: EventTrigger) => void
onDeleteTrigger: (trigger: EventTrigger) => void
}
const SYSTEM_OWNERS = new Set<string>(SUPABASE_ROLES)
@@ -1,37 +1,43 @@
import {
ident,
joinSqlFragments,
keyword,
literal,
safeSql,
type SafeSqlFragment,
} from '@supabase/pg-meta'
import type { DatabaseEventTrigger } from '@/data/database-event-triggers/database-event-triggers-query'
const ensureSemicolon = (statement: string) => {
const trimmed = statement.trim()
return trimmed.endsWith(';') ? trimmed : `${trimmed};`
export type EventTrigger = Omit<DatabaseEventTrigger, 'function_definition'> & {
function_definition: SafeSqlFragment | null
}
const escapeLiteral = (value: string) => value.replace(/'/g, "''")
const escapeIdentifier = (value: string) => value.replace(/"/g, '""')
export const generateEventTriggerCreateSQL = (trigger: DatabaseEventTrigger) => {
const statements: string[] = []
export const generateEventTriggerCreateSQL = (trigger: EventTrigger): SafeSqlFragment => {
const parts: SafeSqlFragment[] = []
if (trigger.function_definition) {
statements.push(ensureSemicolon(trigger.function_definition))
}
if (trigger.event && trigger.function_schema && trigger.function_name) {
statements.push(`DROP EVENT TRIGGER IF EXISTS "${escapeIdentifier(trigger.name)}";`)
statements.push(
[
`CREATE EVENT TRIGGER "${escapeIdentifier(trigger.name)}"`,
`ON ${trigger.event}`,
trigger.tags && trigger.tags.length > 0
? `WHEN TAG IN (${trigger.tags.map((tag) => `'${escapeLiteral(tag)}'`).join(', ')})`
: null,
`EXECUTE FUNCTION "${escapeIdentifier(trigger.function_schema)}"."${escapeIdentifier(
trigger.function_name
)}"();`,
]
.filter(Boolean)
.join('\n')
const hasTrailingSemicolon = /;\s*$/.test(trigger.function_definition)
parts.push(
hasTrailingSemicolon ? trigger.function_definition : safeSql`${trigger.function_definition};`
)
}
return statements.filter(Boolean).join('\n\n').trim()
if (trigger.event && trigger.function_schema && trigger.function_name) {
parts.push(safeSql`DROP EVENT TRIGGER IF EXISTS ${ident(trigger.name)};`)
const tagClause =
trigger.tags && trigger.tags.length > 0
? safeSql`\nWHEN TAG IN (${joinSqlFragments(
trigger.tags.map((tag) => literal(tag)),
', '
)})`
: safeSql``
parts.push(safeSql`CREATE EVENT TRIGGER ${ident(trigger.name)}
ON ${keyword(trigger.event)}${tagClause}
EXECUTE FUNCTION ${ident(trigger.function_schema)}.${ident(trigger.function_name)}();`)
}
return parts.length > 0 ? joinSqlFragments(parts, '\n\n') : safeSql``
}
@@ -1,4 +1,6 @@
export const DEFAULT_EVENT_TRIGGER_SQL = `CREATE OR REPLACE FUNCTION event_trigger_fn()
import { safeSql } from '@supabase/pg-meta'
export const DEFAULT_EVENT_TRIGGER_SQL = safeSql`CREATE OR REPLACE FUNCTION event_trigger_fn()
RETURNS event_trigger
LANGUAGE plpgsql
AS $$
@@ -13,7 +15,7 @@ ON ddl_command_end
EXECUTE FUNCTION event_trigger_fn();
`
export const AUTO_ENABLE_RLS_EVENT_TRIGGER_SQL = `
export const AUTO_ENABLE_RLS_EVENT_TRIGGER_SQL = safeSql`
CREATE OR REPLACE FUNCTION rls_auto_enable()
RETURNS EVENT_TRIGGER
LANGUAGE plpgsql
@@ -59,7 +61,7 @@ export const EVENT_TRIGGER_TEMPLATES = [
{
name: 'Prevent table drops',
description: 'Block dropping tables using the sql_drop event trigger.',
content: `-- Function
content: safeSql`-- Function
CREATE OR REPLACE FUNCTION dont_drop_function()
RETURNS event_trigger LANGUAGE plpgsql AS $$
DECLARE
@@ -8,7 +8,7 @@ import { EmptyStatePresentational } from 'ui-patterns'
import { GenericSkeletonLoader } from 'ui-patterns/ShimmeringLoader'
import { EventTriggerList } from './EventTriggerList'
import { generateEventTriggerCreateSQL } from './EventTriggerList.utils'
import { generateEventTriggerCreateSQL, type EventTrigger } from './EventTriggerList.utils'
import { DEFAULT_EVENT_TRIGGER_SQL, EVENT_TRIGGER_TEMPLATES } from './EventTriggers.constants'
import { DeleteEventTrigger } from '@/components/interfaces/Database/Triggers/DeleteEventTrigger'
import {
@@ -21,10 +21,7 @@ import { ButtonTooltip } from '@/components/ui/ButtonTooltip'
import { DocsButton } from '@/components/ui/DocsButton'
import { Shortcut } from '@/components/ui/Shortcut'
import { useDatabaseEventTriggerDeleteMutation } from '@/data/database-event-triggers/database-event-trigger-delete-mutation'
import {
useDatabaseEventTriggersQuery,
type DatabaseEventTrigger,
} from '@/data/database-event-triggers/database-event-triggers-query'
import { useDatabaseEventTriggersQuery } from '@/data/database-event-triggers/database-event-triggers-query'
import { useAsyncCheckPermissions } from '@/hooks/misc/useCheckPermissions'
import { useSelectedProjectQuery } from '@/hooks/misc/useSelectedProject'
import { DOCS_URL } from '@/lib/constants'
@@ -49,7 +46,7 @@ export const EventTriggersList = () => {
parseAsJson(selectFilterSchema.parse)
)
const ownerFilterValue = ownerFilter ?? DEFAULT_OWNER_FILTER
const [triggerToDelete, setTriggerToDelete] = useState<DatabaseEventTrigger | null>(null)
const [triggerToDelete, setTriggerToDelete] = useState<EventTrigger | null>(null)
const searchInputRef = useRef<HTMLInputElement>(null)
const { openSidebar } = useSidebarManagerSnapshot()
const aiSnap = useAiAssistantStateSnapshot()
@@ -92,7 +89,7 @@ export const EventTriggersList = () => {
openSidebar(SIDEBAR_KEYS.EDITOR_PANEL)
}
const editEventTrigger = (trigger: DatabaseEventTrigger) => {
const editEventTrigger = (trigger: EventTrigger) => {
setEditorPanelInitialPrompt(`Update the event trigger "${trigger.name}" that...`)
const sql = generateEventTriggerCreateSQL(trigger)
setEditorPanelValue(sql.length > 0 ? sql : DEFAULT_EVENT_TRIGGER_SQL)
@@ -100,7 +97,7 @@ export const EventTriggersList = () => {
openSidebar(SIDEBAR_KEYS.EDITOR_PANEL)
}
const editEventTriggerWithAssistant = (trigger: DatabaseEventTrigger) => {
const editEventTriggerWithAssistant = (trigger: EventTrigger) => {
const sql = generateEventTriggerCreateSQL(trigger)
openSidebar(SIDEBAR_KEYS.AI_ASSISTANT)
aiSnap.newChat({
@@ -128,7 +125,7 @@ export const EventTriggersList = () => {
})
}
const duplicateEventTrigger = (trigger: DatabaseEventTrigger) => {
const duplicateEventTrigger = (trigger: EventTrigger) => {
const duplicateTrigger = { ...trigger, name: `${trigger.name}_duplicate` }
setEditorPanelInitialPrompt('Create a new event trigger that...')
const sql = generateEventTriggerCreateSQL(duplicateTrigger)
@@ -137,7 +134,7 @@ export const EventTriggersList = () => {
openSidebar(SIDEBAR_KEYS.EDITOR_PANEL)
}
const handleDeleteEventTrigger = (trigger: DatabaseEventTrigger) => {
const handleDeleteEventTrigger = (trigger: EventTrigger) => {
setTriggerToDelete(trigger)
}
@@ -1,4 +1,3 @@
import { PostgresTrigger } from '@supabase/postgres-meta'
import { PermissionAction } from '@supabase/shared-types/out/constants'
import { useParams } from 'common'
import { includes, sortBy } from 'lodash'
@@ -17,7 +16,7 @@ import {
TableRow,
} from 'ui'
import { generateTriggerCreateSQL } from './TriggerList.utils'
import { generateTriggerCreateSQL, type PostgresTrigger } from './TriggerList.utils'
import { selectFilterSchema } from '@/components/interfaces/Reports/v2/ReportsSelectFilter'
import { SIDEBAR_KEYS } from '@/components/layouts/ProjectLayout/LayoutSidebar/LayoutSidebarProvider'
import { ButtonTooltip } from '@/components/ui/ButtonTooltip'
@@ -1,39 +1,39 @@
interface PostgresTrigger {
activation: string
condition: string | null
enabled_mode: string
events: string[]
function_args: string[]
function_name: string
function_schema: string
id: number
name: string
orientation: string
schema: string
table: string
table_id: number
import { ident, joinSqlFragments, keyword, safeSql, type SafeSqlFragment } from '@supabase/pg-meta'
import type { DatabaseTriggersData } from '@/data/database-triggers/database-triggers-query'
export type PostgresTrigger = Omit<
DatabaseTriggersData[number],
'function_args' | 'condition' | 'events'
> & {
function_args: SafeSqlFragment[]
condition: SafeSqlFragment | null
events: SafeSqlFragment[]
}
export const generateTriggerCreateSQL = (trigger: PostgresTrigger) => {
const events = trigger.events.join(' OR ')
const args = trigger.function_args.length > 0 ? `(${trigger.function_args.join(', ')})` : '()'
export const generateTriggerCreateSQL = (trigger: PostgresTrigger): SafeSqlFragment => {
const events = joinSqlFragments(trigger.events, ' OR ')
const args =
trigger.function_args.length > 0
? safeSql`(${joinSqlFragments(trigger.function_args, ', ')})`
: safeSql`()`
// Note: CREATE OR REPLACE is not supported for triggers
// We need to drop the existing trigger first if we want to replace it
let sql = `
DROP TRIGGER IF EXISTS "${trigger.name}" ON "${trigger.schema}"."${trigger.table}";
let sql = safeSql`
DROP TRIGGER IF EXISTS ${ident(trigger.name)} ON ${ident(trigger.schema)}.${ident(trigger.table)};
CREATE TRIGGER "${trigger.name}"
${trigger.activation} ${events}
ON "${trigger.schema}"."${trigger.table}"
FOR EACH ${trigger.orientation}
CREATE TRIGGER ${ident(trigger.name)}
${keyword(trigger.activation)} ${events}
ON ${ident(trigger.schema)}.${ident(trigger.table)}
FOR EACH ${keyword(trigger.orientation)}
`
if (trigger.condition) {
sql += `WHEN (${trigger.condition})\n`
sql = safeSql`${sql} WHEN (${trigger.condition})\n`
}
sql += `EXECUTE FUNCTION "${trigger.function_schema}"."${trigger.function_name}"${args};`
sql = safeSql`${sql} EXECUTE FUNCTION ${ident(trigger.function_schema)}.${ident(trigger.function_name)}${args};`
return sql.trim()
return sql
}
@@ -1,4 +1,4 @@
import type { PostgresTrigger } from '@supabase/postgres-meta'
import { safeSql } from '@supabase/pg-meta'
import { PermissionAction } from '@supabase/shared-types/out/constants'
import { DatabaseZap, Search } from 'lucide-react'
import { parseAsBoolean, parseAsJson, parseAsString, useQueryState } from 'nuqs'
@@ -10,7 +10,7 @@ import { GenericSkeletonLoader } from 'ui-patterns/ShimmeringLoader'
import { CreateTriggerButtons } from './CreateTriggerButtons'
import { TriggerList } from './TriggerList'
import { generateTriggerCreateSQL } from './TriggerList.utils'
import { generateTriggerCreateSQL, type PostgresTrigger } from './TriggerList.utils'
import { useIsInlineEditorEnabled } from '@/components/interfaces/Account/Preferences/useDashboardSettings'
import { ProtectedSchemaWarning } from '@/components/interfaces/Database/ProtectedSchemaWarning'
import { TriggerSheet } from '@/components/interfaces/Database/Triggers/TriggerSheet'
@@ -115,10 +115,12 @@ export const TriggersList = () => {
setTriggerToDuplicate(null)
if (isInlineEditorEnabled) {
setEditorPanelInitialPrompt('Create a new database trigger that...')
setEditorPanelValue(`create trigger trigger_name
setEditorPanelValue(
safeSql`create trigger trigger_name
after insert or update or delete on table_name
for each row
execute function function_name();`)
execute function function_name();`
)
if (editorPanelTemplates.length > 0) {
setEditorPanelTemplates([])
}
@@ -33,10 +33,10 @@ export function ExplainHeader({ mode, onToggleMode, summary, id, rows }: Explain
const getPromptData = () => {
if (!id) return null
const snippet = snapV2.snippets[id]?.snippet
if (!snippet?.content?.sql) return null
if (!snippet?.content?.unchecked_sql) return null
return buildExplainPrompt({
sql: snippet.content.sql,
sql: snippet.content.unchecked_sql,
explainPlanRows: (rows as QueryPlanRow[]) ?? [],
})
}
@@ -1,3 +1,5 @@
import type { SafeSqlFragment } from '@supabase/pg-meta'
export interface ChartDataPoint {
period_start: number
timestamp: string
@@ -19,7 +21,7 @@ export interface ParsedLogEntry {
application_name?: string
calls?: number
database_name?: string
query?: string
query?: SafeSqlFragment
query_id?: number
total_exec_time?: number
total_plan_time?: number
@@ -1,3 +1,4 @@
import { safeSql, type SafeSqlFragment } from '@supabase/pg-meta'
import { wrapWithRollback } from '@supabase/pg-meta/src/query'
import { useParams } from 'common'
import { Search, TextSearch, X } from 'lucide-react'
@@ -174,7 +175,7 @@ export const QueryInsightsTable = ({
}, [mode, selectedTriageRow, selectedRow, filteredTriageItems, explorerItems])
const runExplain = useCallback(
(query: string) => {
(query: SafeSqlFragment) => {
if (explainResults[query]) return
if (explainLoadingQuery) return
const requestQuery = query
@@ -183,7 +184,7 @@ export const QueryInsightsTable = ({
{
projectRef: project?.ref,
connectionString: project?.connectionString,
sql: wrapWithRollback(`EXPLAIN ANALYZE ${requestQuery}`),
sql: wrapWithRollback(safeSql`EXPLAIN ANALYZE ${requestQuery}`),
},
{
onSuccess(data) {
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { describe, expect, it, vi } from 'vitest'
import { hasIndexRecommendations } from '../../QueryPerformance/IndexAdvisor/index-advisor.utils'
@@ -9,7 +10,7 @@ vi.mock('../../QueryPerformance/IndexAdvisor/index-advisor.utils', () => ({
}))
const baseRow: QueryPerformanceRow = {
query: 'SELECT * FROM users',
query: safeSql`SELECT * FROM users`,
calls: 10,
mean_time: 50,
min_time: 10,
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { describe, expect, it } from 'vitest'
import type { ParsedLogEntry } from '../QueryInsights.types'
@@ -65,7 +66,7 @@ describe('parseSupamonitorLogs', () => {
})
it('handles multiple log entries', () => {
const raw = [makeSampleLog(), makeSampleLog({ query: 'SELECT 2', query_id: 2 })]
const raw = [makeSampleLog(), makeSampleLog({ query: safeSql`SELECT 2`, query_id: 2 })]
const result = parseSupamonitorLogs(raw)
expect(result).toHaveLength(2)
})
@@ -78,7 +79,7 @@ describe('transformLogsToChartData', () => {
})
it('filters out entries with no timestamp', () => {
const logs: ParsedLogEntry[] = [{ query: 'SELECT 1', calls: 5 }]
const logs: ParsedLogEntry[] = [{ query: safeSql`SELECT 1`, calls: 5 }]
const result = transformLogsToChartData(logs)
expect(result).toEqual([])
})
@@ -167,8 +168,8 @@ describe('aggregateLogsByQuery', () => {
it('skips entries with empty or whitespace-only queries', () => {
const logs: ParsedLogEntry[] = [
{ query: '', calls: 5 },
{ query: ' ', calls: 3 },
{ query: safeSql``, calls: 5 },
{ query: safeSql` `, calls: 3 },
]
const result = aggregateLogsByQuery(logs)
expect(result).toEqual([])
@@ -177,7 +178,7 @@ describe('aggregateLogsByQuery', () => {
it('aggregates a single log entry correctly', () => {
const logs: ParsedLogEntry[] = [
{
query: 'SELECT 1',
query: safeSql`SELECT 1`,
user_name: 'postgres',
application_name: 'app',
calls: 10,
@@ -207,7 +208,7 @@ describe('aggregateLogsByQuery', () => {
it('aggregates multiple entries for the same query', () => {
const logs: ParsedLogEntry[] = [
{
query: 'SELECT 1',
query: safeSql`SELECT 1`,
user_name: 'postgres',
calls: 5,
total_exec_time: 50,
@@ -218,7 +219,7 @@ describe('aggregateLogsByQuery', () => {
max_plan_time: 3,
},
{
query: 'SELECT 1',
query: safeSql`SELECT 1`,
user_name: 'postgres',
calls: 10,
total_exec_time: 100,
@@ -243,8 +244,8 @@ describe('aggregateLogsByQuery', () => {
it('normalizes whitespace differences in queries', () => {
const logs: ParsedLogEntry[] = [
{ query: 'SELECT 1', calls: 5, total_exec_time: 50, total_plan_time: 0 },
{ query: 'SELECT 1', calls: 3, total_exec_time: 30, total_plan_time: 0 },
{ query: safeSql`SELECT 1`, calls: 5, total_exec_time: 50, total_plan_time: 0 },
{ query: safeSql`SELECT 1`, calls: 3, total_exec_time: 30, total_plan_time: 0 },
]
const result = aggregateLogsByQuery(logs)
@@ -255,9 +256,9 @@ describe('aggregateLogsByQuery', () => {
it('sorts results by total_time descending', () => {
const logs: ParsedLogEntry[] = [
{ query: 'SELECT 1', calls: 1, total_exec_time: 10, total_plan_time: 0 },
{ query: 'SELECT 2', calls: 1, total_exec_time: 100, total_plan_time: 0 },
{ query: 'SELECT 3', calls: 1, total_exec_time: 50, total_plan_time: 0 },
{ query: safeSql`SELECT 1`, calls: 1, total_exec_time: 10, total_plan_time: 0 },
{ query: safeSql`SELECT 2`, calls: 1, total_exec_time: 100, total_plan_time: 0 },
{ query: safeSql`SELECT 3`, calls: 1, total_exec_time: 50, total_plan_time: 0 },
]
const result = aggregateLogsByQuery(logs)
@@ -270,8 +271,8 @@ describe('aggregateLogsByQuery', () => {
it('calculates prop_total_time as percentage of total execution', () => {
const logs: ParsedLogEntry[] = [
{ query: 'SELECT 1', calls: 1, total_exec_time: 75, total_plan_time: 0 },
{ query: 'SELECT 2', calls: 1, total_exec_time: 25, total_plan_time: 0 },
{ query: safeSql`SELECT 1`, calls: 1, total_exec_time: 75, total_plan_time: 0 },
{ query: safeSql`SELECT 2`, calls: 1, total_exec_time: 25, total_plan_time: 0 },
]
const result = aggregateLogsByQuery(logs)
@@ -282,7 +283,7 @@ describe('aggregateLogsByQuery', () => {
it('handles zero calls gracefully (mean_time defaults to 0)', () => {
const logs: ParsedLogEntry[] = [
{ query: 'SELECT 1', calls: 0, total_exec_time: 100, total_plan_time: 0 },
{ query: safeSql`SELECT 1`, calls: 0, total_exec_time: 100, total_plan_time: 0 },
]
const result = aggregateLogsByQuery(logs)
@@ -293,7 +294,7 @@ describe('aggregateLogsByQuery', () => {
it('sets static fields correctly', () => {
const logs: ParsedLogEntry[] = [
{ query: 'SELECT 1', calls: 1, total_exec_time: 10, total_plan_time: 0 },
{ query: safeSql`SELECT 1`, calls: 1, total_exec_time: 10, total_plan_time: 0 },
]
const result = aggregateLogsByQuery(logs)
@@ -1,4 +1,7 @@
import { type SafeSqlFragment } from '@supabase/pg-meta'
import type { QueryPerformanceRow } from '../../QueryPerformance/QueryPerformance.types'
import type { Logs } from '../../Settings/Logs/Logs.types'
import {
SCHEMA_INTROSPECTION_REGEX,
SUPAMONITOR_EXCLUDED_APP_NAMES,
@@ -26,29 +29,42 @@ export function filterSystemLogs(
})
}
export function parseSupamonitorLogs(logData: any[]): ParsedLogEntry[] {
function asString(unknown: unknown): string | undefined {
if (typeof unknown === 'string') return unknown
if (unknown === null || unknown === undefined) return undefined
return String(unknown)
}
function asNumber(unknown: unknown): number | undefined {
if (typeof unknown === 'number') return unknown
if (unknown === null || unknown === undefined) return undefined
const parsed = Number(unknown)
return Number.isNaN(parsed) ? undefined : parsed
}
export function parseSupamonitorLogs(logData: Logs['result']): ParsedLogEntry[] {
if (!logData || logData.length === 0) return []
return logData.map((log) => ({
timestamp: log.timestamp,
application_name: log.application_name,
calls: log.calls,
database_name: log.database_name,
timestamp: asString(log.timestamp),
application_name: asString(log.application_name),
calls: asNumber(log.calls),
database_name: asString(log.database_name),
query: log.query,
query_id: log.query_id,
total_exec_time: log.total_exec_time,
total_plan_time: log.total_plan_time,
user_name: log.user_name,
mean_exec_time: log.mean_exec_time,
mean_plan_time: log.mean_plan_time,
min_exec_time: log.min_exec_time,
max_exec_time: log.max_exec_time,
min_plan_time: log.min_plan_time,
max_plan_time: log.max_plan_time,
p50_exec_time: log.p50_exec_time,
p95_exec_time: log.p95_exec_time,
p50_plan_time: log.p50_plan_time,
p95_plan_time: log.p95_plan_time,
query_id: asNumber(log.query_id),
total_exec_time: asNumber(log.total_exec_time),
total_plan_time: asNumber(log.total_plan_time),
user_name: asString(log.user_name),
mean_exec_time: asNumber(log.mean_exec_time),
mean_plan_time: asNumber(log.mean_plan_time),
min_exec_time: asNumber(log.min_exec_time),
max_exec_time: asNumber(log.max_exec_time),
min_plan_time: asNumber(log.min_plan_time),
max_plan_time: asNumber(log.max_plan_time),
p50_exec_time: asNumber(log.p50_exec_time),
p95_exec_time: asNumber(log.p95_exec_time),
p50_plan_time: asNumber(log.p50_plan_time),
p95_plan_time: asNumber(log.p95_plan_time),
}))
}
@@ -170,7 +186,7 @@ export function aggregateLogsByQuery(parsedLogs: ParsedLogEntry[]): QueryPerform
const propTotalTime = totalExecutionTime > 0 ? (stats.totalTime / totalExecutionTime) * 100 : 0
aggregatedData.push({
query: stats.query,
query: stats.query as SafeSqlFragment,
rolname: stats.rolname,
application_name: stats.applicationName,
calls: stats.totalCalls,
@@ -1,7 +1,9 @@
import type { SafeSqlFragment } from '@supabase/pg-meta'
import { GetIndexAdvisorResultResponse } from '@/data/database/retrieve-index-advisor-result-query'
export interface QueryPerformanceRow {
query: string
query: SafeSqlFragment
prop_total_time: number
total_time: number
calls: number
@@ -11,7 +13,7 @@ export interface QueryPerformanceRow {
rows_read: number
p95_time?: number
cache_hit_rate: number
rolname: string
rolname?: string
application_name?: string
index_advisor_result?: GetIndexAdvisorResultResponse | null
_total_cache_hits?: number
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { describe, expect, it, vi } from 'vitest'
import { transformStatementDataToRows } from './WithStatements.utils'
@@ -13,13 +14,14 @@ vi.mock('../IndexAdvisor/index-advisor.utils', () => ({
}))
const makeRow = (overrides: Record<string, any> = {}) => ({
query: 'SELECT 1',
query: safeSql`SELECT 1`,
rolname: 'postgres',
calls: 10,
mean_time: 5.0,
min_time: 1.0,
max_time: 20.0,
total_time: 50.0,
prop_total_time: 0,
rows_read: 100,
cache_hit_rate: 0.95,
index_advisor_result: null,
@@ -50,8 +52,18 @@ describe('transformStatementDataToRows', () => {
})
})
it('defaults missing numeric fields to 0', () => {
const data = [{ query: 'SELECT 1' }]
it('preserves zero-valued numeric fields as 0', () => {
const data = [
makeRow({
calls: 0,
mean_time: 0,
min_time: 0,
max_time: 0,
total_time: 0,
rows_read: 0,
cache_hit_rate: 0,
}),
]
const result = transformStatementDataToRows(data)
expect(result).toHaveLength(1)
@@ -72,8 +84,8 @@ describe('transformStatementDataToRows', () => {
it('calculates prop_total_time as percentage of total time', () => {
const data = [
makeRow({ query: 'Q1', total_time: 75 }),
makeRow({ query: 'Q2', total_time: 25 }),
makeRow({ query: safeSql`Q1`, total_time: 75 }),
makeRow({ query: safeSql`Q2`, total_time: 25 }),
]
const result = transformStatementDataToRows(data)
@@ -108,7 +120,7 @@ describe('transformStatementDataToRows', () => {
describe('filterIndexAdvisor mode', () => {
it('keeps rows for non-protected schema queries', () => {
const data = [makeRow({ query: 'SELECT * FROM public.users' })]
const data = [makeRow({ query: safeSql`SELECT * FROM public.users` })]
const result = transformStatementDataToRows(data, true)
expect(result).toHaveLength(1)
})
@@ -116,7 +128,7 @@ describe('transformStatementDataToRows', () => {
it('keeps protected-schema rows that have valid recommendations', () => {
const data = [
makeRow({
query: 'SELECT * FROM auth.users',
query: safeSql`SELECT * FROM auth.users`,
index_advisor_result: { index_statements: ['CREATE INDEX ON auth.users (id)'] },
}),
]
@@ -127,7 +139,7 @@ describe('transformStatementDataToRows', () => {
it('filters out protected-schema rows with no valid recommendations', () => {
const data = [
makeRow({
query: 'SELECT * FROM auth.users',
query: safeSql`SELECT * FROM auth.users`,
index_advisor_result: { _mock_filter_null: true },
}),
]
@@ -138,7 +150,7 @@ describe('transformStatementDataToRows', () => {
it('does not filter protected-schema rows when filterIndexAdvisor is false', () => {
const data = [
makeRow({
query: 'SELECT * FROM auth.users',
query: safeSql`SELECT * FROM auth.users`,
index_advisor_result: { _mock_filter_null: true },
}),
]
@@ -5,7 +5,7 @@ import {
import { QueryPerformanceRow } from '../QueryPerformance.types'
export const transformStatementDataToRows = (
data: any[],
data: QueryPerformanceRow[],
filterIndexAdvisor: boolean = false
): QueryPerformanceRow[] => {
if (!data || data.length === 0) return []
@@ -70,7 +70,7 @@ export const ReportBlock = ({
}
)
const sql = isSnippet ? (data?.content as SqlSnippets.Content)?.sql : undefined
const sql = isSnippet ? (data?.content as SqlSnippets.Content)?.unchecked_sql : undefined
const chartConfig = { ...DEFAULT_CHART_CONFIG, ...(item.chartConfig ?? {}) }
const isDeprecatedChart = DEPRECATED_REPORTS.includes(item.attribute)
const snippetMissing = contentError?.message.includes('Content not found')
@@ -291,7 +291,7 @@ const MonacoEditor = ({
onMount={handleEditorOnMount}
onChange={handleEditorChange}
defaultLanguage="pgsql"
defaultValue={snippet?.snippet.content?.sql}
defaultValue={snippet?.snippet.content?.unchecked_sql}
path={id}
options={{
tabSize: 2,
@@ -140,7 +140,7 @@ export const MoveQueryModal = ({ visible, snippets = [], onClose }: MoveQueryMod
let snippetContent = (snippet as SnippetWithContent)?.content
if (snippetContent === undefined) {
const { content } = await getContentById({ projectRef: ref, id: snippet.id })
if ('sql' in content) {
if ('unchecked_sql' in content) {
snippetContent = content
}
}
@@ -89,11 +89,11 @@ const RenameQueryModal = ({
const generateTitle = async () => {
if ('content' in snippet && isSQLSnippet) {
getGeneratedValues({ sql: snippet.content.sql })
getGeneratedValues({ sql: snippet.content.unchecked_sql })
} else {
try {
const { content } = await getContentById({ projectRef: ref, id: snippet.id })
if ('sql' in content) getGeneratedValues({ sql: content.sql })
if ('unchecked_sql' in content) getGeneratedValues({ sql: content.unchecked_sql })
} catch (error) {
toast.error('Unable to generate title based on query contents')
}
@@ -1,3 +1,4 @@
import { untrustedSql } from '@supabase/pg-meta'
import { IS_PLATFORM } from 'common'
import type { SqlSnippets, UserContent } from '@/types'
@@ -13,7 +14,7 @@ export const NEW_SQL_SNIPPET_SKELETON: UserContent<SqlSnippets.Content> = {
content: {
schema_version: SQL_SNIPPET_SCHEMA_VERSION,
content_id: '',
sql: 'this is a test',
unchecked_sql: untrustedSql(''),
},
}
@@ -1,4 +1,11 @@
import type { Monaco } from '@monaco-editor/react'
import {
acceptUntrustedSql,
rawSql,
safeSql,
type SafeSqlFragment,
type UntrustedSqlFragment,
} from '@supabase/pg-meta'
import { wrapWithRollback } from '@supabase/pg-meta/src/query'
import { useQueryClient } from '@tanstack/react-query'
import { IS_PLATFORM, LOCAL_STORAGE_KEYS, useFlag, useParams } from 'common'
@@ -311,7 +318,8 @@ export const SQLEditor = () => {
const selection = editor.getSelection()
const selectedValue = selection ? editor.getModel()?.getValueInRange(selection) : undefined
const sql = snippet
? ((selectedValue || editorRef.current?.getValue()) ?? snippet.snippet.content?.sql)
? ((selectedValue || editorRef.current?.getValue()) ??
snippet.snippet.content?.unchecked_sql)
: selectedValue || editorRef.current?.getValue()
const formattedSql = formatSql(sql)
@@ -333,7 +341,7 @@ export const SQLEditor = () => {
})
const executeQuery = useCallback(
async (force: boolean = false, sqlOverride?: string) => {
async (force: boolean = false, sqlOverride?: SafeSqlFragment) => {
if (isDiffOpen) {
clearPendingRunRefocus()
return
@@ -353,7 +361,8 @@ export const SQLEditor = () => {
const selectedValue = selection ? editor.getModel()?.getValueInRange(selection) : undefined
const editorSql = snippet
? ((selectedValue || editorRef.current?.getValue()) ?? snippet.snippet.content?.sql)
? ((selectedValue || editorRef.current?.getValue()) ??
snippet.snippet.content?.unchecked_sql)
: selectedValue || editorRef.current?.getValue()
const sql = sqlOverride ?? editorSql
@@ -405,8 +414,9 @@ export const SQLEditor = () => {
return toast.error('Unable to run query: Connection string is missing')
}
const { appendAutoLimit } = checkIfAppendLimitRequired(sql, limit)
const formattedSql = suffixWithLimit(sql, limit)
const userSql = rawSql(sql)
const { appendAutoLimit } = checkIfAppendLimitRequired(userSql, limit)
const formattedSql = suffixWithLimit(userSql, limit)
execute({
projectRef: project.ref,
@@ -463,7 +473,8 @@ export const SQLEditor = () => {
const selectedValue = selection ? editor.getModel()?.getValueInRange(selection) : undefined
const sql = snippet
? ((selectedValue || editorRef.current?.getValue()) ?? snippet.snippet.content?.sql)
? ((selectedValue || editorRef.current?.getValue()) ??
snippet.snippet.content?.unchecked_sql)
: selectedValue || editorRef.current?.getValue()
// Check for multiple statements - EXPLAIN only works on a single statement
@@ -491,7 +502,8 @@ export const SQLEditor = () => {
}
// Wrap the query with EXPLAIN ANALYZE only if it's not already an EXPLAIN query
const explainSql = isExplainSql(sql) ? sql : `EXPLAIN ANALYZE ${sql}`
const userSql = rawSql(sql ?? '')
const explainSql = isExplainSql(sql) ? userSql : safeSql`EXPLAIN ANALYZE ${userSql}`
// Wrap EXPLAIN queries in a transaction with rollback to prevent data modifications
// This ensures EXPLAIN ANALYZE INSERT/UPDATE/DELETE queries don't actually modify data
@@ -566,7 +578,9 @@ export const SQLEditor = () => {
const buildDebugPrompt = useCallback(() => {
const snippet = snapV2.snippets[id]
const result = snapV2.results[id]?.[0]
const sql = (snippet?.snippet.content?.sql ?? '').replace(sqlAiDisclaimerComment, '').trim()
const sql = (snippet?.snippet.content?.unchecked_sql ?? '')
.replace(sqlAiDisclaimerComment, '')
.trim()
const errorMessage = result?.error?.message ?? 'Unknown error'
const prompt = `Help me to debug the attached sql snippet which gives the following error: \n\n${errorMessage}`
@@ -581,7 +595,7 @@ export const SQLEditor = () => {
aiSnap.newChat({
name: 'Debug SQL snippet',
sqlSnippets: [
(snippet.snippet.content?.sql ?? '').replace(sqlAiDisclaimerComment, '').trim(),
(snippet.snippet.content?.unchecked_sql ?? '').replace(sqlAiDisclaimerComment, '').trim(),
],
initialInput: `Help me to debug the attached sql snippet which gives the following error: \n\n${result.error.message}`,
})
@@ -878,7 +892,7 @@ export const SQLEditor = () => {
shouldRefocusAfterRunRef.current = true
setPotentialIssues(undefined)
refocusEditor()
void executeQuery(true, rewrittenSql)
void executeQuery(true, acceptUntrustedSql(rewrittenSql as UntrustedSqlFragment))
}}
/>
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { stripIndent } from 'common-tags'
import { describe, expect, it, test } from 'vitest'
@@ -136,19 +137,19 @@ select * from cities
// [Joshen] These will just need to test the cases when appendAutoLimit returns true then
describe('SQLEditor.utils.ts:suffixWithLimit', () => {
test('Should add the limit param properly if query ends without a semi colon', () => {
const sql = 'select * from countries'
const sql = safeSql`select * from countries`
const limit = 100
const formattedSql = suffixWithLimit(sql, limit)
expect(formattedSql).toBe('select * from countries limit 100;')
})
test('Should add the limit param properly if query ends with a semi colon', () => {
const sql = 'select * from countries;'
const sql = safeSql`select * from countries;`
const limit = 100
const formattedSql = suffixWithLimit(sql, limit)
expect(formattedSql).toBe('select * from countries limit 100;')
})
test('Should add the limit param properly if query ends with multiple semi colon', () => {
const sql = 'select * from countries;;;;;;;'
const sql = safeSql`select * from countries;;;;;;;`
const limit = 100
const formattedSql = suffixWithLimit(sql, limit)
expect(formattedSql).toBe('select * from countries limit 100;')
@@ -1,3 +1,4 @@
import { untrustedSql, type SafeSqlFragment } from '@supabase/pg-meta'
import { TABLE_EVENT_ACTIONS } from 'common/telemetry-constants'
import {
@@ -80,7 +81,7 @@ export const createSqlSnippetSkeletonV2 = ({
content: {
...NEW_SQL_SNIPPET_SKELETON.content,
content_id: id ?? '',
sql: sql ?? '',
unchecked_sql: untrustedSql(sql ?? ''),
} as any,
isNotSavedInDatabaseYet: true,
}
@@ -253,12 +254,10 @@ export const checkIfAppendLimitRequired = (sql: string, limit: number = 0) => {
return { cleanedSql, appendAutoLimit }
}
export const suffixWithLimit = (sql: string, limit: number = 0) => {
export const suffixWithLimit = (sql: SafeSqlFragment, limit: number = 0): SafeSqlFragment => {
const { cleanedSql, appendAutoLimit } = checkIfAppendLimitRequired(sql, limit)
const formattedSql = appendAutoLimit
? cleanedSql.endsWith(';')
? sql.replace(/[;]+$/, ` limit ${limit};`)
: `${sql} limit ${limit};`
: sql
return formattedSql
if (!appendAutoLimit) return sql
return (
cleanedSql.endsWith(';') ? sql.replace(/[;]+$/, ` limit ${limit};`) : `${sql} limit ${limit};`
) as SafeSqlFragment
}
@@ -1,3 +1,4 @@
import type { SafeSqlFragment } from '@supabase/pg-meta'
import React from 'react'
import type { Datum } from '@/components/ui/Charts/Charts.types'
@@ -27,6 +28,7 @@ export interface LogsEndpointParams {
}
export interface CustomLogData {
query?: SafeSqlFragment | undefined
[other: string]: unknown
}
@@ -1,9 +1,9 @@
import { PostgresPolicy } from '@supabase/postgres-meta'
import { difference } from 'lodash'
import { useRouter } from 'next/router'
import { STORAGE_CLIENT_LIBRARY_MAPPINGS } from './Storage.constants'
import type { StoragePolicyFormField } from './Storage.types'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { WrapperMeta } from '@/components/interfaces/Integrations/Wrappers/Wrappers.types'
import { convertKVStringArrayToJson } from '@/components/interfaces/Integrations/Wrappers/Wrappers.utils'
import { FDW } from '@/data/fdw/fdws-query'
@@ -21,7 +21,7 @@ const shortHash = (str: string) => {
return new Uint32Array([hash])[0].toString(36)
}
export type PoliciesByBucket = { name: string | Symbol; policies: PostgresPolicy[] }[]
export type PoliciesByBucket = { name: string | Symbol; policies: Policy[] }[]
/**
* Formats the policies from the objects table in the storage schema
@@ -31,7 +31,7 @@ export type PoliciesByBucket = { name: string | Symbol; policies: PostgresPolicy
*/
export const formatPoliciesForStorage = (
buckets: Bucket[],
policies: PostgresPolicy[]
policies: Policy[]
): PoliciesByBucket => {
if (policies.length === 0) return []
@@ -58,7 +58,7 @@ export const UNKNOWN_BUCKET_SYMBOL = createWrappedSymbol('unknown-bucket', 'Unkn
*/
export const UNGROUPED_POLICY_SYMBOL = createWrappedSymbol('ungrouped-policy', 'Ungrouped')
const formatStoragePolicies = (buckets: Bucket[], policies: PostgresPolicy[]) => {
const formatStoragePolicies = (buckets: Bucket[], policies: Policy[]) => {
const availableBuckets = buckets.map((bucket) => bucket.name)
const formattedPolicies = policies.map((policy) => {
const { definition: policyDefinition, check: policyCheck } = policy
@@ -93,8 +93,8 @@ export const extractBucketNameFromDefinition = (definition: string | null) => {
return bucketDefinition ? bucketDefinition.split("'")[1] : null
}
const groupPoliciesByBucket = (policies: (PostgresPolicy & { bucket: string | Symbol })[]) => {
const policiesByBucket = new Map<string | Symbol, PostgresPolicy[]>()
const groupPoliciesByBucket = (policies: (Policy & { bucket: string | Symbol })[]) => {
const policiesByBucket = new Map<string | Symbol, Policy[]>()
policies.forEach((policy) => {
if (!policiesByBucket.has(policy.bucket)) {
policiesByBucket.set(policy.bucket, [])
@@ -1,4 +1,3 @@
import { PostgresPolicy } from '@supabase/postgres-meta'
import { useParams } from 'common'
import { isEmpty } from 'lodash'
import { parseAsString, useQueryState } from 'nuqs'
@@ -21,6 +20,7 @@ import { StoragePoliciesBucketRow } from './StoragePoliciesBucketRow'
import { BucketsPolicies, type SelectBucketPolicyForAction } from './StoragePoliciesBucketsSection'
import { StoragePoliciesEditPolicyModal } from './StoragePoliciesEditPolicyModal'
import { PolicyEditorModal } from '@/components/interfaces/Auth/Policies/PolicyEditorModal'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { useDatabasePoliciesQuery } from '@/data/database-policies/database-policies-query'
import { useDatabasePolicyCreateMutation } from '@/data/database-policies/database-policy-create-mutation'
import { useDatabasePolicyDeleteMutation } from '@/data/database-policies/database-policy-delete-mutation'
@@ -33,8 +33,8 @@ export const StoragePolicies = () => {
const { ref: projectRef } = useParams()
const { data: project } = useSelectedProjectQuery()
const [selectedPolicyToEdit, setSelectedPolicyToEdit] = useState<PostgresPolicy>()
const [selectedPolicyToDelete, setSelectedPolicyToDelete] = useState<PostgresPolicy>()
const [selectedPolicyToEdit, setSelectedPolicyToEdit] = useState<Policy>()
const [selectedPolicyToDelete, setSelectedPolicyToDelete] = useState<Policy>()
const [isEditingPolicyForBucket, setIsEditingPolicyForBucket] = useState<{
bucket: string
table: string
@@ -1,4 +1,3 @@
import { PostgresPolicy } from '@supabase/postgres-meta'
import { FilesBucket as FilesBucketIcon } from 'icons'
import { noop } from 'lodash'
import { forwardRef, type CSSProperties } from 'react'
@@ -20,6 +19,7 @@ import {
} from 'ui'
import { PolicyRow } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyRow'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { PUBLIC_BUCKET_TOOLTIP } from '@/components/interfaces/Storage/Storage.constants'
import { Bucket } from '@/data/storage/buckets-query'
@@ -27,11 +27,11 @@ interface StoragePoliciesBucketRowProps {
table: string
label: string
bucket?: Bucket
policies: PostgresPolicy[]
policies: Policy[]
style?: CSSProperties
onSelectPolicyAdd: (bucketName: string | undefined, table: string) => void
onSelectPolicyEdit: (policy: PostgresPolicy, bucketName: string, table: string) => void
onSelectPolicyDelete: (policy: PostgresPolicy) => void
onSelectPolicyEdit: (policy: Policy, bucketName: string, table: string) => void
onSelectPolicyDelete: (policy: Policy) => void
}
export const StoragePoliciesBucketRow = forwardRef<HTMLDivElement, StoragePoliciesBucketRowProps>(
@@ -1,4 +1,3 @@
import { PostgresPolicy } from '@supabase/postgres-meta'
import { useVirtualizer } from '@tanstack/react-virtual'
import { ChevronUp, Search, X } from 'lucide-react'
import { forwardRef, useEffect, useState, type HTMLAttributes, type ReactNode } from 'react'
@@ -22,6 +21,7 @@ import {
import { StoragePoliciesBucketRow } from './StoragePoliciesBucketRow'
import StoragePoliciesPlaceholder from './StoragePoliciesPlaceholder'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { useMainScrollContainer } from '@/components/layouts/MainScrollContainerContext'
import { NoSearchResults } from '@/components/ui/NoSearchResults'
import { type Bucket } from '@/data/storage/buckets-query'
@@ -29,12 +29,12 @@ import { useStaticEffectEvent } from '@/hooks/useStaticEffectEvent'
export type SelectBucketPolicyForAction = {
addPolicy: (bucketName?: string, table?: string) => void
editPolicy: (policy: PostgresPolicy, bucketName?: string, table?: string) => void
deletePolicy: (policy: PostgresPolicy) => void
editPolicy: (policy: Policy, bucketName?: string, table?: string) => void
deletePolicy: (policy: Policy) => void
}
type BucketsPoliciesProps = {
buckets: { bucket: Bucket; policies: PostgresPolicy[] }[]
buckets: { bucket: Bucket; policies: Policy[] }[]
search?: string
debouncedSearch?: string
setSearch: (search: string) => void
@@ -131,7 +131,7 @@ export const BucketsPolicies = ({
}
type BucketsPoliciesVirtualizedListProps = {
items: { bucket: Bucket; policies: PostgresPolicy[] }[]
items: { bucket: Bucket; policies: Policy[] }[]
actions: SelectBucketPolicyForAction
pagination: BucketsPoliciesProps['pagination']
}
@@ -194,13 +194,13 @@ export const SQLEditorTreeViewItem = ({
const snippet = element.metadata
let sql: string = ''
if (snippet.content && snippet.content.sql) {
sql = snippet.content.sql
if (snippet.content && snippet.content.unchecked_sql) {
sql = snippet.content.unchecked_sql
} else {
// Fetch the content first
const { content } = await getContentById({ projectRef, id: snippet.id })
if ('sql' in content) {
sql = content.sql
if ('unchecked_sql' in content) {
sql = content.unchecked_sql
}
}
@@ -237,7 +237,7 @@ function SnippetSelector({
</CommandList_Shadcn_>
<CodeBlock
language="sql"
value={isSQLSnippet ? selectedSnippet?.content?.sql : ''}
value={isSQLSnippet ? selectedSnippet?.content?.unchecked_sql : ''}
wrapperClassName="hidden md:block"
className="w-full h-full border-0 [&>code]:overflow-scroll [&>code]:block [&>code]:w-full [&>code]:h-full"
hideCopy
@@ -249,7 +249,7 @@ function SnippetSelector({
function snippetValue(snippet: SqlSnippet) {
if (snippet.type !== 'sql') return ''
return escapeAttributeSelector(
`${snippet.id}-${snippet.name}-${snippet?.content?.sql.slice(0, 30)}`
`${snippet.id}-${snippet.name}-${snippet?.content?.unchecked_sql.slice(0, 30)}`
).toLowerCase()
}
@@ -127,7 +127,7 @@ export const AIAssistant = ({ className }: AIAssistantProps) => {
const isInSQLEditor = router.pathname.includes('/sql/[id]')
const snippet = snippets[entityId ?? '']
const snippetContent = snippet?.snippet?.content?.sql
const snippetContent = snippet?.snippet?.content?.unchecked_sql
const { data: tables } = useTablesQuery(
{
@@ -1,4 +1,5 @@
import type { Monaco } from '@monaco-editor/react'
import { acceptUntrustedSql, safeSql, untrustedSql } from '@supabase/pg-meta'
import { useQueryClient } from '@tanstack/react-query'
import { useDebounce } from '@uidotdev/usehooks'
import { useParams } from 'common'
@@ -111,7 +112,7 @@ export const EditorPanel = () => {
const isInlineEditorHotkeyEnabled = useIsShortcutEnabled(SHORTCUT_IDS.INLINE_EDITOR_TOGGLE)
const isAIAssistantHotkeyEnabled = useIsShortcutEnabled(SHORTCUT_IDS.AI_ASSISTANT_TOGGLE)
const currentValue = value || ''
const currentValue = value || safeSql``
const { ref } = useParams()
const router = useRouter()
@@ -173,7 +174,7 @@ export const EditorPanel = () => {
useEffect(() => {
if (!snippetById || !activeSnippetId) return
const sqlSnippet = snippetById as unknown as Extract<Content, { type: 'sql' }>
const sql = sqlSnippet.content.sql ?? ''
const sql = sqlSnippet.content.unchecked_sql ?? safeSql``
setValue(sql)
setActiveSnippet(sqlSnippet)
originalSnippetRef.current = { sql, name: sqlSnippet.name }
@@ -228,7 +229,7 @@ export const EditorPanel = () => {
}
executeSql({
sql: suffixWithLimit(currentValue, 100),
sql: suffixWithLimit(acceptUntrustedSql(currentValue), 100),
projectRef: project?.ref,
connectionString: project?.connectionString,
isStatementTimeoutDisabled: true,
@@ -243,8 +244,8 @@ export const EditorPanel = () => {
const isValidExplainQuery = isExplainQuery(results ?? [])
const handleChange = (value: string) => {
setValue(value)
onChange?.(value)
setValue(untrustedSql(value))
onChange?.(untrustedSql(value))
}
const onSelectTemplate = (content: string) => {
@@ -1,3 +1,4 @@
import { untrustedSql, type UntrustedSqlFragment } from '@supabase/pg-meta'
import { useMutation } from '@tanstack/react-query'
import { toast } from 'sonner'
@@ -6,7 +7,10 @@ import { BASE_PATH } from '@/lib/constants'
import { ResponseError, UseCustomMutationOptions } from '@/types'
export type ParseClientCodeResponse = {
sql: string | undefined
// Named unchecked_sql to highlight that this SQL must never be run
// automatically without user confirmation — it is AI-generated and may not
// be correct.
unchecked_sql: UntrustedSqlFragment | undefined
valid: boolean
}
@@ -25,7 +29,10 @@ export async function generateSqlTitle({ code }: ParseClientCodeVariables) {
}).then((res) => res.json())
if (response.error) throw new Error(response.error)
return response as ParseClientCodeResponse
return {
valid: response.valid as boolean,
unchecked_sql: response.sql != null ? untrustedSql(response.sql as string) : undefined,
} satisfies ParseClientCodeResponse
} catch (error) {
throw error
}
+2 -2
View File
@@ -2,6 +2,7 @@ import { useQuery } from '@tanstack/react-query'
import { components } from 'api-types'
import type { Content } from './content-query'
import { remapSqlContentField } from './content-remap'
import { contentKeys } from './keys'
import { get, handleError } from '@/data/fetchers'
import type { ResponseError, UseCustomQueryOptions } from '@/types'
@@ -26,8 +27,7 @@ export async function getContentById(
})
if (error) throw handleError(error)
// override content type
return data as unknown as GetUserContentByIdResponse
return remapSqlContentField(data as unknown as GetUserContentByIdResponse)
}
export type ContentIdData = Awaited<ReturnType<typeof getContentById>>
@@ -1,6 +1,7 @@
import { InfiniteData, useInfiniteQuery } from '@tanstack/react-query'
import { Content, ContentType } from './content-query'
import { remapSqlContentFields } from './content-remap'
import { contentKeys } from './keys'
import { get, handleError } from '@/data/fetchers'
import { UseCustomInfiniteQueryOptions } from '@/types'
@@ -40,7 +41,7 @@ export async function getContent(
return {
cursor: data.cursor,
content: data.data as unknown as Content[],
content: remapSqlContentFields(data.data as unknown as Content[]),
}
}
+2 -1
View File
@@ -1,6 +1,7 @@
import { useQuery } from '@tanstack/react-query'
import { components } from 'api-types'
import { remapSqlContentFields } from './content-remap'
import { contentKeys } from './keys'
import { get, handleError } from '@/data/fetchers'
import type { Dashboards, LogSqlSnippets, SqlSnippets, UseCustomQueryOptions } from '@/types'
@@ -49,7 +50,7 @@ export async function getContent(
return {
cursor: data.cursor,
content: data.data as unknown as Content[],
content: remapSqlContentFields(data.data as unknown as Content[]),
}
}
+27
View File
@@ -0,0 +1,27 @@
// Remap `sql` → `unchecked_sql` on SQL snippet content objects as they cross the API boundary.
// The API stores and returns the field as `sql`; the frontend type uses `unchecked_sql` to make
// it explicit that this value must never be executed without user confirmation.
import { untrustedSql } from '@supabase/pg-meta'
export function remapSqlContentField<T extends { type: string }>(item: T): T {
if (item.type !== 'sql') return item
if (!('content' in item)) return item
const content = item.content as Record<string, unknown>
if (!('sql' in content)) return item
const { sql, ...rest } = content
return { ...item, content: { ...rest, unchecked_sql: untrustedSql(sql as string) } } as T
}
export function remapSqlContentFields<T extends { type: string }>(items: Array<T>): Array<T> {
return items.map(remapSqlContentField)
}
// Reverse remap: `unchecked_sql` → `sql` before sending to the API.
export function unmapSqlContentField<T extends { type: string }>(item: T): T {
if (item.type !== 'sql') return item
if (!('content' in item)) return item
const content = item.content as Record<string, unknown>
if (!('unchecked_sql' in content)) return item
const { unchecked_sql, ...rest } = content
return { ...item, content: { ...rest, sql: unchecked_sql } } as T
}
@@ -2,6 +2,7 @@ import { useMutation, useQueryClient } from '@tanstack/react-query'
import { toast } from 'sonner'
import type { Content } from './content-query'
import { unmapSqlContentField } from './content-remap'
import { contentKeys } from './keys'
import type { Snippet } from './sql-folders-query'
import type { components } from '@/data/api'
@@ -25,7 +26,7 @@ export async function upsertContent(
) {
const { data, error } = await put('/platform/projects/{ref}/content', {
params: { path: { ref: projectRef } },
body: payload,
body: unmapSqlContentField(payload),
headers: { Version: '2' },
signal,
})
@@ -1,6 +1,7 @@
import { InfiniteData, useInfiniteQuery } from '@tanstack/react-query'
import { Content } from './content-query'
import { remapSqlContentFields } from './content-remap'
import { contentKeys } from './keys'
import { SNIPPET_PAGE_LIMIT } from './sql-folders-query'
import { get } from '@/data/fetchers'
@@ -50,7 +51,7 @@ export async function getSqlSnippets(
return {
cursor: data.cursor,
contents: data.data as unknown as SqlSnippet[],
contents: remapSqlContentFields(data.data as unknown as SqlSnippet[]),
}
}
@@ -1,6 +1,7 @@
import { useQuery } from '@tanstack/react-query'
import { databaseEventTriggerKeys } from './keys'
import type { EventTrigger } from '@/components/interfaces/Database/Triggers/EventTriggersList/EventTriggerList.utils'
import { executeSql } from '@/data/sql/execute-sql-query'
import type { ResponseError, UseCustomQueryOptions } from '@/types'
@@ -67,16 +68,23 @@ export async function getDatabaseEventTriggers(
export type DatabaseEventTriggersData = Awaited<ReturnType<typeof getDatabaseEventTriggers>>
export type DatabaseEventTriggersError = ResponseError
export const useDatabaseEventTriggersQuery = <TData = DatabaseEventTriggersData>(
function markSavedEventTriggerSafe(trigger: DatabaseEventTrigger): EventTrigger {
return trigger as EventTrigger
}
export const useDatabaseEventTriggersQuery = <TData = EventTrigger[]>(
{ projectRef, connectionString }: DatabaseEventTriggersVariables,
{
enabled = true,
...options
}: UseCustomQueryOptions<DatabaseEventTriggersData, DatabaseEventTriggersError, TData> = {}
}: UseCustomQueryOptions<EventTrigger[], DatabaseEventTriggersError, TData> = {}
) =>
useQuery<DatabaseEventTriggersData, DatabaseEventTriggersError, TData>({
useQuery<EventTrigger[], DatabaseEventTriggersError, TData>({
queryKey: databaseEventTriggerKeys.list(projectRef),
queryFn: ({ signal }) => getDatabaseEventTriggers({ projectRef, connectionString }, signal),
queryFn: ({ signal }) =>
getDatabaseEventTriggers({ projectRef, connectionString }, signal).then((data) =>
data.map(markSavedEventTriggerSafe)
),
enabled: enabled && typeof projectRef !== 'undefined',
...options,
})
@@ -1,4 +1,4 @@
import pgMeta from '@supabase/pg-meta'
import pgMeta, { type SafeSqlFragment } from '@supabase/pg-meta'
import { useQuery } from '@tanstack/react-query'
import { z } from 'zod'
@@ -12,6 +12,9 @@ export type DatabaseFunctionsVariables = {
}
export type DatabaseFunction = z.infer<typeof pgMeta.functions.pgFunctionZod>
export type SavedDatabaseFunction = Omit<DatabaseFunction, 'complete_statement'> & {
complete_statement: SafeSqlFragment
}
const pgMetaFunctionsList = pgMeta.functions.list()
@@ -33,10 +36,10 @@ export async function getDatabaseFunctions(
headers
)
return result as DatabaseFunction[]
return result as SavedDatabaseFunction[]
}
export type DatabaseFunctionsData = z.infer<typeof pgMetaFunctionsList.zod>
export type DatabaseFunctionsData = Awaited<ReturnType<typeof getDatabaseFunctions>>
export type DatabaseFunctionsError = ResponseError
export const useDatabaseFunctionsQuery = <TData = DatabaseFunctionsData>(
@@ -2,6 +2,7 @@ import { DEFAULT_PLATFORM_APPLICATION_NAME } from '@supabase/pg-meta/src/constan
import { useQuery } from '@tanstack/react-query'
import { databasePoliciesKeys } from './keys'
import type { Policy } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { get, handleError } from '@/data/fetchers'
import { useSelectedProjectQuery } from '@/hooks/misc/useSelectedProject'
import { PROJECT_STATUS } from '@/lib/constants'
@@ -46,19 +47,23 @@ export async function getDatabasePolicies(
export type DatabasePoliciesData = Awaited<ReturnType<typeof getDatabasePolicies>>
export type DatabasePoliciesError = ResponseError
export const useDatabasePoliciesQuery = <TData = DatabasePoliciesData>(
function markSavedPolicySafe(policy: DatabasePoliciesData[number]): Policy {
return policy as Policy
}
export const useDatabasePoliciesQuery = <TData = Policy[]>(
{ projectRef, connectionString, schema }: DatabasePoliciesVariables,
{
enabled = true,
...options
}: UseCustomQueryOptions<DatabasePoliciesData, DatabasePoliciesError, TData> = {}
{ enabled = true, ...options }: UseCustomQueryOptions<Policy[], DatabasePoliciesError, TData> = {}
) => {
const { data: project } = useSelectedProjectQuery()
const isActive = project?.status === PROJECT_STATUS.ACTIVE_HEALTHY
return useQuery<DatabasePoliciesData, DatabasePoliciesError, TData>({
return useQuery<Policy[], DatabasePoliciesError, TData>({
queryKey: databasePoliciesKeys.list(projectRef, schema),
queryFn: ({ signal }) => getDatabasePolicies({ projectRef, connectionString, schema }, signal),
queryFn: ({ signal }) =>
getDatabasePolicies({ projectRef, connectionString, schema }, signal).then((data) =>
data.map(markSavedPolicySafe)
),
enabled: enabled && typeof projectRef !== 'undefined' && isActive,
...options,
})
@@ -2,9 +2,14 @@ import { DEFAULT_PLATFORM_APPLICATION_NAME } from '@supabase/pg-meta/src/constan
import { useQuery } from '@tanstack/react-query'
import { databaseTriggerKeys } from './keys'
import type { PostgresTrigger } from '@/components/interfaces/Database/Triggers/TriggersList/TriggerList.utils'
import { get, handleError } from '@/data/fetchers'
import type { ResponseError, UseCustomQueryOptions } from '@/types'
function markSavedTriggerSafe(trigger: DatabaseTriggersData[number]): PostgresTrigger {
return trigger as PostgresTrigger
}
export type DatabaseTriggersVariables = {
projectRef?: string
connectionString?: string | null
@@ -61,16 +66,19 @@ export const useDatabaseHooksQuery = <TData = DatabaseTriggersData>(
...options,
})
export const useDatabaseTriggersQuery = <TData = DatabaseTriggersData>(
export const useDatabaseTriggersQuery = <TData = PostgresTrigger[]>(
{ projectRef, connectionString }: DatabaseTriggersVariables,
{
enabled = true,
...options
}: UseCustomQueryOptions<DatabaseTriggersData, DatabaseTriggersError, TData> = {}
}: UseCustomQueryOptions<PostgresTrigger[], DatabaseTriggersError, TData> = {}
) =>
useQuery<DatabaseTriggersData, DatabaseTriggersError, TData>({
useQuery<PostgresTrigger[], DatabaseTriggersError, TData>({
queryKey: databaseTriggerKeys.list(projectRef),
queryFn: ({ signal }) => getDatabaseTriggers({ projectRef, connectionString }, signal),
queryFn: ({ signal }) =>
getDatabaseTriggers({ projectRef, connectionString }, signal).then((data) =>
data.map(markSavedTriggerSafe)
),
enabled: enabled && typeof projectRef !== 'undefined',
...options,
})
@@ -1,4 +1,4 @@
import { getTableIndexAdvisorSql } from '@supabase/pg-meta'
import { getTableIndexAdvisorSql, type SafeSqlFragment } from '@supabase/pg-meta'
import { useQuery } from '@tanstack/react-query'
import { databaseKeys } from './keys'
@@ -14,7 +14,7 @@ export type TableIndexAdvisorVariables = {
}
export type IndexAdvisorSuggestion = {
query: string
query: SafeSqlFragment
calls: number
total_time: number
mean_time: number
@@ -113,7 +113,7 @@ export async function getTableIndexAdvisorSuggestions({
: 0
return {
query: row.query,
query: row.query as SafeSqlFragment,
calls: row.calls,
total_time: row.total_time,
mean_time: row.mean_time,
@@ -1,3 +1,4 @@
import { ident } from '@supabase/pg-meta'
import { Query } from '@supabase/pg-meta/src/query'
import { useMutation } from '@tanstack/react-query'
import { toast } from 'sonner'
@@ -20,7 +21,7 @@ export function getCellValueSql({
}: Pick<GetCellValueVariables, 'table' | 'column' | 'pkMatch'>) {
return new Query()
.from(table.name, table.schema ?? undefined)
.select(`"${column}"`)
.select(ident(column))
.match(pkMatch)
.toSql()
}
@@ -1,3 +1,4 @@
import { joinSqlFragments, safeSql, type SafeSqlFragment } from '@supabase/pg-meta'
import { wrapWithTransaction } from '@supabase/pg-meta/src/query'
import { useMutation, useQueryClient } from '@tanstack/react-query'
import { toast } from 'sonner'
@@ -24,7 +25,7 @@ export type OperationQueueSaveVariables = {
* Generates SQL for a single queued operation.
* Extend this function as new operation types are added.
*/
function getOperationSql(operation: QueuedOperation): string {
function getOperationSql(operation: QueuedOperation): SafeSqlFragment {
switch (operation.type) {
case QueuedOperationType.EDIT_CELL_CONTENT: {
const { payload } = operation
@@ -95,12 +96,12 @@ export async function saveOperationQueue({
}
const sortedOperations = sortOperations(operations)
const statements = sortedOperations.map((op) => {
const statements: Array<SafeSqlFragment> = sortedOperations.map((op) => {
const sql = getOperationSql(op)
return sql.endsWith(';') ? sql.slice(0, -1) : sql
return (sql.endsWith(';') ? sql.slice(0, -1) : sql) as SafeSqlFragment
})
const transactionSql = wrapWithTransaction(statements.join(';\n') + ';')
const transactionSql = wrapWithTransaction(safeSql`${joinSqlFragments(statements, ';\n')};`)
const sql = wrapWithRoleImpersonation(transactionSql, roleImpersonationState)
@@ -1,4 +1,4 @@
import { ROLE_IMPERSONATION_NO_RESULTS } from '@supabase/pg-meta'
import { ident, joinSqlFragments, ROLE_IMPERSONATION_NO_RESULTS, safeSql } from '@supabase/pg-meta'
import { Query, type QueryFilter } from '@supabase/pg-meta/src/query'
import { getTableRowsSql } from '@supabase/pg-meta/src/query/table-row-query'
import { useQuery, useQueryClient, type QueryClient } from '@tanstack/react-query'
@@ -130,11 +130,15 @@ export const getAllTableRowsSql = ({
.filter(
(column) => (column?.enum ?? []).length > 0 && column.dataType.toLowerCase() === 'array'
)
.map((column) => `"${column.name}"::text[]`)
.map((column) => safeSql`${ident(column.name)}::text[]`)
let queryChains = query
.from(table.name, table.schema ?? undefined)
.select(arrayBasedColumns.length > 0 ? `*,${arrayBasedColumns.join(',')}` : '*')
.select(
arrayBasedColumns.length > 0
? joinSqlFragments([safeSql`*`, ...arrayBasedColumns], ',')
: safeSql`*`
)
filters
.filter((filter) => filter.value && filter.value !== '')
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { Query } from '@supabase/pg-meta/src/query'
import { useQuery } from '@tanstack/react-query'
@@ -8,7 +9,7 @@ import { UseCustomQueryOptions } from '@/types'
const vaultSecretDecryptedValueQuery = (id: string) => {
const sql = new Query()
.from('decrypted_secrets', 'vault')
.select('decrypted_secret')
.select(safeSql`decrypted_secret`)
.match({ id })
.toSql()
@@ -18,7 +19,7 @@ const vaultSecretDecryptedValueQuery = (id: string) => {
const vaultSecretDecryptedValuesQuery = (ids: string[]) => {
const sql = new Query()
.from('decrypted_secrets', 'vault')
.select('id,decrypted_secret')
.select(safeSql`id,decrypted_secret`)
.filter('id', 'in', ids)
.toSql()
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { Query } from '@supabase/pg-meta/src/query'
import { useQuery } from '@tanstack/react-query'
@@ -8,7 +9,7 @@ import type { UseCustomQueryOptions, VaultSecret } from '@/types'
export const getVaultSecretsSql = () => {
const sql = new Query()
.from('secrets', 'vault')
.select('id,name,description,secret,created_at,updated_at')
.select(safeSql`id,name,description,secret,created_at,updated_at`)
.toSql()
return sql
+1 -1
View File
@@ -84,7 +84,7 @@ const buildSnippet = (
description: '',
favorite: false,
content: {
sql: content, // Default content
sql: content,
content_id: uuidv4(),
schema_version: '1.0',
},
+4 -1
View File
@@ -1,3 +1,4 @@
import { type SafeSqlFragment } from '@supabase/pg-meta'
import { format } from 'sql-formatter'
/**
@@ -5,7 +6,9 @@ import { format } from 'sql-formatter'
* formatting is consistent across the app. It also has a try/catch block which returns the original SQL in case of
* an error.
*/
export const formatSql = (sql: string) => {
export function formatSql(sql: SafeSqlFragment): SafeSqlFragment
export function formatSql(sql: string): string
export function formatSql(sql: string): string {
try {
return format(sql, {
language: 'postgresql',
+2 -1
View File
@@ -1,3 +1,4 @@
import { safeSql } from '@supabase/pg-meta'
import { describe, expect, it } from 'vitest'
import type { RoleImpersonationState } from './role-impersonation'
@@ -119,7 +120,7 @@ describe('getPostgrestClaims', () => {
})
describe('wrapWithRoleImpersonation', () => {
const sql = 'select * from colors;'
const sql = safeSql`select * from colors;`
const ref = 'default'
describe('postgres role (undefined)', () => {
+5 -2
View File
@@ -1,4 +1,4 @@
import { getImpersonationSQL } from '@supabase/pg-meta'
import { getImpersonationSQL, type SafeSqlFragment } from '@supabase/pg-meta'
import { uuidv4 } from './helpers'
import type { User } from '@/data/auth/users-infinite-query'
@@ -96,7 +96,10 @@ export function getPostgrestClaims(projectRef: string, role: PostgrestImpersonat
export type RoleImpersonationState = Pick<ValtioRoleImpersonationState, 'role' | 'claims'>
export function wrapWithRoleImpersonation(sql: string, state?: RoleImpersonationState) {
export function wrapWithRoleImpersonation(
sql: SafeSqlFragment,
state?: RoleImpersonationState
): SafeSqlFragment {
const { role, claims } = state ?? { role: undefined, claims: undefined }
if (role === undefined) return sql
@@ -10,9 +10,9 @@ import apiWrapper from '@/lib/api/apiWrapper'
const codeSchema = z.object({
sql: z
.string()
.optional()
.nullable()
.describe(
'The converted SQL query from the provided client library code. Return undefined if the code is invalid'
'The converted SQL query from the provided client library code. Return null if the code is invalid'
),
valid: z.boolean().describe('Whether the provided client library code is valid.'),
})
@@ -51,7 +51,7 @@ export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
output: Output.object({ schema: codeSchema }),
prompt: source`
Convert the follow Supabase client library code into SQL. The response should only be in JSON with the structure: { sql: string, valid: boolean }
If the client library code does not look valid, return { sql: undefined, valid: false }. Otherwise return valid as true and sql as the converted SQL query
If the client library code does not look valid, return { sql: null, valid: false }. Otherwise return valid as true and sql as the converted SQL query
${code}
`,
@@ -1,4 +1,5 @@
import type { PostgresPolicy, PostgresTable } from '@supabase/postgres-meta'
import { ident, safeSql } from '@supabase/pg-meta'
import type { PostgresTable } from '@supabase/postgres-meta'
import { PermissionAction } from '@supabase/shared-types/out/constants'
import { LOCAL_STORAGE_KEYS, useParams } from 'common'
import { Search, X } from 'lucide-react'
@@ -25,7 +26,10 @@ import { Policies } from '@/components/interfaces/Auth/Policies/Policies'
import { PoliciesDataProvider } from '@/components/interfaces/Auth/Policies/PoliciesDataContext'
import { getGeneralPolicyTemplates } from '@/components/interfaces/Auth/Policies/PolicyEditorModal/PolicyEditorModal.constants'
import { PolicyEditorPanel } from '@/components/interfaces/Auth/Policies/PolicyEditorPanel'
import { generatePolicyUpdateSQL } from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import {
generatePolicyUpdateSQL,
type Policy,
} from '@/components/interfaces/Auth/Policies/PolicyTableRow/PolicyTableRow.utils'
import { RLSTesterSheet } from '@/components/interfaces/Auth/RLSTester/RLSTesterSheet'
import AuthLayout from '@/components/layouts/AuthLayout/AuthLayout'
import { DefaultLayout } from '@/components/layouts/DefaultLayout'
@@ -60,7 +64,7 @@ import type { NextPageWithLayout } from '@/types'
*/
const getTableFilterState = (
tables: PostgresTable[],
policies: PostgresPolicy[],
policies: Array<Policy>,
searchString?: string
) => {
const sortedTables = tables.slice().sort((a, b) => a.name.localeCompare(b.name))
@@ -74,8 +78,7 @@ const getTableFilterState = (
const filter = searchString.toLowerCase()
const matchingPolicyKeys = new Set(
policies
// @ts-ignore Type instantiation is excessively deep and possibly infinite
.filter((policy: PostgresPolicy) => policy.name.toLowerCase().includes(filter))
.filter((policy: Policy) => policy.name.toLowerCase().includes(filter))
.map((policy) => `${policy.schema}.${policy.table}`)
)
@@ -190,8 +193,8 @@ const AuthPoliciesPage: NextPageWithLayout = () => {
setShowCreatePolicy(true)
if (isInlineEditorEnabled) {
const defaultSql = `create policy "replace_with_policy_name"
on ${schema}.${table}
const defaultSql = safeSql`create policy "replace_with_policy_name"
on ${ident(schema)}.${ident(table)}
for select
to authenticated
using (
@@ -210,7 +213,7 @@ const AuthPoliciesPage: NextPageWithLayout = () => {
)
const handleSelectEditPolicy = useCallback(
(policy: PostgresPolicy) => {
(policy: Policy) => {
setSelectedTable(undefined)
if (isInlineEditorEnabled) {
+6 -5
View File
@@ -1,3 +1,4 @@
import { safeSql, type DisplayableSqlFragment } from '@supabase/pg-meta'
import { proxy, snapshot, useSnapshot } from 'valtio'
type Template = {
@@ -13,18 +14,18 @@ export type SqlError = {
}
type EditorPanelState = {
value: string
value: DisplayableSqlFragment
templates: Template[]
results: Record<string, unknown>[] | undefined
error: SqlError | undefined
initialPrompt: string
onChange: ((value: string) => void) | undefined
onChange: ((value: DisplayableSqlFragment) => void) | undefined
activeSnippetId: string | null
pendingReset: boolean
}
const initialState: EditorPanelState = {
value: '',
value: safeSql``,
templates: [],
results: undefined,
error: undefined,
@@ -36,7 +37,7 @@ const initialState: EditorPanelState = {
export const editorPanelState = proxy({
...initialState,
setValue(value: string) {
setValue(value: DisplayableSqlFragment) {
editorPanelState.value = value
editorPanelState.onChange?.(value)
editorPanelState.setResults(undefined)
@@ -58,7 +59,7 @@ export const editorPanelState = proxy({
editorPanelState.activeSnippetId = id
},
openAsNew() {
editorPanelState.value = ''
editorPanelState.value = safeSql``
editorPanelState.results = undefined
editorPanelState.error = undefined
editorPanelState.pendingReset = true
+2 -1
View File
@@ -1,3 +1,4 @@
import { untrustedSql } from '@supabase/pg-meta'
import { debounce, memoize } from 'lodash'
import { useMemo } from 'react'
import { toast } from 'sonner'
@@ -178,7 +179,7 @@ export const sqlEditorState = proxy({
}) => {
let snippet = sqlEditorState.snippets[id]?.snippet
if (snippet?.content) {
snippet.content.sql = sql
snippet.content.unchecked_sql = untrustedSql(sql)
sqlEditorState.needsSaving.set(id, shouldInvalidate)
}
},
+5 -1
View File
@@ -1,3 +1,5 @@
import type { UntrustedSqlFragment } from '@supabase/pg-meta'
import { ChartConfig } from '@/components/interfaces/SQLEditor/UtilityPanel/ChartConfig'
export interface UserContent<
@@ -34,7 +36,9 @@ export namespace SqlSnippets {
content_id: string
// A full SQL query - this will be hashed on the /content endpoint
sql: string
// Named unchecked_sql to highlight that this SQL must never be run automatically
// without user confirmation — it may originate from untrusted sources like URL params.
unchecked_sql: UntrustedSqlFragment
// we can add some versioning to this schema in case we need to change the format.
schema_version: string
+32 -6
View File
@@ -270,9 +270,15 @@ test.describe('Database', () => {
await dropTable(databaseTableNameDuplicate)
}
)
const databaseLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'tables?include_columns=true&included_schemas=public'
)
await page.goto(toUrl(`/project/${env.PROJECT_REF}/database/tables?schema=public`))
// Wait for database tables to be populated
await waitForDatabaseToLoad(page, ref)
await databaseLoadWait
// create a new table
await page.getByRole('button', { name: 'New table' }).click()
@@ -485,10 +491,11 @@ test.describe('Database', () => {
test.describe('Triggers', () => {
test('actions works as expected', async ({ page, ref }) => {
const triggersLoadWait = createApiResponseWaiter(page, 'pg-meta', ref, 'triggers')
await page.goto(toUrl(`/project/${env.PROJECT_REF}/database/triggers?schema=public`))
// Wait for database triggers to be populated
await waitForApiResponse(page, 'pg-meta', ref, 'triggers')
await triggersLoadWait
const newTriggerButton = page.getByRole('button', { name: 'New trigger' }).first()
// create new trigger button to exist in public schema
@@ -526,10 +533,11 @@ test.describe('Database', () => {
}
)
const triggersCrudLoadWait = createApiResponseWaiter(page, 'pg-meta', ref, 'triggers')
await page.goto(toUrl(`/project/${env.PROJECT_REF}/database/triggers?schema=public`))
// Wait for database triggers to be populated
await waitForApiResponse(page, 'pg-meta', ref, 'triggers')
await triggersCrudLoadWait
// create new trigger
await page.getByRole('button', { name: 'New trigger' }).first().click()
@@ -603,10 +611,16 @@ test.describe('Database', () => {
test.describe('Database Indexes', () => {
test('actions works as expected', async ({ page, ref }) => {
const indexesLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=indexes-public'
)
await page.goto(toUrl(`/project/${env.PROJECT_REF}/database/indexes?schema=public`))
// Wait for database indexes to be populated
await waitForApiResponse(page, 'pg-meta', ref, 'query?key=indexes-public')
await indexesLoadWait
// create new index button exists in public schema
await expect(page.getByRole('button', { name: 'Create index' })).toBeVisible()
@@ -716,10 +730,16 @@ test.describe('Database', () => {
test.describe('Roles', () => {
test('actions works as expected', async ({ page, ref }) => {
const rolesLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=database-roles'
)
await page.goto(toUrl(`/project/${env.PROJECT_REF}/database/roles`))
// Wait for database roles list to be populated
await waitForApiResponse(page, 'pg-meta', ref, 'query?key=database-roles')
await rolesLoadWait
// filter between active and all roles
await page.getByRole('button', { name: 'Active roles' }).click()
@@ -734,10 +754,16 @@ test.describe('Database', () => {
test('CRUD operations works as expected', async ({ page, ref }) => {
const databaseRoleName = 'pw_database_role'
const databaseRolesWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=database-roles'
)
await page.goto(toUrl(`/project/${env.PROJECT_REF}/database/roles`))
// Wait for database roles to be populated
await waitForApiResponse(page, 'pg-meta', ref, 'query?key=database-roles')
await databaseRolesWait
// delete role if exists
const exists = (await page.getByRole('button', { name: databaseRoleName }).count()) > 0
@@ -3,7 +3,7 @@ import { expect, Page } from '@playwright/test'
import { createTable, dropTable, query } from '../utils/db/index.js'
import { test, withSetupCleanup } from '../utils/test.js'
import { toUrl } from '../utils/to-url.js'
import { waitForTableToLoad } from '../utils/wait-for-response.js'
import { createApiResponseWaiter, waitForTableToLoad } from '../utils/wait-for-response.js'
const QUEUE_OPERATIONS_KEY = 'supabase-ui-queue-operations'
const tableNamePrefix = 'pw_queue_table'
@@ -46,8 +46,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -94,8 +100,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -137,8 +149,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -175,8 +193,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -217,8 +241,14 @@ test.describe('Queue Table Operations', () => {
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -262,8 +292,14 @@ test.describe('Queue Table Operations', () => {
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -309,8 +345,14 @@ test.describe('Queue Table Operations', () => {
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -350,8 +392,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -388,8 +436,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -433,8 +487,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -473,8 +533,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -508,8 +574,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -552,8 +624,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -591,8 +669,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -645,8 +729,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -701,8 +791,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName1}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -755,8 +851,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName1}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -799,8 +901,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -856,8 +964,14 @@ test.describe('Queue Table Operations', () => {
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await enableQueueOperations(page)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.reload()
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
+56 -8
View File
@@ -212,8 +212,14 @@ testRunner('table editor', () => {
await expect(page.getByRole('cell', { name: 'value1, value2', exact: true })).toBeVisible()
// create a new table with new column for enums
const tableEditorEnumLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor`))
await waitForTableToLoad(page, ref) // load tables
await tableEditorEnumLoadWait // load tables
await page.getByRole('button', { name: 'New table', exact: true }).click()
await page.getByTestId('table-name-input').fill(tableNameEnum)
@@ -477,8 +483,14 @@ testRunner('table editor', () => {
}
)
const tableLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableLoadWait
await page.getByRole('button', { name: `View ${viewName}`, exact: true }).click()
@@ -1319,8 +1331,14 @@ testRunner('table editor', () => {
}
)
const tableEditorLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableEditorLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -1464,8 +1482,14 @@ testRunner('table editor', () => {
}
)
const tableEditorLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableEditorLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -1514,8 +1538,14 @@ testRunner('table editor', () => {
}
)
const tableEditorLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableEditorLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -1582,8 +1612,14 @@ testRunner('table editor', () => {
}
)
const tableEditorLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableEditorLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -1647,8 +1683,14 @@ testRunner('table editor', () => {
}
)
const tableEditorLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableEditorLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
@@ -1761,8 +1803,14 @@ testRunner('table editor', () => {
}
)
const tableEditorLoadWait = createApiResponseWaiter(
page,
'pg-meta',
ref,
'query?key=entity-types-public-'
)
await page.goto(toUrl(`/project/${ref}/editor?schema=public`))
await waitForTableToLoad(page, ref)
await tableEditorLoadWait
await page.getByRole('button', { name: `View ${tableName}`, exact: true }).click()
await page.waitForURL(/\/editor\/\d+\?schema=public$/)
+11 -2
View File
@@ -33,8 +33,17 @@ export * from './sql/studio/sql-editor'
export * from './sql/studio/role-impersonation'
export * from './sql/studio/integrations'
export { ident, literal, keyword, safeSql, joinSqlFragments } from './pg-format'
export type { SafeSqlFragment } from './pg-format'
export {
ident,
literal,
keyword,
safeSql,
rawSql,
untrustedSql,
acceptUntrustedSql,
joinSqlFragments,
} from './pg-format'
export type { SafeSqlFragment, UntrustedSqlFragment, DisplayableSqlFragment } from './pg-format'
export default {
roles,
+84 -25
View File
@@ -26,6 +26,18 @@ export interface PgFormatConfig {
*/
export type SafeSqlFragment = string & { readonly __safeSqlFragmentBrand: never }
/**
* A branded string type representing SQL that may have been influenced by a
* third party (URL params, AI output, external content). Safe to display;
* must never be auto-executed or persisted as user-authored content.
* Promote to SafeSqlFragment via acceptUntrustedSql() only inside an
* explicit user-action event handler.
*/
export type UntrustedSqlFragment = string & { readonly __untrustedSqlFragmentBrand: never }
/** Either brand — for read-only display surfaces that accept both. */
export type DisplayableSqlFragment = SafeSqlFragment | UntrustedSqlFragment
export type SqlFragmentSeparator =
| ','
| ', '
@@ -49,8 +61,8 @@ const FMT_PATTERN_CONFIG: PgFormatConfigPattern = {
}
// convert to Postgres default ISO 8601 format
function formatDate(date: string): string {
return date.replace('T', ' ').replace('Z', '+00')
function formatDate(date: SafeSqlFragment): SafeSqlFragment {
return date.replace('T', ' ').replace('Z', '+00') as SafeSqlFragment
}
function isReserved(value: string): boolean {
@@ -60,14 +72,18 @@ function isReserved(value: string): boolean {
return false
}
function arrayToList(useSpace: boolean, array: unknown[], formatter: (value: unknown) => string) {
let sql = ''
function arrayToList<ElementType = unknown>(
useSpace: boolean,
array: ElementType[],
formatter: (value: ElementType) => SafeSqlFragment
): SafeSqlFragment {
let sql = safeSql``
sql += useSpace ? ' (' : '('
sql = useSpace ? safeSql`${sql} (` : safeSql`${sql} (`
for (const [index, element] of array.entries()) {
sql += (index === 0 ? '' : ', ') + formatter(element)
sql = safeSql`${sql}${index === 0 ? safeSql`` : safeSql`, `}${formatter(element)}`
}
sql += ')'
sql = safeSql`${sql})`
return sql
}
@@ -82,7 +98,7 @@ export function ident(value?: unknown): SafeSqlFragment {
} else if (value === true) {
return '"t"' as SafeSqlFragment
} else if (value instanceof Date) {
return `"${formatDate(value.toISOString())}"` as SafeSqlFragment
return safeSql`"${formatDate(value.toISOString() as SafeSqlFragment)}"`
} else if (Array.isArray(value)) {
const temporary: string[] = []
for (const element of value) {
@@ -149,7 +165,7 @@ export function literal(value?: unknown): SafeSqlFragment {
return "'t'" as SafeSqlFragment
}
if (value instanceof Date) {
return `'${formatDate(value.toISOString())}'` as SafeSqlFragment
return safeSql`'${formatDate(value.toISOString() as SafeSqlFragment)}'`
}
if (Array.isArray(value)) {
const temporary: string[] = []
@@ -210,22 +226,32 @@ export function keyword(value: string): SafeSqlFragment {
return value as SafeSqlFragment
}
type Stringifyable =
| SafeSqlFragment
| number
| boolean
| Date
| null
| undefined
| Record<string | number | symbol, unknown>
| Stringifyable[]
// eslint-disable-next-line radar/cognitive-complexity
export function string(value?: unknown): string {
export function string(value?: Stringifyable): SafeSqlFragment {
if (value === undefined || value === null) {
return ''
return safeSql``
}
if (value === false) {
return 'f'
return safeSql`f`
}
if (value === true) {
return 't'
return safeSql`t`
}
if (value instanceof Date) {
return formatDate(value.toISOString())
return formatDate(value.toISOString() as SafeSqlFragment)
}
if (Array.isArray(value)) {
const temporary: string[] = []
const temporary: SafeSqlFragment[] = []
for (const [index, element] of value.entries()) {
if (element !== null && element !== undefined) {
if (Array.isArray(element) === true) {
@@ -235,13 +261,14 @@ export function string(value?: unknown): string {
}
}
}
return temporary.toString()
return temporary.toString() as SafeSqlFragment
}
if (value === Object(value)) {
return JSON.stringify(value)
if (!!value && typeof value === 'object') {
return JSON.stringify(value) as SafeSqlFragment
}
return String(value).toString().slice(0) // return copy
// value is number or SafeSqlFragment
return String(value).toString().slice(0) as SafeSqlFragment // return copy
}
export function config(cfg: PgFormatConfig): void {
@@ -263,7 +290,7 @@ export function config(cfg: PgFormatConfig): void {
}
}
export function withArray(fmt: string, parameters: unknown[]): string {
export function withArray(fmt: SafeSqlFragment, parameters: SafeSqlFragment[]): SafeSqlFragment {
let index = 0
let reText = '%(%|(\\d+\\$)?['
@@ -274,10 +301,9 @@ export function withArray(fmt: string, parameters: unknown[]): string {
const re = new RegExp(reText, 'g')
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
return fmt.replace(re, (_, type: string): string => {
return fmt.replace(re, (_, type: string): SafeSqlFragment => {
if (type === '%') {
return '%'
return safeSql`%`
}
let position = index
@@ -306,10 +332,12 @@ export function withArray(fmt: string, parameters: unknown[]): string {
if (type === FMT_PATTERN_CONFIG.string) {
return string(parameters[position])
}
})
throw new Error(`unsupported format type: ${type}`)
}) as SafeSqlFragment
}
export function format(fmt: string, ...arguments_: unknown[]): string {
export function format(fmt: SafeSqlFragment, ...arguments_: SafeSqlFragment[]): SafeSqlFragment {
return withArray(fmt, arguments_)
}
@@ -328,6 +356,37 @@ export function safeSql(
) as SafeSqlFragment
}
/**
* Marks a user-provided SQL string as a SafeSqlFragment for execution.
* Only use this when the user has explicitly typed or authored the SQL
* (e.g. a SQL editor, RLS tester). Never use for arbitrary data.
*/
export function rawSql(sql: string): SafeSqlFragment {
return sql as SafeSqlFragment
}
/**
* Marks SQL that may have been influenced by a third party (URL params, AI
* output, external content) as UntrustedSqlFragment. Safe to display; must
* never be auto-executed or persisted as user-authored.
*/
export function untrustedSql(sql: string): UntrustedSqlFragment {
return sql as UntrustedSqlFragment
}
/**
* Promote SQL to executable after explicit user acknowledgment.
* Accepts DisplayableSqlFragment (SafeSqlFragment | UntrustedSqlFragment) because
* a Run action approves whatever is currently in the editor, whether the user typed
* it themselves or loaded it from an external source.
* ONLY call from an event handler tied to a deliberate user action (onClick,
* keydown on Run shortcut). Never call from useEffect, render, or any path
* that runs without a user gesture.
*/
export function acceptUntrustedSql(sql: DisplayableSqlFragment): SafeSqlFragment {
return sql as unknown as SafeSqlFragment
}
/**
* Joins an array of already-safe SQL fragments with a fixed structural
* separator.
+199 -77
View File
@@ -1,4 +1,11 @@
import { format, ident, literal } from '../pg-format'
import {
format,
ident,
joinSqlFragments,
literal,
safeSql,
type SafeSqlFragment,
} from '../pg-format'
import type { Dictionary, Filter, QueryPagination, QueryTable, Sort } from './types'
export function countQuery(
@@ -7,12 +14,12 @@ export function countQuery(
filters?: Filter[]
}
) {
let query = `select count(*) from ${queryTable(table)}`
let query = safeSql`select count(*) from ${queryTable(table)}`
const { filters } = options ?? {}
if (filters) {
query = applyFilters(query, filters)
}
return query + ';'
return safeSql`${query};`
}
export function truncateQuery(
@@ -22,12 +29,12 @@ export function truncateQuery(
cascade?: boolean
}
) {
let query = `truncate ${queryTable(table)}`
let query = safeSql`truncate ${queryTable(table)}`
const { cascade } = options ?? {}
if (cascade) {
query += ' cascade'
query = safeSql`${query} cascade`
}
return query + ';'
return safeSql`${query};`
}
export function deleteQuery(
@@ -41,18 +48,22 @@ export function deleteQuery(
if (!filters || filters.length === 0) {
throw new Error('no filters for this delete query')
}
let query = `delete from ${queryTable(table)}`
let query = safeSql`delete from ${queryTable(table)}`
const { returning, enumArrayColumns } = options ?? {}
if (filters) {
query = applyFilters(query, filters)
}
if (returning) {
query +=
const returningFragment =
enumArrayColumns === undefined || enumArrayColumns.length === 0
? ` returning *`
: ` returning *, ${enumArrayColumns.map((x) => `${ident(x)}::text[]`).join(',')}`
? safeSql` returning *`
: safeSql` returning *, ${joinSqlFragments(
enumArrayColumns.map((x) => safeSql`${ident(x)}::text[]`),
','
)}`
query = safeSql`${query}${returningFragment}`
}
return query + ';'
return safeSql`${query};`
}
export function insertQuery(
@@ -67,36 +78,41 @@ export function insertQuery(
throw new Error('no value to insert')
}
const { returning, enumArrayColumns } = options ?? {}
const queryColumns = Object.keys(values[0])
.map((x) => ident(x))
.join(',')
let query = ''
const queryColumns = joinSqlFragments(
Object.keys(values[0]).map((x) => ident(x)),
','
)
let query = safeSql``
if (queryColumns.length == 0) {
query = format(
'insert into %1$s select from jsonb_populate_recordset(null::%1$s, %2$s)',
safeSql`insert into %1$s select from jsonb_populate_recordset(null::%1$s, %2$s)`,
queryTable(table),
literal(JSON.stringify(values))
)
} else {
query = format(
'insert into %1$s (%2$s) select %2$s from jsonb_populate_recordset(null::%1$s, %3$s)',
safeSql`insert into %1$s (%2$s) select %2$s from jsonb_populate_recordset(null::%1$s, %3$s)`,
queryTable(table),
queryColumns,
literal(JSON.stringify(values))
)
}
if (returning) {
query +=
const returningStatement =
enumArrayColumns === undefined || enumArrayColumns.length === 0
? ` returning *`
: ` returning *, ${enumArrayColumns.map((x) => `${ident(x)}::text[]`).join(',')}`
? safeSql` returning *`
: safeSql` returning *, ${joinSqlFragments(
enumArrayColumns.map((x) => safeSql`${ident(x)}::text[]`),
','
)}`
query = safeSql`${query}${returningStatement}`
}
return query + ';'
return safeSql`${query};`
}
export function selectQuery(
table: QueryTable,
columns?: string,
columns?: SafeSqlFragment,
options?: {
filters?: Filter[]
pagination?: QueryPagination
@@ -105,9 +121,9 @@ export function selectQuery(
isFinal = true,
isCTE = false
) {
let query = ''
const queryColumn = columns ?? '*'
query += `select ${queryColumn} from ${isCTE ? queryCTE(table) : queryTable(table)}`
let query = safeSql``
const queryColumn = columns ?? safeSql`*`
query = safeSql`select ${queryColumn} from ${isCTE ? queryCTE(table) : queryTable(table)}`
const { filters, pagination, sorts } = options ?? {}
if (filters) {
@@ -118,9 +134,9 @@ export function selectQuery(
}
if (pagination) {
const { limit, offset } = pagination ?? {}
query += ` limit ${literal(limit)} offset ${literal(offset)}`
query = safeSql`${query} limit ${literal(limit)} offset ${literal(offset)}`
}
return `${query}${isFinal ? ';' : ''}`
return safeSql`${query}${isFinal ? safeSql`;` : safeSql``}`
}
export function updateQuery(
@@ -136,11 +152,12 @@ export function updateQuery(
if (!filters || filters.length === 0) {
throw new Error('no filters for this update query')
}
const queryColumns = Object.keys(value)
.map((x) => ident(x))
.join(',')
const queryColumns = joinSqlFragments(
Object.keys(value).map((x) => ident(x)),
','
)
let query = format(
'update %1$s set (%2$s) = (select %2$s from json_populate_record(null::%1$s, %3$s))',
safeSql`update %1$s set (%2$s) = (select %2$s from json_populate_record(null::%1$s, %3$s))`,
queryTable(table),
queryColumns,
literal(JSON.stringify(value))
@@ -149,23 +166,27 @@ export function updateQuery(
query = applyFilters(query, filters)
}
if (returning) {
query +=
const returning =
enumArrayColumns === undefined || enumArrayColumns.length === 0
? ` returning *`
: ` returning *, ${enumArrayColumns.map((x) => `${ident(x)}::text[]`).join(',')}`
? safeSql` returning *`
: safeSql` returning *, ${joinSqlFragments(
enumArrayColumns.map((x) => safeSql`${ident(x)}::text[]`),
','
)}`
query = safeSql`${query}${returning}`
}
return query + ';'
return safeSql`${query};`
}
//============================================================
// Filter Utils
//============================================================
function applyFilters(query: string, filters: Filter[]) {
function applyFilters(query: SafeSqlFragment, filters: Filter[]) {
if (filters.length === 0) return query
query += ` where ${filters
.map((filter) => {
query = safeSql`${query} where ${joinSqlFragments(
filters.map((filter) => {
// Handle composite values
if (Array.isArray(filter.column)) {
switch (filter.operator) {
@@ -194,22 +215,23 @@ function applyFilters(query: string, filters: Filter[]) {
case '!~~*':
return castColumnToText(filter)
default:
return `${ident(filter.column)} ${filter.operator} ${filterLiteral(filter.value)}`
return safeSql`${ident(filter.column)} ${filter.operator as SafeSqlFragment} ${filterLiteral(filter.value)}`
}
})
.join(' and ')}`
}),
' and '
)}`
return query
}
function inFilterSql(filter: Filter) {
let values: Array<unknown>
let values: Array<SafeSqlFragment>
if (Array.isArray(filter.value)) {
values = filter.value.map((x) => filterLiteral(x))
} else {
const filterValueTxt = String(filter.value)
values = filterValueTxt.split(',').map((x) => filterLiteral(x))
}
return `${ident(filter.column)} ${filter.operator} (${values.join(',')})`
return safeSql`${ident(filter.column)} ${filter.operator as SafeSqlFragment} (${joinSqlFragments(values, ',')})`
}
function defaultTupleFilterSql(filter: Filter) {
@@ -223,9 +245,15 @@ function defaultTupleFilterSql(filter: Filter) {
throw new Error('Tuple filter value must have the same length as the column array')
}
const columns = `(${filter.column.map((c) => ident(c)).join(', ')})`
const values = `(${filter.value.map((v) => filterLiteral(v)).join(', ')})`
return `${columns} ${filter.operator} ${values}`
const columns = safeSql`(${joinSqlFragments(
filter.column.map((c) => ident(c)),
', '
)})`
const values = safeSql`(${joinSqlFragments(
filter.value.map((v) => filterLiteral(v)),
', '
)})`
return safeSql`${columns} ${filter.operator as SafeSqlFragment} ${values}`
}
function inTupleFilterSql(filter: Filter) {
@@ -236,25 +264,34 @@ function inTupleFilterSql(filter: Filter) {
throw new Error(`Values for a tuple 'in' filter must be an array`)
}
const columns = `(${filter.column.map((c) => ident(c)).join(', ')})`
const columns = safeSql`(${joinSqlFragments(
filter.column.map((c) => ident(c)),
', '
)})`
const values = filter.value.map((v) => {
if (Array.isArray(v)) {
if (v.length !== filter.column.length) {
throw new Error(`Tuple value length must match column length`)
}
return `(${v.map((x) => filterLiteral(x)).join(', ')})`
return safeSql`(${joinSqlFragments(
v.map((x) => filterLiteral(x)),
', '
)})`
} else {
const filterValueTxt = String(v)
const currValues = filterValueTxt.split(',')
if (currValues.length !== filter.column.length) {
throw new Error(`Tuple value length must match column length`)
}
return `(${currValues.map((x) => filterLiteral(x)).join(', ')})`
return safeSql`(${joinSqlFragments(
currValues.map((x) => filterLiteral(x)),
', '
)})`
}
})
return `${columns} ${filter.operator} (${values.join(', ')})`
return safeSql`${columns} ${filter.operator as SafeSqlFragment} (${joinSqlFragments(values, ', ')})`
}
function isFilterSql(filter: Filter) {
@@ -264,41 +301,126 @@ function isFilterSql(filter: Filter) {
case 'false':
case 'true':
case 'not null':
return `${ident(filter.column)} ${filter.operator} ${filterValueTxt}`
return safeSql`${ident(filter.column)} ${filter.operator as SafeSqlFragment} ${filterValueTxt as SafeSqlFragment}`
default:
return `${ident(filter.column)} ${filter.operator} ${filterLiteral(filter.value)}`
return safeSql`${ident(filter.column)} ${filter.operator as SafeSqlFragment} ${filterLiteral(filter.value)}`
}
}
function castColumnToText(filter: Filter) {
return `${ident(filter.column)}::text ${filter.operator} ${filterLiteral(filter.value)}`
return safeSql`${ident(filter.column)}::text ${filter.operator as SafeSqlFragment} ${filterLiteral(filter.value)}`
}
function filterLiteral(value: any) {
if (typeof value === 'string') {
if (value?.startsWith('ARRAY[') && value?.endsWith(']')) {
return value
function parseArrayLiteral(value: string): SafeSqlFragment | null {
if (!value.startsWith('ARRAY[')) return null
// Find the closing ] of the ARRAY, tracking quoted strings
const afterPrefix = value.slice(6)
let inString = false
let arrayCloseIdx = -1
for (let i = 0; i < afterPrefix.length; i++) {
const ch = afterPrefix[i]
if (!inString) {
if (ch === ']') {
arrayCloseIdx = i
break
} else if (ch === "'") {
inString = true
}
} else {
return literal(value)
if (ch === "'" && afterPrefix[i + 1] === "'") {
i++ // escaped ''
} else if (ch === "'") {
inString = false
}
}
}
return value
if (arrayCloseIdx === -1) return null
const contents = afterPrefix.slice(0, arrayCloseIdx)
const suffix = afterPrefix.slice(arrayCloseIdx + 1) // e.g. "::status_type[]" or ""
// Validate type cast suffix: only allow ::word_chars[]? or empty
let typeCast: SafeSqlFragment = safeSql``
if (suffix !== '') {
const match = suffix.match(/^::([A-Za-z_][A-Za-z0-9_]*)(\[\])?$/)
if (!match) return null
typeCast = safeSql`::${match[1] as SafeSqlFragment}${match[2] ? safeSql`[]` : safeSql``}`
}
// Parse comma-separated, single-quoted items
const rawItems: Array<string> = []
let current = ''
let inStr = false
for (let i = 0; i < contents.length; i++) {
const ch = contents[i]
if (!inStr) {
if (ch === "'") {
inStr = true
current += ch
} else if (ch === ',') {
rawItems.push(current.trim())
current = ''
} else {
current += ch
}
} else {
if (ch === "'" && contents[i + 1] === "'") {
current += "''"
i++
} else if (ch === "'") {
current += ch
inStr = false
} else {
current += ch
}
}
}
if (current.trim()) rawItems.push(current.trim())
const unquoted = rawItems.map((item) => {
if (item.startsWith("'") && item.endsWith("'")) {
return item.slice(1, -1).replace(/''/g, "'")
}
return item
})
const formattedItems = joinSqlFragments(
unquoted.map((x) => literal(x)),
','
)
return safeSql`ARRAY[${formattedItems}]${typeCast}`
}
function filterLiteral(value: any): SafeSqlFragment {
if (typeof value === 'boolean') {
return (value ? 'true' : 'false') as SafeSqlFragment
}
if (typeof value === 'string') {
if (value.startsWith('ARRAY[')) {
const parsed = parseArrayLiteral(value)
if (parsed !== null) return parsed
}
return literal(value)
}
return literal(value)
}
//============================================================
// Sort Utils
//============================================================
function applySorts(query: string, sorts: Sort[]) {
function applySorts(query: SafeSqlFragment, sorts: Sort[]): SafeSqlFragment {
const validSorts = sorts.filter((sort) => sort.column)
if (validSorts.length === 0) return query
query += ` order by ${validSorts
.map((x) => {
const order = x.ascending ? 'asc' : 'desc'
const nullOrder = x.nullsFirst ? 'nulls first' : 'nulls last'
return `${ident(x.table)}.${ident(x.column)} ${order} ${nullOrder}`
})
.join(', ')}`
query = safeSql`${query} order by ${joinSqlFragments(
validSorts.map((x) => {
const order = x.ascending ? safeSql`asc` : safeSql`desc`
const nullOrder = x.nullsFirst ? safeSql`nulls first` : safeSql`nulls last`
return safeSql`${ident(x.table)}.${ident(x.column)} ${order} ${nullOrder}`
}),
', '
)}`
return query
}
@@ -307,29 +429,29 @@ function applySorts(query: string, sorts: Sort[]) {
//============================================================
function queryTable(table: QueryTable) {
return `${ident(table.schema)}.${ident(table.name)}`
return safeSql`${ident(table.schema)}.${ident(table.name)}`
}
function queryCTE(table: QueryTable) {
return `${ident(table.name)}`
return safeSql`${ident(table.name)}`
}
export function wrapWithTransaction(sql: string) {
return /* SQL */ `
export function wrapWithTransaction(sql: SafeSqlFragment) {
return safeSql`
begin;
${sql}
commit;
`
}
export function wrapWithRollback(sql: string) {
return /* SQL */ `
export function wrapWithRollback(sql: SafeSqlFragment) {
return safeSql`
begin;
${sql}
rollback;
`
}
+20 -13
View File
@@ -1,11 +1,12 @@
import type { SafeSqlFragment } from '../pg-format'
import { IQueryFilter, QueryFilter } from './QueryFilter'
import type { Dictionary, QueryTable } from './types'
export interface IQueryAction {
count: () => IQueryFilter
delete: (options?: { returning: boolean }) => IQueryFilter
insert: (values: Dictionary<any>[], options?: { returning: boolean }) => IQueryFilter
select: (columns?: string) => IQueryFilter
insert: (values: Array<Dictionary<any>>, options?: { returning: boolean }) => IQueryFilter
select: (columns?: SafeSqlFragment) => IQueryFilter
update: (value: Dictionary<any>, options?: { returning: boolean }) => IQueryFilter
truncate: (options?: { returning: boolean }) => IQueryFilter
}
@@ -17,7 +18,7 @@ export class QueryAction implements IQueryAction {
* Performs a COUNT on the table.
*/
count() {
return new QueryFilter(this.table, 'count')
return new QueryFilter(this.table, { action: 'count' })
}
/**
@@ -25,8 +26,8 @@ export class QueryAction implements IQueryAction {
*
* @param options.returning If `true`, return the deleted row(s) in the response.
*/
delete(options?: { returning: boolean; enumArrayColumns?: string[] }) {
return new QueryFilter(this.table, 'delete', undefined, options)
delete(options?: { returning: boolean; enumArrayColumns?: Array<string> }) {
return new QueryFilter(this.table, { action: 'delete' }, options)
}
/**
@@ -35,8 +36,11 @@ export class QueryAction implements IQueryAction {
* @param values The values to insert.
* @param options.returning If `true`, return the inserted row(s) in the response.
*/
insert(values: Dictionary<any>[], options?: { returning: boolean; enumArrayColumns?: string[] }) {
return new QueryFilter(this.table, 'insert', values, options)
insert(
values: Array<Dictionary<any>>,
options?: { returning: boolean; enumArrayColumns?: Array<string> }
) {
return new QueryFilter(this.table, { action: 'insert', actionValue: values }, options)
}
/**
@@ -44,8 +48,8 @@ export class QueryAction implements IQueryAction {
*
* @param columns the query columns, by default set to '*'.
*/
select(columns?: string) {
return new QueryFilter(this.table, 'select', columns)
select(columns?: SafeSqlFragment) {
return new QueryFilter(this.table, { action: 'select', actionValue: columns })
}
/**
@@ -54,14 +58,17 @@ export class QueryAction implements IQueryAction {
* @param value The value to update.
* @param options.returning If `true`, return the updated row(s) in the response.
*/
update(value: Dictionary<any>, options?: { returning: boolean; enumArrayColumns?: string[] }) {
return new QueryFilter(this.table, 'update', value, options)
update(
value: Dictionary<any>,
options?: { returning: boolean; enumArrayColumns?: Array<string> }
) {
return new QueryFilter(this.table, { action: 'update', actionValue: value }, options)
}
/**
* Performs a TRUNCATE on the table
*/
truncate(options?: { returning: boolean; enumArrayColumns?: string[] }) {
return new QueryFilter(this.table, 'truncate', undefined, options)
truncate(options?: { returning: boolean; enumArrayColumns?: Array<string> }) {
return new QueryFilter(this.table, { action: 'truncate' }, options)
}
}
+8 -12
View File
@@ -1,5 +1,5 @@
import { IQueryModifier, QueryModifier } from './QueryModifier'
import type { Dictionary, Filter, FilterOperator, QueryTable, Sort } from './types'
import type { ActionConfig, Dictionary, Filter, FilterOperator, QueryTable, Sort } from './types'
export interface IQueryFilter {
filter: (column: string, operator: FilterOperator, value: string) => IQueryFilter
@@ -8,14 +8,13 @@ export interface IQueryFilter {
}
export class QueryFilter implements IQueryFilter, IQueryModifier {
protected filters: Filter[] = []
protected sorts: Sort[] = []
protected filters: Array<Filter> = []
protected sorts: Array<Sort> = []
constructor(
protected table: QueryTable,
protected action: 'count' | 'delete' | 'insert' | 'select' | 'update' | 'truncate',
protected actionValue?: string | string[] | Dictionary<any> | Dictionary<any>[],
protected actionOptions?: { returning: boolean; enumArrayColumns?: string[] }
protected actionConfig: ActionConfig,
protected actionOptions?: { returning: boolean; enumArrayColumns?: Array<string> }
) {}
filter(column: string | string[], operator: FilterOperator, value: any) {
@@ -47,8 +46,7 @@ export class QueryFilter implements IQueryFilter, IQueryModifier {
clone(): QueryFilter {
const clonedData = structuredClone({
table: this.table,
action: this.action,
actionValue: this.actionValue,
actionConfig: this.actionConfig,
actionOptions: this.actionOptions,
filters: this.filters,
sorts: this.sorts,
@@ -56,8 +54,7 @@ export class QueryFilter implements IQueryFilter, IQueryModifier {
const cloned = new QueryFilter(
clonedData.table,
clonedData.action,
clonedData.actionValue,
clonedData.actionConfig,
clonedData.actionOptions
)
@@ -72,8 +69,7 @@ export class QueryFilter implements IQueryFilter, IQueryModifier {
}
_getQueryModifier() {
return new QueryModifier(this.table, this.action, {
actionValue: this.actionValue,
return new QueryModifier(this.table, this.actionConfig, {
actionOptions: this.actionOptions,
filters: this.filters,
sorts: this.sorts,
+15 -13
View File
@@ -1,3 +1,4 @@
import { safeSql, type SafeSqlFragment } from '../pg-format'
import {
countQuery,
deleteQuery,
@@ -6,7 +7,7 @@ import {
truncateQuery,
updateQuery,
} from './Query.utils'
import type { Dictionary, Filter, QueryPagination, QueryTable, Sort } from './types'
import type { ActionConfig, Filter, QueryPagination, QueryTable, Sort } from './types'
export interface IQueryModifier {
range: (from: number, to: number) => QueryModifier
@@ -18,12 +19,11 @@ export class QueryModifier implements IQueryModifier {
constructor(
protected table: QueryTable,
protected action: 'count' | 'delete' | 'insert' | 'select' | 'update' | 'truncate',
protected actionConfig: ActionConfig,
protected options?: {
actionValue?: string | string[] | Dictionary<any> | Dictionary<any>[]
actionOptions?: { returning?: boolean; cascade?: boolean; enumArrayColumns?: string[] }
filters?: Filter[]
sorts?: Sort[]
actionOptions?: { returning?: boolean; cascade?: boolean; enumArrayColumns?: Array<string> }
filters?: Array<Filter>
sorts?: Array<Sort>
}
) {}
@@ -41,10 +41,12 @@ export class QueryModifier implements IQueryModifier {
/**
* Return SQL string for query chains
*/
toSql(options: { isCTE: boolean; isFinal: boolean } = { isCTE: false, isFinal: true }) {
toSql(
options: { isCTE: boolean; isFinal: boolean } = { isCTE: false, isFinal: true }
): SafeSqlFragment {
try {
const { actionValue, actionOptions, filters, sorts } = this.options ?? {}
switch (this.action) {
const { actionOptions, filters, sorts } = this.options ?? {}
switch (this.actionConfig.action) {
case 'count': {
return countQuery(this.table, { filters })
}
@@ -55,7 +57,7 @@ export class QueryModifier implements IQueryModifier {
})
}
case 'insert': {
return insertQuery(this.table, actionValue as Dictionary<any>[], {
return insertQuery(this.table, this.actionConfig.actionValue, {
returning: actionOptions?.returning,
enumArrayColumns: actionOptions?.enumArrayColumns,
})
@@ -63,7 +65,7 @@ export class QueryModifier implements IQueryModifier {
case 'select': {
return selectQuery(
this.table,
actionValue as string | undefined,
this.actionConfig.actionValue,
{
filters,
pagination: this.pagination,
@@ -74,7 +76,7 @@ export class QueryModifier implements IQueryModifier {
)
}
case 'update': {
return updateQuery(this.table, actionValue as Dictionary<any>, {
return updateQuery(this.table, this.actionConfig.actionValue, {
filters,
returning: actionOptions?.returning,
enumArrayColumns: actionOptions?.enumArrayColumns,
@@ -86,7 +88,7 @@ export class QueryModifier implements IQueryModifier {
})
}
default: {
return ''
return safeSql``
}
}
} catch (error) {
+60 -44
View File
@@ -1,4 +1,11 @@
import { ident } from '../pg-format'
import {
ident,
joinSqlFragments,
keyword,
literal,
safeSql,
type SafeSqlFragment,
} from '../pg-format'
import { PGForeignTable } from '../pg-meta-foreign-tables'
import { PGMaterializedView } from '../pg-meta-materialized-views'
import { PGTable } from '../pg-meta-tables'
@@ -28,44 +35,50 @@ export interface BuildTableRowsQueryArgs {
}
// Text and JSON types that should be truncated
export const TEXT_TYPES = ['text', 'varchar', 'char', 'character varying', 'character']
export const JSON_TYPES = ['json', 'jsonb']
export const TEXT_TYPES = [
safeSql`text`,
safeSql`varchar`,
safeSql`char`,
safeSql`character varying`,
safeSql`character`,
]
export const JSON_TYPES = [safeSql`json`, safeSql`jsonb`]
const JSON_SET = new Set(JSON_TYPES)
// Additional PostgreSQL types that can hold large values and should be truncated
export const ADDITIONAL_LARGE_TYPES = [
// Standard PostgreSQL types
'bytea', // Binary data
'xml', // XML data
'hstore', // Key-value store
'clob', // Character large object
safeSql`bytea`, // Binary data
safeSql`xml`, // XML data
safeSql`hstore`, // Key-value store
safeSql`clob`, // Character large object
// Extension-specific types
// pgvector extension (for AI/ML/RAG applications)
'vector', // Vector type used for embeddings
safeSql`vector`, // Vector type used for embeddings
// PostGIS extension types
'geometry', // Spatial data type
'geography', // Spatial data type
safeSql`geometry`, // Spatial data type
safeSql`geography`, // Spatial data type
// Full-text search types
'tsvector', // Text search vector
'tsquery', // Text search query
safeSql`tsvector`, // Text search vector
safeSql`tsquery`, // Text search query
// Range types
'daterange', // Date range
'tsrange', // Timestamp range
'tstzrange', // Timestamp with timezone range
'numrange', // Numeric range
'int4range', // Integer range
'int8range', // Bigint range
safeSql`daterange`, // Date range
safeSql`tsrange`, // Timestamp range
safeSql`tstzrange`, // Timestamp with timezone range
safeSql`numrange`, // Numeric range
safeSql`int4range`, // Integer range
safeSql`int8range`, // Bigint range
// Other extension types
'cube', // Multi-dimensional cube
'ltree', // Label tree
'lquery', // Label tree query
'jsonpath', // JSON path expressions
'citext', // Case-insensitive text
safeSql`cube`, // Multi-dimensional cube
safeSql`ltree`, // Label tree
safeSql`lquery`, // Label tree query
safeSql`jsonpath`, // JSON path expressions
safeSql`citext`, // Case-insensitive text
]
export const LARGE_COLUMNS_TYPES = [...TEXT_TYPES, ...JSON_TYPES, ...ADDITIONAL_LARGE_TYPES]
@@ -104,7 +117,7 @@ export const getDefaultOrderByColumns = (
* the data as truncated or not
*/
export const shouldTruncateColumn = (columnFormat: string): boolean =>
LARGE_COLUMNS_TYPES_SET.has(columnFormat.toLowerCase())
(LARGE_COLUMNS_TYPES_SET as Set<string>).has(columnFormat.toLowerCase())
export const DEFAULT_PAGE_SIZE = 100
@@ -125,8 +138,8 @@ export const getTableRowsSql = ({
maxCharacters = MAX_CHARACTERS,
maxArraySize = MAX_ARRAY_SIZE,
sortExcludedColumns = [],
}: BuildTableRowsQueryArgs) => {
if (!table || !table.columns) return ``
}: BuildTableRowsQueryArgs): SafeSqlFragment => {
if (!table || !table.columns) return safeSql``
const query = new Query()
@@ -135,7 +148,7 @@ export const getTableRowsSql = ({
filters.forEach((x) => {
const col = table.columns?.find((y) => y.name === x.column)
const isStringTypeColumn = !!col ? TEXT_TYPES.includes(col.format) : true
const isStringTypeColumn = !!col ? (TEXT_TYPES as string[]).includes(col.format) : true
queryChains = queryChains.filter(
x.column,
x.operator,
@@ -168,7 +181,7 @@ export const getTableRowsSql = ({
// filtering, applying limits and order by, then we can apply selection with some conditional logic to truncate large columns
// allowing postgres to only truncate the columns within the subset that we'll return instead of attemting to do it on
// all the rows within the table
const baseSelectQuery = `with _base_query as (${queryChains.range(from, to).toSql({ isCTE: false, isFinal: false })})`
const baseSelectQuery = safeSql`with _base_query as (${queryChains.range(from, to).toSql({ isCTE: false, isFinal: false })})`
const allColumnNames = table.columns
.sort((a, b) => a.ordinal_position - b.ordinal_position)
@@ -180,13 +193,13 @@ export const getTableRowsSql = ({
.map((column) => column.name)
// Create select expressions for each column, applying truncation only to needed columns
const selectExpressions = allColumnNames.map(({ name: columnName }) => {
const selectExpressions: Array<SafeSqlFragment> = allColumnNames.map(({ name: columnName }) => {
const escapedColumnName = ident(columnName)
if (columnsToTruncate.includes(columnName)) {
return `case
when octet_length(${escapedColumnName}::text) > ${maxCharacters}
then left(${escapedColumnName}::text, ${maxCharacters}) || '...'
return safeSql`case
when octet_length(${escapedColumnName}::text) > ${literal(maxCharacters)}
then left(${escapedColumnName}::text, ${literal(maxCharacters)}) || '...'
else ${escapedColumnName}::text
end as ${escapedColumnName}`
} else {
@@ -206,10 +219,13 @@ export const getTableRowsSql = ({
const index = selectExpressions.findIndex(
(expr) => expr === ident(columnName) // if the column is selected without any truncation applied to it
)
// If the column is a json, the final cast remain an array of json
const typeCast = JSON_SET.has(format) ? `${format}[]` : 'text[]'
const lastElement =
typeCast === 'text[]' ? `array['...']` : `array['{"truncated": true}'::json]`
const isJson = (JSON_SET as Set<string>).has(format)
// format comes from pg_attribute (e.g. 'text', 'json') — ident() ensures safe quoting
const arrayTypeCast = isJson ? safeSql`::${keyword(format)}[]` : safeSql`::text[]`
const lastElement: SafeSqlFragment = isJson
? safeSql`array['{"truncated": true}'::json]`
: safeSql`array['...']`
const col = ident(columnName)
if (index >= 0) {
// We cast to text[] but limit the array size if the total size of the array is too large (same logic than for text fields)
// This returns the first MAX_ARRAY_SIZE elements of the array (adjustable) and adds '...' if truncated
@@ -218,28 +234,28 @@ export const getTableRowsSql = ({
// Also handle multi-dimentionals array truncation, but won't happen the extra `...` element to it as we can't determine what's
// the right number of items to generate within the array. Studio side, we'll consider any multi-dimentional array as possibly
// truncated.
selectExpressions[index] = `
selectExpressions[index] = safeSql`
case
when octet_length(${ident(columnName)}::text) > ${maxCharacters}
when octet_length(${col}::text) > ${literal(maxCharacters)}
then
case
when array_ndims(${ident(columnName)}) = 1
when array_ndims(${col}) = 1
then
(select array_cat(${ident(columnName)}[1:${maxArraySize}]::${typeCast}, ${lastElement}::${typeCast}))::${typeCast}
(select array_cat(${col}[1:${literal(maxArraySize)}]${arrayTypeCast}, ${lastElement}${arrayTypeCast}))${arrayTypeCast}
else
${ident(columnName)}[1:${maxArraySize}]::${typeCast}
${col}[1:${literal(maxArraySize)}]${arrayTypeCast}
end
else ${ident(columnName)}::${typeCast}
else ${col}${arrayTypeCast}
end
`
}
})
const selectClause = selectExpressions.join(',')
const selectClause = joinSqlFragments(selectExpressions, ',')
const finalQuery = new Query()
// Now, we apply our selection logic with the tables truncation on the _base_query constructed before
const finalQueryChain = finalQuery.from('_base_query').select(selectClause)
return `${baseSelectQuery}
return safeSql`${baseSelectQuery}
${finalQueryChain.toSql({ isCTE: true, isFinal: true })}`
}
+8
View File
@@ -1,3 +1,11 @@
import type { SafeSqlFragment } from '../pg-format'
export type ActionConfig =
| { action: 'count' | 'delete' | 'truncate' }
| { action: 'select'; actionValue?: SafeSqlFragment }
| { action: 'insert'; actionValue: Array<Dictionary<any>> }
| { action: 'update'; actionValue: Dictionary<any> }
export interface Sort {
table: string
column: string
@@ -1,6 +1,8 @@
export const QUEUES_SCHEMA = 'pgmq_public'
import { safeSql, type SafeSqlFragment } from '../../../pg-format'
export const HIDE_QUEUES_FROM_POSTGREST_SQL = /* SQL */ `
export const QUEUES_SCHEMA = safeSql`pgmq_public`
export const HIDE_QUEUES_FROM_POSTGREST_SQL = safeSql`
drop function if exists
${QUEUES_SCHEMA}.pop(queue_name text),
${QUEUES_SCHEMA}.send(queue_name text, message jsonb, sleep_seconds integer),
@@ -27,11 +29,15 @@ export const HIDE_QUEUES_FROM_POSTGREST_SQL = /* SQL */ `
drop schema if exists ${QUEUES_SCHEMA};
`
export const getExposeQueuesSQL = ({ isNewerPgmqversion }: { isNewerPgmqversion: boolean }) => {
const conditionalJsonb = isNewerPgmqversion ? `, conditional := '{}'::jsonb` : ''
const jsonBArg = isNewerPgmqversion ? `, jsonb` : ''
export const getExposeQueuesSQL = ({
isNewerPgmqversion,
}: {
isNewerPgmqversion: boolean
}): SafeSqlFragment => {
const conditionalJsonb = isNewerPgmqversion ? safeSql`, conditional := '{}'::jsonb` : safeSql``
const jsonBArg = isNewerPgmqversion ? safeSql`, jsonb` : safeSql``
return /* SQL */ `
return safeSql`
create schema if not exists ${QUEUES_SCHEMA};
grant usage on schema ${QUEUES_SCHEMA} to postgres, anon, authenticated, service_role;
@@ -204,12 +210,12 @@ export const getExposeQueuesSQL = ({ isNewerPgmqversion }: { isNewerPgmqversion:
grant usage, select, update
on sequences
to anon, authenticated, service_role;
`.trim()
`
}
// [Joshen] Check if all the relevant functions exist to indicate whether PGMQ has been exposed through PostgREST
export const getQueuesExposePostgrestStatusSQL = () => {
return /**SQL */ `
export const getQueuesExposePostgrestStatusSQL = (): SafeSqlFragment => {
return safeSql`
SELECT exists (select schema_name FROM information_schema.schemata WHERE schema_name = '${QUEUES_SCHEMA}');
`.trim()
`
}
@@ -1,3 +1,4 @@
import { literal, safeSql, type SafeSqlFragment } from '../../../pg-format'
import { Filter, Query } from '../../../query'
import { COUNT_ESTIMATE_SQL, THRESHOLD_COUNT } from './get-count-estimate'
@@ -17,8 +18,8 @@ export const getTableRowsCountSql = ({
filters?: Filter[]
enforceExactCount?: boolean
isUsingReadReplica?: boolean
}) => {
if (!table) return ``
}): SafeSqlFragment => {
if (!table) return safeSql``
if (enforceExactCount) {
const query = new Query()
@@ -28,16 +29,23 @@ export const getTableRowsCountSql = ({
.forEach((x) => {
queryChains = queryChains.filter(x.column, x.operator, x.value)
})
return `select (${queryChains.toSql().slice(0, -1)}), false as is_estimate;`
const queryChainsSql = queryChains.toSql()
const queryChainsSqlWithoutSemicolon = queryChainsSql.endsWith(';')
? (queryChainsSql.slice(0, -1) as SafeSqlFragment)
: queryChainsSql
return safeSql`select (${queryChainsSqlWithoutSemicolon}), false as is_estimate;`
} else {
const selectQuery = new Query()
let selectQueryChains = selectQuery.from(table.name, table.schema ?? undefined).select('*')
let selectQueryChains = selectQuery.from(table.name, table.schema ?? undefined).select()
filters
.filter((x) => x.value && x.value != '')
.forEach((x) => {
selectQueryChains = selectQueryChains.filter(x.column, x.operator, x.value)
})
const selectBaseSql = selectQueryChains.toSql()
const selectBaseSqlWithoutSemicolon = selectBaseSql.endsWith(';')
? (selectBaseSql.slice(0, -1) as SafeSqlFragment)
: selectBaseSql
const countQuery = new Query()
let countQueryChains = countQuery.from(table.name, table.schema ?? undefined).count()
@@ -46,42 +54,45 @@ export const getTableRowsCountSql = ({
.forEach((x) => {
countQueryChains = countQueryChains.filter(x.column, x.operator, x.value)
})
const countBaseSql = countQueryChains.toSql().slice(0, -1)
const countBaseSql = countQueryChains.toSql()
const countBaseSqlWithoutSemicolon = countBaseSql.endsWith(';')
? (countBaseSql.slice(0, -1) as SafeSqlFragment)
: countBaseSql
if (isUsingReadReplica) {
const sql = `
const sql = safeSql`
with approximation as (
select reltuples as estimate
from pg_class
where oid = ${table.id}
where oid = ${literal(table.id)}
)
select
case
when estimate > ${THRESHOLD_COUNT} then (select -1)
else (${countBaseSql})
when estimate > ${literal(THRESHOLD_COUNT)} then (select -1)
else (${countBaseSqlWithoutSemicolon})
end as count,
estimate > ${THRESHOLD_COUNT} as is_estimate
estimate > ${literal(THRESHOLD_COUNT)} as is_estimate
from approximation;
`.trim()
`
return sql
} else {
const sql = `
const sql = safeSql`
${COUNT_ESTIMATE_SQL}
with approximation as (
select reltuples as estimate
from pg_class
where oid = ${table.id}
where oid = ${literal(table.id)}
)
select
case
when estimate > ${THRESHOLD_COUNT} then ${filters.length > 0 ? `pg_temp.count_estimate('${selectBaseSql.replaceAll("'", "''")}')` : 'estimate'}
else (${countBaseSql})
when estimate > ${literal(THRESHOLD_COUNT)} then ${filters.length > 0 ? safeSql`pg_temp.count_estimate('${selectBaseSqlWithoutSemicolon.replaceAll("'", "''") as SafeSqlFragment}')` : safeSql`estimate`}
else (${countBaseSqlWithoutSemicolon})
end as count,
estimate > ${THRESHOLD_COUNT} as is_estimate
estimate > ${literal(THRESHOLD_COUNT)} as is_estimate
from approximation;
`.trim()
`
return sql
}
@@ -1,4 +1,4 @@
import { ident, literal } from '../../../pg-format'
import { ident, joinSqlFragments, literal, safeSql, type SafeSqlFragment } from '../../../pg-format'
import { wrapWithTransaction } from '../../../query'
export const getCreateEnumeratedTypeSQL = ({
@@ -12,11 +12,13 @@ export const getCreateEnumeratedTypeSQL = ({
values: string[]
description?: string
}) => {
const typeSql = `${ident(schema)}.${ident(name)}`
const createSql = `create type ${typeSql} as enum (${values.map(literal).join(', ')});`
const typeSql = safeSql`${ident(schema)}.${ident(name)}`
const createSql = safeSql`create type ${typeSql} as enum (${joinSqlFragments(values.map(literal), ', ')});`
const commentSql =
description !== undefined ? `comment on type ${typeSql} is ${literal(description)};` : ''
return wrapWithTransaction(`${createSql} ${commentSql}`)
description !== undefined
? safeSql`comment on type ${typeSql} is ${literal(description)};`
: safeSql``
return wrapWithTransaction(safeSql`${createSql} ${commentSql}`)
}
export const getDeleteEnumeratedTypeSQL = ({ schema, name }: { schema: string; name: string }) => {
@@ -34,12 +36,12 @@ export const getUpdateEnumeratedTypeSQL = ({
description?: string
values?: { original: string; updated: string; isNew: boolean }[]
}) => {
const statements: string[] = []
const typeSql = `${ident(schema)}.${ident(name.updated)}`
const statements: SafeSqlFragment[] = []
const typeSql = safeSql`${ident(schema)}.${ident(name.updated)}`
if (name.original !== name.updated) {
statements.push(
`alter type ${ident(schema)}.${ident(name.original)} rename to ${ident(name.updated)};`
safeSql`alter type ${ident(schema)}.${ident(name.original)} rename to ${ident(name.updated)};`
)
}
@@ -50,23 +52,23 @@ export const getUpdateEnumeratedTypeSQL = ({
// Consider if any new enums were added before any existing enums
const firstExistingEnumValue = values.find((x) => !x.isNew)
statements.push(
`alter type ${typeSql} add value ${literal(x.updated)} before ${literal(firstExistingEnumValue?.original)};`
safeSql`alter type ${typeSql} add value ${literal(x.updated)} before ${literal(firstExistingEnumValue?.original)};`
)
} else {
statements.push(
`alter type ${typeSql} add value ${literal(x.updated)} after ${literal(values[idx - 1].updated)};`
safeSql`alter type ${typeSql} add value ${literal(x.updated)} after ${literal(values[idx - 1].updated)};`
)
}
} else if (x.original !== x.updated) {
statements.push(
`alter type ${typeSql} rename value ${literal(x.original)} to ${literal(x.updated)};`
safeSql`alter type ${typeSql} rename value ${literal(x.original)} to ${literal(x.updated)};`
)
}
})
}
if (description !== undefined) {
statements.push(`comment on type ${typeSql} is ${literal(description)};`)
statements.push(safeSql`comment on type ${typeSql} is ${literal(description)};`)
}
return wrapWithTransaction(statements.join(' '))
return wrapWithTransaction(joinSqlFragments(statements, ' '))
}
@@ -1,4 +1,4 @@
import { literal } from '../../pg-format'
import { literal, safeSql, type SafeSqlFragment } from '../../pg-format'
function getPostgrestRoleImpersonationSql({
role,
@@ -6,25 +6,25 @@ function getPostgrestRoleImpersonationSql({
}: {
role: string
unexpiredClaims: Object
}) {
return `
}): SafeSqlFragment {
return safeSql`
select set_config('role', ${literal(role)}, true),
set_config('request.jwt.claims', ${literal(JSON.stringify(unexpiredClaims))}, true),
set_config('request.method', 'POST', true),
set_config('request.path', '/impersonation-example-request-path', true),
set_config('request.headers', '{"accept": "*/*"}', true);
`.trim()
`
}
function getCustomRoleImpersonationSql(roleName: string) {
return /* SQL */ `
function getCustomRoleImpersonationSql(roleName: string): SafeSqlFragment {
return safeSql`
set local role ${literal(roleName)};
`.trim()
`
}
// Includes getPostgrestRoleImpersonationSql() and wrapWithRoleImpersonation()
export const ROLE_IMPERSONATION_SQL_LINE_COUNT = 11
export const ROLE_IMPERSONATION_NO_RESULTS = 'ROLE_IMPERSONATION_NO_RESULTS'
export const ROLE_IMPERSONATION_NO_RESULTS = safeSql`ROLE_IMPERSONATION_NO_RESULTS`
export const getImpersonationSQL = ({
role,
@@ -36,16 +36,16 @@ export const getImpersonationSQL = ({
role: string
}
unexpiredClaims?: Object
sql: string
}) => {
sql: SafeSqlFragment
}): SafeSqlFragment => {
const impersonationSql =
role.type === 'postgrest'
? unexpiredClaims !== undefined
? getPostgrestRoleImpersonationSql({ role: role.role, unexpiredClaims })
: ''
: safeSql``
: getCustomRoleImpersonationSql(role.role)
return /* SQL */ `
return safeSql`
${impersonationSql}
-- If the users sql returns no rows, pg-meta will
@@ -53,5 +53,5 @@ export const getImpersonationSQL = ({
select 1 as "${ROLE_IMPERSONATION_NO_RESULTS}";
${sql}
`.trim()
`
}
@@ -1,5 +1,6 @@
import { afterAll, describe, expect, test } from 'vitest'
import { ident, joinSqlFragments, safeSql } from '../../src/pg-format'
import { Query } from '../../src/query/Query'
import { cleanupRoot, createTestDatabase } from '../db/utils'
@@ -117,7 +118,7 @@ describe('Advanced Query Tests', () => {
describe('Special Table and Column Names', () => {
withTestDatabase('should handle tables with spaces', async (db) => {
const query = new Query()
const sql = query.from('table with spaces', 'public').select('*').toSql()
const sql = query.from('table with spaces', 'public').select().toSql()
expect(sql).toMatchInlineSnapshot(`"select * from public."table with spaces";"`)
const result = await validateSql(db, sql)
@@ -128,7 +129,7 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle tables with double quotes', async (db) => {
const query = new Query()
const sql = query.from('quoted"table', 'public').select('*').toSql()
const sql = query.from('quoted"table', 'public').select().toSql()
expect(sql).toMatchInlineSnapshot(`"select * from public."quoted""table";"`)
const result = await validateSql(db, sql)
@@ -139,7 +140,7 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle tables with single quotes', async (db) => {
const query = new Query()
const sql = query.from("quoted'table", 'public').select('*').toSql()
const sql = query.from("quoted'table", 'public').select().toSql()
expect(sql).toMatchInlineSnapshot(`"select * from public."quoted'table";"`)
const result = await validateSql(db, sql)
@@ -150,7 +151,7 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle camelCase table names', async (db) => {
const query = new Query()
const sql = query.from('camelCaseTable', 'public').select('*').toSql()
const sql = query.from('camelCaseTable', 'public').select().toSql()
expect(sql).toMatchInlineSnapshot(`"select * from public."camelCaseTable";"`)
const result = await validateSql(db, sql)
@@ -161,7 +162,7 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle tables with special characters', async (db) => {
const query = new Query()
const sql = query.from('special#$%^&Table', 'public').select('*').toSql()
const sql = query.from('special#$%^&Table', 'public').select().toSql()
expect(sql).toMatchInlineSnapshot(`"select * from public."special#$%^&Table";"`)
const result = await validateSql(db, sql)
@@ -172,7 +173,10 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle columns with spaces', async (db) => {
const query = new Query()
const sql = query.from('table with spaces', 'public').select('"column with spaces"').toSql()
const sql = query
.from('table with spaces', 'public')
.select(safeSql`"column with spaces"`)
.toSql()
expect(sql).toMatchInlineSnapshot(
`"select "column with spaces" from public."table with spaces";"`
@@ -185,7 +189,10 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle columns with double quotes', async (db) => {
const query = new Query()
const sql = query.from('table with spaces', 'public').select('"quoted""column"').toSql()
const sql = query
.from('table with spaces', 'public')
.select(safeSql`"quoted""column"`)
.toSql()
expect(sql).toMatchInlineSnapshot(
`"select "quoted""column" from public."table with spaces";"`
@@ -198,7 +205,10 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle columns with single quotes', async (db) => {
const query = new Query()
const sql = query.from('table with spaces', 'public').select('"quoted\'column"').toSql()
const sql = query
.from('table with spaces', 'public')
.select(safeSql`"quoted'column"`)
.toSql()
expect(sql).toMatchInlineSnapshot(`"select "quoted'column" from public."table with spaces";"`)
const result = await validateSql(db, sql)
@@ -209,7 +219,10 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle camelCase column names', async (db) => {
const query = new Query()
const sql = query.from('table with spaces', 'public').select('"camelCaseColumn"').toSql()
const sql = query
.from('table with spaces', 'public')
.select(safeSql`"camelCaseColumn"`)
.toSql()
expect(sql).toMatchInlineSnapshot(
`"select "camelCaseColumn" from public."table with spaces";"`
@@ -222,7 +235,10 @@ describe('Advanced Query Tests', () => {
withTestDatabase('should handle columns with special characters', async (db) => {
const query = new Query()
const sql = query.from('table with spaces', 'public').select('"special#$%^&Column"').toSql()
const sql = query
.from('table with spaces', 'public')
.select(safeSql`"special#$%^&Column"`)
.toSql()
expect(sql).toMatchInlineSnapshot(
`"select "special#$%^&Column" from public."table with spaces";"`
@@ -255,7 +271,7 @@ describe('Advanced Query Tests', () => {
// The Query class handles the proper quoting
const sql = query
.from('table with spaces', 'public')
.select('*')
.select()
.filter('column with spaces', '=', 'test value')
.toSql()
@@ -278,7 +294,7 @@ describe('Advanced Query Tests', () => {
const query = new Query()
const sql = query
.from('normal_table', 'public')
.select('*')
.select()
.filter('name', '=', "O'Reilly")
.toSql()
@@ -336,7 +352,7 @@ describe('Advanced Query Tests', () => {
const query = new Query()
const sql = query
.from('normal_table', 'public')
.select('id, name')
.select(safeSql`id, name`)
.filter('id', '>', 10)
.filter('name', '~~', '%John%')
.order('normal_table', 'name', true, false)
@@ -501,7 +517,7 @@ describe('Advanced Query Tests', () => {
const query = new Query()
const sql = query
.from('normal_table', 'public')
.select('*')
.select()
.filter('name', '=', 'Special $ ^ & * ( ) _ + { } | : < > ? characters')
.toSql()
@@ -527,7 +543,7 @@ describe('Advanced Query Tests', () => {
const query = new Query()
const sql = query
.from('normal_table', 'public')
.select('*')
.select()
.filter('id', 'in', [1, 2, 3])
.toSql()
@@ -547,11 +563,7 @@ describe('Advanced Query Tests', () => {
`)
const query = new Query()
const sql = query
.from('normal_table', 'public')
.select('*')
.filter('name', 'is', 'null')
.toSql()
const sql = query.from('normal_table', 'public').select().filter('name', 'is', 'null').toSql()
expect(sql).toMatchInlineSnapshot(`"select * from public.normal_table where name is null;"`)
const result = await validateSql(db, sql)
@@ -572,7 +584,7 @@ describe('Advanced Query Tests', () => {
const query = new Query()
const sql = query
.from('normal_table', 'public')
.select('*')
.select()
.filter('name', 'is', 'not null')
.toSql()
+113 -77
View File
@@ -1,5 +1,6 @@
import { describe, expect, test } from 'vitest'
import { ident, joinSqlFragments, safeSql } from '../../src/pg-format'
import { Query } from '../../src/query/Query'
import * as QueryUtils from '../../src/query/Query.utils'
import { QueryAction } from '../../src/query/QueryAction'
@@ -33,7 +34,7 @@ describe('QueryAction', () => {
expect(filter).toBeInstanceOf(QueryFilter)
expect(filter['table']).toEqual(table)
expect(filter['action']).toBe('count')
expect(filter['actionConfig'].action).toBe('count')
})
test('delete() should create a QueryFilter with the correct action and options', () => {
@@ -42,7 +43,7 @@ describe('QueryAction', () => {
expect(filter).toBeInstanceOf(QueryFilter)
expect(filter['table']).toEqual(table)
expect(filter['action']).toBe('delete')
expect(filter['actionConfig'].action).toBe('delete')
expect(filter['actionOptions']).toEqual({ returning: true })
})
@@ -53,19 +54,18 @@ describe('QueryAction', () => {
expect(filter).toBeInstanceOf(QueryFilter)
expect(filter['table']).toEqual(table)
expect(filter['action']).toBe('insert')
expect(filter['actionValue']).toEqual(values)
expect(filter['actionConfig']).toEqual({ action: 'insert', actionValue: values })
expect(filter['actionOptions']).toEqual({ returning: true })
})
test('select() should create a QueryFilter with the correct action and columns', () => {
const action = new QueryAction(table)
const filter = action.select('id, name')
const cols = joinSqlFragments([ident('id'), ident('name')], ', ')
const filter = action.select(cols)
expect(filter).toBeInstanceOf(QueryFilter)
expect(filter['table']).toEqual(table)
expect(filter['action']).toBe('select')
expect(filter['actionValue']).toBe('id, name')
expect(filter['actionConfig']).toEqual({ action: 'select', actionValue: 'id, name' })
})
test('update() should create a QueryFilter with the correct action, value and options', () => {
@@ -75,8 +75,7 @@ describe('QueryAction', () => {
expect(filter).toBeInstanceOf(QueryFilter)
expect(filter['table']).toEqual(table)
expect(filter['action']).toBe('update')
expect(filter['actionValue']).toEqual(value)
expect(filter['actionConfig']).toEqual({ action: 'update', actionValue: value })
expect(filter['actionOptions']).toEqual({ returning: true })
})
@@ -86,7 +85,7 @@ describe('QueryAction', () => {
expect(filter).toBeInstanceOf(QueryFilter)
expect(filter['table']).toEqual(table)
expect(filter['action']).toBe('truncate')
expect(filter['actionConfig'].action).toBe('truncate')
expect(filter['actionOptions']).toEqual({ returning: true })
})
})
@@ -95,7 +94,10 @@ describe('QueryFilter', () => {
const table: QueryTable = { name: 'users', schema: 'public' }
test('filter() should add a filter and return the filter instance', () => {
const queryFilter = new QueryFilter(table, 'select', 'id, name')
const queryFilter = new QueryFilter(table, {
action: 'select',
actionValue: joinSqlFragments([ident('id'), ident('name')], ', '),
})
const result = queryFilter.filter('id', '=', 1)
expect(result).toBe(queryFilter)
@@ -103,7 +105,10 @@ describe('QueryFilter', () => {
})
test('match() should add multiple filters and return the filter instance', () => {
const queryFilter = new QueryFilter(table, 'select', 'id, name')
const queryFilter = new QueryFilter(table, {
action: 'select',
actionValue: joinSqlFragments([ident('id'), ident('name')], ', '),
})
const result = queryFilter.match({ id: 1, name: 'John' })
expect(result).toBe(queryFilter)
@@ -114,7 +119,10 @@ describe('QueryFilter', () => {
})
test('order() should add a sort and return the filter instance', () => {
const queryFilter = new QueryFilter(table, 'select', 'id, name')
const queryFilter = new QueryFilter(table, {
action: 'select',
actionValue: joinSqlFragments([ident('id'), ident('name')], ', '),
})
const result = queryFilter.order('users', 'name', false, true)
expect(result).toBe(queryFilter)
@@ -124,7 +132,10 @@ describe('QueryFilter', () => {
})
test('range() should delegate to QueryModifier.range() and return the result', () => {
const queryFilter = new QueryFilter(table, 'select', 'id, name')
const queryFilter = new QueryFilter(table, {
action: 'select',
actionValue: joinSqlFragments([ident('id'), ident('name')], ', '),
})
const result = queryFilter.range(0, 10)
expect(result).toBeInstanceOf(QueryModifier)
@@ -133,7 +144,10 @@ describe('QueryFilter', () => {
})
test('toSql() should delegate to QueryModifier.toSql() and return the SQL string', () => {
const queryFilter = new QueryFilter(table, 'select', 'id, name')
const queryFilter = new QueryFilter(table, {
action: 'select',
actionValue: joinSqlFragments([ident('id'), ident('name')], ', '),
})
queryFilter.filter('id', '=', 1)
const result = queryFilter.toSql()
@@ -147,8 +161,9 @@ describe('QueryModifier', () => {
const table: QueryTable = { name: 'users', schema: 'public' }
test('range() should set the pagination and return the modifier instance', () => {
const queryModifier = new QueryModifier(table, 'select', {
actionValue: 'id, name',
const queryModifier = new QueryModifier(table, {
action: 'select',
actionValue: joinSqlFragments([ident('id'), ident('name')], ', '),
})
const result = queryModifier.range(0, 10)
@@ -157,28 +172,35 @@ describe('QueryModifier', () => {
})
test('toSql() should generate the correct SQL for a count query', () => {
const queryModifier = new QueryModifier(table, 'count')
const queryModifier = new QueryModifier(table, { action: 'count' })
const result = queryModifier.toSql()
expect(result).toBe('select count(*) from public.users;')
})
test('toSql() should generate the correct SQL for a delete query with filters', () => {
const queryModifier = new QueryModifier(table, 'delete', {
filters: [{ column: 'id', operator: '=', value: 1 }],
actionOptions: { returning: true },
})
const queryModifier = new QueryModifier(
table,
{ action: 'delete' },
{
filters: [{ column: 'id', operator: '=', value: 1 }],
actionOptions: { returning: true },
}
)
const result = queryModifier.toSql()
expect(result).toBe('delete from public.users where id = 1 returning *;')
})
test('toSql() should generate the correct SQL for a select query with filters, sorts and pagination', () => {
const queryModifier = new QueryModifier(table, 'select', {
actionValue: 'id, name',
filters: [{ column: 'id', operator: '>', value: 10 }],
sorts: [{ table: 'users', column: 'name', ascending: true, nullsFirst: false }],
})
const queryModifier = new QueryModifier(
table,
{ action: 'select', actionValue: joinSqlFragments([ident('id'), ident('name')], ', ') },
{
filters: [{ column: 'id', operator: '>', value: 10 }],
sorts: [{ table: 'users', column: 'name', ascending: true, nullsFirst: false }],
}
)
queryModifier.range(0, 5)
const result = queryModifier.toSql()
expect(result).toMatchInlineSnapshot(
@@ -187,16 +209,20 @@ describe('QueryModifier', () => {
})
test('toSql() should generate the correct SQL for a truncate query', () => {
const queryModifier = new QueryModifier(table, 'truncate')
const queryModifier = new QueryModifier(table, { action: 'truncate' })
const result = queryModifier.toSql()
expect(result).toBe('truncate public.users;')
})
test('toSql() should generate the correct SQL for a truncate query with cascade', () => {
const queryModifier = new QueryModifier(table, 'truncate', {
actionOptions: { cascade: true },
})
const queryModifier = new QueryModifier(
table,
{ action: 'truncate' },
{
actionOptions: { cascade: true },
}
)
const result = queryModifier.toSql()
expect(result).toBe('truncate public.users cascade;')
@@ -298,31 +324,34 @@ describe('Query.utils', () => {
})
test('should generate a correct select query with custom columns', () => {
const result = QueryUtils.selectQuery(table, 'id, name')
const result = QueryUtils.selectQuery(
table,
joinSqlFragments([ident('id'), ident('name')], ', ')
)
expect(result).toBe('select id, name from public.users;')
})
test('should generate a correct select query with filters', () => {
const filters = [{ column: 'id', operator: '>' as const, value: 1 }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where id > 1;')
})
test('should generate a correct select query with sorts', () => {
const sorts = [{ table: 'users', column: 'name', ascending: true, nullsFirst: false }]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toBe('select * from public.users order by users.name asc nulls last;')
})
test('should generate a correct select query with pagination', () => {
const pagination = { limit: 10, offset: 0 }
const result = QueryUtils.selectQuery(table, '*', { pagination: pagination })
const result = QueryUtils.selectQuery(table, safeSql`*`, { pagination: pagination })
expect(result).toBe('select * from public.users limit 10 offset 0;')
})
test('should ignore sorts with undefined column', () => {
const sorts: Sort[] = [{ table: 'users', column: '', ascending: true, nullsFirst: false }]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toMatchInlineSnapshot(`"select * from public.users;"`)
})
})
@@ -372,7 +401,7 @@ describe('Query.utils', () => {
describe('applyFilters', () => {
test('should correctly apply equality filters', () => {
const filters: Filter[] = [{ column: 'name', operator: '=', value: 'John' }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe("select * from public.users where name = 'John';")
})
@@ -381,93 +410,93 @@ describe('Query.utils', () => {
{ column: 'name', operator: '=', value: 'John' },
{ column: 'age', operator: '>', value: 25 },
]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe("select * from public.users where name = 'John' and age > 25;")
})
test('should correctly handle "in" operator with array values', () => {
const filters: Filter[] = [{ column: 'id', operator: 'in', value: [1, 2, 3] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where id in (1,2,3);')
})
test('should correctly handle "in" operator with comma-separated string', () => {
const filters: Filter[] = [{ column: 'id', operator: 'in', value: '1,2,3' }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe("select * from public.users where id in ('1','2','3');")
})
test('should correctly handle "is" operator with null value', () => {
const filters: Filter[] = [{ column: 'email', operator: 'is', value: 'null' }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where email is null;')
})
test('should correctly handle "is" operator with not null value', () => {
const filters: Filter[] = [{ column: 'email', operator: 'is', value: 'not null' }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where email is not null;')
})
test('should correctly handle "is" operator with boolean values', () => {
const filters: Filter[] = [{ column: 'active', operator: 'is', value: 'true' }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where active is true;')
})
test('should correctly escape string values in filters', () => {
const filters: Filter[] = [{ column: 'name', operator: '=', value: "O'Reilly" }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toContain("where name = 'O''Reilly'")
})
test('should error if tuple filter value length does not match column length', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '=', value: [1] }]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError(
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError(
'Tuple filter value must have the same length as the column array'
)
})
test('should error if tuple filter value is not an array', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '=', value: 1 }]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError(
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError(
'Tuple filter value must be an array'
)
})
test('should correctly handle tuple filters with equality operator', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '=', value: [1, 2] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where (id, version) = (1, 2);')
})
test('should correctly handle tuple filters with greater than operator', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '>', value: [1, 2] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where (id, version) > (1, 2);')
})
test('should correctly handle tuple filters with greater than or equal operator', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '>=', value: [1, 2] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where (id, version) >= (1, 2);')
})
test('should correctly handle tuple filters with less than operator', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '<', value: [10, 5] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where (id, version) < (10, 5);')
})
test('should correctly handle tuple filters with less than or equal operator', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '<=', value: [10, 5] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where (id, version) <= (10, 5);')
})
test('should correctly handle tuple filters with not equal operator (<>)', () => {
const filters: Filter[] = [{ column: ['id', 'version'], operator: '<>', value: [1, 2] }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where (id, version) <> (1, 2);')
})
@@ -483,7 +512,7 @@ describe('Query.utils', () => {
],
},
]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe(
'select * from public.users where (id, version) in ((1, 2), (3, 4), (5, 6));'
)
@@ -497,7 +526,7 @@ describe('Query.utils', () => {
value: [[1, 2], [3, 4], [5]],
},
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
test('should correctly handle tuple filters with in operator using strings', () => {
@@ -508,7 +537,7 @@ describe('Query.utils', () => {
value: ['one,two', 'three,four', 'five,six'],
},
]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe(
`select * from public.users where (id, version) in (('one', 'two'), ('three', 'four'), ('five', 'six'));`
)
@@ -522,14 +551,14 @@ describe('Query.utils', () => {
value: ['one,two', 'three,four', 'five'],
},
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
test('should correctly handle tuple filters with string values', () => {
const filters: Filter[] = [
{ column: ['first_name', 'last_name'], operator: '=', value: ['John', 'Doe'] },
]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe(
"select * from public.users where (first_name, last_name) = ('John', 'Doe');"
)
@@ -540,7 +569,7 @@ describe('Query.utils', () => {
{ column: ['id', 'version'], operator: '>', value: [1, 2] },
{ column: 'active', operator: '=', value: true },
]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe(
'select * from public.users where (id, version) > (1, 2) and active = true;'
)
@@ -554,7 +583,7 @@ describe('Query.utils', () => {
value: [null, null],
},
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
test('should error when trying to use "~~" operator as a tuple filter', () => {
@@ -565,28 +594,28 @@ describe('Query.utils', () => {
value: ['%John%', '%Doe%'],
},
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
test('should error when trying to use "~~*" operator as a tuple filter', () => {
const filters: Filter[] = [
{ column: ['first_name', 'last_name'], operator: '~~*', value: ['%john%', '%doe%'] },
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
test('should error when trying to use "!~~" operator as a tuple filter', () => {
const filters: Filter[] = [
{ column: ['first_name', 'last_name'], operator: '!~~', value: ['%Admin%', '%System%'] },
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
test('should error when trying to use "!~~*" operator as a tuple filter', () => {
const filters: Filter[] = [
{ column: ['first_name', 'last_name'], operator: '!~~*', value: ['%admin%', '%system%'] },
]
expect(() => QueryUtils.selectQuery(table, '*', { filters: filters })).toThrowError()
expect(() => QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })).toThrowError()
})
})
@@ -595,7 +624,7 @@ describe('Query.utils', () => {
const sorts: Sort[] = [
{ table: 'users', column: 'name', ascending: true, nullsFirst: false },
]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toBe('select * from public.users order by users.name asc nulls last;')
})
@@ -603,7 +632,7 @@ describe('Query.utils', () => {
const sorts: Sort[] = [
{ table: 'users', column: 'name', ascending: false, nullsFirst: false },
]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toBe('select * from public.users order by users.name desc nulls last;')
})
@@ -611,7 +640,7 @@ describe('Query.utils', () => {
const sorts: Sort[] = [
{ table: 'users', column: 'name', ascending: true, nullsFirst: true },
]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toBe('select * from public.users order by users.name asc nulls first;')
})
@@ -620,7 +649,7 @@ describe('Query.utils', () => {
{ table: 'users', column: 'last_name', ascending: true, nullsFirst: false },
{ table: 'users', column: 'first_name', ascending: true, nullsFirst: false },
]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toBe(
'select * from public.users order by users.last_name asc nulls last, users.first_name asc nulls last;'
)
@@ -628,7 +657,7 @@ describe('Query.utils', () => {
test('should ignore sorts with undefined column', () => {
const sorts: Sort[] = [{ table: 'users', column: '', ascending: true, nullsFirst: false }]
const result = QueryUtils.selectQuery(table, '*', { sorts: sorts })
const result = QueryUtils.selectQuery(table, safeSql`*`, { sorts: sorts })
expect(result).toMatchInlineSnapshot(`"select * from public.users;"`)
})
})
@@ -636,13 +665,13 @@ describe('Query.utils', () => {
describe('filterLiteral', () => {
test('should correctly handle array literal syntax', () => {
const filters: Filter[] = [{ column: 'tags', operator: '=', value: "ARRAY['tag1','tag2']" }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe("select * from public.users where tags = ARRAY['tag1','tag2'];")
})
test('should correctly handle non-string values', () => {
const filters: Filter[] = [{ column: 'active', operator: '=', value: true }]
const result = QueryUtils.selectQuery(table, '*', { filters: filters })
const result = QueryUtils.selectQuery(table, safeSql`*`, { filters: filters })
expect(result).toBe('select * from public.users where active = true;')
})
})
@@ -659,7 +688,10 @@ describe('Query.utils', () => {
describe('End-to-end query chaining', () => {
test('should correctly build a simple select query', () => {
const query = new Query()
const sql = query.from('users', 'public').select('id, name, email').toSql()
const sql = query
.from('users', 'public')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.toSql()
expect(sql).toBe('select id, name, email from public.users;')
})
@@ -668,7 +700,7 @@ describe('End-to-end query chaining', () => {
const query = new Query()
const sql = query
.from('users', 'public')
.select('id, name, email')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.filter('id', '>', 10)
.toSql()
@@ -679,7 +711,7 @@ describe('End-to-end query chaining', () => {
const query = new Query()
const sql = query
.from('users', 'public')
.select('id, name, email')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.filter('id', '>', 10)
.filter('name', '~~', '%John%')
.toSql()
@@ -693,7 +725,7 @@ describe('End-to-end query chaining', () => {
const query = new Query()
const sql = query
.from('users', 'public')
.select('id, name, email')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.match({ active: true, role: 'admin' })
.toSql()
@@ -706,7 +738,7 @@ describe('End-to-end query chaining', () => {
const query = new Query()
const sql = query
.from('users', 'public')
.select('id, name, email')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.order('users', 'name', true, false)
.toSql()
@@ -715,7 +747,11 @@ describe('End-to-end query chaining', () => {
test('should correctly build a select query with pagination', () => {
const query = new Query()
const sql = query.from('users', 'public').select('id, name, email').range(0, 9).toSql()
const sql = query
.from('users', 'public')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.range(0, 9)
.toSql()
expect(sql).toBe('select id, name, email from public.users limit 10 offset 0;')
})
@@ -724,7 +760,7 @@ describe('End-to-end query chaining', () => {
const query = new Query()
const sql = query
.from('users', 'public')
.select('id, name, email')
.select(joinSqlFragments([ident('id'), ident('name'), ident('email')], ', '))
.filter('id', '>', 10)
.match({ active: true })
.order('users', 'name', true, false)