AI SQL: Update model + add test suite (#19644)

* refactor: ai sql logic to common package

* feat: ai sql tests

* feat: test rls policy ai chat assistant

* refactor: jest env loading

* feat(ai): improve sql title and description quality

* fix: ai message role type

* Move the new files to a separate package.

* Remove a forgotten console.log.

* Migrate the tests to use snapshots and commit the snapshots.

* Separate the functions which require edge runtime to be exported via /edge subpath.

* Bust the turbo cache when one of the deps has been rebuilt.

* chore: fix package main/type references

* fix: package lock out of sync

* fix: ai sql debugging to fix typos

---------

Co-authored-by: Ivan Vasilov <vasilov.ivan@gmail.com>
This commit is contained in:
Greg Richardson
2023-12-18 11:23:59 -07:00
committed by GitHub
parent adc6db57c2
commit a6f1313490
23 changed files with 1427 additions and 857 deletions
@@ -11,7 +11,7 @@ import { AIPolicyPre } from './AIPolicyPre'
interface MessageProps {
name?: string
role: 'function' | 'user' | 'assistant' | 'system'
role: 'function' | 'user' | 'assistant' | 'system' | 'data'
content?: string
createdAt?: number
isDebug?: boolean
+2
View File
@@ -39,6 +39,8 @@
"@uidotdev/usehooks": "^2.4.1",
"@zip.js/zip.js": "^2.7.29",
"ai": "^2.2.26",
"ai-commands": "*",
"ajv": "^8.6.3",
"awesome-debounce-promise": "^2.1.0",
"blueimp-md5": "^2.19.0",
"clsx": "^1.2.1",
+34 -142
View File
@@ -1,44 +1,10 @@
import { SchemaBuilder } from '@serafin/schema-builder'
import { codeBlock, stripIndent } from 'common-tags'
import { isError } from 'data/utils/error-check'
import { jsonrepair } from 'jsonrepair'
import { ContextLengthError, EmptySqlError, debugSql } from 'ai-commands'
import apiWrapper from 'lib/api/apiWrapper'
import { NextApiRequest, NextApiResponse } from 'next'
import { OpenAI } from 'openai'
const openAiKey = process.env.OPENAI_KEY
const debugSqlSchema = SchemaBuilder.emptySchema()
.addString('solution', {
description: 'A short suggested solution for the error (as concise as possible).',
})
.addString('sql', {
description: 'The SQL rewritten to apply the solution. Includes all the original SQL.',
})
type DebugSqlResult = typeof debugSqlSchema.T
const completionFunctions: Record<
string,
OpenAI.Chat.Completions.ChatCompletionCreateParams.Function
> = {
debugSql: {
name: 'debugSql',
description: stripIndent`
Debugs a Postgres SQL error and modifies the SQL to fix it.
- Create extensions if they are missing (only for valid extensions)
- Suggest creating tables if they are missing
- Include all of the original SQL
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
- When creating tables, always add foreign key references inline
- Prefer 'text' over 'varchar'
- Prefer 'timestamp with time zone' over 'date'
- Use vector(384) data type for any embedding/vector related query
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
`,
parameters: debugSqlSchema.schema as Record<string, unknown>,
},
}
const openai = new OpenAI({ apiKey: openAiKey })
async function handler(req: NextApiRequest, res: NextApiResponse) {
if (!openAiKey) {
@@ -59,128 +25,54 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
}
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
const openAI = new OpenAI({ apiKey: openAiKey })
const {
body: { errorMessage, sql, entityDefinitions },
} = req
const model = 'gpt-3.5-turbo-0613'
const maxCompletionTokenCount = 2048
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
if (hasEntityDefinitions) {
completionMessages.push({
role: 'user',
content: codeBlock`
Here is my database schema for reference:
${entityDefinitions.join('\n\n')}
`,
})
}
completionMessages.push(
{
role: 'user',
content: stripIndent`
Here is my current SQL:
${sql}
`,
},
{
role: 'user',
content: stripIndent`
Here is the error I am getting:
${errorMessage}
`,
}
)
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model,
messages: completionMessages,
max_tokens: maxCompletionTokenCount,
temperature: 0,
function_call: {
name: completionFunctions.debugSql.name,
},
functions: [completionFunctions.debugSql],
stream: false,
}
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
try {
completionResponse = await openAI.chat.completions.create(completionOptions)
} catch (error: any) {
console.error(`AI SQL debugging failed: ${error.message}`)
const result = await debugSql(openai, errorMessage, sql, entityDefinitions)
return res.json(result)
} catch (error) {
if (error instanceof Error) {
console.error(`AI SQL debugging failed: ${error.message}`)
if ('code' in error && error.code === 'context_length_exceeded') {
if (hasEntityDefinitions) {
const definitionsLength = entityDefinitions.reduce(
(sum: number, def: string) => sum + def.length,
0
)
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
if (definitionsLength > sql.length) {
return res.status(400).json({
error:
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
})
if (error instanceof ContextLengthError) {
// If there are more entity definitions than the SQL provided, attribute the
// error to the database metadata
if (hasEntityDefinitions) {
const definitionsLength = entityDefinitions.reduce(
(sum: number, def: string) => sum + def.length,
0
)
if (definitionsLength > sql.length) {
return res.status(400).json({
error:
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
})
}
}
// Otherwise attribute the error to the SQL being too large
return res.status(400).json({
error:
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
})
}
return res.status(400).json({
error:
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
})
if (error instanceof EmptySqlError) {
res.status(400).json({
error: 'Unable to debug SQL. No fix identified for the error.',
})
}
} else {
console.log(`Unknown error: ${error}`)
}
return res.status(500).json({
error: 'There was an unknown error debugging the SQL snippet. Please try again.',
})
}
const [firstChoice] = completionResponse.choices
const sqlResponseString = firstChoice.message?.function_call?.arguments
if (!sqlResponseString) {
console.error(
`AI SQL debugging failed: OpenAI response succeeded, but response format was incorrect`
)
return res.status(500).json({
error: 'There was an unknown error debugging the SQL snippet. Please try again.',
})
}
try {
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
const repairedJsonString = jsonrepair(sqlResponseString)
const debugSqlResult: DebugSqlResult = JSON.parse(repairedJsonString)
if (!debugSqlResult.sql) {
console.error(`AI SQL debugging failed: Unable to debug SQL for the given error message`)
return res.status(400).json({
error: 'Unable to debug SQL',
})
}
return res.json(debugSqlResult)
} catch (error) {
console.error(
`AI SQL editing failed: ${
isError(error) ? error.message : 'An unknown error occurred'
}, sqlResponseString: ${sqlResponseString}`
)
return res.status(500).json({
error: 'There was an unknown error editing the SQL snippet. Please try again.',
})
}
}
const wrapper = (req: NextApiRequest, res: NextApiResponse) => apiWrapper(req, res, handler)
+39 -135
View File
@@ -1,39 +1,10 @@
import { SchemaBuilder } from '@serafin/schema-builder'
import { codeBlock, stripIndent } from 'common-tags'
import { isError } from 'data/utils/error-check'
import { jsonrepair } from 'jsonrepair'
import { ContextLengthError, EmptySqlError, editSql } from 'ai-commands'
import apiWrapper from 'lib/api/apiWrapper'
import { NextApiRequest, NextApiResponse } from 'next'
import { OpenAI } from 'openai'
const openAiKey = process.env.OPENAI_KEY
const editSqlSchema = SchemaBuilder.emptySchema().addString('sql', {
description: stripIndent`
The modified SQL (must be valid SQL).
- Assume the query hasn't been executed yet
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
- When creating tables, always add foreign key references inline
- Prefer 'text' over 'varchar'
- Prefer 'timestamp with time zone' over 'date'
- Use vector(384) data type for any embedding/vector related query
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
- Use real examples when possible
`,
})
type EditSqlResult = typeof editSqlSchema.T
const completionFunctions: Record<
string,
OpenAI.Chat.Completions.ChatCompletionCreateParams.Function
> = {
editSql: {
name: 'editSql',
description: "Edits a Postgres SQL query based on the user's instructions",
parameters: editSqlSchema.schema as Record<string, unknown>,
},
}
const openai = new OpenAI({ apiKey: openAiKey })
async function handler(req: NextApiRequest, res: NextApiResponse) {
if (!openAiKey) {
@@ -58,112 +29,45 @@ export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
body: { prompt, sql, entityDefinitions },
} = req
const openAI = new OpenAI({ apiKey: openAiKey })
const model = 'gpt-3.5-turbo-0613'
const maxCompletionTokenCount = 2048
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
if (hasEntityDefinitions) {
completionMessages.push({
role: 'user',
content: codeBlock`
Here is my database schema for reference:
${entityDefinitions.join('\n\n')}
`,
})
}
completionMessages.push(
{
role: 'user',
content: stripIndent`
Here is my current SQL:
${sql}
`,
},
{
role: 'user',
content: prompt,
}
)
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model,
messages: completionMessages,
max_tokens: maxCompletionTokenCount,
temperature: 0,
function_call: {
name: completionFunctions.editSql.name,
},
functions: [completionFunctions.editSql],
stream: false,
}
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
try {
completionResponse = await openAI.chat.completions.create(completionOptions)
} catch (error: any) {
console.error(`AI SQL editing failed: ${error.message}`)
if ('code' in error && error.code === 'context_length_exceeded') {
if (hasEntityDefinitions) {
const definitionsLength = entityDefinitions.reduce(
(sum: number, def: string) => sum + def.length,
0
)
if (definitionsLength > sql.length) {
return res.status(400).json({
error:
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
})
}
}
return res.status(400).json({
error:
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
})
}
return res.status(500).json({
error: 'There was an unknown error editing the SQL snippet. Please try again.',
})
}
const [firstChoice] = completionResponse.choices
const sqlResponseString = firstChoice.message?.function_call?.arguments
if (!sqlResponseString) {
console.error(
`AI SQL editing failed: OpenAI response succeeded, but response format was incorrect`
)
return res.status(500).json({
error: 'There was an unknown error editing the SQL snippet. Please try again.',
})
}
try {
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
const repairedJsonString = jsonrepair(sqlResponseString)
const editSqlResult: EditSqlResult = JSON.parse(repairedJsonString)
if (!editSqlResult.sql) {
console.error(`AI SQL editing failed: Unable to edit SQL for the given prompt`)
return res.status(400).json({
error: 'Unable to edit SQL. Try adding more details to your prompt.',
})
}
return res.json(editSqlResult)
const result = await editSql(openai, prompt, sql, entityDefinitions)
return res.json(result)
} catch (error) {
console.error(
`AI SQL editing failed: ${
isError(error) ? error.message : 'An unknown error occurred'
}, sqlResponseString: ${sqlResponseString}`
)
if (error instanceof Error) {
console.error(`AI SQL editing failed: ${error.message}`)
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
if (error instanceof ContextLengthError) {
// If there are more entity definitions than the SQL provided, attribute the
// error to the database metadata
if (hasEntityDefinitions) {
const definitionsLength = entityDefinitions.reduce(
(sum: number, def: string) => sum + def.length,
0
)
if (definitionsLength > sql.length) {
return res.status(400).json({
error:
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
})
}
}
// Otherwise attribute the error to the SQL being too large
return res.status(400).json({
error:
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
})
}
if (error instanceof EmptySqlError) {
res.status(400).json({
error: 'Unable to edit SQL. Try adding more details to your prompt.',
})
}
} else {
console.log(`Unknown error: ${error}`)
}
return res.status(500).json({
error: 'There was an unknown error editing the SQL snippet. Please try again.',
+25 -124
View File
@@ -1,44 +1,10 @@
import { SchemaBuilder } from '@serafin/schema-builder'
import { codeBlock, stripIndent } from 'common-tags'
import { isError } from 'data/utils/error-check'
import { jsonrepair } from 'jsonrepair'
import { ContextLengthError, EmptySqlError, generateSql } from 'ai-commands'
import apiWrapper from 'lib/api/apiWrapper'
import { NextApiRequest, NextApiResponse } from 'next'
import { OpenAI } from 'openai'
const openAiKey = process.env.OPENAI_KEY
const generateSqlSchema = SchemaBuilder.emptySchema()
.addString('sql', {
description: stripIndent`
The generated SQL (must be valid SQL).
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
- Prefer creating foreign key references in the create statement
- Prefer 'text' over 'varchar'
- Prefer 'timestamp with time zone' over 'date'
- Use vector(384) data type for any embedding/vector related query
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
`,
})
.addString('title', {
description: stripIndent`
The title of the SQL.
- Omit words like 'SQL', 'Postgres', or 'Query'
`,
})
type GenerateSqlResult = typeof generateSqlSchema.T
const completionFunctions: Record<
string,
OpenAI.Chat.Completions.ChatCompletionCreateParams.Function
> = {
generateSql: {
name: 'generateSql',
description: 'Generates Postgres SQL based on a natural language prompt',
parameters: generateSqlSchema.schema as Record<string, unknown>,
},
}
const openai = new OpenAI({ apiKey: openAiKey })
async function handler(req: NextApiRequest, res: NextApiResponse) {
if (!openAiKey) {
@@ -59,102 +25,37 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
}
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
const openAI = new OpenAI({ apiKey: openAiKey })
const {
body: { prompt, entityDefinitions },
} = req
const model = 'gpt-3.5-turbo-0613'
const maxCompletionTokenCount = 1024
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
if (hasEntityDefinitions) {
completionMessages.push({
role: 'user',
content: codeBlock`
Here is my database schema for reference:
${entityDefinitions.join('\n\n')}
`,
})
}
completionMessages.push({
role: 'user',
content: prompt,
})
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model,
messages: completionMessages,
max_tokens: maxCompletionTokenCount,
temperature: 0,
function_call: {
name: completionFunctions.generateSql.name,
},
functions: [completionFunctions.generateSql],
stream: false,
}
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
try {
completionResponse = await openAI.chat.completions.create(completionOptions)
} catch (error: any) {
console.error(`AI SQL generation failed: ${error.message}`)
if ('code' in error && error.code === 'context_length_exceeded' && hasEntityDefinitions) {
return res.status(400).json({
error:
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
})
}
return res.status(500).json({
error: 'There was an unknown error generating the SQL snippet. Please try again.',
})
}
const [firstChoice] = completionResponse.choices
const sqlResponseString = firstChoice.message?.function_call?.arguments
if (!sqlResponseString) {
console.error(
`AI SQL generation failed: OpenAI response succeeded, but response format was incorrect`
)
return res.status(500).json({
error: 'There was an unknown error generating the SQL snippet. Please try again.',
})
}
try {
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
const repairedJsonString = jsonrepair(sqlResponseString)
const generateSqlResult: GenerateSqlResult = JSON.parse(repairedJsonString)
if (!generateSqlResult.sql) {
console.error(`AI SQL generation failed: Unable to generate SQL for the given prompt`)
res.status(400).json({
error: 'Unable to generate SQL. Try adding more details to your prompt.',
})
return
}
return res.json(generateSqlResult)
const result = await generateSql(openai, prompt, entityDefinitions)
return res.json(result)
} catch (error) {
console.error(
`AI SQL editing failed: ${
isError(error) ? error.message : 'An unknown error occurred'
}, sqlResponseString: ${sqlResponseString}`
)
if (error instanceof Error) {
console.error(`AI SQL generation failed: ${error.message}`)
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
if (error instanceof ContextLengthError && hasEntityDefinitions) {
return res.status(400).json({
error:
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
})
}
if (error instanceof EmptySqlError) {
res.status(400).json({
error: 'Unable to generate SQL. Try adding more details to your prompt.',
})
}
} else {
console.log(`Unknown error: ${error}`)
}
return res.status(500).json({
error: 'There was an unknown error editing the SQL snippet. Please try again.',
error: 'There was an unknown error generating the SQL snippet. Please try again.',
})
}
}
+4 -66
View File
@@ -1,5 +1,5 @@
import { OpenAIStream, StreamingTextResponse } from 'ai'
import { codeBlock, oneLine, stripIndent } from 'common-tags'
import { StreamingTextResponse } from 'ai'
import { chatRlsPolicy } from 'ai-commands/edge'
import { NextRequest } from 'next/server'
import OpenAI from 'openai'
@@ -47,72 +47,10 @@ async function handlePost(request: NextRequest) {
const { messages, entityDefinitions, policyDefinition } = body
const initMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
{
role: 'system',
content: stripIndent`
You're an Postgres expert in writing row level security policies. Your purpose is to
generate a policy with the constraints given by the user. You will be provided a schema
on which the policy should be applied.
The output should use the following instructions:
- The generated SQL must be valid SQL.
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
- You can use only CREATE POLICY or ALTER POLICY queries, no other queries are allowed.
- You can add short explanations to your messages.
- The result should be a valid markdown. The SQL code should be wrapped in \`\`\`.
- Always use "auth.uid()" instead of "current_user".
- You can't use "USING" expression on INSERT policies.
- Only use "WITH CHECK" expression on INSERT or UPDATE policies.
- The policy name should be short text explaining the policy, enclosed in double quotes.
- Always put explanations as separate text. Never use inline SQL comments.
- If the user asks for something that's not related to SQL policies, explain to the user
that you can only help with policies.
The output should look like this:
"CREATE POLICY user_policy ON users FOR INSERT USING (user_name = current_user) WITH (true);"
`,
},
]
if (entityDefinitions) {
const definitions = codeBlock`${entityDefinitions.join('\n\n')}`
initMessages.push({
role: 'user',
content: oneLine`Here is my database schema for reference: ${definitions}`,
})
}
if (policyDefinition !== undefined) {
const definitionBlock = codeBlock`${policyDefinition}`
initMessages.push({
role: 'user',
content: codeBlock`
Here is my policy definition for reference:
${definitionBlock}
`.trim(),
})
}
if (messages) {
initMessages.push(...messages)
}
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: 'gpt-3.5-turbo-1106',
messages: initMessages,
max_tokens: 1024,
temperature: 0,
stream: true,
}
try {
const response = await openai.chat.completions.create(completionOptions)
// Proxy the streamed SSE response from OpenAI
const stream = OpenAIStream(response)
const stream = await chatRlsPolicy(openai, messages, entityDefinitions, policyDefinition)
return new StreamingTextResponse(stream)
} catch (error: any) {
} catch (error) {
console.error(error)
return new Response(
+17 -103
View File
@@ -1,36 +1,10 @@
import { SchemaBuilder } from '@serafin/schema-builder'
import { stripIndent } from 'common-tags'
import { isError } from 'data/utils/error-check'
import { jsonrepair } from 'jsonrepair'
import { ContextLengthError, titleSql } from 'ai-commands'
import apiWrapper from 'lib/api/apiWrapper'
import { NextApiRequest, NextApiResponse } from 'next'
import { OpenAI } from 'openai'
const openAiKey = process.env.OPENAI_KEY
const generateTitleSchema = SchemaBuilder.emptySchema()
.addString('title', {
description: stripIndent`
The generated title for the SQL snippet (short and concise).
- Omit these words: 'SQL', 'Postgres', 'Query', 'Database'
`,
})
.addString('description', {
description: stripIndent`
The generated description for the SQL snippet (longer and more detailed than title).
- Read the SQL line by line and summarize it
`,
})
type GenerateTitleResult = typeof generateTitleSchema.T
const completionFunctions: Record<string, OpenAI.ChatCompletionCreateParams.Function> = {
generateTitle: {
name: 'generateTitle',
description: 'Generates a short title and detailed description for a Postgres SQL snippet',
parameters: generateTitleSchema.schema as Record<string, unknown>,
},
}
const openai = new OpenAI({ apiKey: openAiKey })
async function handler(req: NextApiRequest, res: NextApiResponse) {
if (!openAiKey) {
@@ -51,89 +25,29 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
}
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
const openAI = new OpenAI({ apiKey: openAiKey })
const {
body: { sql },
} = req
const model = 'gpt-3.5-turbo-0613'
const maxCompletionTokenCount = 1024
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
{
role: 'user',
content: sql,
},
]
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model,
messages: completionMessages,
max_tokens: maxCompletionTokenCount,
temperature: 0,
function_call: {
name: completionFunctions.generateTitle.name,
},
functions: [completionFunctions.generateTitle],
stream: false,
}
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
try {
completionResponse = await openAI.chat.completions.create(completionOptions)
} catch (error: any) {
console.error(`AI title generation failed: ${error.message}`)
if ('code' in error && error.code === 'context_length_exceeded') {
return res.status(400).json({
error:
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
})
}
return res.status(500).json({
error: 'There was an unknown error generating the snippet title. Please try again.',
})
}
const [firstChoice] = completionResponse.choices
const titleResponseString = firstChoice.message?.function_call?.arguments
if (!titleResponseString) {
console.error(
`AI title generation failed: OpenAI response succeeded, but response format was incorrect`
)
return res.status(500).json({
error: 'There was an unknown error generating the snippet title. Please try again.',
})
}
try {
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
const repairedJsonString = jsonrepair(titleResponseString)
const generateTitleResult: GenerateTitleResult = JSON.parse(repairedJsonString)
if (!generateTitleResult.title) {
console.error(`AI title generation failed: Unable to generate title for the given SQL`)
res.status(400).json({
error: 'Unable to generate title',
})
}
return res.json(generateTitleResult)
const result = await titleSql(openai, sql)
return res.json(result)
} catch (error) {
console.error(
`AI SQL editing failed: ${
isError(error) ? error.message : 'An unknown error occurred'
}, titleResponseString: ${titleResponseString}`
)
if (error instanceof Error) {
console.error(`AI title generation failed: ${error.message}`)
if (error instanceof ContextLengthError) {
return res.status(400).json({
error:
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
})
}
} else {
console.log(`Unknown error: ${error}`)
}
return res.status(500).json({
error: 'There was an unknown error editing the SQL snippet. Please try again.',
error: 'There was an unknown error generating the snippet title. Please try again.',
})
}
}
+475 -286
View File
File diff suppressed because it is too large Load Diff
+12
View File
@@ -0,0 +1,12 @@
# ai-commands
## Main purpose
This package contains all of the features that call the OpenAI API. Each feature is implemented as a plain function, so it
can be easily tested for regressions.
The streaming functions only work on the Edge runtime, so they must be imported via the dedicated `edge` subpath:
```
import { chatRlsPolicy } from 'ai-commands/edge'
```
+4
View File
@@ -0,0 +1,4 @@
module.exports = {
presets: [['@babel/preset-env', { targets: { node: 'current' } }]],
plugins: ['babel-plugin-transform-import-meta'],
}
+2
View File
@@ -0,0 +1,2 @@
export * from './src/errors'
export * from './src/sql.edge'
+2
View File
@@ -0,0 +1,2 @@
export * from './src/errors'
export * from './src/sql'
+14
View File
@@ -0,0 +1,14 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'^.+\\.ts?$': 'ts-jest',
'^.+\\.(js|jsx)$': 'babel-jest',
},
setupFiles: ['./test/setup.ts'],
testTimeout: 15000,
transformIgnorePatterns: [
'node_modules/(?!(mdast-.*|micromark|micromark-.*|unist-.*|decode-named-character-reference|character-entities)/)',
],
}
+34
View File
@@ -0,0 +1,34 @@
{
"name": "ai-commands",
"version": "0.0.0",
"main": "./index.ts",
"types": "./index.ts",
"license": "MIT",
"scripts": {
"typecheck": "tsc --noEmit",
"test": "jest"
},
"dependencies": {
"@serafin/schema-builder": "^0.18.5",
"ai": "^2.2.29",
"common-tags": "^1.8.2",
"config": "*",
"jsonrepair": "^3.5.0",
"openai": "^4.20.1"
},
"devDependencies": {
"@babel/core": "^7.23.6",
"@babel/preset-env": "^7.23.6",
"@jest/globals": "^29.7.0",
"@types/common-tags": "^1.8.4",
"babel-jest": "^29.7.0",
"babel-plugin-transform-import-meta": "^2.2.1",
"dotenv": "^16.3.1",
"jest": "^29.7.0",
"mdast-util-from-markdown": "^2.0.0",
"sql-formatter": "^15.0.2",
"ts-jest": "^29.1.1",
"tsconfig": "*",
"typescript": "^5.2.2"
}
}
@@ -0,0 +1,51 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`debug fix order of operations 1`] = `
"create table departments (
id bigint primary key generated always as identity,
name text
);
create table employees (
id bigint primary key generated always as identity,
name text,
email text,
department_id bigint references departments (id)
);"
`;
exports[`debug fix typos 1`] = `
"select
*
from
employees;"
`;
exports[`edit add length constraint 1`] = `
"create table employees (
id bigint primary key generated always as identity,
name text check (length(name) >= 4),
email text
);"
`;
exports[`generate single table with specified columns 1`] = `
"create table employees (
id bigint primary key generated always as identity,
name text,
email text,
position text
);"
`;
exports[`generate single table with specified columns 2`] = `"Employee Tracking Table"`;
exports[`rls chat select policy using table definition 1`] = `
"create policy select_todo_policy on todos for
select
using (user_id = auth.uid ());"
`;
exports[`title title matches content 1`] = `"Employee and Department Tables"`;
exports[`title title matches content 2`] = `"Tables to track employees and their respective departments"`;
+23
View File
@@ -0,0 +1,23 @@
export class ApplicationError extends Error {
constructor(message: string, public data: Record<string, any> = {}) {
super(message)
}
}
export class ContextLengthError extends ApplicationError {
constructor() {
super('LLM context length exceeded')
}
}
export class EmptyResponseError extends ApplicationError {
constructor() {
super('LLM API response succeeded but returned nothing')
}
}
export class EmptySqlError extends ApplicationError {
constructor() {
super('LLM did not generate any SQL')
}
}
+90
View File
@@ -0,0 +1,90 @@
import { OpenAIStream } from 'ai'
import { codeBlock, oneLine, stripIndent } from 'common-tags'
import OpenAI from 'openai'
import { ContextLengthError } from './errors'
export type AiAssistantMessage = {
content: string
role: 'user' | 'assistant'
}
/**
* Responds to a conversation about building an RLS policy.
*
* @returns A `ReadableStream` containing the response text and SQL.
*/
export async function chatRlsPolicy(
openai: OpenAI,
messages: AiAssistantMessage[],
entityDefinitions?: string[],
policyDefinition?: string
): Promise<ReadableStream<Uint8Array>> {
const initMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
{
role: 'system',
content: stripIndent`
You're an Postgres expert in writing row level security policies. Your purpose is to
generate a policy with the constraints given by the user. You will be provided a schema
on which the policy should be applied.
The output should use the following instructions:
- The generated SQL must be valid SQL.
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
- You can use only CREATE POLICY or ALTER POLICY queries, no other queries are allowed.
- You can add short explanations to your messages.
- The result should be a valid markdown. The SQL code should be wrapped in \`\`\`.
- Always use "auth.uid()" instead of "current_user".
- You can't use "USING" expression on INSERT policies.
- Only use "WITH CHECK" expression on INSERT or UPDATE policies.
- The policy name should be short text explaining the policy, enclosed in double quotes.
- Always put explanations as separate text. Never use inline SQL comments.
- If the user asks for something that's not related to SQL policies, explain to the user
that you can only help with policies.
The output should look like this:
"CREATE POLICY user_policy ON users FOR INSERT USING (user_name = current_user) WITH (true);"
`,
},
]
if (entityDefinitions) {
const definitions = codeBlock`${entityDefinitions.join('\n\n')}`
initMessages.push({
role: 'user',
content: oneLine`Here is my database schema for reference: ${definitions}`,
})
}
if (policyDefinition !== undefined) {
const definitionBlock = codeBlock`${policyDefinition}`
initMessages.push({
role: 'user',
content: codeBlock`
Here is my policy definition for reference:
${definitionBlock}
`.trim(),
})
}
if (messages) {
initMessages.push(...messages)
}
try {
const response = await openai.chat.completions.create({
model: 'gpt-3.5-turbo-1106',
messages: initMessages,
max_tokens: 1024,
temperature: 0,
stream: true,
})
// Transform the streamed SSE response from OpenAI to a ReadableStream
return OpenAIStream(response)
} catch (error) {
if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
throw new ContextLengthError()
}
throw error
}
}
+128
View File
@@ -0,0 +1,128 @@
import { describe, expect, test } from '@jest/globals'
import { codeBlock } from 'common-tags'
import OpenAI from 'openai'
import { collectStream, extractMarkdownSql, formatSql } from '../test/util'
import { debugSql, editSql, generateSql, titleSql } from './sql'
import { chatRlsPolicy } from './sql.edge'
// These tests hit the live OpenAI API, so a valid OPENAI_KEY env var is
// required (loaded via the package's jest env setup).
const openAiKey = process.env.OPENAI_KEY
const openai = new OpenAI({ apiKey: openAiKey })
describe('generate', () => {
  test('single table with specified columns', async () => {
    // Snapshot both the formatted SQL and the generated title. Formatting
    // first keeps the snapshot stable against whitespace-only variations in
    // the model output.
    const { sql, title } = await generateSql(
      openai,
      'create a table to track employees with name, email, and position'
    )
    expect(formatSql(sql)).toMatchSnapshot()
    expect(title).toMatchSnapshot()
  })
})
describe('edit', () => {
  test('add length constraint', async () => {
    // The fixture table intentionally has no constraint on `name`; the edit
    // should add one without dropping the existing columns.
    const { sql } = await editSql(
      openai,
      'force name to be at least 4 characters',
      codeBlock`
        create table employees (
          id bigint primary key generated always as identity,
          name text,
          email text
        );
      `
    )
    expect(formatSql(sql)).toMatchSnapshot()
  })
})
describe('debug', () => {
  test('fix order of operations', async () => {
    // `departments` is referenced before it is created — the fix should make
    // the foreign key reference resolvable (e.g. by reordering statements).
    const { sql } = await debugSql(
      openai,
      'relation "departments" does not exist',
      codeBlock`
        create table employees (
          id bigint primary key generated always as identity,
          name text,
          email text,
          department_id bigint references departments (id)
        );
        create table departments (
          id bigint primary key generated always as identity,
          name text
        );
      `
    )
    expect(formatSql(sql)).toMatchSnapshot()
  })
  test('fix typos', async () => {
    // Typo debugging: "fromm" should be corrected to "from", and the model
    // should also explain the fix via `solution`.
    const { sql, solution } = await debugSql(
      openai,
      'syntax error at or near "fromm"',
      codeBlock`
        select * fromm employees;
      `
    )
    expect(solution).toBeDefined()
    expect(formatSql(sql)).toMatchSnapshot()
  })
})
describe('title', () => {
  test('title matches content', async () => {
    // The generated title/description should reflect the employees +
    // departments schema; both are pinned via snapshots.
    const { title, description } = await titleSql(
      openai,
      codeBlock`
        create table employees (
          id bigint primary key generated always as identity,
          name text,
          email text,
          department_id bigint references departments (id)
        );
        create table departments (
          id bigint primary key generated always as identity,
          name text
        );
      `
    )
    expect(title).toMatchSnapshot()
    expect(description).toMatchSnapshot()
  })
})
describe('rls chat', () => {
  test('select policy using table definition', async () => {
    // chatRlsPolicy streams a markdown response: collect the stream, extract
    // the SQL code block from the markdown, and snapshot the formatted policy.
    const responseStream = await chatRlsPolicy(
      openai,
      [
        {
          role: 'user',
          content: 'Users can only select their own todos',
        },
      ],
      [
        codeBlock`
          create table todos (
            id bigint primary key generated always as identity,
            task text,
            email text,
            user_id uuid references auth.users (id)
          );
        `,
      ]
    )
    const responseText = await collectStream(responseStream)
    const [sql] = extractMarkdownSql(responseText)
    expect(formatSql(sql)).toMatchSnapshot()
  })
})
+409
View File
@@ -0,0 +1,409 @@
import { SchemaBuilder } from '@serafin/schema-builder'
import { codeBlock, stripIndent } from 'common-tags'
import { jsonrepair } from 'jsonrepair'
import OpenAI from 'openai'
import { ContextLengthError, EmptyResponseError, EmptySqlError } from './errors'
// Declare JSON schema for each function that the LLM can call
// NOTE: the `description` strings below are sent to the model as part of the
// function-calling schema — they are prompt content, not documentation.
// Changing them changes model output.

