Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenFGA/OktaFGA retriever to authorize user/document access when doing RAG #6629

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions libs/langchain-community/src/retrievers/fga.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { BaseRetriever, type BaseRetrieverInput } from "@langchain/core/retrievers";
import type { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";
import { Document, DocumentInterface } from "@langchain/core/documents";
import { OpenFgaClient, ClientCheckRequest } from "@openfga/sdk";

type FGARetrieverArgs = {
user: string;
retriever: BaseRetriever;
fgaClient: OpenFgaClient;
fields?: BaseRetrieverInput;
checkFromDocument: (user: string, doc: DocumentInterface<Record<string, any>>, query: string) => ClientCheckRequest;
}

export class FGARetriever extends BaseRetriever {
static lc_name() {
return "FGARetriever";
}

lc_namespace = ["langchain", "retrievers", "fga"];
private retriever: BaseRetriever;
private checkFromDocument: (user: string, doc: DocumentInterface<Record<string, any>>, query: string) => ClientCheckRequest;
private user: string;
private fgaClient: OpenFgaClient;

constructor({ user, retriever, fgaClient, fields, checkFromDocument }: FGARetrieverArgs) {
super(fields);
this.user = user;
this.fgaClient = fgaClient;
this.retriever = retriever;
this.checkFromDocument = checkFromDocument;
}

private async accessByDocument(checks: ClientCheckRequest[]): Promise<Map<string, boolean>> {
const results = await this.fgaClient.batchCheck(checks);
return results.responses.reduce((c: Map<string, boolean>, v) => {
c.set(v._request.object, v.allowed || false);
return c;
}, new Map<string, boolean>());
}

async _getRelevantDocuments(
query: string,
runManager?: CallbackManagerForRetrieverRun
): Promise<Document[]> {
const documents = await this.retriever._getRelevantDocuments(query, runManager);
const out = documents.reduce((out, doc) => {
const check = this.checkFromDocument(this.user, doc, query);
out.checks.push(check);
out.documentToObject.set(doc, check.object);
return out;
}, { checks: [] as ClientCheckRequest[], documentToObject: new Map<DocumentInterface<Record<string, any>>, string>() });
const { checks, documentToObject } = out;
const resultsByObject = await this.accessByDocument(checks);

return documents.filter((d, _) => resultsByObject.get(documentToObject.get(d) || '') === true);
}
}
Empty file.
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,8 @@
"eslint --cache --fix"
],
"*.md": "prettier --config .prettierrc --write"
},
"dependencies": {
"@openfga/sdk": "^0.6.2"
}
}
34 changes: 34 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13551,6 +13551,18 @@ __metadata:
languageName: node
linkType: hard

"@openfga/sdk@npm:^0.6.2":
version: 0.6.2
resolution: "@openfga/sdk@npm:0.6.2"
dependencies:
"@opentelemetry/api": ^1.9.0
"@opentelemetry/semantic-conventions": ^1.25.0
axios: ^1.6.8
tiny-async-pool: ^2.1.0
checksum: 6b75080d450b0be659f334dc8e5fd82551a1794fb87b24eb68dc6791c1c14cc4dfb238d55d858a20d2fa2d58c9301980bde1b16bc738cafcc1d4f40ab3f5d5e5
languageName: node
linkType: hard

"@opensearch-project/opensearch@npm:^2.2.0":
version: 2.2.0
resolution: "@opensearch-project/opensearch@npm:2.2.0"
Expand All @@ -13571,6 +13583,20 @@ __metadata:
languageName: node
linkType: hard

"@opentelemetry/api@npm:^1.9.0":
version: 1.9.0
resolution: "@opentelemetry/api@npm:1.9.0"
checksum: 9e88e59d53ced668f3daaecfd721071c5b85a67dd386f1c6f051d1be54375d850016c881f656ffbe9a03bedae85f7e89c2f2b635313f9c9b195ad033cdc31020
languageName: node
linkType: hard

"@opentelemetry/semantic-conventions@npm:^1.25.0":
version: 1.25.1
resolution: "@opentelemetry/semantic-conventions@npm:1.25.1"
checksum: fea418a4b09c55121c6da11c49dd2105116533838c484aead17e8acf8029dad711e145849812f9c61f9e48fad8e2b6cf103d2c18847ca993032ce9b27c2f863d
languageName: node
linkType: hard

"@petamoriken/float16@npm:^3.8.6":
version: 3.8.7
resolution: "@petamoriken/float16@npm:3.8.7"
Expand Down Expand Up @@ -32327,6 +32353,7 @@ __metadata:
version: 0.0.0-use.local
resolution: "langchainjs@workspace:."
dependencies:
"@openfga/sdk": ^0.6.2
"@tsconfig/recommended": ^1.0.2
"@types/jest": ^29.5.3
"@types/semver": ^7
Expand Down Expand Up @@ -40441,6 +40468,13 @@ __metadata:
languageName: node
linkType: hard

"tiny-async-pool@npm:^2.1.0":
version: 2.1.0
resolution: "tiny-async-pool@npm:2.1.0"
checksum: 8891326f30e587590f94c5e1f8cab59c9aa305e442fc5b9f7ea997f8611d805a797aae2ea93cd00b42b494ef749353df38b4555e2a769d6fff31a3db7add7208
languageName: node
linkType: hard

"tiny-invariant@npm:^1.0.2":
version: 1.3.1
resolution: "tiny-invariant@npm:1.3.1"
Expand Down