diff --git a/README.md b/README.md index 10a0417..5d3e60a 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,55 @@ This works across the entire test suite. Note that if unique parameters are passed to the `beforeTemplateIsBaked` (`null` in the above example), separate databases will still be created. +### Manual template creation + +In some cases, if you do extensive setup in your `beforeTemplateIsBaked` hook, you might want to obtain a separate, additional database within it if your application uses several databases for different purposes. This is possible by using the `manuallyBuildAdditionalTemplate()` function passed to your hook callback: + +```ts +import test from "ava" + +const getTestDatabase = getTestPostgresDatabaseFactory({ + beforeTemplateIsBaked: async ({ + params, + connection: { pool }, + manuallyBuildAdditionalTemplate, + }) => { + await pool.query(`CREATE TABLE "bar" ("id" SERIAL PRIMARY KEY)`) + + const fooTemplateBuilder = await manuallyBuildAdditionalTemplate() + await fooTemplateBuilder.connection.pool.query( + `CREATE TABLE "foo" ("id" SERIAL PRIMARY KEY)` + ) + const { templateName: fooTemplateName } = await fooTemplateBuilder.finish() + + return { fooTemplateName } + }, +}) + +test("foo", async (t) => { + const barDatabase = await getTestDatabase({ type: "bar" }) + + // the "bar" database has the "bar" table... + await t.notThrowsAsync(async () => { + await barDatabase.pool.query(`SELECT * FROM "bar"`) + }) + + // ...but not the "foo" table... + await t.throwsAsync(async () => { + await barDatabase.pool.query(`SELECT * FROM "foo"`) + }) + + // ...and we can obtain a separate database with the "foo" table + const fooDatabase = await getTestDatabase.fromTemplate( + t, + barDatabase.beforeTemplateIsBakedResult.fooTemplateName + ) + await t.notThrowsAsync(async () => { + await fooDatabase.pool.query(`SELECT * FROM "foo"`) + }) +}) +``` + ### Bind mounts & `exec`ing in the container `ava-postgres` uses [testcontainers](https://www.npmjs.com/package/testcontainers) under the hood to manage the Postgres container. diff --git a/src/index.ts b/src/index.ts index 0d08122..6861bb4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,12 +12,12 @@ import type { GetTestPostgresDatabase, GetTestPostgresDatabaseFactoryOptions, GetTestPostgresDatabaseOptions, + GetTestPostgresDatabaseResult, } from "./public-types" import { Pool } from "pg" import type { Jsonifiable } from "type-fest" import type { ExecutionContext } from "ava" -import { once } from "node:events" -import { createBirpc } from "birpc" +import { BirpcReturn, createBirpc } from "birpc" import { ExecResult } from "testcontainers" import isPlainObject from "lodash/isPlainObject" @@ -105,7 +105,7 @@ export const getTestPostgresDatabaseFactory = < Params extends Jsonifiable = never >( options?: GetTestPostgresDatabaseFactoryOptions -) => { +): GetTestPostgresDatabase => { const initialData: InitialWorkerData = { postgresVersion: options?.postgresVersion ?? "14", containerOptions: options?.container, @@ -136,57 +136,73 @@ export const getTestPostgresDatabaseFactory = < } let rpcCallback: (data: any) => void - const rpc = createBirpc( - { - runBeforeTemplateIsBakedHook: async (connection, params) => { - if (options?.beforeTemplateIsBaked) { - const connectionDetails = - mapWorkerConnectionDetailsToConnectionDetails(connection) - - // Ignore if the pool is terminated by the shared worker - // (This happens in CI for some reason even though we drain the pool first.) - connectionDetails.pool.on("error", (error) => { - if ( - error.message.includes( - "terminating connection due to administrator command" - ) - ) { - return - } + const rpc: BirpcReturn = + createBirpc( + { + runBeforeTemplateIsBakedHook: async (connection, params) => { + if (options?.beforeTemplateIsBaked) { + const connectionDetails = + mapWorkerConnectionDetailsToConnectionDetails(connection) - throw error - }) + // Ignore if the pool is terminated by the shared worker + // (This happens in CI for some reason even though we drain the pool first.) + connectionDetails.pool.on("error", (error) => { + if ( + error.message.includes( + "terminating connection due to administrator command" + ) + ) { + return + } - const hookResult = await options.beforeTemplateIsBaked({ - params: params as any, - connection: connectionDetails, - containerExec: async (command): Promise => - rpc.execCommandInContainer(command), - }) + throw error + }) - await teardownConnection(connectionDetails) + const hookResult = await options.beforeTemplateIsBaked({ + params: params as any, + connection: connectionDetails, + containerExec: async (command): Promise => + rpc.execCommandInContainer(command), + // This is what allows a consumer to get a "nested" database from within their beforeTemplateIsBaked hook + manuallyBuildAdditionalTemplate: async () => { + const connection = + mapWorkerConnectionDetailsToConnectionDetails( + await rpc.createEmptyDatabase() + ) - if (hookResult && !isSerializable(hookResult)) { - throw new TypeError( - "Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values." - ) - } + return { + connection, + finish: async () => { + await teardownConnection(connection) + return rpc.convertDatabaseToTemplate(connection.database) + }, + } + }, + }) - return hookResult - } - }, - }, - { - post: async (data) => { - const worker = await workerPromise - await worker.available - worker.publish(data) - }, - on: (data) => { - rpcCallback = data + await teardownConnection(connectionDetails) + + if (hookResult && !isSerializable(hookResult)) { + throw new TypeError( + "Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values." + ) + } + + return hookResult + } + }, }, - } - ) + { + post: async (data) => { + const worker = await workerPromise + await worker.available + worker.publish(data) + }, + on: (data) => { + rpcCallback = data + }, + } + ) // Automatically cleaned up by AVA since each test file runs in a separate worker const _messageHandlerPromise = (async () => { @@ -198,11 +214,11 @@ export const getTestPostgresDatabaseFactory = < } })() - const getTestPostgresDatabase: GetTestPostgresDatabase = async ( + const getTestPostgresDatabase = async ( t: ExecutionContext, params: any, getTestDatabaseOptions?: GetTestPostgresDatabaseOptions - ) => { + ): Promise => { const testDatabaseConnection = await rpc.getTestDatabase({ databaseDedupeKey: getTestDatabaseOptions?.databaseDedupeKey, params, @@ -223,7 +239,22 @@ export const getTestPostgresDatabaseFactory = < } } - return getTestPostgresDatabase + getTestPostgresDatabase.fromTemplate = async ( + t: ExecutionContext, + templateName: string + ) => { + const connection = mapWorkerConnectionDetailsToConnectionDetails( + await rpc.createDatabaseFromTemplate(templateName) + ) + + t.teardown(async () => { + await teardownConnection(connection) + }) + + return connection + } + + return getTestPostgresDatabase as any } export * from "./public-types" diff --git a/src/internal-types.ts b/src/internal-types.ts index 2700f0e..20aeac6 100644 --- a/src/internal-types.ts +++ b/src/internal-types.ts @@ -30,4 +30,11 @@ export interface SharedWorkerFunctions { beforeTemplateIsBakedResult: unknown }> execCommandInContainer: (command: string[]) => Promise + createEmptyDatabase: () => Promise + createDatabaseFromTemplate: ( + templateName: string + ) => Promise + convertDatabaseToTemplate: ( + databaseName: string + ) => Promise<{ templateName: string }> } diff --git a/src/public-types.ts b/src/public-types.ts index b95e3a7..92e4d2c 100644 --- a/src/public-types.ts +++ b/src/public-types.ts @@ -1,8 +1,8 @@ import type { Pool } from "pg" import type { Jsonifiable } from "type-fest" -import { ExecutionContext } from "ava" -import { ExecResult } from "testcontainers" -import { BindMode } from "testcontainers/build/types" +import type { ExecutionContext } from "ava" +import type { ExecResult } from "testcontainers" +import type { BindMode } from "testcontainers/build/types" export interface ConnectionDetails { connectionString: string @@ -58,6 +58,59 @@ export interface GetTestPostgresDatabaseFactoryOptions< connection: ConnectionDetails params: Params containerExec: (command: string[]) => Promise + /** + * In some cases, if you do extensive setup in your `beforeTemplateIsBaked` hook, you might want to obtain a separate, additional database within it if your application uses several databases for different purposes. + * + * @example + * ```ts + * import test from "ava" + * + * const getTestDatabase = getTestPostgresDatabaseFactory({ + * beforeTemplateIsBaked: async ({ + * params, + * connection: { pool }, + * manuallyBuildAdditionalTemplate, + * }) => { + * await pool.query(`CREATE TABLE "bar" ("id" SERIAL PRIMARY KEY)`) + * + * const fooTemplateBuilder = await manuallyBuildAdditionalTemplate() + * await fooTemplateBuilder.connection.pool.query( + * `CREATE TABLE "foo" ("id" SERIAL PRIMARY KEY)` + * ) + * const { templateName: fooTemplateName } = await fooTemplateBuilder.finish() + * + * return { fooTemplateName } + * }, + * }) + * + * test("foo", async (t) => { + * const barDatabase = await getTestDatabase({ type: "bar" }) + * + * // the "bar" database has the "bar" table... + * await t.notThrowsAsync(async () => { + * await barDatabase.pool.query(`SELECT * FROM "bar"`) + * }) + * + * // ...but not the "foo" table... + * await t.throwsAsync(async () => { + * await barDatabase.pool.query(`SELECT * FROM "foo"`) + * }) + * + * // ...and we can obtain a separate database with the "foo" table + * const fooDatabase = await getTestDatabase.fromTemplate( + * t, + * barDatabase.beforeTemplateIsBakedResult.fooTemplateName + * ) + * await t.notThrowsAsync(async () => { + * await fooDatabase.pool.query(`SELECT * FROM "foo"`) + * }) + * }) + * ``` + */ + manuallyBuildAdditionalTemplate: () => Promise<{ + connection: ConnectionDetails + finish: () => Promise<{ templateName: string }> + }> }) => Promise } @@ -94,14 +147,23 @@ export type GetTestPostgresDatabaseOptions = { // https://github.com/microsoft/TypeScript/issues/23182#issuecomment-379091887 type IsNeverType = [T] extends [never] ? true : false +interface BaseGetTestPostgresDatabase { + fromTemplate( + t: ExecutionContext, + templateName: string + ): Promise +} + export type GetTestPostgresDatabase = IsNeverType extends true - ? ( + ? (( t: ExecutionContext, args?: null, options?: GetTestPostgresDatabaseOptions - ) => Promise - : ( + ) => Promise) & + BaseGetTestPostgresDatabase + : (( t: ExecutionContext, args: Params, options?: GetTestPostgresDatabaseOptions - ) => Promise + ) => Promise) & + BaseGetTestPostgresDatabase diff --git a/src/tests/hooks.test.ts b/src/tests/hooks.test.ts index 1944190..2340170 100644 --- a/src/tests/hooks.test.ts +++ b/src/tests/hooks.test.ts @@ -145,3 +145,41 @@ test("beforeTemplateIsBaked (result isn't serializable)", async (t) => { } ) }) + +test("beforeTemplateIsBaked with manual template build", async (t) => { + const getTestDatabase = getTestPostgresDatabaseFactory({ + postgresVersion: process.env.POSTGRES_VERSION, + workerDedupeKey: "beforeTemplateIsBakedHookManualTemplateBuild", + beforeTemplateIsBaked: async ({ + connection: { pool }, + manuallyBuildAdditionalTemplate, + }) => { + await pool.query(`CREATE TABLE "bar" ("id" SERIAL PRIMARY KEY)`) + + const fooTemplateBuilder = await manuallyBuildAdditionalTemplate() + await fooTemplateBuilder.connection.pool.query( + `CREATE TABLE "foo" ("id" SERIAL PRIMARY KEY)` + ) + const { templateName: fooTemplateName } = + await fooTemplateBuilder.finish() + + return { fooTemplateName } + }, + }) + + const barDatabase = await getTestDatabase(t) + t.truthy(barDatabase.beforeTemplateIsBakedResult.fooTemplateName) + + const fooDatabase = await getTestDatabase.fromTemplate( + t, + barDatabase.beforeTemplateIsBakedResult.fooTemplateName + ) + + await t.notThrowsAsync(async () => { + await fooDatabase.pool.query('SELECT * FROM "foo"') + }, "foo table should exist on database manually created from template") + + await t.throwsAsync(async () => { + await fooDatabase.pool.query('SELECT * FROM "bar"') + }) +}) diff --git a/src/worker.ts b/src/worker.ts index 8130e6f..6ca2ee0 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -73,6 +73,32 @@ export class Worker { const container = (await this.startContainerPromise).container return container.exec(command) }, + createEmptyDatabase: async () => { + const { postgresClient } = await this.startContainerPromise + const databaseName = getRandomDatabaseName() + await postgresClient.query(`CREATE DATABASE ${databaseName}`) + return this.getConnectionDetails(databaseName) + }, + createDatabaseFromTemplate: async (templateName) => { + const { postgresClient } = await this.startContainerPromise + const databaseName = getRandomDatabaseName() + await postgresClient.query( + `CREATE DATABASE ${databaseName} WITH TEMPLATE ${templateName};` + ) + + testWorker.teardown(async () => { + await this.teardownDatabase(databaseName) + }) + + return this.getConnectionDetails(databaseName) + }, + convertDatabaseToTemplate: async (databaseName) => { + const { postgresClient } = await this.startContainerPromise + await postgresClient.query( + `ALTER DATABASE ${databaseName} WITH is_template TRUE;` + ) + return { templateName: databaseName } + }, }, rpcChannel ) @@ -148,8 +174,7 @@ export class Worker { return } - await this.forceDisconnectClientsFrom(databaseName!) - await postgresClient.query(`DROP DATABASE ${databaseName}`) + await this.teardownDatabase(databaseName!) }) return { @@ -158,6 +183,22 @@ export class Worker { } } + private async teardownDatabase(databaseName: string) { + const { postgresClient } = await this.startContainerPromise + + try { + await this.forceDisconnectClientsFrom(databaseName!) + await postgresClient.query(`DROP DATABASE ${databaseName}`) + } catch (error) { + if ((error as Error)?.message?.includes("does not exist")) { + // Database was likely a nested database and manually dropped by the test worker, ignore + return + } + + throw error + } + } + private async createTemplate(rpc: WorkerRpc, params?: Jsonifiable) { const databaseName = getRandomDatabaseName()