// generateSql: returns brand-new SQL plus a short title for it.
const generateSqlSchema = SchemaBuilder.emptySchema()
  .addString('sql', {
    description: stripIndent`
      The generated SQL (must be valid SQL).
      - For primary keys, always use "id bigint primary key generated always as identity" (not serial)
      - Prefer creating foreign key references in the create statement
      - Prefer 'text' over 'varchar'
      - Prefer 'timestamp with time zone' over 'date'
      - Use vector(384) data type for any embedding/vector related query
      - Always use double apostrophe in SQL strings (eg. 'Night''s watch')
    `,
  })
  .addString('title', {
    description: stripIndent`
      The title of the SQL.
      - Omit words like 'SQL', 'Postgres', or 'Query'
    `,
  })

// editSql: returns only the modified SQL (no title).
const editSqlSchema = SchemaBuilder.emptySchema().addString('sql', {
  description: stripIndent`
    The modified SQL (must be valid SQL).
    - Assume the query hasn't been executed yet
    - For primary keys, always use "id bigint primary key generated always as identity" (not serial)
    - When creating tables, always add foreign key references inline
    - Prefer 'text' over 'varchar'
    - Prefer 'timestamp with time zone' over 'date'
    - Use vector(384) data type for any embedding/vector related query
    - Always use double apostrophe in SQL strings (eg. 'Night''s watch')
    - Use real examples when possible
    - Add constraints if requested
  `,
})

