From 6bd45e1124e414124689758a85131091dcff1f1c Mon Sep 17 00:00:00 2001 From: Deepu Date: Mon, 27 Jan 2025 12:39:19 +0100 Subject: [PATCH 1/9] optimize FGA checks by deduping for langchain --- .../src/retrievers/fga-retriever.ts | 13 ++- packages/ai-langchain/test/index.test.ts | 100 ++++++++++++++++-- 2 files changed, 103 insertions(+), 10 deletions(-) diff --git a/packages/ai-langchain/src/retrievers/fga-retriever.ts b/packages/ai-langchain/src/retrievers/fga-retriever.ts index 8615a44..62925ad 100644 --- a/packages/ai-langchain/src/retrievers/fga-retriever.ts +++ b/packages/ai-langchain/src/retrievers/fga-retriever.ts @@ -142,8 +142,19 @@ export class FGARetriever extends BaseRetriever { const { checks, documentToObject } = documents.reduce( (acc, doc) => { const check = this.buildQuery(doc); - acc.checks.push(check); acc.documentToObject.set(doc, check.object); + // Skip duplicate checks for same user, object, and relation + if ( + acc.checks.some( + (ex) => + ex.object === check.object && + ex.user === check.user && + ex.relation === check.relation + ) + ) { + return acc; + } + acc.checks.push(check); return acc; }, { diff --git a/packages/ai-langchain/test/index.test.ts b/packages/ai-langchain/test/index.test.ts index 0359a56..d6abd16 100644 --- a/packages/ai-langchain/test/index.test.ts +++ b/packages/ai-langchain/test/index.test.ts @@ -1,6 +1,10 @@ import { describe, it, expect, vi } from "vitest"; import { FGARetriever } from "../src/retrievers/fga-retriever"; -import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk"; +import { + OpenFgaClient, + CredentialsMethod, + ConsistencyPreference, +} from "@openfga/sdk"; import { Document } from "@langchain/core/documents"; import { BaseRetriever } from "@langchain/core/retrievers"; @@ -41,10 +45,6 @@ describe("FGARetriever", () => { metadata: { id: "private-doc" }, pageContent: "private content", }), - new Document({ - metadata: { id: "private-doc-2" }, - pageContent: "private content 2", - }), ]; const args = { @@ -62,7 +62,7 @@ describe("FGARetriever", () => { expect(retriever).toBeInstanceOf(FGARetriever); }); - it("should filter relevant documents based on accessByDocument results", async () => { + it("should filter documents based on permissions", async () => { const retriever = FGARetriever.create(args, mockClient); // @ts-ignore mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); @@ -78,21 +78,103 @@ describe("FGARetriever", () => { expect(result).toEqual([mockDocuments[0]]); }); - it("should return joined string of filtered doc content", async () => { + it("should handle empty document list", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue([]); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(0); + }); + + it("should handle empty permission list", async () => { const retriever = FGARetriever.create(args, mockClient); // @ts-ignore mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] }); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(0); + }); + + it("should deduplicate permission checks for same object/user/relation", async () => { + const duplicateDocuments = [ + ...mockDocuments, + new Document({ + metadata: { id: "public-doc" }, + pageContent: "public content", + }), + new Document({ + metadata: { id: "private-doc" }, + pageContent: "private content", + }), + ]; + + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue(duplicateDocuments); + // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ { request: { object: "doc:public-doc" }, allowed: true }, { request: { object: "doc:private-doc" }, allowed: false }, - { request: { object: "doc:private-doc-2" }, allowed: true }, + ], + }); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(2); + expect(mockClient.batchCheck).toHaveBeenCalledTimes(1); + expect(mockClient.batchCheck).toBeCalledWith( + { + checks: [ + { + object: "doc:public-doc", + relation: "viewer", + user: "user:user1", + }, + { + object: "doc:private-doc", + relation: "viewer", + user: "user:user1", + }, + ], + }, + { consistency: ConsistencyPreference.HigherConsistency } + ); + }); + + it("should handle all documents being filtered out", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: false }, + { request: { object: "doc:private-doc" }, allowed: false }, + { request: { object: "doc:private-doc-2" }, allowed: false }, + ], + }); + + const result = await retriever.invoke("test query"); + expect(result).toHaveLength(0); + }); + + it("should return joined string of filtered doc content", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockRetriever._getRelevantDocuments.mockResolvedValue(mockDocuments); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: true }, + { request: { object: "doc:private-doc" }, allowed: true }, ], }); const tool = retriever.asJoinedStringTool(); const result = await tool.invoke({ query: "test query" }); - expect(result).toEqual("public content\n\nprivate content 2"); + expect(result).toEqual("public content\n\nprivate content"); }); }); From d9def5915cb40d0fd9cb09776ac2a26c7105812b Mon Sep 17 00:00:00 2001 From: Deepu Date: Mon, 27 Jan 2025 12:39:54 +0100 Subject: [PATCH 2/9] optimize FGA checks by deduping for genkit --- .../ai-genkit/src/retrievers/fga-reranker.ts | 13 +- packages/ai-genkit/test/index.test.ts | 134 +++++++++++++++++- 2 files changed, 144 insertions(+), 3 deletions(-) diff --git a/packages/ai-genkit/src/retrievers/fga-reranker.ts b/packages/ai-genkit/src/retrievers/fga-reranker.ts index d2eeee5..6db0309 100644 --- a/packages/ai-genkit/src/retrievers/fga-reranker.ts +++ b/packages/ai-genkit/src/retrievers/fga-reranker.ts @@ -157,8 +157,19 @@ export class FGAReranker { const { checks, documentToObjectMap } = documents.reduce( (acc, document: Document) => { const check = this.buildQuery(document); - acc.checks.push(check); acc.documentToObjectMap.set(document, check.object); + // Skip duplicate checks for same user, object, and relation + if ( + acc.checks.some( + (ex) => + ex.object === check.object && + ex.user === check.user && + ex.relation === check.relation + ) + ) { + return acc; + } + acc.checks.push(check); return acc; }, { diff --git a/packages/ai-genkit/test/index.test.ts b/packages/ai-genkit/test/index.test.ts index 4135dbb..8736eae 100644 --- a/packages/ai-genkit/test/index.test.ts +++ b/packages/ai-genkit/test/index.test.ts @@ -2,7 +2,11 @@ import { describe, it, expect, vi } from "vitest"; import { genkit, Document } from "genkit"; import { FGAReranker, auth0 } from "../src/retrievers/fga-reranker"; -import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk"; +import { + OpenFgaClient, + CredentialsMethod, + ConsistencyPreference, +} from "@openfga/sdk"; describe("FGAReranker", async () => { process.env.FGA_CLIENT_ID = "client-id"; @@ -54,7 +58,7 @@ describe("FGAReranker", async () => { expect(retriever.__action.name).toBe("auth0/fga-reranker"); }); - it("should filter relevant documents based on batchCheck results", async () => { + it("should filter relevant documents based on permission", async () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ @@ -72,4 +76,130 @@ describe("FGAReranker", async () => { expect(rankedDocuments[0].content).toEqual(documents[0].content); expect(rankedDocuments[0].metadata.id).toEqual(documents[0].metadata?.id); }); + + it("should handle empty document list", async () => { + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents: [], + }); + + expect(rankedDocuments).toEqual([]); + }); + + it("should handle empty permission list", async () => { + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents, + }); + + expect(rankedDocuments).toEqual([]); + }); + + it("should deduplicate permission checks for same object/user/relation", async () => { + const duplicateDocuments = [ + ...documents, + Document.fromText("private content", { id: "private-doc" }), + Document.fromText("private content", { id: "public-doc" }), + ]; + + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: true }, + { request: { object: "doc:private-doc" }, allowed: false }, + ], + }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents: duplicateDocuments, + }); + + expect(mockClient.batchCheck).toHaveBeenCalledTimes(1); + expect(mockClient.batchCheck).toBeCalledWith( + { + checks: [ + { object: "doc:public-doc", relation: "viewer", user: "user:user1" }, + { object: "doc:private-doc", relation: "viewer", user: "user:user1" }, + ], + }, + { consistency: ConsistencyPreference.HigherConsistency } + ); + expect(rankedDocuments.length).toEqual(2); + expect(rankedDocuments[0].content).toEqual(documents[0].content); + expect(rankedDocuments[0].metadata.id).toEqual(documents[0].metadata?.id); + }); + + it("should handle all documents being filtered out", async () => { + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:private-doc" }, allowed: false }, + { request: { object: "doc:public-doc" }, allowed: false }, + ], + }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents, + }); + + expect(rankedDocuments).toEqual([]); + }); + + it("should handle batchCheck error gracefully", async () => { + // @ts-ignore + mockClient.batchCheck = vi + .fn() + .mockRejectedValue(new Error("FGA API Error")); + + await expect( + ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents, + }) + ).rejects.toThrow("FGA API Error"); + }); + + it("should preserve document metadata in filtered results", async () => { + const docsWithMetadata = [ + Document.fromText("public content", { + id: "public-doc", + importance: "high", + }), + Document.fromText("private content", { + id: "private-doc", + importance: "high", + }), + ]; + + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: true }, + { request: { object: "doc:private-doc" }, allowed: true }, + ], + }); + + const rankedDocuments = await ai.rerank({ + reranker: FGAReranker.create(args, mockClient), + query: "input", + documents: docsWithMetadata, + }); + + expect(rankedDocuments).toHaveLength(2); + expect(rankedDocuments[0].metadata).toEqual({ + id: "public-doc", + importance: "high", + score: 1, + }); + }); }); From d2c5d699a442ee9e9161849ba32594e2bcb16564 Mon Sep 17 00:00:00 2001 From: Deepu Date: Mon, 27 Jan 2025 12:57:13 +0100 Subject: [PATCH 3/9] optimize FGA checks by deduping for llamaindex --- .../src/retrievers/fga-retriever.ts | 13 ++- packages/ai-llamaindex/test/index.test.ts | 92 ++++++++++++++++++- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts index a9ed89d..cde92dc 100644 --- a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts +++ b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts @@ -138,8 +138,19 @@ export class FGARetriever extends BaseRetriever { const { checks, documentToObjectMap } = retrievedNodes.reduce( (acc, nodeWithScore: NodeWithScore) => { const check = this.buildQuery(nodeWithScore.node); - acc.checks.push(check); acc.documentToObjectMap.set(nodeWithScore, check.object); + // Skip duplicate checks for same user, object, and relation + if ( + acc.checks.some( + (ex) => + ex.object === check.object && + ex.user === check.user && + ex.relation === check.relation + ) + ) { + return acc; + } + acc.checks.push(check); return acc; }, { diff --git a/packages/ai-llamaindex/test/index.test.ts b/packages/ai-llamaindex/test/index.test.ts index 89bfaf3..b91768a 100644 --- a/packages/ai-llamaindex/test/index.test.ts +++ b/packages/ai-llamaindex/test/index.test.ts @@ -3,7 +3,11 @@ import { FGARetriever, FGARetrieverCheckerFn, } from "../src/retrievers/fga-retriever"; -import { OpenFgaClient, CredentialsMethod } from "@openfga/sdk"; +import { + OpenFgaClient, + CredentialsMethod, + ConsistencyPreference, +} from "@openfga/sdk"; import { BaseRetriever, NodeWithScore } from "llamaindex"; describe("FGARetriever", () => { @@ -71,7 +75,91 @@ describe("FGARetriever", () => { const retriever = FGARetriever.create(args, mockClient); - const result = await retriever._retrieve({ query: "test" }); + const result = await retriever.retrieve({ query: "test" }); expect(result).toEqual([mockDocuments[0]]); }); + + it("should handle empty document list", async () => { + // @ts-ignore + args.retriever.retrieve.mockResolvedValue([]); + const retriever = FGARetriever.create(args, mockClient); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(0); + }); + + it("should handle empty permission list", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [] }); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(0); + }); + + it("should deduplicate permission checks for same object/user/relation", async () => { + const duplicateDocuments = [ + ...mockDocuments, + { + node: { text: "public content", metadata: { id: "public-doc" } }, + score: 1, + } as unknown as NodeWithScore, + { + node: { text: "private content", metadata: { id: "private-doc" } }, + score: 1, + } as unknown as NodeWithScore, + ]; + + // @ts-ignore + args.retriever.retrieve.mockResolvedValue(duplicateDocuments); + + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: true }, + { request: { object: "doc:private-doc" }, allowed: false }, + ], + }); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(2); + expect(mockClient.batchCheck).toHaveBeenCalledTimes(1); + expect(mockClient.batchCheck).toBeCalledWith( + { + checks: [ + { object: "doc:public-doc", relation: "viewer", user: "user:user1" }, + { object: "doc:private-doc", relation: "viewer", user: "user:user1" }, + ], + }, + { consistency: ConsistencyPreference.HigherConsistency } + ); + }); + + it("should handle all documents being filtered out", async () => { + const retriever = FGARetriever.create(args, mockClient); + // @ts-ignore + mockClient.batchCheck = vi.fn().mockResolvedValue({ + result: [ + { request: { object: "doc:public-doc" }, allowed: false }, + { request: { object: "doc:private-doc" }, allowed: false }, + ], + }); + + const result = await retriever.retrieve({ query: "test" }); + expect(result).toHaveLength(0); + }); + + it("should handle batchCheck error gracefully", async () => { + // @ts-ignore + mockClient.batchCheck = vi + .fn() + .mockRejectedValue(new Error("FGA API Error")); + + const retriever = FGARetriever.create(args, mockClient); + + await expect(retriever.retrieve({ query: "test" })).rejects.toThrow( + "FGA API Error" + ); + }); }); From 5d91da20db6e85b1065946cf59426427b5fae7ee Mon Sep 17 00:00:00 2001 From: Deepu Date: Mon, 27 Jan 2025 12:57:32 +0100 Subject: [PATCH 4/9] add failure case for langchain --- packages/ai-langchain/test/index.test.ts | 25 +++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/packages/ai-langchain/test/index.test.ts b/packages/ai-langchain/test/index.test.ts index d6abd16..1190ac9 100644 --- a/packages/ai-langchain/test/index.test.ts +++ b/packages/ai-langchain/test/index.test.ts @@ -128,16 +128,8 @@ describe("FGARetriever", () => { expect(mockClient.batchCheck).toBeCalledWith( { checks: [ - { - object: "doc:public-doc", - relation: "viewer", - user: "user:user1", - }, - { - object: "doc:private-doc", - relation: "viewer", - user: "user:user1", - }, + { object: "doc:public-doc", relation: "viewer", user: "user:user1" }, + { object: "doc:private-doc", relation: "viewer", user: "user:user1" }, ], }, { consistency: ConsistencyPreference.HigherConsistency } @@ -153,7 +145,6 @@ describe("FGARetriever", () => { result: [ { request: { object: "doc:public-doc" }, allowed: false }, { request: { object: "doc:private-doc" }, allowed: false }, - { request: { object: "doc:private-doc-2" }, allowed: false }, ], }); @@ -177,4 +168,16 @@ describe("FGARetriever", () => { const result = await tool.invoke({ query: "test query" }); expect(result).toEqual("public content\n\nprivate content"); }); + + it("should handle batchCheck error gracefully", async () => { + // @ts-ignore + mockClient.batchCheck = vi + .fn() + .mockRejectedValue(new Error("FGA API Error")); + + const retriever = FGARetriever.create(args, mockClient); + await expect(retriever.invoke("test query")).rejects.toThrow( + "FGA API Error" + ); + }); }); From 7e3adf3990a5cd204f195fbad630e4e75f5f5d80 Mon Sep 17 00:00:00 2001 From: Deepu Date: Mon, 27 Jan 2025 15:18:46 +0100 Subject: [PATCH 5/9] more efficient checks --- packages/ai-genkit/src/retrievers/fga-reranker.ts | 15 +++++---------- .../ai-langchain/src/retrievers/fga-retriever.ts | 15 +++++---------- .../ai-llamaindex/src/retrievers/fga-retriever.ts | 15 +++++---------- 3 files changed, 15 insertions(+), 30 deletions(-) diff --git a/packages/ai-genkit/src/retrievers/fga-reranker.ts b/packages/ai-genkit/src/retrievers/fga-reranker.ts index 6db0309..199ecbe 100644 --- a/packages/ai-genkit/src/retrievers/fga-reranker.ts +++ b/packages/ai-genkit/src/retrievers/fga-reranker.ts @@ -159,22 +159,17 @@ export class FGAReranker { const check = this.buildQuery(document); acc.documentToObjectMap.set(document, check.object); // Skip duplicate checks for same user, object, and relation - if ( - acc.checks.some( - (ex) => - ex.object === check.object && - ex.user === check.user && - ex.relation === check.relation - ) - ) { - return acc; + const checkKey = `${check.user}|${check.object}|${check.relation}`; + if (!acc.seenChecks.has(checkKey)) { + acc.seenChecks.add(checkKey); + acc.checks.push(check); } - acc.checks.push(check); return acc; }, { checks: [] as ClientBatchCheckItem[], documentToObjectMap: new Map(), + seenChecks: new Set(), } ); diff --git a/packages/ai-langchain/src/retrievers/fga-retriever.ts b/packages/ai-langchain/src/retrievers/fga-retriever.ts index 62925ad..64813de 100644 --- a/packages/ai-langchain/src/retrievers/fga-retriever.ts +++ b/packages/ai-langchain/src/retrievers/fga-retriever.ts @@ -144,17 +144,11 @@ export class FGARetriever extends BaseRetriever { const check = this.buildQuery(doc); acc.documentToObject.set(doc, check.object); // Skip duplicate checks for same user, object, and relation - if ( - acc.checks.some( - (ex) => - ex.object === check.object && - ex.user === check.user && - ex.relation === check.relation - ) - ) { - return acc; + const checkKey = `${check.user}|${check.object}|${check.relation}`; + if (!acc.seenChecks.has(checkKey)) { + acc.seenChecks.add(checkKey); + acc.checks.push(check); } - acc.checks.push(check); return acc; }, { @@ -163,6 +157,7 @@ export class FGARetriever extends BaseRetriever { DocumentInterface>, string >(), + seenChecks: new Set(), } ); diff --git a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts index cde92dc..0a21446 100644 --- a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts +++ b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts @@ -140,22 +140,17 @@ export class FGARetriever extends BaseRetriever { const check = this.buildQuery(nodeWithScore.node); acc.documentToObjectMap.set(nodeWithScore, check.object); // Skip duplicate checks for same user, object, and relation - if ( - acc.checks.some( - (ex) => - ex.object === check.object && - ex.user === check.user && - ex.relation === check.relation - ) - ) { - return acc; + const checkKey = `${check.user}|${check.object}|${check.relation}`; + if (!acc.seenChecks.has(checkKey)) { + acc.seenChecks.add(checkKey); + acc.checks.push(check); } - acc.checks.push(check); return acc; }, { checks: [] as ClientBatchCheckItem[], documentToObjectMap: new Map, string>(), + seenChecks: new Set(), } ); From 4025a9399c24e097ca923150ae456cbd1cd1b285 Mon Sep 17 00:00:00 2001 From: Deepu Date: Mon, 27 Jan 2025 16:21:50 +0100 Subject: [PATCH 6/9] make consistency configurable --- .../ai-genkit/src/retrievers/fga-reranker.ts | 8 ++++++-- .../ai-langchain/src/retrievers/fga-retriever.ts | 13 ++++++++----- .../src/retrievers/fga-retriever.ts | 16 +++++++++------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/packages/ai-genkit/src/retrievers/fga-reranker.ts b/packages/ai-genkit/src/retrievers/fga-reranker.ts index 199ecbe..81e52bd 100644 --- a/packages/ai-genkit/src/retrievers/fga-reranker.ts +++ b/packages/ai-genkit/src/retrievers/fga-reranker.ts @@ -12,6 +12,7 @@ export type FGARerankerCheckerFn = (doc: Document) => ClientBatchCheckItem; export type FGARerankerConstructorArgs = { buildQuery: FGARerankerCheckerFn; + consistency?: ConsistencyPreference; }; export type FGARerankerArgs = FGARerankerConstructorArgs & { @@ -45,6 +46,7 @@ export type FGARerankerArgs = FGARerankerConstructorArgs & { export class FGAReranker { lc_namespace = ["genkit", "rerankers", "fga-reranker"]; private buildQuery: FGARerankerCheckerFn; + private consistency: ConsistencyPreference; private fgaClient: OpenFgaClient; static lc_name() { @@ -52,10 +54,11 @@ export class FGAReranker { } private constructor( - { buildQuery }: FGARerankerConstructorArgs, + { buildQuery, consistency }: FGARerankerConstructorArgs, fgaClient?: OpenFgaClient ) { this.buildQuery = buildQuery; + this.consistency = consistency; this.fgaClient = fgaClient || new OpenFgaClient({ @@ -134,7 +137,8 @@ export class FGAReranker { const response = await this.fgaClient.batchCheck( { checks }, { - consistency: ConsistencyPreference.HigherConsistency, + consistency: + this.consistency || ConsistencyPreference.HigherConsistency, } ); diff --git a/packages/ai-langchain/src/retrievers/fga-retriever.ts b/packages/ai-langchain/src/retrievers/fga-retriever.ts index 64813de..7920b32 100644 --- a/packages/ai-langchain/src/retrievers/fga-retriever.ts +++ b/packages/ai-langchain/src/retrievers/fga-retriever.ts @@ -19,7 +19,8 @@ export type FGARetrieverCheckerFn = ( export type FGARetrieverArgs = { retriever: BaseRetriever; buildQuery: FGARetrieverCheckerFn; - fields?: BaseRetrieverInput; + consistency?: ConsistencyPreference; + retrieverFields?: BaseRetrieverInput; }; type AccessByDocumentFn = ( @@ -63,10 +64,10 @@ export class FGARetriever extends BaseRetriever { private constructor({ retriever, buildQuery, - fields, + retrieverFields, accessByDocument, }: FGARetrieverArgsWithAccessByDocument) { - super(fields); + super(retrieverFields); this.buildQuery = buildQuery; this.retriever = retriever; this.accessByDocument = accessByDocument as AccessByDocumentFn; @@ -78,7 +79,8 @@ export class FGARetriever extends BaseRetriever { * @param args - @FGARetrieverArgs * @param args.retriever - The underlying retriever instance to fetch documents. * @param args.buildQuery - A function to generate access check requests for each document. - * @param args.fields - Optional - Additional fields to pass to the underlying retriever. + * @param args.consistency - Optional - The consistency preference for the OpenFGA client. + * @param args.retrieverFields - Optional - Additional fields to pass to the underlying retriever. * @param fgaClient - Optional - OpenFgaClient instance to execute checks against. * @returns A newly created FGARetriever instance configured with the provided arguments. */ @@ -107,7 +109,8 @@ export class FGARetriever extends BaseRetriever { const response = await client.batchCheck( { checks }, { - consistency: ConsistencyPreference.HigherConsistency, + consistency: + args.consistency || ConsistencyPreference.HigherConsistency, } ); return response.result.reduce( diff --git a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts index 0a21446..151d696 100644 --- a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts +++ b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts @@ -20,6 +20,7 @@ export type FGARetrieverCheckerFn = ( export interface FGARetrieverArgs { buildQuery: FGARetrieverCheckerFn; retriever: BaseRetriever; + consistency?: ConsistencyPreference; } /** @@ -51,6 +52,7 @@ export class FGARetriever extends BaseRetriever { lc_namespace = ["llamaindex", "retrievers", "fga-retriever"]; private retriever: BaseRetriever; private buildQuery: FGARetrieverCheckerFn; + private consistency: ConsistencyPreference; private fgaClient: OpenFgaClient; static lc_name() { @@ -58,13 +60,14 @@ export class FGARetriever extends BaseRetriever { } private constructor( - { buildQuery, retriever }: FGARetrieverArgs, + { buildQuery, retriever, consistency }: FGARetrieverArgs, fgaClient?: OpenFgaClient ) { super(); this.retriever = retriever; this.buildQuery = buildQuery; + this.consistency = consistency; this.fgaClient = fgaClient || new OpenFgaClient({ @@ -89,14 +92,12 @@ export class FGARetriever extends BaseRetriever { * @param args - @FGARetrieverArgs * @param args.retriever - The underlying retriever instance to fetch documents. * @param args.buildQuery - A function to generate access check requests for each document. + * @param args.consistency - Optional - The consistency preference for the OpenFGA client. * @param fgaClient - Optional - OpenFgaClient instance to execute checks against. * @returns A newly created FGARetriever instance configured with the provided arguments. */ - static create( - { buildQuery, retriever }: FGARetrieverArgs, - fgaClient?: OpenFgaClient - ) { - return new FGARetriever({ buildQuery, retriever }, fgaClient); + static create(args: FGARetrieverArgs, fgaClient?: OpenFgaClient) { + return new FGARetriever(args, fgaClient); } /** @@ -111,7 +112,8 @@ export class FGARetriever extends BaseRetriever { const response = await this.fgaClient.batchCheck( { checks }, { - consistency: ConsistencyPreference.HigherConsistency, + consistency: + this.consistency || ConsistencyPreference.HigherConsistency, } ); From 38a5aab11bb9c87555eadd88406c14cd3e56414f Mon Sep 17 00:00:00 2001 From: Deepu Date: Tue, 28 Jan 2025 12:03:44 +0100 Subject: [PATCH 7/9] refactor langchain retriever to check for user and relation as well --- .../src/retrievers/fga-retriever.ts | 114 +++++++++--------- packages/ai-langchain/test/index.test.ts | 54 ++++++++- 2 files changed, 107 insertions(+), 61 deletions(-) diff --git a/packages/ai-langchain/src/retrievers/fga-retriever.ts b/packages/ai-langchain/src/retrievers/fga-retriever.ts index 7920b32..d330117 100644 --- a/packages/ai-langchain/src/retrievers/fga-retriever.ts +++ b/packages/ai-langchain/src/retrievers/fga-retriever.ts @@ -23,14 +23,6 @@ export type FGARetrieverArgs = { retrieverFields?: BaseRetrieverInput; }; -type AccessByDocumentFn = ( - checks: ClientBatchCheckItem[] -) => Promise>; - -type FGARetrieverArgsWithAccessByDocument = FGARetrieverArgs & { - accessByDocument: AccessByDocumentFn; -}; - /** * A retriever that allows filtering documents based on access control checks * using OpenFGA. This class wraps an underlying retriever and performs batch @@ -59,18 +51,33 @@ export class FGARetriever extends BaseRetriever { lc_namespace = ["@langchain", "retrievers"]; private retriever: BaseRetriever; private buildQuery: FGARetrieverCheckerFn; - private accessByDocument: AccessByDocumentFn; + private consistency: ConsistencyPreference; + private fgaClient: OpenFgaClient; - private constructor({ - retriever, - buildQuery, - retrieverFields, - accessByDocument, - }: FGARetrieverArgsWithAccessByDocument) { + private constructor( + { buildQuery, retriever, consistency, retrieverFields }: FGARetrieverArgs, + fgaClient?: OpenFgaClient + ) { super(retrieverFields); this.buildQuery = buildQuery; this.retriever = retriever; - this.accessByDocument = accessByDocument as AccessByDocumentFn; + this.consistency = consistency; + this.fgaClient = + fgaClient || + new OpenFgaClient({ + apiUrl: process.env.FGA_API_URL || "https://api.us1.fga.dev", + storeId: process.env.FGA_STORE_ID!, + credentials: { + method: CredentialsMethod.ClientCredentials, + config: { + apiTokenIssuer: process.env.FGA_API_TOKEN_ISSUER || "auth.fga.dev", + apiAudience: + process.env.FGA_API_AUDIENCE || "https://api.us1.fga.dev/", + clientId: process.env.FGA_CLIENT_ID!, + clientSecret: process.env.FGA_CLIENT_SECRET!, + }, + }, + }); } /** @@ -88,41 +95,7 @@ export class FGARetriever extends BaseRetriever { args: FGARetrieverArgs, fgaClient?: OpenFgaClient ): FGARetriever { - const client = - fgaClient || - new OpenFgaClient({ - apiUrl: process.env.FGA_API_URL || "https://api.us1.fga.dev", - storeId: process.env.FGA_STORE_ID!, - credentials: { - method: CredentialsMethod.ClientCredentials, - config: { - apiTokenIssuer: process.env.FGA_API_TOKEN_ISSUER || "auth.fga.dev", - apiAudience: - process.env.FGA_API_AUDIENCE || "https://api.us1.fga.dev/", - clientId: process.env.FGA_CLIENT_ID!, - clientSecret: process.env.FGA_CLIENT_SECRET!, - }, - }, - }); - - const accessByDocument: AccessByDocumentFn = async function (checks) { - const response = await client.batchCheck( - { checks }, - { - consistency: - args.consistency || ConsistencyPreference.HigherConsistency, - } - ); - return response.result.reduce( - (permissionMap: Map, result) => { - permissionMap.set(result.request.object, result.allowed || false); - return permissionMap; - }, - new Map() - ); - }; - - return new FGARetriever({ ...args, accessByDocument }); + return new FGARetriever(args, fgaClient); } /** @@ -145,9 +118,9 @@ export class FGARetriever extends BaseRetriever { const { checks, documentToObject } = documents.reduce( (acc, doc) => { const check = this.buildQuery(doc); - acc.documentToObject.set(doc, check.object); + const checkKey = this.getCheckKey(check); + acc.documentToObject.set(doc, checkKey); // Skip duplicate checks for same user, object, and relation - const checkKey = `${check.user}|${check.object}|${check.relation}`; if (!acc.seenChecks.has(checkKey)) { acc.seenChecks.add(checkKey); acc.checks.push(check); @@ -164,13 +137,44 @@ export class FGARetriever extends BaseRetriever { } ); - const resultsByObject = await this.accessByDocument(checks); + const permissionsMap = await this.checkPermissions(checks); return documents.filter( - (d, i) => resultsByObject.get(documentToObject.get(d) || "") === true + (d, i) => permissionsMap.get(documentToObject.get(d) || "") === true ); } + /** + * Checks permissions for a list of client requests. + * + * @param checks - An array of `ClientBatchCheckItem` objects representing the permissions to be checked. + * @returns A promise that resolves to a `Map` where the keys are object identifiers and the values are booleans indicating whether the permission is allowed. + */ + private async checkPermissions( + checks: ClientBatchCheckItem[] + ): Promise> { + const response = await this.fgaClient.batchCheck( + { checks }, + { + consistency: + this.consistency || ConsistencyPreference.HigherConsistency, + } + ); + + return response.result.reduce( + (permissionMap: Map, result) => { + const checkKey = this.getCheckKey(result.request); + permissionMap.set(checkKey, result.allowed || false); + return permissionMap; + }, + new Map() + ); + } + + private getCheckKey(check: ClientBatchCheckItem): string { + return `${check.user}|${check.object}|${check.relation}`; + } + /** * Converts the FGA retriever into a tool that can be used by a LangGraph agent. * @returns StructuredToolInterface. diff --git a/packages/ai-langchain/test/index.test.ts b/packages/ai-langchain/test/index.test.ts index 1190ac9..9294eee 100644 --- a/packages/ai-langchain/test/index.test.ts +++ b/packages/ai-langchain/test/index.test.ts @@ -69,8 +69,22 @@ describe("FGARetriever", () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -117,8 +131,22 @@ describe("FGARetriever", () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -159,8 +187,22 @@ describe("FGARetriever", () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: true }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, ], }); From 5a7d08530e90aa75496a2ffb00fb0114314f93be Mon Sep 17 00:00:00 2001 From: Deepu Date: Tue, 28 Jan 2025 12:04:00 +0100 Subject: [PATCH 8/9] refactor llamaindex retriever to check for user and relation as well --- .../src/retrievers/fga-retriever.ts | 11 ++++-- packages/ai-llamaindex/test/index.test.ts | 36 ++++++++++++++++--- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts index 151d696..3664991 100644 --- a/packages/ai-llamaindex/src/retrievers/fga-retriever.ts +++ b/packages/ai-llamaindex/src/retrievers/fga-retriever.ts @@ -119,13 +119,18 @@ export class FGARetriever extends BaseRetriever { return response.result.reduce( (permissionMap: Map, result) => { - permissionMap.set(result.request.object, result.allowed || false); + const checkKey = this.getCheckKey(result.request); + permissionMap.set(checkKey, result.allowed || false); return permissionMap; }, new Map() ); } + private getCheckKey(check: ClientBatchCheckItem): string { + return `${check.user}|${check.object}|${check.relation}`; + } + /** * Retrieves nodes based on the provided query parameters, processes * them through a checker function, @@ -140,9 +145,9 @@ export class FGARetriever extends BaseRetriever { const { checks, documentToObjectMap } = retrievedNodes.reduce( (acc, nodeWithScore: NodeWithScore) => { const check = this.buildQuery(nodeWithScore.node); - acc.documentToObjectMap.set(nodeWithScore, check.object); + const checkKey = this.getCheckKey(check); + acc.documentToObjectMap.set(nodeWithScore, checkKey); // Skip duplicate checks for same user, object, and relation - const checkKey = `${check.user}|${check.object}|${check.relation}`; if (!acc.seenChecks.has(checkKey)) { acc.seenChecks.add(checkKey); acc.checks.push(check); diff --git a/packages/ai-llamaindex/test/index.test.ts b/packages/ai-llamaindex/test/index.test.ts index b91768a..f326659 100644 --- a/packages/ai-llamaindex/test/index.test.ts +++ b/packages/ai-llamaindex/test/index.test.ts @@ -68,8 +68,22 @@ describe("FGARetriever", () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -117,8 +131,22 @@ describe("FGARetriever", () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); From 5a47d1f1772f468dee58b3fccc70756c5c47be27 Mon Sep 17 00:00:00 2001 From: Deepu Date: Tue, 28 Jan 2025 12:10:33 +0100 Subject: [PATCH 9/9] refactor genkit reranker to check for user and relation as well --- .../ai-genkit/src/retrievers/fga-reranker.ts | 11 ++- packages/ai-genkit/test/index.test.ts | 72 ++++++++++++++++--- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/packages/ai-genkit/src/retrievers/fga-reranker.ts b/packages/ai-genkit/src/retrievers/fga-reranker.ts index 81e52bd..12669ea 100644 --- a/packages/ai-genkit/src/retrievers/fga-reranker.ts +++ b/packages/ai-genkit/src/retrievers/fga-reranker.ts @@ -144,13 +144,18 @@ export class FGAReranker { return response.result.reduce( (permissionMap: Map, result) => { - permissionMap.set(result.request.object, result.allowed || false); + const checkKey = this.getCheckKey(result.request); + permissionMap.set(checkKey, result.allowed || false); return permissionMap; }, new Map() ); } + private getCheckKey(check: ClientBatchCheckItem): string { + return `${check.user}|${check.object}|${check.relation}`; + } + /** * Retrieves a filtered list of documents based on permission checks. * @@ -161,9 +166,9 @@ export class FGAReranker { const { checks, documentToObjectMap } = documents.reduce( (acc, document: Document) => { const check = this.buildQuery(document); - acc.documentToObjectMap.set(document, check.object); + const checkKey = this.getCheckKey(check); + acc.documentToObjectMap.set(document, checkKey); // Skip duplicate checks for same user, object, and relation - const checkKey = `${check.user}|${check.object}|${check.relation}`; if (!acc.seenChecks.has(checkKey)) { acc.seenChecks.add(checkKey); acc.checks.push(check); diff --git a/packages/ai-genkit/test/index.test.ts b/packages/ai-genkit/test/index.test.ts index 8736eae..19c651f 100644 --- a/packages/ai-genkit/test/index.test.ts +++ b/packages/ai-genkit/test/index.test.ts @@ -62,8 +62,22 @@ describe("FGAReranker", async () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -110,8 +124,22 @@ describe("FGAReranker", async () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: false }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -140,8 +168,22 @@ describe("FGAReranker", async () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:private-doc" }, allowed: false }, - { request: { object: "doc:public-doc" }, allowed: false }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: false, + }, ], }); @@ -184,8 +226,22 @@ describe("FGAReranker", async () => { // @ts-ignore mockClient.batchCheck = vi.fn().mockResolvedValue({ result: [ - { request: { object: "doc:public-doc" }, allowed: true }, - { request: { object: "doc:private-doc" }, allowed: true }, + { + request: { + object: "doc:public-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, + { + request: { + object: "doc:private-doc", + user: "user:user1", + relation: "viewer", + }, + allowed: true, + }, ], });