Skip to content

Commit

Permalink
Merge pull request #38 from deepu105/optimize-for-fga-batch-check
Browse files Browse the repository at this point in the history
Optimize for fga batch check rate limits
  • Loading branch information
jcenturion authored Jan 30, 2025
2 parents 5a53838 + 5a47d1f commit 98247d7
Show file tree
Hide file tree
Showing 6 changed files with 561 additions and 91 deletions.
25 changes: 20 additions & 5 deletions packages/ai-genkit/src/retrievers/fga-reranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export type FGARerankerCheckerFn = (doc: Document) => ClientBatchCheckItem;

export type FGARerankerConstructorArgs = {
buildQuery: FGARerankerCheckerFn;
consistency?: ConsistencyPreference;
};

export type FGARerankerArgs = FGARerankerConstructorArgs & {
Expand Down Expand Up @@ -45,17 +46,19 @@ export type FGARerankerArgs = FGARerankerConstructorArgs & {
export class FGAReranker {
lc_namespace = ["genkit", "rerankers", "fga-reranker"];
private buildQuery: FGARerankerCheckerFn;
private consistency: ConsistencyPreference;
private fgaClient: OpenFgaClient;

static lc_name() {
return "FGAReranker";
}

private constructor(
{ buildQuery }: FGARerankerConstructorArgs,
{ buildQuery, consistency }: FGARerankerConstructorArgs,
fgaClient?: OpenFgaClient
) {
this.buildQuery = buildQuery;
this.consistency = consistency;
this.fgaClient =
fgaClient ||
new OpenFgaClient({
Expand Down Expand Up @@ -134,19 +137,25 @@ export class FGAReranker {
const response = await this.fgaClient.batchCheck(
{ checks },
{
consistency: ConsistencyPreference.HigherConsistency,
consistency:
this.consistency || ConsistencyPreference.HigherConsistency,
}
);

return response.result.reduce(
(permissionMap: Map<string, boolean>, result) => {
permissionMap.set(result.request.object, result.allowed || false);
const checkKey = this.getCheckKey(result.request);
permissionMap.set(checkKey, result.allowed || false);
return permissionMap;
},
new Map<string, boolean>()
);
}

private getCheckKey(check: ClientBatchCheckItem): string {
return `${check.user}|${check.object}|${check.relation}`;
}

/**
* Retrieves a filtered list of documents based on permission checks.
*
Expand All @@ -157,13 +166,19 @@ export class FGAReranker {
const { checks, documentToObjectMap } = documents.reduce(
(acc, document: Document) => {
const check = this.buildQuery(document);
acc.checks.push(check);
acc.documentToObjectMap.set(document, check.object);
const checkKey = this.getCheckKey(check);
acc.documentToObjectMap.set(document, checkKey);
// Skip duplicate checks for same user, object, and relation
if (!acc.seenChecks.has(checkKey)) {
acc.seenChecks.add(checkKey);
acc.checks.push(check);
}
return acc;
},
{
checks: [] as ClientBatchCheckItem[],
documentToObjectMap: new Map<Document, string>(),
seenChecks: new Set<string>(),
}
);

Expand Down
194 changes: 190 additions & 4 deletions packages/ai-genkit/test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { describe, it, expect, vi } from "vitest";
import { genkit, Document } from "genkit";
import { FGAReranker, auth0 } from "../src/retrievers/fga-reranker";

import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk";
import {
OpenFgaClient,
CredentialsMethod,
ConsistencyPreference,
} from "@openfga/sdk";

describe("FGAReranker", async () => {
process.env.FGA_CLIENT_ID = "client-id";
Expand Down Expand Up @@ -54,12 +58,26 @@ describe("FGAReranker", async () => {
expect(retriever.__action.name).toBe("auth0/fga-reranker");
});

it("should filter relevant documents based on batchCheck results", async () => {
it("should filter relevant documents based on permission", async () => {
// @ts-ignore
mockClient.batchCheck = vi.fn().mockResolvedValue({
result: [
{ request: { object: "doc:public-doc" }, allowed: true },
{ request: { object: "doc:private-doc" }, allowed: false },
{
request: {
object: "doc:public-doc",
user: "user:user1",
relation: "viewer",
},
allowed: true,
},
{
request: {
object: "doc:private-doc",
user: "user:user1",
relation: "viewer",
},
allowed: false,
},
],
});

Expand All @@ -72,4 +90,172 @@ describe("FGAReranker", async () => {
expect(rankedDocuments[0].content).toEqual(documents[0].content);
expect(rankedDocuments[0].metadata.id).toEqual(documents[0].metadata?.id);
});

it("should handle empty document list", async () => {
const rankedDocuments = await ai.rerank({
reranker: FGAReranker.create(args, mockClient),
query: "input",
documents: [],
});

expect(rankedDocuments).toEqual([]);
});

it("should handle empty permission list", async () => {
// @ts-ignore
mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] });

const rankedDocuments = await ai.rerank({
reranker: FGAReranker.create(args, mockClient),
query: "input",
documents,
});

expect(rankedDocuments).toEqual([]);
});

it("should deduplicate permission checks for same object/user/relation", async () => {
const duplicateDocuments = [
...documents,
Document.fromText("private content", { id: "private-doc" }),
Document.fromText("private content", { id: "public-doc" }),
];

// @ts-ignore
mockClient.batchCheck = vi.fn().mockResolvedValue({
result: [
{
request: {
object: "doc:public-doc",
user: "user:user1",
relation: "viewer",
},
allowed: true,
},
{
request: {
object: "doc:private-doc",
user: "user:user1",
relation: "viewer",
},
allowed: false,
},
],
});

const rankedDocuments = await ai.rerank({
reranker: FGAReranker.create(args, mockClient),
query: "input",
documents: duplicateDocuments,
});

expect(mockClient.batchCheck).toHaveBeenCalledTimes(1);
expect(mockClient.batchCheck).toBeCalledWith(
{
checks: [
{ object: "doc:public-doc", relation: "viewer", user: "user:user1" },
{ object: "doc:private-doc", relation: "viewer", user: "user:user1" },
],
},
{ consistency: ConsistencyPreference.HigherConsistency }
);
expect(rankedDocuments.length).toEqual(2);
expect(rankedDocuments[0].content).toEqual(documents[0].content);
expect(rankedDocuments[0].metadata.id).toEqual(documents[0].metadata?.id);
});

it("should handle all documents being filtered out", async () => {
// @ts-ignore
mockClient.batchCheck = vi.fn().mockResolvedValue({
result: [
{
request: {
object: "doc:private-doc",
user: "user:user1",
relation: "viewer",
},
allowed: false,
},
{
request: {
object: "doc:public-doc",
user: "user:user1",
relation: "viewer",
},
allowed: false,
},
],
});

const rankedDocuments = await ai.rerank({
reranker: FGAReranker.create(args, mockClient),
query: "input",
documents,
});

expect(rankedDocuments).toEqual([]);
});

it("should handle batchCheck error gracefully", async () => {
// @ts-ignore
mockClient.batchCheck = vi
.fn()
.mockRejectedValue(new Error("FGA API Error"));

await expect(
ai.rerank({
reranker: FGAReranker.create(args, mockClient),
query: "input",
documents,
})
).rejects.toThrow("FGA API Error");
});

it("should preserve document metadata in filtered results", async () => {
const docsWithMetadata = [
Document.fromText("public content", {
id: "public-doc",
importance: "high",
}),
Document.fromText("private content", {
id: "private-doc",
importance: "high",
}),
];

// @ts-ignore
mockClient.batchCheck = vi.fn().mockResolvedValue({
result: [
{
request: {
object: "doc:public-doc",
user: "user:user1",
relation: "viewer",
},
allowed: true,
},
{
request: {
object: "doc:private-doc",
user: "user:user1",
relation: "viewer",
},
allowed: true,
},
],
});

const rankedDocuments = await ai.rerank({
reranker: FGAReranker.create(args, mockClient),
query: "input",
documents: docsWithMetadata,
});

expect(rankedDocuments).toHaveLength(2);
expect(rankedDocuments[0].metadata).toEqual({
id: "public-doc",
importance: "high",
score: 1,
});
});
});
Loading

0 comments on commit 98247d7

Please sign in to comment.