// debugSql: returns a concise explanation plus the corrected SQL.
const debugSqlSchema = SchemaBuilder.emptySchema()
  .addString('solution', {
    description: 'A short suggested solution for the error (as concise as possible).',
  })
  .addString('sql', {
    description: 'The SQL rewritten to apply the solution. Includes all the original logic, but modified to fix the issue.',
  })

// generateTitle: returns a title + description for an existing SQL snippet.
const generateTitleSchema = SchemaBuilder.emptySchema()
  .addString('title', {
    description: stripIndent`
      The generated title for the SQL snippet (short and concise).
      - Omit these words: 'SQL', 'Postgres', 'Query', 'Database'
    `,
  })
  .addString('description', {
    description: stripIndent`
      The generated description for the SQL snippet.
    `,
  })

// Reference auto-generated types for each JSON schema
// (`.T` is a phantom property exposing the schema's inferred TS type).
export type GenerateSqlResult = typeof generateSqlSchema.T
export type EditSqlResult = typeof editSqlSchema.T
export type DebugSqlResult = typeof debugSqlSchema.T
export type GenerateTitleResult = typeof generateTitleSchema.T
// Combine the completion functions
// Each entry is passed to the OpenAI chat API as a callable function (tool);
// the `description` strings are prompt content — changing them changes output.
const completionFunctions = {
  generateSql: {
    name: 'generateSql',
    description: 'Generates Postgres SQL based on a natural language prompt',
    parameters: generateSqlSchema.schema as Record<string, unknown>,
  },
  editSql: {
    name: 'editSql',
    description: "Edits a Postgres SQL query based on the user's instructions",
    parameters: editSqlSchema.schema as Record<string, unknown>,
  },
  debugSql: {
    name: 'debugSql',
    description: stripIndent`
      Debugs a Postgres SQL error. Returns the fixed SQL and a solution explaining it.
      - Create extensions if they are missing (only for valid extensions)
      - Suggest creating tables if they are missing
      - Include all of the original SQL
      - For primary keys, always use "id bigint primary key generated always as identity" (not serial)
      - When creating tables, always add foreign key references inline
      - Prefer 'text' over 'varchar'
      - Prefer 'timestamp with time zone' over 'date'
      - Use vector(384) data type for any embedding/vector related query
      - Always use double apostrophe in SQL strings (eg. 'Night''s watch')
    `,
    parameters: debugSqlSchema.schema as Record<string, unknown>,
  },
  generateTitle: {
    name: 'generateTitle',
    description: stripIndent`
      Generates a short title and summarized description for a Postgres SQL snippet.
      The description should describe why this table was created (eg. "Table to track todos")
    `,
    parameters: generateTitleSchema.schema as Record<string, unknown>,
  },
  // `satisfies` validates each entry against the SDK's Function type while
  // keeping the literal keys/names for type-safe lookup below.
} satisfies Record<string, OpenAI.Chat.Completions.ChatCompletionCreateParams.Function>
/**
 * Generates a SQL snippet based on the provided prompt.
 *
 * @param openai - An initialized OpenAI client.
 * @param prompt - Natural language description of the SQL to generate.
 * @param entityDefinitions - Optional SQL definitions of existing entities,
 *   sent to the model as schema context.
 * @returns The generated SQL along with a title for it.
 * @throws EmptyResponseError if the model returned no function arguments.
 * @throws EmptySqlError if the model's arguments contained no SQL.
 * @throws ContextLengthError if the prompt + context exceeded the model's window.
 */
