diff --git a/packages/ai-genkit/src/retrievers/fga-reranker.ts b/packages/ai-genkit/src/retrievers/fga-reranker.ts index d2eeee5..12669ea 100644 --- a/packages/ai-genkit/src/retrievers/fga-reranker.ts +++ b/packages/ai-genkit/src/retrievers/fga-reranker.ts @@ -12,6 +12,7 @@ export type FGARerankerCheckerFn = (doc: Document) => ClientBatchCheckItem; export type FGARerankerConstructorArgs = { buildQuery: FGARerankerCheckerFn; + consistency?: ConsistencyPreference; }; export type FGARerankerArgs = FGARerankerConstructorArgs & { @@ -45,6 +46,7 @@ export type FGARerankerArgs = FGARerankerConstructorArgs & { export class FGAReranker { lc_namespace = ["genkit", "rerankers", "fga-reranker"]; private buildQuery: FGARerankerCheckerFn; + private consistency: ConsistencyPreference; private fgaClient: OpenFgaClient; static lc_name() { @@ -52,10 +54,11 @@ export class FGAReranker { } private constructor( - { buildQuery }: FGARerankerConstructorArgs, + { buildQuery, consistency }: FGARerankerConstructorArgs, fgaClient?: OpenFgaClient ) { this.buildQuery = buildQuery; + this.consistency = consistency; this.fgaClient = fgaClient || new OpenFgaClient({ @@ -134,19 +137,25 @@ export class FGAReranker { const response = await this.fgaClient.batchCheck( { checks }, { - consistency: ConsistencyPreference.HigherConsistency, + consistency: + this.consistency || ConsistencyPreference.HigherConsistency, } ); return response.result.reduce( (permissionMap: Map, result) => { - permissionMap.set(result.request.object, result.allowed || false); + const checkKey = this.getCheckKey(result.request); + permissionMap.set(checkKey, result.allowed || false); return permissionMap; }, new Map() ); } + private getCheckKey(check: ClientBatchCheckItem): string { + return `${check.user}|${check.object}|${check.relation}`; + } + /** * Retrieves a filtered list of documents based on permission checks. * @@ -157,13 +166,19 @@ export class FGAReranker { const { checks, documentToObjectMap } = documents.reduce( (acc, document: Document) => { const check = this.buildQuery(document); - acc.checks.push(check); - acc.documentToObjectMap.set(document, check.object); + const checkKey = this.getCheckKey(check); + acc.documentToObjectMap.set(document, checkKey); + // Skip duplicate checks for same user, object, and relation + if (!acc.seenChecks.has(checkKey)) { + acc.seenChecks.add(checkKey); + acc.checks.push(check); + } return acc; }, { checks: [] as ClientBatchCheckItem[], documentToObjectMap: new Map(), + seenChecks: new Set(), } ); diff --git a/packages/ai-genkit/test/index.test.ts b/packages/ai-genkit/test/index.test.ts index 4135dbb..19c651f 100644 --- a/packages/ai-genkit/test/index.test.ts +++ b/packages/ai-genkit/test/index.test.ts @@ -2,7 +2,11 @@ import { describe, it, expect, vi } from "vitest"; import { genkit, Document } from "genkit"; import { FGAReranker, auth0 } from "../src/retrievers/fga-reranker"; -import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk"; +import { + OpenFgaClient, + CredentialsMethod, + ConsistencyPreference, +} from "@openfga/sdk"; describe("FGAReranker", async () => { process.env.FGA_CLIENT_ID = "client-id"; @@ -54,12 +58,26 @@ describe("FGAReranker", async () => { expect(retriever.__action.name).toBe("auth0/fga-reranker"); }); - it("should filter relevant documents based on batchCheck results", async () => { + it("should filter relevant documents based on permission", async () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -72,4 +90,172 @@ describe("FGAReranker", async () => { expect(rankedDocuments[0].content).toEqual(documents[0].content); expect(rankedDocuments[0].metadata.id).toEqual(documents[0].metadata?.id); }); + + it("should handle empty document list", async () => { + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents: [], + }); + + expect(rankedDocuments).toEqual([]); + }); + + it("should handle empty permission list", async () => { + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents, + }); + + expect(rankedDocuments).toEqual([]); + }); + + it("should deduplicate permission checks for same object/user/relation", async () => { + const duplicateDocuments = [ + ...documents, + Document.fromText("private content", { id: "private-doc" }), + Document.fromText("private content", { id: "public-doc" }), + ]; + + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, + ], + }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents: duplicateDocuments, + }); + + expect(mockClient.batchCheck).toHaveBeenCalledTimes(1); + expect(mockClient.batchCheck).toBeCalledWith( + { + checks: [ + { object: "doc:public-doc", relation: "viewer", user: "user:user1" }, + { object: "doc:private-doc", relation: "viewer", user: "user:user1" }, + ], + }, + { consistency: ConsistencyPreference.HigherConsistency } + ); + expect(rankedDocuments.length).toEqual(2); + expect(rankedDocuments[0].content).toEqual(documents[0].content); + expect(rankedDocuments[0].metadata.id).toEqual(documents[0].metadata?.id); + }); + + it("should handle all documents being filtered out", async () => { + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, + ], + }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents, + }); + + expect(rankedDocuments).toEqual([]); + }); + + it("should handle batchCheck error gracefully", async () => { + // @ts-ignore + mockClient.batchCheck = vi + .fn() + .mockRejectedValue(new Error("FGA API Error")); + + await expect( + ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents, + }) + ).rejects.toThrow("FGA API Error"); + }); + + it("should preserve document metadata in filtered results", async () => { + const docsWithMetadata = [ + Document.fromText("public content", { + id: "public-doc", + importance: "high", + }), + Document.fromText("private content", { + id: "private-doc", + importance: "high", + }), + ]; + + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + ], + }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents: docsWithMetadata, + }); + + expect(rankedDocuments).toHaveLength(2); + expect(rankedDocuments[0].metadata).toEqual({ + id: "public-doc", + importance: "high", + score: 1, + }); + }); }); diff --git a/packages/ai-langchain/src/retrievers/fga-retriever.ts b/packages/ai-langchain/src/retrievers/fga-retriever.ts index 8615a44..d330117 100644 --- a/packages/ai-langchain/src/retrievers/fga-retriever.ts +++ b/packages/ai-langchain/src/retrievers/fga-retriever.ts @@ -19,15 +19,8 @@ export type FGARetrieverCheckerFn = ( export type FGARetrieverArgs = { retriever: BaseRetriever; buildQuery: FGARetrieverCheckerFn; - fields?: BaseRetrieverInput; -}; - -type AccessByDocumentFn = ( - checks: ClientBatchCheckItem[] -) => Promise>; - -type FGARetrieverArgsWithAccessByDocument = FGARetrieverArgs & { - accessByDocument: AccessByDocumentFn; + consistency?: ConsistencyPreference; + retrieverFields?: BaseRetrieverInput; }; /** @@ -58,35 +51,18 @@ export class FGARetriever extends BaseRetriever { lc_namespace = ["@langchain", "retrievers"]; private retriever: BaseRetriever; private buildQuery: FGARetrieverCheckerFn; - private accessByDocument: AccessByDocumentFn; + private consistency: ConsistencyPreference; + private fgaClient: OpenFgaClient; - private constructor({ - retriever, - buildQuery, - fields, - accessByDocument, - }: FGARetrieverArgsWithAccessByDocument) { - super(fields); + private constructor( + { buildQuery, retriever, consistency, retrieverFields }: FGARetrieverArgs, + fgaClient?: OpenFgaClient + ) { + super(retrieverFields); this.buildQuery = buildQuery; this.retriever = retriever; - this.accessByDocument = accessByDocument as AccessByDocumentFn; - } - - /** - * Creates a new FGARetriever instance using the given arguments and optional OpenFgaClient. - * - * @param args - @FGARetrieverArgs - * @param args.retriever - The underlying retriever instance to fetch documents. - * @param args.buildQuery - A function to generate access check requests for each document. - * @param args.fields - Optional - Additional fields to pass to the underlying retriever. - * @param fgaClient - Optional - OpenFgaClient instance to execute checks against. - * @returns A newly created FGARetriever instance configured with the provided arguments. - */ - static create( - args: FGARetrieverArgs, - fgaClient?: OpenFgaClient - ): FGARetriever { - const client = + this.consistency = consistency; + this.fgaClient = fgaClient || new OpenFgaClient({ apiUrl: process.env.FGA_API_URL || "https://api.us1.fga.dev", @@ -102,24 +78,24 @@ export class FGARetriever extends BaseRetriever { }, }, }); + } - const accessByDocument: AccessByDocumentFn = async function (checks) { - const response = await client.batchCheck( - { checks }, - { - consistency: ConsistencyPreference.HigherConsistency, - } - ); - return response.result.reduce( - (permissionMap: Map, result) => { - permissionMap.set(result.request.object, result.allowed || false); - return permissionMap; - }, - new Map() - ); - }; - - return new FGARetriever({ ...args, accessByDocument }); + /** + * Creates a new FGARetriever instance using the given arguments and optional OpenFgaClient. + * + * @param args - @FGARetrieverArgs + * @param args.retriever - The underlying retriever instance to fetch documents. + * @param args.buildQuery - A function to generate access check requests for each document. + * @param args.consistency - Optional - The consistency preference for the OpenFGA client. + * @param args.retrieverFields - Optional - Additional fields to pass to the underlying retriever. + * @param fgaClient - Optional - OpenFgaClient instance to execute checks against. + * @returns A newly created FGARetriever instance configured with the provided arguments. + */ + static create( + args: FGARetrieverArgs, + fgaClient?: OpenFgaClient + ): FGARetriever { + return new FGARetriever(args, fgaClient); } /** @@ -142,8 +118,13 @@ export class FGARetriever extends BaseRetriever { const { checks, documentToObject } = documents.reduce( (acc, doc) => { const check = this.buildQuery(doc); - acc.checks.push(check); - acc.documentToObject.set(doc, check.object); + const checkKey = this.getCheckKey(check); + acc.documentToObject.set(doc, checkKey); + // Skip duplicate checks for same user, object, and relation + if (!acc.seenChecks.has(checkKey)) { + acc.seenChecks.add(checkKey); + acc.checks.push(check); + } return acc; }, { @@ -152,16 +133,48 @@ export class FGARetriever extends BaseRetriever { DocumentInterface>, string >(), + seenChecks: new Set(), } ); - const resultsByObject = await this.accessByDocument(checks); + const permissionsMap = await this.checkPermissions(checks); return documents.filter( - (d, i) => resultsByObject.get(documentToObject.get(d) || "") === true + (d, i) => permissionsMap.get(documentToObject.get(d) || "") === true ); } + /** + * Checks permissions for a list of client requests. + * + * @param checks - An array of `ClientBatchCheckItem` objects representing the permissions to be checked. + * @returns A promise that resolves to a `Map` where the keys are object identifiers and the values are booleans indicating whether the permission is allowed. + */ + private async checkPermissions( + checks: ClientBatchCheckItem[] + ): Promise> { + const response = await this.fgaClient.batchCheck( + { checks }, + { + consistency: + this.consistency || ConsistencyPreference.HigherConsistency, + } + ); + + return response.result.reduce( + (permissionMap: Map, result) => { + const checkKey = this.getCheckKey(result.request); + permissionMap.set(checkKey, result.allowed || false); + return permissionMap; + }, + new Map() + ); + } + + private getCheckKey(check: ClientBatchCheckItem): string { + return `${check.user}|${check.object}|${check.relation}`; + } + /** * Converts the FGA retriever into a tool that can be used by a LangGraph agent. * @returns StructuredToolInterface. diff --git a/packages/ai-langchain/test/index.test.ts b/packages/ai-langchain/test/index.test.ts index 0359a56..9294eee 100644 --- a/packages/ai-langchain/test/index.test.ts +++ b/packages/ai-langchain/test/index.test.ts @@ -1,6 +1,10 @@ import { describe, it, expect, vi } from "vitest"; import { FGARetriever } from "../src/retrievers/fga-retriever"; -import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk"; +import { + OpenFgaClient, + CredentialsMethod, + ConsistencyPreference, +} from "@openfga/sdk"; import { Document } from "@langchain/core/documents"; import { BaseRetriever } from "@langchain/core/retrievers"; @@ -41,10 +45,6 @@ describe("FGARetriever", () => { metadata: { id: "private-doc" }, pageContent: "private content", }), - new Document({ - metadata: { id: "private-doc-2" }, - pageContent: "private content 2", - }), ]; const args = { @@ -62,15 +62,29 @@ describe("FGARetriever", () => { expect(retriever).toBeInstanceOf(FGARetriever); }); - it("should filter relevant documents based on accessByDocument results", async () => { + it("should filter documents based on permissions", async () => { const retriever = FGARetriever.create(args, mockClient); // @ts-ignore mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -78,21 +92,134 @@ describe("FGARetriever", () => { expect(result).toEqual([mockDocuments[0]]); }); - it("should return joined string of filtered doc content", async () => { + it("should handle empty document list", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue([]); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(0); + }); + + it("should handle empty permission list", async () => { const retriever = FGARetriever.create(args, mockClient); // @ts-ignore mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] }); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(0); + }); + + it("should deduplicate permission checks for same object/user/relation", async () => { + const duplicateDocuments = [ + ...mockDocuments, + new Document({ + metadata: { id: "public-doc" }, + pageContent: "public content", + }), + new Document({ + metadata: { id: "private-doc" }, + pageContent: "private content", + }), + ]; + + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue(duplicateDocuments); + // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, + ], + }); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(2); + expect(mockClient.batchCheck).toHaveBeenCalledTimes(1); + expect(mockClient.batchCheck).toBeCalledWith( + { + checks: [ + { object: "doc:public-doc", relation: "viewer", user: "user:user1" }, + { object: "doc:private-doc", relation: "viewer", user: "user:user1" }, + ], + }, + { consistency: ConsistencyPreference.HigherConsistency } + ); + }); + + it("should handle all documents being filtered out", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: false }, { request: { object: "doc:private-doc" }, allowed: false }, - { request: { object: "doc:private-doc-2" }, allowed: true }, + ], + }); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(0); + }); + + it("should return joined string of filtered doc content", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, ], }); const tool = retriever.asJoinedStringTool(); const result = await tool.invoke({ query: "test query" }); - expect(result).toEqual("public content\n\nprivate content 2"); + expect(result).toEqual("public content\n\nprivate content"); + }); + + it("should handle batchCheck error gracefully", async () => { + // @ts-ignore + mockClient.batchCheck = vi + .fn() + .mockRejectedValue(new Error("FGA API Error")); + + const retriever = FGARetriever.create(args, mockClient); + await expect(retriever.invoke("test query")).rejects.toThrow( + "FGA API Error" + ); }); }); diff --git a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts index a9ed89d..3664991 100644 --- a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts +++ b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts @@ -20,6 +20,7 @@ export type FGARetrieverCheckerFn = ( export interface FGARetrieverArgs { buildQuery: FGARetrieverCheckerFn; retriever: BaseRetriever; + consistency?: ConsistencyPreference; } /** @@ -51,6 +52,7 @@ export class FGARetriever extends BaseRetriever { lc_namespace = ["llamaindex", "retrievers", "fga-retriever"]; private retriever: BaseRetriever; private buildQuery: FGARetrieverCheckerFn; + private consistency: ConsistencyPreference; private fgaClient: OpenFgaClient; static lc_name() { @@ -58,13 +60,14 @@ export class FGARetriever extends BaseRetriever { } private constructor( - { buildQuery, retriever }: FGARetrieverArgs, + { buildQuery, retriever, consistency }: FGARetrieverArgs, fgaClient?: OpenFgaClient ) { super(); this.retriever = retriever; this.buildQuery = buildQuery; + this.consistency = consistency; this.fgaClient = fgaClient || new OpenFgaClient({ @@ -89,14 +92,12 @@ export class FGARetriever extends BaseRetriever { * @param args - @FGARetrieverArgs * @param args.retriever - The underlying retriever instance to fetch documents. * @param args.buildQuery - A function to generate access check requests for each document. + * @param args.consistency - Optional - The consistency preference for the OpenFGA client. * @param fgaClient - Optional - OpenFgaClient instance to execute checks against. * @returns A newly created FGARetriever instance configured with the provided arguments. */ - static create( - { buildQuery, retriever }: FGARetrieverArgs, - fgaClient?: OpenFgaClient - ) { - return new FGARetriever({ buildQuery, retriever }, fgaClient); + static create(args: FGARetrieverArgs, fgaClient?: OpenFgaClient) { + return new FGARetriever(args, fgaClient); } /** @@ -111,19 +112,25 @@ export class FGARetriever extends BaseRetriever { const response = await this.fgaClient.batchCheck( { checks }, { - consistency: ConsistencyPreference.HigherConsistency, + consistency: + this.consistency || ConsistencyPreference.HigherConsistency, } ); return response.result.reduce( (permissionMap: Map, result) => { - permissionMap.set(result.request.object, result.allowed || false); + const checkKey = this.getCheckKey(result.request); + permissionMap.set(checkKey, result.allowed || false); return permissionMap; }, new Map() ); } + private getCheckKey(check: ClientBatchCheckItem): string { + return `${check.user}|${check.object}|${check.relation}`; + } + /** * Retrieves nodes based on the provided query parameters, processes * them through a checker function, @@ -138,13 +145,19 @@ export class FGARetriever extends BaseRetriever { const { checks, documentToObjectMap } = retrievedNodes.reduce( (acc, nodeWithScore: NodeWithScore) => { const check = this.buildQuery(nodeWithScore.node); - acc.checks.push(check); - acc.documentToObjectMap.set(nodeWithScore, check.object); + const checkKey = this.getCheckKey(check); + acc.documentToObjectMap.set(nodeWithScore, checkKey); + // Skip duplicate checks for same user, object, and relation + if (!acc.seenChecks.has(checkKey)) { + acc.seenChecks.add(checkKey); + acc.checks.push(check); + } return acc; }, { checks: [] as ClientBatchCheckItem[], documentToObjectMap: new Map, string>(), + seenChecks: new Set(), } ); diff --git a/packages/ai-llamaindex/test/index.test.ts b/packages/ai-llamaindex/test/index.test.ts index 89bfaf3..f326659 100644 --- a/packages/ai-llamaindex/test/index.test.ts +++ b/packages/ai-llamaindex/test/index.test.ts @@ -3,7 +3,11 @@ import { FGARetriever, FGARetrieverCheckerFn, } from "../src/retrievers/fga-retriever"; -import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk"; +import { + OpenFgaClient, + CredentialsMethod, + ConsistencyPreference, +} from "@openfga/sdk"; import { BaseRetriever, NodeWithScore } from "llamaindex"; describe("FGARetriever", () => { @@ -64,14 +68,126 @@ describe("FGARetriever", () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); const retriever = FGARetriever.create(args, mockClient); - const result = await retriever._retrieve({ query: "test" }); + const result = await retriever.retrieve({ query: "test" }); expect(result).toEqual([mockDocuments[0]]); }); + + it("should handle empty document list", async () => { + // @ts-ignore + args.retriever.retrieve.mockResolvedValue([]); + const retriever = FGARetriever.create(args, mockClient); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(0); + }); + + it("should handle empty permission list", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] }); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(0); + }); + + it("should deduplicate permission checks for same object/user/relation", async () => { + const duplicateDocuments = [ + ...mockDocuments, + { + node: { text: "public content", metadata: { id: "public-doc" } }, + score: 1, + } as unknown as NodeWithScore, + { + node: { text: "private content", metadata: { id: "private-doc" } }, + score: 1, + } as unknown as NodeWithScore, + ]; + + // @ts-ignore + args.retriever.retrieve.mockResolvedValue(duplicateDocuments); + + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, + ], + }); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(2); + expect(mockClient.batchCheck).toHaveBeenCalledTimes(1); + expect(mockClient.batchCheck).toBeCalledWith( + { + checks: [ + { object: "doc:public-doc", relation: "viewer", user: "user:user1" }, + { object: "doc:private-doc", relation: "viewer", user: "user:user1" }, + ], + }, + { consistency: ConsistencyPreference.HigherConsistency } + ); + }); + + it("should handle all documents being filtered out", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: false }, + { request: { object: "doc:private-doc" }, allowed: false }, + ], + }); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(0); + }); + + it("should handle batchCheck error gracefully", async () => { + // @ts-ignore + mockClient.batchCheck = vi + .fn() + .mockRejectedValue(new Error("FGA API Error")); + + const retriever = FGARetriever.create(args, mockClient); + + await expect(retriever.retrieve({ query: "test" })).rejects.toThrow( + "FGA API Error" + ); + }); });