mirror of
https://github.com/supabase/supabase.git
synced 2026-05-06 08:56:46 -04:00
AI SQL: Update model + add test suite (#19644)
* refactor: ai sql logic to common package * feat: ai sql tests * feat: test rls policy ai chat assistant * refactor: jest env loading * feat(ai): improve sql title and description quality * fix: ai message role type * Move the new files to a separate package. * Remove a forgotten console.log. * Migrate the tests to use snapshots and commit the snapshots. * Separate the functions which require edge runtime to be exported via /edge subpath. * Bust the turbo cache when one of the deps has been rebuilt. * chore: fix package main/type references * fix: package lock out of sync * fix: ai sql debugging to fix typos --------- Co-authored-by: Ivan Vasilov <vasilov.ivan@gmail.com>
This commit is contained in:
@@ -11,7 +11,7 @@ import { AIPolicyPre } from './AIPolicyPre'
|
||||
|
||||
interface MessageProps {
|
||||
name?: string
|
||||
role: 'function' | 'user' | 'assistant' | 'system'
|
||||
role: 'function' | 'user' | 'assistant' | 'system' | 'data'
|
||||
content?: string
|
||||
createdAt?: number
|
||||
isDebug?: boolean
|
||||
|
||||
@@ -39,6 +39,8 @@
|
||||
"@uidotdev/usehooks": "^2.4.1",
|
||||
"@zip.js/zip.js": "^2.7.29",
|
||||
"ai": "^2.2.26",
|
||||
"ai-commands": "*",
|
||||
"ajv": "^8.6.3",
|
||||
"awesome-debounce-promise": "^2.1.0",
|
||||
"blueimp-md5": "^2.19.0",
|
||||
"clsx": "^1.2.1",
|
||||
|
||||
@@ -1,44 +1,10 @@
|
||||
import { SchemaBuilder } from '@serafin/schema-builder'
|
||||
import { codeBlock, stripIndent } from 'common-tags'
|
||||
import { isError } from 'data/utils/error-check'
|
||||
import { jsonrepair } from 'jsonrepair'
|
||||
import { ContextLengthError, EmptySqlError, debugSql } from 'ai-commands'
|
||||
import apiWrapper from 'lib/api/apiWrapper'
|
||||
import { NextApiRequest, NextApiResponse } from 'next'
|
||||
import { OpenAI } from 'openai'
|
||||
|
||||
const openAiKey = process.env.OPENAI_KEY
|
||||
|
||||
const debugSqlSchema = SchemaBuilder.emptySchema()
|
||||
.addString('solution', {
|
||||
description: 'A short suggested solution for the error (as concise as possible).',
|
||||
})
|
||||
.addString('sql', {
|
||||
description: 'The SQL rewritten to apply the solution. Includes all the original SQL.',
|
||||
})
|
||||
|
||||
type DebugSqlResult = typeof debugSqlSchema.T
|
||||
|
||||
const completionFunctions: Record<
|
||||
string,
|
||||
OpenAI.Chat.Completions.ChatCompletionCreateParams.Function
|
||||
> = {
|
||||
debugSql: {
|
||||
name: 'debugSql',
|
||||
description: stripIndent`
|
||||
Debugs a Postgres SQL error and modifies the SQL to fix it.
|
||||
- Create extensions if they are missing (only for valid extensions)
|
||||
- Suggest creating tables if they are missing
|
||||
- Include all of the original SQL
|
||||
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
|
||||
- When creating tables, always add foreign key references inline
|
||||
- Prefer 'text' over 'varchar'
|
||||
- Prefer 'timestamp with time zone' over 'date'
|
||||
- Use vector(384) data type for any embedding/vector related query
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
`,
|
||||
parameters: debugSqlSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
}
|
||||
const openai = new OpenAI({ apiKey: openAiKey })
|
||||
|
||||
async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
if (!openAiKey) {
|
||||
@@ -59,128 +25,54 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
}
|
||||
|
||||
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
|
||||
const openAI = new OpenAI({ apiKey: openAiKey })
|
||||
const {
|
||||
body: { errorMessage, sql, entityDefinitions },
|
||||
} = req
|
||||
|
||||
const model = 'gpt-3.5-turbo-0613'
|
||||
const maxCompletionTokenCount = 2048
|
||||
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
|
||||
|
||||
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
|
||||
|
||||
if (hasEntityDefinitions) {
|
||||
completionMessages.push({
|
||||
role: 'user',
|
||||
content: codeBlock`
|
||||
Here is my database schema for reference:
|
||||
${entityDefinitions.join('\n\n')}
|
||||
`,
|
||||
})
|
||||
}
|
||||
|
||||
completionMessages.push(
|
||||
{
|
||||
role: 'user',
|
||||
content: stripIndent`
|
||||
Here is my current SQL:
|
||||
${sql}
|
||||
`,
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: stripIndent`
|
||||
Here is the error I am getting:
|
||||
${errorMessage}
|
||||
`,
|
||||
}
|
||||
)
|
||||
|
||||
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model,
|
||||
messages: completionMessages,
|
||||
max_tokens: maxCompletionTokenCount,
|
||||
temperature: 0,
|
||||
function_call: {
|
||||
name: completionFunctions.debugSql.name,
|
||||
},
|
||||
functions: [completionFunctions.debugSql],
|
||||
stream: false,
|
||||
}
|
||||
|
||||
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
|
||||
try {
|
||||
completionResponse = await openAI.chat.completions.create(completionOptions)
|
||||
} catch (error: any) {
|
||||
console.error(`AI SQL debugging failed: ${error.message}`)
|
||||
const result = await debugSql(openai, errorMessage, sql, entityDefinitions)
|
||||
return res.json(result)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
console.error(`AI SQL debugging failed: ${error.message}`)
|
||||
|
||||
if ('code' in error && error.code === 'context_length_exceeded') {
|
||||
if (hasEntityDefinitions) {
|
||||
const definitionsLength = entityDefinitions.reduce(
|
||||
(sum: number, def: string) => sum + def.length,
|
||||
0
|
||||
)
|
||||
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
|
||||
|
||||
if (definitionsLength > sql.length) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
|
||||
})
|
||||
if (error instanceof ContextLengthError) {
|
||||
// If there are more entity definitions than the SQL provided, attribute the
|
||||
// error to the database metadata
|
||||
if (hasEntityDefinitions) {
|
||||
const definitionsLength = entityDefinitions.reduce(
|
||||
(sum: number, def: string) => sum + def.length,
|
||||
0
|
||||
)
|
||||
if (definitionsLength > sql.length) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
|
||||
})
|
||||
}
|
||||
}
|
||||
// Otherwise attribute the error to the SQL being too large
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
|
||||
})
|
||||
}
|
||||
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
|
||||
})
|
||||
if (error instanceof EmptySqlError) {
|
||||
res.status(400).json({
|
||||
error: 'Unable to debug SQL. No fix identified for the error.',
|
||||
})
|
||||
}
|
||||
} else {
|
||||
console.log(`Unknown error: ${error}`)
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error debugging the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
const [firstChoice] = completionResponse.choices
|
||||
|
||||
const sqlResponseString = firstChoice.message?.function_call?.arguments
|
||||
|
||||
if (!sqlResponseString) {
|
||||
console.error(
|
||||
`AI SQL debugging failed: OpenAI response succeeded, but response format was incorrect`
|
||||
)
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error debugging the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
|
||||
const repairedJsonString = jsonrepair(sqlResponseString)
|
||||
|
||||
const debugSqlResult: DebugSqlResult = JSON.parse(repairedJsonString)
|
||||
|
||||
if (!debugSqlResult.sql) {
|
||||
console.error(`AI SQL debugging failed: Unable to debug SQL for the given error message`)
|
||||
|
||||
return res.status(400).json({
|
||||
error: 'Unable to debug SQL',
|
||||
})
|
||||
}
|
||||
|
||||
return res.json(debugSqlResult)
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`AI SQL editing failed: ${
|
||||
isError(error) ? error.message : 'An unknown error occurred'
|
||||
}, sqlResponseString: ${sqlResponseString}`
|
||||
)
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error editing the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const wrapper = (req: NextApiRequest, res: NextApiResponse) => apiWrapper(req, res, handler)
|
||||
|
||||
@@ -1,39 +1,10 @@
|
||||
import { SchemaBuilder } from '@serafin/schema-builder'
|
||||
import { codeBlock, stripIndent } from 'common-tags'
|
||||
import { isError } from 'data/utils/error-check'
|
||||
import { jsonrepair } from 'jsonrepair'
|
||||
import { ContextLengthError, EmptySqlError, editSql } from 'ai-commands'
|
||||
import apiWrapper from 'lib/api/apiWrapper'
|
||||
import { NextApiRequest, NextApiResponse } from 'next'
|
||||
import { OpenAI } from 'openai'
|
||||
|
||||
const openAiKey = process.env.OPENAI_KEY
|
||||
|
||||
const editSqlSchema = SchemaBuilder.emptySchema().addString('sql', {
|
||||
description: stripIndent`
|
||||
The modified SQL (must be valid SQL).
|
||||
- Assume the query hasn't been executed yet
|
||||
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
|
||||
- When creating tables, always add foreign key references inline
|
||||
- Prefer 'text' over 'varchar'
|
||||
- Prefer 'timestamp with time zone' over 'date'
|
||||
- Use vector(384) data type for any embedding/vector related query
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
- Use real examples when possible
|
||||
`,
|
||||
})
|
||||
|
||||
type EditSqlResult = typeof editSqlSchema.T
|
||||
|
||||
const completionFunctions: Record<
|
||||
string,
|
||||
OpenAI.Chat.Completions.ChatCompletionCreateParams.Function
|
||||
> = {
|
||||
editSql: {
|
||||
name: 'editSql',
|
||||
description: "Edits a Postgres SQL query based on the user's instructions",
|
||||
parameters: editSqlSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
}
|
||||
const openai = new OpenAI({ apiKey: openAiKey })
|
||||
|
||||
async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
if (!openAiKey) {
|
||||
@@ -58,112 +29,45 @@ export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
|
||||
body: { prompt, sql, entityDefinitions },
|
||||
} = req
|
||||
|
||||
const openAI = new OpenAI({ apiKey: openAiKey })
|
||||
const model = 'gpt-3.5-turbo-0613'
|
||||
const maxCompletionTokenCount = 2048
|
||||
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
|
||||
|
||||
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
|
||||
|
||||
if (hasEntityDefinitions) {
|
||||
completionMessages.push({
|
||||
role: 'user',
|
||||
content: codeBlock`
|
||||
Here is my database schema for reference:
|
||||
${entityDefinitions.join('\n\n')}
|
||||
`,
|
||||
})
|
||||
}
|
||||
|
||||
completionMessages.push(
|
||||
{
|
||||
role: 'user',
|
||||
content: stripIndent`
|
||||
Here is my current SQL:
|
||||
${sql}
|
||||
`,
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: prompt,
|
||||
}
|
||||
)
|
||||
|
||||
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model,
|
||||
messages: completionMessages,
|
||||
max_tokens: maxCompletionTokenCount,
|
||||
temperature: 0,
|
||||
function_call: {
|
||||
name: completionFunctions.editSql.name,
|
||||
},
|
||||
functions: [completionFunctions.editSql],
|
||||
stream: false,
|
||||
}
|
||||
|
||||
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
|
||||
try {
|
||||
completionResponse = await openAI.chat.completions.create(completionOptions)
|
||||
} catch (error: any) {
|
||||
console.error(`AI SQL editing failed: ${error.message}`)
|
||||
if ('code' in error && error.code === 'context_length_exceeded') {
|
||||
if (hasEntityDefinitions) {
|
||||
const definitionsLength = entityDefinitions.reduce(
|
||||
(sum: number, def: string) => sum + def.length,
|
||||
0
|
||||
)
|
||||
if (definitionsLength > sql.length) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
|
||||
})
|
||||
}
|
||||
}
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
|
||||
})
|
||||
}
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error editing the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
const [firstChoice] = completionResponse.choices
|
||||
|
||||
const sqlResponseString = firstChoice.message?.function_call?.arguments
|
||||
|
||||
if (!sqlResponseString) {
|
||||
console.error(
|
||||
`AI SQL editing failed: OpenAI response succeeded, but response format was incorrect`
|
||||
)
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error editing the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
|
||||
const repairedJsonString = jsonrepair(sqlResponseString)
|
||||
|
||||
const editSqlResult: EditSqlResult = JSON.parse(repairedJsonString)
|
||||
|
||||
if (!editSqlResult.sql) {
|
||||
console.error(`AI SQL editing failed: Unable to edit SQL for the given prompt`)
|
||||
|
||||
return res.status(400).json({
|
||||
error: 'Unable to edit SQL. Try adding more details to your prompt.',
|
||||
})
|
||||
}
|
||||
|
||||
return res.json(editSqlResult)
|
||||
const result = await editSql(openai, prompt, sql, entityDefinitions)
|
||||
return res.json(result)
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`AI SQL editing failed: ${
|
||||
isError(error) ? error.message : 'An unknown error occurred'
|
||||
}, sqlResponseString: ${sqlResponseString}`
|
||||
)
|
||||
if (error instanceof Error) {
|
||||
console.error(`AI SQL editing failed: ${error.message}`)
|
||||
|
||||
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
|
||||
|
||||
if (error instanceof ContextLengthError) {
|
||||
// If there are more entity definitions than the SQL provided, attribute the
|
||||
// error to the database metadata
|
||||
if (hasEntityDefinitions) {
|
||||
const definitionsLength = entityDefinitions.reduce(
|
||||
(sum: number, def: string) => sum + def.length,
|
||||
0
|
||||
)
|
||||
if (definitionsLength > sql.length) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
|
||||
})
|
||||
}
|
||||
}
|
||||
// Otherwise attribute the error to the SQL being too large
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
|
||||
})
|
||||
}
|
||||
|
||||
if (error instanceof EmptySqlError) {
|
||||
res.status(400).json({
|
||||
error: 'Unable to edit SQL. Try adding more details to your prompt.',
|
||||
})
|
||||
}
|
||||
} else {
|
||||
console.log(`Unknown error: ${error}`)
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error editing the SQL snippet. Please try again.',
|
||||
|
||||
@@ -1,44 +1,10 @@
|
||||
import { SchemaBuilder } from '@serafin/schema-builder'
|
||||
import { codeBlock, stripIndent } from 'common-tags'
|
||||
import { isError } from 'data/utils/error-check'
|
||||
import { jsonrepair } from 'jsonrepair'
|
||||
import { ContextLengthError, EmptySqlError, generateSql } from 'ai-commands'
|
||||
import apiWrapper from 'lib/api/apiWrapper'
|
||||
import { NextApiRequest, NextApiResponse } from 'next'
|
||||
import { OpenAI } from 'openai'
|
||||
|
||||
const openAiKey = process.env.OPENAI_KEY
|
||||
|
||||
const generateSqlSchema = SchemaBuilder.emptySchema()
|
||||
.addString('sql', {
|
||||
description: stripIndent`
|
||||
The generated SQL (must be valid SQL).
|
||||
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
|
||||
- Prefer creating foreign key references in the create statement
|
||||
- Prefer 'text' over 'varchar'
|
||||
- Prefer 'timestamp with time zone' over 'date'
|
||||
- Use vector(384) data type for any embedding/vector related query
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
`,
|
||||
})
|
||||
.addString('title', {
|
||||
description: stripIndent`
|
||||
The title of the SQL.
|
||||
- Omit words like 'SQL', 'Postgres', or 'Query'
|
||||
`,
|
||||
})
|
||||
|
||||
type GenerateSqlResult = typeof generateSqlSchema.T
|
||||
|
||||
const completionFunctions: Record<
|
||||
string,
|
||||
OpenAI.Chat.Completions.ChatCompletionCreateParams.Function
|
||||
> = {
|
||||
generateSql: {
|
||||
name: 'generateSql',
|
||||
description: 'Generates Postgres SQL based on a natural language prompt',
|
||||
parameters: generateSqlSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
}
|
||||
const openai = new OpenAI({ apiKey: openAiKey })
|
||||
|
||||
async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
if (!openAiKey) {
|
||||
@@ -59,102 +25,37 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
}
|
||||
|
||||
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
|
||||
const openAI = new OpenAI({ apiKey: openAiKey })
|
||||
const {
|
||||
body: { prompt, entityDefinitions },
|
||||
} = req
|
||||
|
||||
const model = 'gpt-3.5-turbo-0613'
|
||||
const maxCompletionTokenCount = 1024
|
||||
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
|
||||
|
||||
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
|
||||
|
||||
if (hasEntityDefinitions) {
|
||||
completionMessages.push({
|
||||
role: 'user',
|
||||
content: codeBlock`
|
||||
Here is my database schema for reference:
|
||||
${entityDefinitions.join('\n\n')}
|
||||
`,
|
||||
})
|
||||
}
|
||||
|
||||
completionMessages.push({
|
||||
role: 'user',
|
||||
content: prompt,
|
||||
})
|
||||
|
||||
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model,
|
||||
messages: completionMessages,
|
||||
max_tokens: maxCompletionTokenCount,
|
||||
temperature: 0,
|
||||
function_call: {
|
||||
name: completionFunctions.generateSql.name,
|
||||
},
|
||||
functions: [completionFunctions.generateSql],
|
||||
stream: false,
|
||||
}
|
||||
|
||||
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
|
||||
try {
|
||||
completionResponse = await openAI.chat.completions.create(completionOptions)
|
||||
} catch (error: any) {
|
||||
console.error(`AI SQL generation failed: ${error.message}`)
|
||||
|
||||
if ('code' in error && error.code === 'context_length_exceeded' && hasEntityDefinitions) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
|
||||
})
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error generating the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
const [firstChoice] = completionResponse.choices
|
||||
|
||||
const sqlResponseString = firstChoice.message?.function_call?.arguments
|
||||
|
||||
if (!sqlResponseString) {
|
||||
console.error(
|
||||
`AI SQL generation failed: OpenAI response succeeded, but response format was incorrect`
|
||||
)
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error generating the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
|
||||
const repairedJsonString = jsonrepair(sqlResponseString)
|
||||
|
||||
const generateSqlResult: GenerateSqlResult = JSON.parse(repairedJsonString)
|
||||
|
||||
if (!generateSqlResult.sql) {
|
||||
console.error(`AI SQL generation failed: Unable to generate SQL for the given prompt`)
|
||||
|
||||
res.status(400).json({
|
||||
error: 'Unable to generate SQL. Try adding more details to your prompt.',
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return res.json(generateSqlResult)
|
||||
const result = await generateSql(openai, prompt, entityDefinitions)
|
||||
return res.json(result)
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`AI SQL editing failed: ${
|
||||
isError(error) ? error.message : 'An unknown error occurred'
|
||||
}, sqlResponseString: ${sqlResponseString}`
|
||||
)
|
||||
if (error instanceof Error) {
|
||||
console.error(`AI SQL generation failed: ${error.message}`)
|
||||
|
||||
const hasEntityDefinitions = entityDefinitions !== undefined && entityDefinitions.length > 0
|
||||
|
||||
if (error instanceof ContextLengthError && hasEntityDefinitions) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your database metadata is too large for Supabase AI to ingest. Try disabling database metadata in AI settings.',
|
||||
})
|
||||
}
|
||||
|
||||
if (error instanceof EmptySqlError) {
|
||||
res.status(400).json({
|
||||
error: 'Unable to generate SQL. Try adding more details to your prompt.',
|
||||
})
|
||||
}
|
||||
} else {
|
||||
console.log(`Unknown error: ${error}`)
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error editing the SQL snippet. Please try again.',
|
||||
error: 'There was an unknown error generating the SQL snippet. Please try again.',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { OpenAIStream, StreamingTextResponse } from 'ai'
|
||||
import { codeBlock, oneLine, stripIndent } from 'common-tags'
|
||||
import { StreamingTextResponse } from 'ai'
|
||||
import { chatRlsPolicy } from 'ai-commands/edge'
|
||||
import { NextRequest } from 'next/server'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
@@ -47,72 +47,10 @@ async function handlePost(request: NextRequest) {
|
||||
|
||||
const { messages, entityDefinitions, policyDefinition } = body
|
||||
|
||||
const initMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
|
||||
{
|
||||
role: 'system',
|
||||
content: stripIndent`
|
||||
You're an Postgres expert in writing row level security policies. Your purpose is to
|
||||
generate a policy with the constraints given by the user. You will be provided a schema
|
||||
on which the policy should be applied.
|
||||
|
||||
The output should use the following instructions:
|
||||
- The generated SQL must be valid SQL.
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
- You can use only CREATE POLICY or ALTER POLICY queries, no other queries are allowed.
|
||||
- You can add short explanations to your messages.
|
||||
- The result should be a valid markdown. The SQL code should be wrapped in \`\`\`.
|
||||
- Always use "auth.uid()" instead of "current_user".
|
||||
- You can't use "USING" expression on INSERT policies.
|
||||
- Only use "WITH CHECK" expression on INSERT or UPDATE policies.
|
||||
- The policy name should be short text explaining the policy, enclosed in double quotes.
|
||||
- Always put explanations as separate text. Never use inline SQL comments.
|
||||
- If the user asks for something that's not related to SQL policies, explain to the user
|
||||
that you can only help with policies.
|
||||
|
||||
The output should look like this:
|
||||
"CREATE POLICY user_policy ON users FOR INSERT USING (user_name = current_user) WITH (true);"
|
||||
`,
|
||||
},
|
||||
]
|
||||
|
||||
if (entityDefinitions) {
|
||||
const definitions = codeBlock`${entityDefinitions.join('\n\n')}`
|
||||
initMessages.push({
|
||||
role: 'user',
|
||||
content: oneLine`Here is my database schema for reference: ${definitions}`,
|
||||
})
|
||||
}
|
||||
|
||||
if (policyDefinition !== undefined) {
|
||||
const definitionBlock = codeBlock`${policyDefinition}`
|
||||
initMessages.push({
|
||||
role: 'user',
|
||||
content: codeBlock`
|
||||
Here is my policy definition for reference:
|
||||
${definitionBlock}
|
||||
`.trim(),
|
||||
})
|
||||
}
|
||||
|
||||
if (messages) {
|
||||
initMessages.push(...messages)
|
||||
}
|
||||
|
||||
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: 'gpt-3.5-turbo-1106',
|
||||
messages: initMessages,
|
||||
max_tokens: 1024,
|
||||
temperature: 0,
|
||||
stream: true,
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await openai.chat.completions.create(completionOptions)
|
||||
// Proxy the streamed SSE response from OpenAI
|
||||
const stream = OpenAIStream(response)
|
||||
|
||||
const stream = await chatRlsPolicy(openai, messages, entityDefinitions, policyDefinition)
|
||||
return new StreamingTextResponse(stream)
|
||||
} catch (error: any) {
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
|
||||
return new Response(
|
||||
|
||||
@@ -1,36 +1,10 @@
|
||||
import { SchemaBuilder } from '@serafin/schema-builder'
|
||||
import { stripIndent } from 'common-tags'
|
||||
import { isError } from 'data/utils/error-check'
|
||||
import { jsonrepair } from 'jsonrepair'
|
||||
import { ContextLengthError, titleSql } from 'ai-commands'
|
||||
import apiWrapper from 'lib/api/apiWrapper'
|
||||
import { NextApiRequest, NextApiResponse } from 'next'
|
||||
import { OpenAI } from 'openai'
|
||||
|
||||
const openAiKey = process.env.OPENAI_KEY
|
||||
|
||||
const generateTitleSchema = SchemaBuilder.emptySchema()
|
||||
.addString('title', {
|
||||
description: stripIndent`
|
||||
The generated title for the SQL snippet (short and concise).
|
||||
- Omit these words: 'SQL', 'Postgres', 'Query', 'Database'
|
||||
`,
|
||||
})
|
||||
.addString('description', {
|
||||
description: stripIndent`
|
||||
The generated description for the SQL snippet (longer and more detailed than title).
|
||||
- Read the SQL line by line and summarize it
|
||||
`,
|
||||
})
|
||||
|
||||
type GenerateTitleResult = typeof generateTitleSchema.T
|
||||
|
||||
const completionFunctions: Record<string, OpenAI.ChatCompletionCreateParams.Function> = {
|
||||
generateTitle: {
|
||||
name: 'generateTitle',
|
||||
description: 'Generates a short title and detailed description for a Postgres SQL snippet',
|
||||
parameters: generateTitleSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
}
|
||||
const openai = new OpenAI({ apiKey: openAiKey })
|
||||
|
||||
async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
if (!openAiKey) {
|
||||
@@ -51,89 +25,29 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
}
|
||||
|
||||
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
|
||||
const openAI = new OpenAI({ apiKey: openAiKey })
|
||||
const {
|
||||
body: { sql },
|
||||
} = req
|
||||
|
||||
const model = 'gpt-3.5-turbo-0613'
|
||||
const maxCompletionTokenCount = 1024
|
||||
|
||||
const completionMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: sql,
|
||||
},
|
||||
]
|
||||
|
||||
const completionOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
|
||||
model,
|
||||
messages: completionMessages,
|
||||
max_tokens: maxCompletionTokenCount,
|
||||
temperature: 0,
|
||||
function_call: {
|
||||
name: completionFunctions.generateTitle.name,
|
||||
},
|
||||
functions: [completionFunctions.generateTitle],
|
||||
stream: false,
|
||||
}
|
||||
|
||||
let completionResponse: OpenAI.Chat.Completions.ChatCompletion
|
||||
try {
|
||||
completionResponse = await openAI.chat.completions.create(completionOptions)
|
||||
} catch (error: any) {
|
||||
console.error(`AI title generation failed: ${error.message}`)
|
||||
|
||||
if ('code' in error && error.code === 'context_length_exceeded') {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
|
||||
})
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error generating the snippet title. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
const [firstChoice] = completionResponse.choices
|
||||
|
||||
const titleResponseString = firstChoice.message?.function_call?.arguments
|
||||
|
||||
if (!titleResponseString) {
|
||||
console.error(
|
||||
`AI title generation failed: OpenAI response succeeded, but response format was incorrect`
|
||||
)
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error generating the snippet title. Please try again.',
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
// Attempt to repair broken JSON from OpenAI (eg. multiline strings)
|
||||
const repairedJsonString = jsonrepair(titleResponseString)
|
||||
|
||||
const generateTitleResult: GenerateTitleResult = JSON.parse(repairedJsonString)
|
||||
|
||||
if (!generateTitleResult.title) {
|
||||
console.error(`AI title generation failed: Unable to generate title for the given SQL`)
|
||||
|
||||
res.status(400).json({
|
||||
error: 'Unable to generate title',
|
||||
})
|
||||
}
|
||||
|
||||
return res.json(generateTitleResult)
|
||||
const result = await titleSql(openai, sql)
|
||||
return res.json(result)
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`AI SQL editing failed: ${
|
||||
isError(error) ? error.message : 'An unknown error occurred'
|
||||
}, titleResponseString: ${titleResponseString}`
|
||||
)
|
||||
if (error instanceof Error) {
|
||||
console.error(`AI title generation failed: ${error.message}`)
|
||||
|
||||
if (error instanceof ContextLengthError) {
|
||||
return res.status(400).json({
|
||||
error:
|
||||
'Your SQL query is too large for Supabase AI to ingest. Try splitting it into smaller queries.',
|
||||
})
|
||||
}
|
||||
} else {
|
||||
console.log(`Unknown error: ${error}`)
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: 'There was an unknown error editing the SQL snippet. Please try again.',
|
||||
error: 'There was an unknown error generating the snippet title. Please try again.',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+475
-286
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
||||
# ai-commands
|
||||
|
||||
## Main purpose
|
||||
|
||||
This package contains all features involving the OpenAI API. Technically, each feature is implemented as a function which
|
||||
can be easily tested for regressions.
|
||||
|
||||
The streaming functions only work on the Edge runtime, so they can only be imported via a special `edge` subpath, like so:
|
||||
|
||||
```
|
||||
import { chatRlsPolicy } from 'ai-commands/edge'
|
||||
```
|
||||
@@ -0,0 +1,4 @@
|
||||
module.exports = {
|
||||
presets: [['@babel/preset-env', { targets: { node: 'current' } }]],
|
||||
plugins: ['babel-plugin-transform-import-meta'],
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
export * from './src/errors'
|
||||
export * from './src/sql.edge'
|
||||
@@ -0,0 +1,2 @@
|
||||
export * from './src/errors'
|
||||
export * from './src/sql'
|
||||
@@ -0,0 +1,14 @@
|
||||
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
transform: {
|
||||
'^.+\\.ts?$': 'ts-jest',
|
||||
'^.+\\.(js|jsx)$': 'babel-jest',
|
||||
},
|
||||
setupFiles: ['./test/setup.ts'],
|
||||
testTimeout: 15000,
|
||||
transformIgnorePatterns: [
|
||||
'node_modules/(?!(mdast-.*|micromark|micromark-.*|unist-.*|decode-named-character-reference|character-entities)/)',
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"name": "ai-commands",
|
||||
"version": "0.0.0",
|
||||
"main": "./index.ts",
|
||||
"types": "./index.ts",
|
||||
"license": "MIT",
|
||||
"scripts": {
|
||||
"typecheck": "tsc --noEmit",
|
||||
"test": "jest"
|
||||
},
|
||||
"dependencies": {
|
||||
"@serafin/schema-builder": "^0.18.5",
|
||||
"ai": "^2.2.29",
|
||||
"common-tags": "^1.8.2",
|
||||
"config": "*",
|
||||
"jsonrepair": "^3.5.0",
|
||||
"openai": "^4.20.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.23.6",
|
||||
"@babel/preset-env": "^7.23.6",
|
||||
"@jest/globals": "^29.7.0",
|
||||
"@types/common-tags": "^1.8.4",
|
||||
"babel-jest": "^29.7.0",
|
||||
"babel-plugin-transform-import-meta": "^2.2.1",
|
||||
"dotenv": "^16.3.1",
|
||||
"jest": "^29.7.0",
|
||||
"mdast-util-from-markdown": "^2.0.0",
|
||||
"sql-formatter": "^15.0.2",
|
||||
"ts-jest": "^29.1.1",
|
||||
"tsconfig": "*",
|
||||
"typescript": "^5.2.2"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// Jest Snapshot v1, https://goo.gl/fbAQLP
|
||||
|
||||
exports[`debug fix order of operations 1`] = `
|
||||
"create table departments (
|
||||
id bigint primary key generated always as identity,
|
||||
name text
|
||||
);
|
||||
|
||||
create table employees (
|
||||
id bigint primary key generated always as identity,
|
||||
name text,
|
||||
email text,
|
||||
department_id bigint references departments (id)
|
||||
);"
|
||||
`;
|
||||
|
||||
exports[`debug fix typos 1`] = `
|
||||
"select
|
||||
*
|
||||
from
|
||||
employees;"
|
||||
`;
|
||||
|
||||
exports[`edit add length constraint 1`] = `
|
||||
"create table employees (
|
||||
id bigint primary key generated always as identity,
|
||||
name text check (length(name) >= 4),
|
||||
email text
|
||||
);"
|
||||
`;
|
||||
|
||||
exports[`generate single table with specified columns 1`] = `
|
||||
"create table employees (
|
||||
id bigint primary key generated always as identity,
|
||||
name text,
|
||||
email text,
|
||||
position text
|
||||
);"
|
||||
`;
|
||||
|
||||
exports[`generate single table with specified columns 2`] = `"Employee Tracking Table"`;
|
||||
|
||||
exports[`rls chat select policy using table definition 1`] = `
|
||||
"create policy select_todo_policy on todos for
|
||||
select
|
||||
using (user_id = auth.uid ());"
|
||||
`;
|
||||
|
||||
exports[`title title matches content 1`] = `"Employee and Department Tables"`;
|
||||
|
||||
exports[`title title matches content 2`] = `"Tables to track employees and their respective departments"`;
|
||||
@@ -0,0 +1,23 @@
|
||||
/**
 * Base class for every error raised by the AI command helpers.
 *
 * `data` carries optional structured context alongside the message.
 */
export class ApplicationError extends Error {
  constructor(message: string, public data: Record<string, any> = {}) {
    super(message)

    // Report the concrete class (eg. "ContextLengthError") instead of the generic "Error"
    this.name = new.target.name

    // Keep `instanceof` checks working if this code is ever down-levelled to ES5,
    // where subclassing `Error` otherwise loses the subclass prototype
    Object.setPrototypeOf(this, new.target.prototype)
  }
}

/**
 * Thrown when the prompt sent to the LLM exceeds its context window.
 */
export class ContextLengthError extends ApplicationError {
  constructor() {
    super('LLM context length exceeded')
  }
}

/**
 * Thrown when the LLM API call succeeds but the response contains no usable payload.
 */
export class EmptyResponseError extends ApplicationError {
  constructor() {
    super('LLM API response succeeded but returned nothing')
  }
}

/**
 * Thrown when the LLM responds but the response contains no SQL.
 */
export class EmptySqlError extends ApplicationError {
  constructor() {
    super('LLM did not generate any SQL')
  }
}
|
||||
@@ -0,0 +1,90 @@
|
||||
import { OpenAIStream } from 'ai'
|
||||
import { codeBlock, oneLine, stripIndent } from 'common-tags'
|
||||
import OpenAI from 'openai'
|
||||
import { ContextLengthError } from './errors'
|
||||
|
||||
/**
 * A single turn of the assistant conversation: who wrote it and what was said.
 */
export type AiAssistantMessage = {
  role: 'user' | 'assistant'
  content: string
}
|
||||
|
||||
/**
 * Responds to a conversation about building an RLS policy.
 *
 * The request is assembled in this order: the system prompt, the optional database
 * schema, the optional existing policy definition, and finally the conversation so far.
 *
 * @param openai - The OpenAI client used to send the request.
 * @param messages - The conversation so far.
 * @param entityDefinitions - Optional SQL definitions of the relevant database entities.
 * @param policyDefinition - Optional SQL of an existing policy to use as reference.
 * @returns A `ReadableStream` containing the response text and SQL.
 * @throws {ContextLengthError} When the OpenAI API reports `context_length_exceeded`.
 */
export async function chatRlsPolicy(
  openai: OpenAI,
  messages: AiAssistantMessage[],
  entityDefinitions?: string[],
  policyDefinition?: string
): Promise<ReadableStream<Uint8Array>> {
  const initMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
    {
      role: 'system',
      content: stripIndent`
        You're a Postgres expert in writing row level security policies. Your purpose is to
        generate a policy with the constraints given by the user. You will be provided a schema
        on which the policy should be applied.

        The output should use the following instructions:
        - The generated SQL must be valid SQL.
        - Always use double apostrophe in SQL strings (eg. 'Night''s watch')
        - You can use only CREATE POLICY or ALTER POLICY queries, no other queries are allowed.
        - You can add short explanations to your messages.
        - The result should be a valid markdown. The SQL code should be wrapped in \`\`\`.
        - Always use "auth.uid()" instead of "current_user".
        - You can't use "USING" expression on INSERT policies.
        - Only use "WITH CHECK" expression on INSERT or UPDATE policies.
        - The policy name should be short text explaining the policy, enclosed in double quotes.
        - Always put explanations as separate text. Never use inline SQL comments.
        - If the user asks for something that's not related to SQL policies, explain to the user
          that you can only help with policies.

        The output should look like this:
        CREATE POLICY "Users can insert their own rows" ON users FOR INSERT WITH CHECK (user_id = auth.uid());
      `,
    },
  ]

  // Skip the schema message entirely when there is nothing to share
  if (entityDefinitions !== undefined && entityDefinitions.length > 0) {
    const definitions = codeBlock`${entityDefinitions.join('\n\n')}`

    // NOTE(review): `oneLine` collapses the schema's line breaks into single spaces
    initMessages.push({
      role: 'user',
      content: oneLine`Here is my database schema for reference: ${definitions}`,
    })
  }

  if (policyDefinition !== undefined) {
    const definitionBlock = codeBlock`${policyDefinition}`
    initMessages.push({
      role: 'user',
      content: codeBlock`
        Here is my policy definition for reference:
        ${definitionBlock}
      `.trim(),
    })
  }

  if (messages) {
    initMessages.push(...messages)
  }

  try {
    const response = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo-1106',
      messages: initMessages,
      max_tokens: 1024,
      temperature: 0,
      stream: true,
    })

    // Transform the streamed SSE response from OpenAI to a ReadableStream
    return OpenAIStream(response)
  } catch (error) {
    // Surface context overflow as a dedicated error type that callers can check for
    if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
      throw new ContextLengthError()
    }
    throw error
  }
}
|
||||
@@ -0,0 +1,128 @@
|
||||
import { describe, expect, test } from '@jest/globals'
|
||||
import { codeBlock } from 'common-tags'
|
||||
import OpenAI from 'openai'
|
||||
import { collectStream, extractMarkdownSql, formatSql } from '../test/util'
|
||||
import { debugSql, editSql, generateSql, titleSql } from './sql'
|
||||
import { chatRlsPolicy } from './sql.edge'
|
||||
|
||||
const openAiKey = process.env.OPENAI_KEY
|
||||
const openai = new OpenAI({ apiKey: openAiKey })
|
||||
|
||||
describe('generate', () => {
|
||||
test('single table with specified columns', async () => {
|
||||
const { sql, title } = await generateSql(
|
||||
openai,
|
||||
'create a table to track employees with name, email, and position'
|
||||
)
|
||||
|
||||
expect(formatSql(sql)).toMatchSnapshot()
|
||||
expect(title).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('edit', () => {
|
||||
test('add length constraint', async () => {
|
||||
const { sql } = await editSql(
|
||||
openai,
|
||||
'force name to be at least 4 characters',
|
||||
codeBlock`
|
||||
create table employees (
|
||||
id bigint primary key generated always as identity,
|
||||
name text,
|
||||
email text
|
||||
);
|
||||
`
|
||||
)
|
||||
|
||||
expect(formatSql(sql)).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('debug', () => {
|
||||
test('fix order of operations', async () => {
|
||||
const { sql } = await debugSql(
|
||||
openai,
|
||||
'relation "departments" does not exist',
|
||||
codeBlock`
|
||||
create table employees (
|
||||
id bigint primary key generated always as identity,
|
||||
name text,
|
||||
email text,
|
||||
department_id bigint references departments (id)
|
||||
);
|
||||
|
||||
create table departments (
|
||||
id bigint primary key generated always as identity,
|
||||
name text
|
||||
);
|
||||
`
|
||||
)
|
||||
|
||||
expect(formatSql(sql)).toMatchSnapshot()
|
||||
})
|
||||
|
||||
test('fix typos', async () => {
|
||||
const { sql, solution } = await debugSql(
|
||||
openai,
|
||||
'syntax error at or near "fromm"',
|
||||
codeBlock`
|
||||
select * fromm employees;
|
||||
`
|
||||
)
|
||||
|
||||
expect(solution).toBeDefined()
|
||||
expect(formatSql(sql)).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('title', () => {
|
||||
test('title matches content', async () => {
|
||||
const { title, description } = await titleSql(
|
||||
openai,
|
||||
codeBlock`
|
||||
create table employees (
|
||||
id bigint primary key generated always as identity,
|
||||
name text,
|
||||
email text,
|
||||
department_id bigint references departments (id)
|
||||
);
|
||||
|
||||
create table departments (
|
||||
id bigint primary key generated always as identity,
|
||||
name text
|
||||
);
|
||||
`
|
||||
)
|
||||
|
||||
expect(title).toMatchSnapshot()
|
||||
expect(description).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('rls chat', () => {
|
||||
test('select policy using table definition', async () => {
|
||||
const responseStream = await chatRlsPolicy(
|
||||
openai,
|
||||
[
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Users can only select their own todos',
|
||||
},
|
||||
],
|
||||
[
|
||||
codeBlock`
|
||||
create table todos (
|
||||
id bigint primary key generated always as identity,
|
||||
task text,
|
||||
email text,
|
||||
user_id uuid references auth.users (id)
|
||||
);
|
||||
`,
|
||||
]
|
||||
)
|
||||
const responseText = await collectStream(responseStream)
|
||||
const [sql] = extractMarkdownSql(responseText)
|
||||
|
||||
expect(formatSql(sql)).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,409 @@
|
||||
import { SchemaBuilder } from '@serafin/schema-builder'
|
||||
import { codeBlock, stripIndent } from 'common-tags'
|
||||
import { jsonrepair } from 'jsonrepair'
|
||||
import OpenAI from 'openai'
|
||||
import { ContextLengthError, EmptyResponseError, EmptySqlError } from './errors'
|
||||
|
||||
// Declare JSON schema for each function that the LLM can call
|
||||
const generateSqlSchema = SchemaBuilder.emptySchema()
|
||||
.addString('sql', {
|
||||
description: stripIndent`
|
||||
The generated SQL (must be valid SQL).
|
||||
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
|
||||
- Prefer creating foreign key references in the create statement
|
||||
- Prefer 'text' over 'varchar'
|
||||
- Prefer 'timestamp with time zone' over 'date'
|
||||
- Use vector(384) data type for any embedding/vector related query
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
`,
|
||||
})
|
||||
.addString('title', {
|
||||
description: stripIndent`
|
||||
The title of the SQL.
|
||||
- Omit words like 'SQL', 'Postgres', or 'Query'
|
||||
`,
|
||||
})
|
||||
|
||||
const editSqlSchema = SchemaBuilder.emptySchema().addString('sql', {
|
||||
description: stripIndent`
|
||||
The modified SQL (must be valid SQL).
|
||||
- Assume the query hasn't been executed yet
|
||||
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
|
||||
- When creating tables, always add foreign key references inline
|
||||
- Prefer 'text' over 'varchar'
|
||||
- Prefer 'timestamp with time zone' over 'date'
|
||||
- Use vector(384) data type for any embedding/vector related query
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
- Use real examples when possible
|
||||
- Add constraints if requested
|
||||
`,
|
||||
})
|
||||
|
||||
const debugSqlSchema = SchemaBuilder.emptySchema()
|
||||
.addString('solution', {
|
||||
description: 'A short suggested solution for the error (as concise as possible).',
|
||||
})
|
||||
.addString('sql', {
|
||||
description: 'The SQL rewritten to apply the solution. Includes all the original logic, but modified to fix the issue.',
|
||||
})
|
||||
|
||||
const generateTitleSchema = SchemaBuilder.emptySchema()
|
||||
.addString('title', {
|
||||
description: stripIndent`
|
||||
The generated title for the SQL snippet (short and concise).
|
||||
- Omit these words: 'SQL', 'Postgres', 'Query', 'Database'
|
||||
`,
|
||||
})
|
||||
.addString('description', {
|
||||
description: stripIndent`
|
||||
The generated description for the SQL snippet.
|
||||
`,
|
||||
})
|
||||
|
||||
// Reference auto-generated types for each JSON schema
|
||||
export type GenerateSqlResult = typeof generateSqlSchema.T
|
||||
export type EditSqlResult = typeof editSqlSchema.T
|
||||
export type DebugSqlResult = typeof debugSqlSchema.T
|
||||
export type GenerateTitleResult = typeof generateTitleSchema.T
|
||||
|
||||
// Combine the completion functions
|
||||
const completionFunctions = {
|
||||
generateSql: {
|
||||
name: 'generateSql',
|
||||
description: 'Generates Postgres SQL based on a natural language prompt',
|
||||
parameters: generateSqlSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
editSql: {
|
||||
name: 'editSql',
|
||||
description: "Edits a Postgres SQL query based on the user's instructions",
|
||||
parameters: editSqlSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
debugSql: {
|
||||
name: 'debugSql',
|
||||
description: stripIndent`
|
||||
Debugs a Postgres SQL error. Returns the fixed SQL and a solution explaining it.
|
||||
- Create extensions if they are missing (only for valid extensions)
|
||||
- Suggest creating tables if they are missing
|
||||
- Include all of the original SQL
|
||||
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
|
||||
- When creating tables, always add foreign key references inline
|
||||
- Prefer 'text' over 'varchar'
|
||||
- Prefer 'timestamp with time zone' over 'date'
|
||||
- Use vector(384) data type for any embedding/vector related query
|
||||
- Always use double apostrophe in SQL strings (eg. 'Night''s watch')
|
||||
`,
|
||||
parameters: debugSqlSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
generateTitle: {
|
||||
name: 'generateTitle',
|
||||
description: stripIndent`
|
||||
Generates a short title and summarized description for a Postgres SQL snippet.
|
||||
|
||||
The description should describe why this table was created (eg. "Table to track todos")
|
||||
`,
|
||||
parameters: generateTitleSchema.schema as Record<string, unknown>,
|
||||
},
|
||||
} satisfies Record<string, OpenAI.Chat.Completions.ChatCompletionCreateParams.Function>
|
||||
|
||||
/** A single message sent to the chat completions API. */
type CompletionMessage = OpenAI.Chat.Completions.ChatCompletionMessageParam

/** Function definition shape accepted by the `tools` option of the chat completions API. */
type CompletionFunction = OpenAI.Chat.Completions.ChatCompletionTool['function']

/**
 * Builds the optional leading message that shares the database schema with the LLM.
 *
 * @returns A single-message array, or an empty array when no definitions were provided.
 */
function buildSchemaMessages(entityDefinitions?: string[]): CompletionMessage[] {
  if (entityDefinitions === undefined || entityDefinitions.length === 0) {
    return []
  }

  return [
    {
      role: 'user',
      content: codeBlock`
        Here is my database schema for reference:
        ${entityDefinitions.join('\n\n')}
      `,
    },
  ]
}

/**
 * Forces the LLM to call the given function and returns the parsed call arguments.
 *
 * @throws {EmptyResponseError} When the response contains no tool call arguments.
 * @throws {ContextLengthError} When the OpenAI API reports `context_length_exceeded`.
 */
async function requestFunctionCall<T>(
  openai: OpenAI,
  messages: CompletionMessage[],
  completionFunction: CompletionFunction,
  maxTokens: number
): Promise<T> {
  try {
    const completionResponse = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo-1106',
      messages,
      max_tokens: maxTokens,
      temperature: 0,
      tool_choice: {
        type: 'function',
        function: {
          name: completionFunction.name,
        },
      },
      tools: [
        {
          type: 'function',
          function: completionFunction,
        },
      ],
      stream: false,
    })

    // `choices` may be empty, so guard the first choice as well as its tool calls
    const [firstChoice] = completionResponse.choices
    const [firstTool] = firstChoice?.message?.tool_calls ?? []

    const responseString = firstTool?.function.arguments

    if (!responseString) {
      throw new EmptyResponseError()
    }

    // Repair the arguments string before parsing, in case it is not strictly valid JSON
    const result: T = JSON.parse(jsonrepair(responseString))

    return result
  } catch (error) {
    if (error instanceof Error && 'code' in error && error.code === 'context_length_exceeded') {
      throw new ContextLengthError()
    }
    throw error
  }
}

/**
 * Generates a SQL snippet based on the provided prompt.
 *
 * @param openai - The OpenAI client used to send the request.
 * @param prompt - Natural language description of the SQL to generate.
 * @param entityDefinitions - Optional SQL definitions of existing database entities.
 * @returns The generated SQL along with a title for it.
 * @throws {EmptyResponseError} When the LLM returns no function call arguments.
 * @throws {EmptySqlError} When the function call arguments contain no SQL.
 * @throws {ContextLengthError} When the prompt exceeds the LLM context length.
 */
export async function generateSql(openai: OpenAI, prompt: string, entityDefinitions?: string[]) {
  const completionMessages: CompletionMessage[] = [
    ...buildSchemaMessages(entityDefinitions),
    {
      role: 'user',
      content: prompt,
    },
  ]

  const result = await requestFunctionCall<GenerateSqlResult>(
    openai,
    completionMessages,
    completionFunctions.generateSql,
    1024
  )

  if (!result.sql) {
    throw new EmptySqlError()
  }

  return result
}

/**
 * Modifies a SQL snippet based on the provided prompt.
 *
 * @param openai - The OpenAI client used to send the request.
 * @param prompt - Natural language instructions describing the change.
 * @param sql - The SQL to modify.
 * @param entityDefinitions - Optional SQL definitions of existing database entities.
 * @returns The modified SQL.
 * @throws {EmptyResponseError} When the LLM returns no function call arguments.
 * @throws {EmptySqlError} When the function call arguments contain no SQL.
 * @throws {ContextLengthError} When the prompt exceeds the LLM context length.
 */
export async function editSql(
  openai: OpenAI,
  prompt: string,
  sql: string,
  entityDefinitions?: string[]
) {
  const completionMessages: CompletionMessage[] = [
    ...buildSchemaMessages(entityDefinitions),
    {
      role: 'user',
      // `codeBlock` keeps multi-line SQL aligned; `stripIndent` would leak template indentation
      content: codeBlock`
        Here is my current SQL:
        ${sql}
      `,
    },
    {
      role: 'user',
      content: prompt,
    },
  ]

  const result = await requestFunctionCall<EditSqlResult>(
    openai,
    completionMessages,
    completionFunctions.editSql,
    2048
  )

  if (!result.sql) {
    throw new EmptySqlError()
  }

  return result
}

/**
 * Debugs SQL errors.
 *
 * @param openai - The OpenAI client used to send the request.
 * @param errorMessage - The error reported for the SQL.
 * @param sql - The SQL that produced the error.
 * @param entityDefinitions - Optional SQL definitions of existing database entities.
 * @returns A suggested SQL fix along with a solution explanation.
 * @throws {EmptyResponseError} When the LLM returns no function call arguments.
 * @throws {EmptySqlError} When the function call arguments contain no SQL.
 * @throws {ContextLengthError} When the prompt exceeds the LLM context length.
 */
export async function debugSql(
  openai: OpenAI,
  errorMessage: string,
  sql: string,
  entityDefinitions?: string[]
) {
  const completionMessages: CompletionMessage[] = [
    ...buildSchemaMessages(entityDefinitions),
    {
      role: 'user',
      // `codeBlock` keeps multi-line SQL aligned; `stripIndent` would leak template indentation
      content: codeBlock`
        Here is my current SQL:
        ${sql}
      `,
    },
    {
      role: 'user',
      content: codeBlock`
        Here is the error I am getting:
        ${errorMessage}
      `,
    },
  ]

  const result = await requestFunctionCall<DebugSqlResult>(
    openai,
    completionMessages,
    completionFunctions.debugSql,
    2048
  )

  if (!result.sql) {
    throw new EmptySqlError()
  }

  return result
}

/**
 * Generates a snippet title based on the provided SQL.
 *
 * @param openai - The OpenAI client used to send the request.
 * @param sql - The SQL snippet to describe.
 * @returns A title and description for the SQL snippet.
 * @throws {EmptyResponseError} When the LLM returns no function call arguments.
 * @throws {ContextLengthError} When the SQL exceeds the LLM context length.
 */
export async function titleSql(openai: OpenAI, sql: string) {
  const completionMessages: CompletionMessage[] = [
    {
      role: 'user',
      content: sql,
    },
  ]

  return requestFunctionCall<GenerateTitleResult>(
    openai,
    completionMessages,
    completionFunctions.generateTitle,
    1024
  )
}
|
||||
@@ -0,0 +1,8 @@
|
||||
import { config } from 'dotenv'
|
||||
import { statSync } from 'fs'
|
||||
|
||||
// Use studio .env.local for now
|
||||
const envPath = '../../apps/studio/.env.local'
|
||||
|
||||
statSync(envPath)
|
||||
config({ path: envPath })
|
||||
@@ -0,0 +1,47 @@
|
||||
import { fromMarkdown } from 'mdast-util-from-markdown'
|
||||
import type { Code } from 'mdast-util-from-markdown/lib'
|
||||
import { format } from 'sql-formatter'
|
||||
|
||||
declare global {
|
||||
interface ReadableStream<R = any> {
|
||||
[Symbol.asyncIterator](): AsyncIterableIterator<R>
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Normalizes Postgres SQL so that equivalent statements format the same way.
 *
 * Keywords are lower-cased and the `postgresql` dialect is used for formatting.
 *
 * @returns The formatted SQL.
 */
export const formatSql = (sql: string) => {
  const formatted = format(sql, { language: 'postgresql', keywordCase: 'lower' })
  return formatted
}
|
||||
|
||||
/**
 * Drains a binary stream, decoding every chunk and joining the pieces into one string.
 *
 * @returns A single string combining all the decoded stream chunks.
 */
export async function collectStream<R extends BufferSource>(stream: ReadableStream<R>) {
  const reader = stream.pipeThrough(new TextDecoderStream()).getReader()
  const pieces: string[] = []

  // Pull decoded chunks one at a time until the stream reports completion
  while (true) {
    const { done, value } = await reader.read()
    if (done) {
      break
    }
    pieces.push(value)
  }

  return pieces.join('')
}
|
||||
|
||||
/**
 * Parses markdown and collects the content of every top-level code block tagged `sql`.
 *
 * @returns An array of string content from each SQL code block.
 */
export function extractMarkdownSql(markdown: string) {
  const sqlBlocks: string[] = []

  // Only direct children of the document root are inspected
  for (const node of fromMarkdown(markdown).children) {
    if (node.type === 'code' && node.lang === 'sql') {
      sqlBlocks.push(node.value)
    }
  }

  return sqlBlocks
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"extends": "tsconfig/react-library.json",
|
||||
"include": ["."],
|
||||
"exclude": ["dist", "build", "node_modules"]
|
||||
}
|
||||
@@ -54,6 +54,7 @@
|
||||
"cache": false
|
||||
},
|
||||
"typecheck": {
|
||||
"dependsOn": ["^typecheck"],
|
||||
"outputs": ["**/node_modules/.cache/tsbuildinfo.json"]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user