export async function generateSql(openai: OpenAI, prompt: string, entityDefinitions?: string[]) {
  const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
  const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
  if (hasEntityDefinitions) {
    completionMessages.push({
      role: 'user',
      content: codeBlock`
        Here is my database schema for reference:
        ${entityDefinitions.join('\n\n')}
      `,
    })
  }
  completionMessages.push({
    role: 'user',
    content: prompt,
  })
  try {
    const completionResponse = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo-1106',
      messages: completionMessages,
      max_tokens: 1024,
      temperature: 0,
      // Force the model to respond via the generateSql function so the
      // output is structured JSON rather than free-form text.
      tool_choice: {
        type: 'function',
        function: {
          name: completionFunctions.generateSql.name,
        },
      },
      tools: [
        {
          type: 'function',
          function: completionFunctions.generateSql,
        },
      ],
      stream: false,
    })
    // Guard against an empty `choices` array: without the optional chains an
    // empty response would surface as a TypeError instead of EmptyResponseError.
    const sqlResponseString =
      completionResponse.choices[0]?.message?.tool_calls?.[0]?.function.arguments
    if (!sqlResponseString) {
      throw new EmptyResponseError()
    }
    // LLMs occasionally emit slightly malformed JSON; repair it before parsing.
    const repairedJsonString = jsonrepair(sqlResponseString)
    const result: GenerateSqlResult = JSON.parse(repairedJsonString)
    if (!result.sql) {
      throw new EmptySqlError()
    }
    return result
  } catch (error) {
    // Translate the SDK's context-length error into our typed error; rethrow
    // everything else (including the typed errors thrown above) untouched.
    if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
      throw new ContextLengthError()
    }
    throw error
  }
}
/**
 * Modifies a SQL snippet based on the provided prompt.
 *
 * @param openai - An initialized OpenAI client.
 * @param prompt - Natural language instructions describing the desired edit.
 * @param sql - The current SQL to modify.
 * @param entityDefinitions - Optional SQL definitions of existing entities,
 *   sent to the model as schema context.
 * @returns The modified SQL.
 * @throws EmptyResponseError if the model returned no function arguments.
 * @throws EmptySqlError if the model's arguments contained no SQL.
 * @throws ContextLengthError if the prompt + context exceeded the model's window.
 */
