Skip to content

Commit

Permalink
fix(search): the revelance sorting was not done when having the llm r…
Browse files Browse the repository at this point in the history
…esults
  • Loading branch information
sneko committed Mar 13, 2024
1 parent 68bd0e1 commit d7180a0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
49 changes: 40 additions & 9 deletions src/server/routers/initiative.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { initiativeNotFoundError } from '@etabli/src/models/entities/errors';
import { prisma } from '@etabli/src/prisma/client';
import { initiativePrismaToModel } from '@etabli/src/server/routers/mappers';
import { publicProcedure, router } from '@etabli/src/server/trpc';
import { paginate } from '@etabli/src/utils/page';

export const initiativeRouter = router({
getInitiative: publicProcedure.input(GetInitiativeSchema).query(async ({ ctx, input }) => {
Expand Down Expand Up @@ -48,22 +49,35 @@ export const initiativeRouter = router({
};
}),
listInitiatives: publicProcedure.input(ListInitiativesSchema).query(async ({ ctx, input }) => {
// TODO: for when implementing filters on associations
const where: Prisma.InitiativeWhereInput = {};

let homemadePaginationInitiativeIds: string[] | null = null;
let homemadePaginationTotalCount: number | null = null;

if (!!input.filterBy.query) {
// Restrict the search, the pagination will work as expected
const initiativesIds = await llmManagerInstance.getInitiativesFromQuery(input.filterBy.query);
const matchingInitiativesIds = await llmManagerInstance.getInitiativesFromQuery(input.filterBy.query);
const currentPageInitiativesIds = paginate(matchingInitiativesIds, input.pageSize, input.page);

homemadePaginationInitiativeIds = currentPageInitiativesIds;
homemadePaginationTotalCount = matchingInitiativesIds.length;

where.id = {
in: initiativesIds,
in: homemadePaginationInitiativeIds,
};
}

// We do a transaction to get along the total count
// Refs:
// - https://github.com/prisma/prisma/issues/7550
// - https://github.com/prisma/prisma/discussions/3087
const [initiatives, totalCount] = await prisma.$transaction([
//
// ---
// Prisma has a limitation to order by a given list (https://github.com/prisma/prisma/issues/11336#issuecomment-1986031261) and cannot only use a `whereRaw` to work around
// The first possibility would be to switch to a full raw query, but it's more complicated for typings and `include`... so we decided to do it in multiple steps to preverse the Prisma usage
// Note: in case of a full raw query we would just kept the original code and enable the logs debug mode to have 90% of the raw query written, and using a where with something like `ORDER BY array_position(ARRAY[1, 2, 3]::uuid[], table_name.id::uuid)`
let [initiatives, totalCount] = await prisma.$transaction([
prisma.initiative.findMany({
where: where,
include: {
Expand All @@ -86,15 +100,32 @@ export const initiativeRouter = router({
},
},
},
orderBy: {
name: 'asc',
},
skip: (input.page - 1) * input.pageSize,
take: input.pageSize,
...(!!homemadePaginationInitiativeIds
? {
// The pagination is done before this, and sorting will be done after
}
: {
orderBy: {
name: 'asc',
},
skip: (input.page - 1) * input.pageSize,
take: input.pageSize,
}),
}),
prisma.initiative.count({ where: where }),
...(homemadePaginationTotalCount !== null ? [] : [prisma.initiative.count({ where: where })]),
]);

if (totalCount === undefined && homemadePaginationTotalCount !== null) {
totalCount = homemadePaginationTotalCount;
}

if (!!homemadePaginationInitiativeIds) {
// As explained above we need to sort them since Prisma does not handle this, in the order returned by the LLM instance since IDs are initially sorted by relevance score
initiatives = initiatives.sort((a, b) => {
return (homemadePaginationInitiativeIds as string[]).indexOf(a.id) - (homemadePaginationInitiativeIds as string[]).indexOf(b.id);
});
}

return {
initiatives: initiatives.map((initiative) =>
initiativePrismaToModel({
Expand Down
5 changes: 5 additions & 0 deletions src/utils/page.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
export function formatPageTitle(pageSpecificTitle: string) {
return `${pageSpecificTitle} - Établi`;
}

export function paginate<T>(items: T[], pageSize: number, pageNumber: number) {
// Human-readable page numbers in our application start with 1, so we reduce 1 in the first argument
return items.slice((pageNumber - 1) * pageSize, pageNumber * pageSize);
}

0 comments on commit d7180a0

Please sign in to comment.