Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for auth with token #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ After creating .env, please setup your real values for following variables:

1. LLM_API_KEY - Example dRAG uses LLM to formulate the answer. If you want to add your own variables for any other provider, model or any kind of data you need, you can do in Auth service -> UserConfig table, those variables will be available inside your DRAG:
```javascript
const userData = await authService.authenticateAndCache(req.sessionSid);
const userData = await authService.authenticateAndCache(req);
```

```sh
Expand Down
11 changes: 5 additions & 6 deletions controllers/exampleSparqlController.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export default {
try {
const { question, chatHistory } = req.body;

const userData = await authService.authenticateAndCache(req.sessionSid);
const userData = await authService.authenticateAndCache(req);

const llmService = new LLMService(userData);

Expand All @@ -22,14 +22,13 @@ export default {
chatHistory
);

const headline = await llmService.getDigitalDocumentTitle(standaloneQuestion);
const headline = await llmService.getDigitalDocumentTitle(
standaloneQuestion
);

const sparqlQuery = getSparqlQuery(headline);

const queryResults = await dkgService.query(
sparqlQuery,
userData
);
const queryResults = await dkgService.query(sparqlQuery, userData);

let result = queryResults[0];
if (result.length === 0) {
Expand Down
2 changes: 1 addition & 1 deletion controllers/exampleVectorController.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export default {
try {
const { question, chatHistory } = req.body;

const userData = await authService.authenticateAndCache(req.sessionSid);
const userData = await authService.authenticateAndCache(req);

const llmService = new LLMService(userData);

Expand Down
10 changes: 1 addition & 9 deletions middleware/auth.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
import authService from "../services/authService.js";

export function authenticateToken(req, res, next) {
const token = req.headers.authorization;

if (!token || !token.startsWith("Bearer ")) {
return res.status(401).json({ error: "Unauthorized" });
}

const authToken = token.split(" ")[1];

authService
.authenticateAndCache(authToken)
.authenticateAndCache(req)
.then((userData) => {
if (userData) {
next();
Expand Down
4 changes: 2 additions & 2 deletions services/authService.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import CacheService from "./cacheService.js";
import { authenticateToken } from "./userManagementService.js";

class AuthService {
async authenticateAndCache(cookie) {
async authenticateAndCache(req) {
try {
const { userData, expiresIn } = await authenticateToken(cookie);
const { userData, expiresIn } = await authenticateToken(req);

return userData;
} catch (error) {
Expand Down
156 changes: 104 additions & 52 deletions services/userManagementService.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,83 +15,135 @@ const baseUserData = {
apiKey: process.env.LLM_API_KEY,
};

export async function authenticateToken(cookie) {
export async function authenticateToken(req) {
// try {
// const response = await axios.get(`${AUTH_ENDPOINT}/auth/check`, {
// headers: { Cookie: `${COOKIE_NAME}=${cookie}` },
// withCredentials: "true",
// });
// } catch (error) {
// logger.error("Error fetching user config: " + error);
// }

try {
const response = await axios.get(`${AUTH_ENDPOINT}/auth/check`, {
headers: { Cookie: `${COOKIE_NAME}=${cookie}` },
withCredentials: "true",
});
const authHeader = req.headers["authorization"];

if (authHeader && authHeader.startsWith("Bearer ")) {
// Bearer token is present
const token = authHeader.split(" ")[1];

if (!token) {
throw Error("Invalid Bearer token format");
}

const response = await axios.get(`${AUTH_ENDPOINT}/auth/check`, {
headers: { Authorization: `Bearer ${token}` },
withCredentials: true,
});
return prepareResponse(response, baseUserData);
} else {
// Bearer token not present, check for session cookie
const sessionCookie = req.headers.cookie;

if (!sessionCookie) {
throw Error("No Bearer token or session cookie found");
}

const response = await axios.get(`${AUTH_ENDPOINT}/auth/check`, {
headers: { Cookie: sessionCookie },
withCredentials: true,
});
return prepareResponse(response, baseUserData);
}
} catch (error) {
console.error("Error during authentication:", error);
logger.error("Error fetching user config: " + error);
throw new Error("Internal server error");
}
}

const userConfig = response.data.user.config.find(
function prepareResponse(response, baseUserData) {
const userConfig =
response.data.user.config.find(
(cfg) => cfg.option === DRAG_USER_CONFIG_OPTION
) || null;

const edgeNodePublishMode = response.data.user.config.find(
const edgeNodePublishMode =
response.data.user.config.find(
(cfg) => cfg.option === "edge_node_publish_mode"
)?.value || null;

const edgeNodeParanetUAL = response.data.user.config.find(
(cfg) => cfg.option === "edge_node_paranet_ual"
const edgeNodeParanetUAL =
response.data.user.config.find(
(cfg) => cfg.option === "edge_node_paranet_ual"
)?.value || null;

const environment = response.data.user.config.find(
const environment =
response.data.user.config.find(
(cfg) => cfg.option === "edge_node_environment"
)?.value || null;

const runTimeNodeEndpoint = response.data.user.config.find(
const runTimeNodeEndpoint =
response.data.user.config.find(
(cfg) => cfg.option === "run_time_node_endpoint"
)?.value || null;

const blockchain = response.data.user.config.find(
(cfg) => cfg.option === "blockchain"
)?.value || null;
const blockchain =
response.data.user.config.find((cfg) => cfg.option === "blockchain")
?.value || null;

const vectorDBUri = response.data.user.config.find(
(cfg) => cfg.option === "milvus_address"
)?.value || null;
const vectorDBUri =
response.data.user.config.find((cfg) => cfg.option === "milvus_address")
?.value || null;

const vectorDBCredentials = response.data.user.config
const vectorDBCredentials =
response.data.user.config
.find((cfg) => cfg.option === "milvus_token")
?.value?.split(":") || null;

const vectorDBUsername = (vectorDBCredentials !== null && vectorDBCredentials.length > 0) ? vectorDBCredentials[0] : null;
const vectorDBPassword = (vectorDBCredentials !== null && vectorDBCredentials.length > 0) ? vectorDBCredentials[1] : null;

const vectorCollection = response.data.user.config
const vectorDBUsername =
vectorDBCredentials !== null && vectorDBCredentials.length > 0
? vectorDBCredentials[0]
: null;
const vectorDBPassword =
vectorDBCredentials !== null && vectorDBCredentials.length > 0
? vectorDBCredentials[1]
: null;

const vectorCollection =
response.data.user.config
.find((cfg) => cfg.option === "vector_collection")
?.value?.split(",") || null;

const embeddingModelAPIKey = response.data.user.config.find(
const embeddingModelAPIKey =
response.data.user.config.find(
(cfg) => cfg.option === "embedding_model_api_key"
)?.value || null;

const embeddingModel = response.data.user.config.find(
(cfg) => cfg.option === "embedding_model"
)?.value || null;

const cohereKey = response.data.user.config.find(
(cfg) => cfg.option === "cohere_key"
)?.value || null;

return {
userData: {
...baseUserData,
id: userConfig.id,
edgeNodePublishMode: edgeNodePublishMode,
paranetUAL: edgeNodeParanetUAL,
environment: environment,
endpoint: runTimeNodeEndpoint,
blockchain: blockchain,
vectorCollection: vectorCollection,
vectorDBUri: vectorDBUri,
vectorDBUsername: vectorDBUsername,
vectorDBPassword: vectorDBPassword,
embeddingModelAPIKey: embeddingModelAPIKey,
embeddingModel: embeddingModel,
cohereKey: cohereKey,
},
};
} catch (error) {
logger.error("Error fetching user config: " + error);
}
const embeddingModel =
response.data.user.config.find((cfg) => cfg.option === "embedding_model")
?.value || null;

const cohereKey =
response.data.user.config.find((cfg) => cfg.option === "cohere_key")
?.value || null;

return {
userData: {
...baseUserData,
id: userConfig.id,
edgeNodePublishMode: edgeNodePublishMode,
paranetUAL: edgeNodeParanetUAL,
environment: environment,
endpoint: runTimeNodeEndpoint,
blockchain: blockchain,
vectorCollection: vectorCollection,
vectorDBUri: vectorDBUri,
vectorDBUsername: vectorDBUsername,
vectorDBPassword: vectorDBPassword,
embeddingModelAPIKey: embeddingModelAPIKey,
embeddingModel: embeddingModel,
cohereKey: cohereKey,
},
};
}