export async function editSql(
  openai: OpenAI,
  prompt: string,
  sql: string,
  entityDefinitions?: string[]
) {
  const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
  const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
  if (hasEntityDefinitions) {
    completionMessages.push({
      role: 'user',
      content: codeBlock`
        Here is my database schema for reference:
        ${entityDefinitions.join('\n\n')}
      `,
    })
  }
  completionMessages.push(
    {
      role: 'user',
      content: stripIndent`
        Here is my current SQL:
        ${sql}
      `,
    },
    {
      role: 'user',
      content: prompt,
    }
  )
  try {
    const completionResponse = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo-1106',
      messages: completionMessages,
      max_tokens: 2048,
      temperature: 0,
      // Force the model to respond via the editSql function so the output is
      // structured JSON rather than free-form text.
      tool_choice: {
        type: 'function',
        function: {
          name: completionFunctions.editSql.name,
        },
      },
      tools: [
        {
          type: 'function',
          function: completionFunctions.editSql,
        },
      ],
      stream: false,
    })
    // Guard against an empty `choices` array: without the optional chains an
    // empty response would surface as a TypeError instead of EmptyResponseError.
    const sqlResponseString =
      completionResponse.choices[0]?.message?.tool_calls?.[0]?.function.arguments
    if (!sqlResponseString) {
      throw new EmptyResponseError()
    }
    // LLMs occasionally emit slightly malformed JSON; repair it before parsing.
    const repairedJsonString = jsonrepair(sqlResponseString)
    const result: EditSqlResult = JSON.parse(repairedJsonString)
    if (!result.sql) {
      throw new EmptySqlError()
    }
    return result
  } catch (error) {
    // Translate the SDK's context-length error into our typed error; rethrow
    // everything else untouched.
    if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
      throw new ContextLengthError()
    }
    throw error
  }
}
/**
 * Debugs SQL errors.
 *
 * @param openai - An initialized OpenAI client.
 * @param errorMessage - The Postgres error message to debug.
 * @param sql - The SQL that produced the error.
 * @param entityDefinitions - Optional SQL definitions of existing entities,
 *   sent to the model as schema context.
 * @returns A suggested SQL fix along with a solution explanation.
 * @throws EmptyResponseError if the model returned no function arguments.
 * @throws EmptySqlError if the model's arguments contained no SQL.
 * @throws ContextLengthError if the prompt + context exceeded the model's window.
 */
