Skip to content

Commit

Permalink
feat: Support INSERT INTO ... RETURNING (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb authored Aug 15, 2022
1 parent fa2a94a commit 7f3f706
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 13 deletions.
40 changes: 39 additions & 1 deletion src/lib/builders/commands/insert.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ test("insert works", async (t) => {
)
.run(pool)

assert<Equals<typeof result, never>>()
assert<Equals<typeof result, {}[]>>()

const {
rows: [{ count: countAfterInsert }],
Expand Down Expand Up @@ -254,3 +254,41 @@ test("onConflict().doNothing().where()", async (t) => {

await t.notThrowsAsync(async () => await query.run(pool))
})

test(".returning()", async (t) => {
const { pool } = await getTestDatabase()

const query = new InsertCommand("actor")
.values({
first_name: "foo",
last_name: "bar",
})
.returning("first_name")

const result = await query.run(pool)

assert<Equals<typeof result[0], { first_name: string }>>()

t.is(result.length, 1)
t.deepEqual(result[0], {
first_name: "foo",
})
})

test(".returning() (array, wildcard)", async (t) => {
const { pool } = await getTestDatabase()

const query = new InsertCommand("actor")
.values({
first_name: "foo",
last_name: "bar",
})
.returning(["actor.*"])

const result = await query.run(pool)
t.is(result.length, 1)
t.like(result[0], {
first_name: "foo",
last_name: "bar",
})
})
69 changes: 63 additions & 6 deletions src/lib/builders/commands/insert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,37 @@ import { mix } from "ts-mixer"
import { cols, sql, vals } from "zapatos/db"
import * as schema from "zapatos/schema"
import { SQLCommand } from "../common/sql-command"
import { mapWithSeparator } from "../utils/map-with-separator"
import {
ColumnSpecificationsForTableWithWildcards,
constructColumnSelection,
mapWithSeparator,
SelectableFromColumnSpecifications,
} from "../utils"
import { OnConflictClauseBuilder } from "./clauses/on-conflict"

export interface InsertCommand<TableName extends schema.Table>
extends OnConflictClauseBuilder<TableName>,
SQLCommand<never> {}
export interface InsertCommand<
TableName extends schema.Table,
SelectableMap extends Record<TableName, any> = Record<
TableName,
schema.SelectableForTable<TableName>
>,
Returning = {}
> extends OnConflictClauseBuilder<TableName>,
SQLCommand<Returning[]> {}

@mix(OnConflictClauseBuilder, SQLCommand)
export class InsertCommand<TableName extends schema.Table> {
export class InsertCommand<
TableName extends schema.Table,
SelectableMap extends Record<TableName, any> = Record<
TableName,
schema.SelectableForTable<TableName>
>,
Returning = {}
> {
private readonly _tableName: string
private _rows: Array<schema.InsertableForTable<TableName>> = []
private _returning: ColumnSpecificationsForTableWithWildcards<TableName>[] =
[]

constructor(tableName: TableName) {
this._tableName = tableName
Expand All @@ -30,6 +50,38 @@ export class InsertCommand<TableName extends schema.Table> {
return this
}

returning<T extends ColumnSpecificationsForTableWithWildcards<TableName>[]>(
...columnNames: T
): InsertCommand<
TableName,
SelectableMap,
Returning &
SelectableFromColumnSpecifications<TableName, T[number], SelectableMap>
>
returning<T extends ColumnSpecificationsForTableWithWildcards<TableName>[]>(
columnSpecifications: T
): InsertCommand<
TableName,
SelectableMap,
Returning &
SelectableFromColumnSpecifications<TableName, T[number], SelectableMap>
>
returning<T extends ColumnSpecificationsForTableWithWildcards<TableName>[]>(
...args: any
): InsertCommand<
TableName,
SelectableMap,
Returning &
SelectableFromColumnSpecifications<TableName, T[number], SelectableMap>
> {
if (args.length === 1 && Array.isArray(args[0])) {
this._returning = [...this._returning, ...args[0]]
} else {
this._returning = [...this._returning, ...args]
}
return this as any
}

compile() {
const colsSQL = sql`${cols(this._rows[0])}`

Expand All @@ -39,10 +91,15 @@ export class InsertCommand<TableName extends schema.Table> {
(v) => sql`(${vals(v)})`
)

const returningSQL =
this._returning.length > 0
? sql`RETURNING ${constructColumnSelection(this._returning)}`
: []

return sql`INSERT INTO ${
this._tableName
} (${colsSQL}) VALUES ${valuesSQL} ${this.compileOnConflict(
Object.keys(this._rows[0])
)};`.compile()
)} ${returningSQL};`.compile()
}
}
29 changes: 25 additions & 4 deletions src/lib/builders/commands/select.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,33 @@ export class SelectCommand<
return new DeleteCommand<TableName>(this._tableName).where(this._whereable)
}

insert() {
return new InsertCommand<TableName>(this._tableName).where(this._whereable)
insert(
...rows: Array<schema.InsertableForTable<TableName>>
): InsertCommand<TableName>
insert(
rows: Array<schema.InsertableForTable<TableName>>[]
): InsertCommand<TableName>
insert(...args: any[]) {
return new InsertCommand<TableName>(this._tableName)
.where(this._whereable)
.values(...args)
}

update() {
return new UpdateCommand<TableName>(this._tableName).where(this._whereable)
update<T extends keyof schema.UpdatableForTable<TableName>>(
columnName: T,
value: schema.UpdatableForTable<TableName>[T]
): UpdateCommand<TableName>
update(
values: Partial<schema.UpdatableForTable<TableName>>
): UpdateCommand<TableName>
update(...args: any[]): UpdateCommand<TableName> {
let command = new UpdateCommand<TableName>(this._tableName).where(
this._whereable
)
if (args.length > 0) {
command = (command as any).set(...args)
}
return command
}

compile() {
Expand Down
21 changes: 19 additions & 2 deletions src/lib/builders/query-builder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,29 @@ test("select()", async (t) => {
t.snapshot(result)
})

test("insert()", async (t) => {
const { pool } = await getTestDatabase()
const query = new QueryBuilder("film")
.whereIn("film_id", [1, 2])
.insert({
title: "Test",
language_id: 1,
description: "foo bar",
fulltext: "foo bar",
})
.returning("*")

const result = await query.run(pool)
const { last_update, ...resultWithoutLastUpdate } = result[0]

t.snapshot(resultWithoutLastUpdate)
})

test("update()", async (t) => {
const { pool } = await getTestDatabase()
const query = new QueryBuilder("film")
.whereIn("film_id", [1, 2])
.update()
.set("description", "foo bar")
.update("description", "foo bar")
.returning(["film_id", "description"])
const result = await query.run(pool)

Expand Down
20 changes: 20 additions & 0 deletions src/tests/snapshots/src/lib/builders/query-builder.test.ts.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ Generated by [AVA](https://avajs.dev).
},
]

## insert()

> Snapshot 1
{
description: 'foo bar',
film_id: 1001,
fulltext: '\'bar\':3 \'foo\':2 \'test\':1',
language_id: 1,
length: null,
original_language_id: null,
rating: 'G',
release_year: null,
rental_duration: 3,
rental_rate: '4.99',
replacement_cost: '19.99',
special_features: null,
title: 'Test',
}

## update()

> Snapshot 1
Expand Down
Binary file modified src/tests/snapshots/src/lib/builders/query-builder.test.ts.snap
Binary file not shown.

0 comments on commit 7f3f706

Please sign in to comment.