export async function debugSql(
  openai: OpenAI,
  errorMessage: string,
  sql: string,
  entityDefinitions?: string[]
) {
  const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
  const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
  if (hasEntityDefinitions) {
    completionMessages.push({
      role: 'user',
      content: codeBlock`
        Here is my database schema for reference:
        ${entityDefinitions.join('\n\n')}
      `,
    })
  }
  completionMessages.push(
    {
      role: 'user',
      content: stripIndent`
        Here is my current SQL:
        ${sql}
      `,
    },
    {
      role: 'user',
      content: stripIndent`
        Here is the error I am getting:
        ${errorMessage}
      `,
    }
  )
  try {
    const completionResponse = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo-1106',
      messages: completionMessages,
      max_tokens: 2048,
      temperature: 0,
      // Force the model to respond via the debugSql function so the output is
      // structured JSON rather than free-form text.
      tool_choice: {
        type: 'function',
        function: {
          name: completionFunctions.debugSql.name,
        },
      },
      tools: [
        {
          type: 'function',
          function: completionFunctions.debugSql,
        },
      ],
      stream: false,
    })
    // Guard against an empty `choices` array: without the optional chains an
    // empty response would surface as a TypeError instead of EmptyResponseError.
    const sqlResponseString =
      completionResponse.choices[0]?.message?.tool_calls?.[0]?.function.arguments
    if (!sqlResponseString) {
      throw new EmptyResponseError()
    }
    // LLMs occasionally emit slightly malformed JSON; repair it before parsing.
    const repairedJsonString = jsonrepair(sqlResponseString)
    const result: DebugSqlResult = JSON.parse(repairedJsonString)
    if (!result.sql) {
      throw new EmptySqlError()
    }
    return result
  } catch (error) {
    // Translate the SDK's context-length error into our typed error; rethrow
    // everything else untouched.
    if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
      throw new ContextLengthError()
    }
    throw error
  }
}
/**
 * Generates a snippet title based on the provided SQL.
 *
 * @param openai - An initialized OpenAI client.
 * @param sql - The SQL snippet to title.
 * @returns A title and description for the SQL snippet.
 * @throws EmptyResponseError if the model returned no function arguments.
 * @throws ContextLengthError if the SQL exceeded the model's context window.
 */
export async function titleSql(openai: OpenAI, sql: string) {
  const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
    {
      role: 'user',
      content: sql,
    },
  ]
  try {
    const completionResponse = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo-1106',
      messages: completionMessages,
      max_tokens: 1024,
      temperature: 0,
      // Force the model to respond via the generateTitle function so the
      // output is structured JSON rather than free-form text.
      tool_choice: {
        type: 'function',
        function: {
          name: completionFunctions.generateTitle.name,
        },
      },
      tools: [
        {
          type: 'function',
          function: completionFunctions.generateTitle,
        },
      ],
      stream: false,
    })
    // Guard against an empty `choices` array: without the optional chains an
    // empty response would surface as a TypeError instead of EmptyResponseError.
    const sqlResponseString =
      completionResponse.choices[0]?.message?.tool_calls?.[0]?.function.arguments
    if (!sqlResponseString) {
      throw new EmptyResponseError()
    }
    // LLMs occasionally emit slightly malformed JSON; repair it before parsing.
    const repairedJsonString = jsonrepair(sqlResponseString)
    // NOTE(review): unlike the SQL functions, `title`/`description` are not
    // validated here — callers should tolerate missing fields.
    const result: GenerateTitleResult = JSON.parse(repairedJsonString)
    return result
  } catch (error) {
    // Translate the SDK's context-length error into our typed error; rethrow
    // everything else untouched.
    if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
      throw new ContextLengthError()
    }
    throw error
  }
}
+8
View File
@@ -0,0 +1,8 @@
import { config } from 'dotenv'
import { statSync } from 'fs'

// Use studio .env.local for now (shared across packages).
const studioEnvFile = '../../apps/studio/.env.local'

// Fail fast with a clear fs error if the file is absent, rather than letting
// dotenv silently no-op and tests fail later with confusing auth errors.
statSync(studioEnvFile)

config({ path: studioEnvFile })
+47
View File
@@ -0,0 +1,47 @@
import { fromMarkdown } from 'mdast-util-from-markdown'
import type { Code } from 'mdast-util-from-markdown/lib'
import { format } from 'sql-formatter'
// Node's web ReadableStream is async-iterable at runtime, but the bundled
// TypeScript lib types used here don't declare it — augment the global type
// so `for await (... of stream)` type-checks in these utilities.
declare global {
  interface ReadableStream<R = any> {
    [Symbol.asyncIterator](): AsyncIterableIterator<R>
  }
}
/**
 * Formats Postgres SQL into a consistent format.
 *
 * Lower-cases keywords so that snapshots are stable regardless of the
 * casing the model happened to emit.
 *
 * @returns The formatted SQL.
 */
export const formatSql = (sql: string) => {
  const formatterOptions = { language: 'postgresql', keywordCase: 'lower' } as const
  return format(sql, formatterOptions)
}
/**
 * Collects an `ArrayBuffer` stream into a single decoded string.
 *
 * Pipes the binary stream through a UTF-8 `TextDecoderStream` and reads
 * chunks until the stream is exhausted.
 *
 * @returns A single string combining all the decoded stream chunks.
 */
export async function collectStream<R extends BufferSource>(stream: ReadableStream<R>) {
  const reader = stream.pipeThrough(new TextDecoderStream()).getReader()

  let combined = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) {
      break
    }
    combined += value
  }

  return combined
}
/**
 * Parses markdown and extracts all SQL code blocks.
 *
 * @returns An array of string content from each SQL code block.
 */
export function extractMarkdownSql(markdown: string) {
  const sqlBlocks: string[] = []

  for (const node of fromMarkdown(markdown).children) {
    // Only fenced code blocks explicitly tagged as `sql` are collected.
    if (node.type === 'code' && node.lang === 'sql') {
      sqlBlocks.push(node.value)
    }
  }

  return sqlBlocks
}
+5
View File
@@ -0,0 +1,5 @@
{
"extends": "tsconfig/react-library.json",
"include": ["."],
"exclude": ["dist", "build", "node_modules"]
}
+1
View File
@@ -54,6 +54,7 @@
"cache": false
},
"typecheck": {
"dependsOn": ["^typecheck"],
"outputs": ["**/node_modules/.cache/tsbuildinfo.json"]
}
}