diff --git a/middleware.js b/middleware.js index a3b20e7a170..6a1a0545218 100644 --- a/middleware.js +++ b/middleware.js @@ -21,24 +21,31 @@ export async function middleware(req) { const protocol = process.env.NODE_ENV === "development" ? "http" : "https"; const hostname = req.headers.get("host"); const reqPathName = req.nextUrl.pathname; - const sessionRequired = ["/account", "/api/account"]; - const adminRequired = ["/admin", "/api/admin"]; + const adminRequired = [ "/admin", "/api/admin" ]; + // Trailing slash is necessary to catch URL /account/statistics/* but not index page + const premiumRequired = [ "/account/statistics/" ]; const adminUsers = process.env.ADMIN_USERS.split(","); const hostedDomain = process.env.NEXT_PUBLIC_BASE_URL.replace( /http:\/\/|https:\/\//, "", ); - const hostedDomains = [hostedDomain, `www.${hostedDomain}`]; - + const hostedDomains = [ hostedDomain, `www.${hostedDomain}` ]; + const sessionRequired = [ "/account", "/api/account" ]; + if ( + !sessionRequired + .concat(adminRequired) + .some((path) => reqPathName.startsWith(path)) + ) { + return NextResponse.next(); + } // if custom domain + on root path if (!hostedDomains.includes(hostname) && reqPathName === "/") { console.log(`custom domain used: "${hostname}"`); let res; let profile; - let url = `${ - process.env.NEXT_PUBLIC_BASE_URL - }/api/search/${encodeURIComponent(hostname)}`; + let url = `${process.env.NEXT_PUBLIC_BASE_URL + }/api/search/${encodeURIComponent(hostname)}`; try { res = await fetch(url, { method: "GET", @@ -73,38 +80,44 @@ export async function middleware(req) { console.error(`custom domain NOT matched "${hostname}"`); } - // if not in sessionRequired or adminRequired, skip - if ( - !sessionRequired - .concat(adminRequired) - .some((path) => reqPathName.startsWith(path)) - ) { - return NextResponse.next(); - } - - const session = await getToken({ + // Check token existence or validity + const token = await getToken({ req: req, secret: process.env.NEXTAUTH_SECRET, }); - // if no session reject request - if (!session) { + console.log(token) + + // if no token reject request + if (!token) { if (reqPathName.startsWith("/api")) { return NextResponse.json({}, { status: 401 }); } return NextResponse.redirect(new URL("/auth/signin", req.url)); } - const username = session.username; - // if admin request check user is allowed - if (adminRequired.some((path) => reqPathName.startsWith(path))) { - if (!adminUsers.includes(username)) { - if (reqPathName.startsWith("/api")) { - return NextResponse.json({}, { status: 401 }); - } - return NextResponse.redirect(new URL("/404", req.url)); + // Premium path + const isPremiumRoute = premiumRequired.some((path) => reqPathName.startsWith(path)) + const isUserPremium = token.accountType === "premium" + if (isPremiumRoute && !isUserPremium) { + if (reqPathName.startsWith("/api")) { + return NextResponse.json({}, { status: 401 }); + } + return NextResponse.redirect(new URL("/pricing", req.url)) + } + + // Admin Path + const username = token.username; + const isAdminRoute = adminRequired.some((path) => reqPathName.startsWith(path)) + const isUserAdmin = adminUsers.includes(username) + if (isAdminRoute && !isUserAdmin) { + if (reqPathName.startsWith("/api")) { + return NextResponse.json({}, { status: 401 }); } + return NextResponse.redirect(new URL("/404", req.url)); } + // Allow request return NextResponse.next(); } + diff --git a/pages/account/statistics/link/[id].js b/pages/account/statistics/link/[id].js index a15c95dbe1f..d2e8ed73df4 100644 --- a/pages/account/statistics/link/[id].js +++ b/pages/account/statistics/link/[id].js @@ -19,14 +19,6 @@ const DynamicChart = dynamic( export async function getServerSideProps(context) { const { req, res } = context; const session = await getServerSession(req, res, authOptions); - if (session.accountType !== "premium") { - return { - redirect: { - destination: "/account/onboarding", - permanent: false, - }, - }; - } const username = session.username; const { status, profile } = await getUserApi(req, res, username); @@ -75,8 +67,8 @@ export default function Statistics({ data }) { )} {data.stats.length > 0 && ( -
-
+
+

Link clicks for {data.url}

diff --git a/pages/account/statistics/locations.js b/pages/account/statistics/locations.js index e266f7a2ea3..8d4f8451914 100644 --- a/pages/account/statistics/locations.js +++ b/pages/account/statistics/locations.js @@ -20,15 +20,6 @@ export async function getServerSideProps(context) { const { req, res } = context; const session = await getServerSession(req, res, authOptions); - if (session.accountType !== "premium") { - return { - redirect: { - destination: "/account/onboarding", - permanent: false, - }, - }; - } - const username = session.username; const { status, profile } = await getUserApi(req, res, username); if (status !== 200) { @@ -50,7 +41,7 @@ export async function getServerSideProps(context) { stats = Object.keys(profile.stats.countries) .map((country) => ({ country, - value: profile.stats.countries[country], + value: profile.stats.countries[ country ], })) .sort((a, b) => b.value - a.value); } @@ -99,14 +90,14 @@ export default function Locations({ stats }) { - + {stats && stats.map((item) => ( - + {item.country} - + {abbreviateNumber(item.value)} diff --git a/pages/account/statistics/referers.js b/pages/account/statistics/referers.js index fe28bedd92e..25c81e8f4ef 100644 --- a/pages/account/statistics/referers.js +++ b/pages/account/statistics/referers.js @@ -20,15 +20,6 @@ export async function getServerSideProps(context) { const { req, res } = context; const session = await getServerSession(req, res, authOptions); - if (session.accountType !== "premium") { - return { - redirect: { - destination: "/account/onboarding", - permanent: false, - }, - }; - } - const username = session.username; const { status, profile } = await getUserApi(req, res, username); if (status !== 200) { @@ -50,7 +41,7 @@ export async function getServerSideProps(context) { stats = Object.keys(profile.stats.referers) .map((referer) => ({ referer, - value: profile.stats.referers[referer], + value: profile.stats.referers[ referer ], })) .sort((a, b) => b.value - a.value); } @@ -99,14 +90,14 @@ export default function Locations({ stats }) { - + {stats && stats.map((item) => ( - + {item.referer.replaceAll("|", ".")} - + {abbreviateNumber(item.value)} diff --git a/pages/api/auth/[...nextauth].js b/pages/api/auth/[...nextauth].js index 14d5b8ce205..860e12fc6bf 100644 --- a/pages/api/auth/[...nextauth].js +++ b/pages/api/auth/[...nextauth].js @@ -60,6 +60,12 @@ export const authOptions = { token.id = profile.id; token.username = profile.login; } + const user = await User.findOne({ _id: token.sub }); + if (user) { + token.accountType = user.type; + } else { + token.accountType = "free"; + } return token; }, async session({ session, token }) { @@ -122,13 +128,13 @@ export const authOptions = { upsert: true, }, ); - const link = await Link.create([defaultLink(profile._id)], { + const link = await Link.create([ defaultLink(profile._id) ], { new: true, }); profile = await Profile.findOneAndUpdate( { username }, { - $push: { links: new ObjectId(link[0]._id) }, + $push: { links: new ObjectId(link[ 0 ]._id) }, }, { new: true }, ); @@ -159,13 +165,13 @@ export const authOptions = { // add github link to profile if no links exist if (profile.links.length === 0) { logger.info("no links found for: ", username); - const link = await Link.create([defaultLink(profile._id)], { + const link = await Link.create([ defaultLink(profile._id) ], { new: true, }); await Profile.findOneAndUpdate( { username }, { - $push: { links: new ObjectId(link[0]._id) }, + $push: { links: new ObjectId(link[ 0 ]._id) }, }, ); } diff --git a/tests/account/stats/location.spec.js b/tests/account/stats/location.spec.js index ce1271150f8..b6b65d13155 100644 --- a/tests/account/stats/location.spec.js +++ b/tests/account/stats/location.spec.js @@ -6,7 +6,7 @@ const premiumUser = { name: "Automated Test Premium User", email: "test-profile-user-6@test.com", username: "_test-profile-user-6", - type: "premium", + accountType: "premium", }; test("Guest user cannot access premium locations stats", async ({ @@ -25,7 +25,7 @@ test("Logged in free user cannot access premium locations stats", async ({ const page = await context.newPage(); await page.goto("/account/statistics/locations"); await page.waitForLoadState("networkidle"); - await expect(page).toHaveURL(/account\/onboarding/); + await expect(page).toHaveURL(/\/pricing/); }); test("Logged in premium user can access premium locations stats", async ({ @@ -46,7 +46,7 @@ test.describe("accessibility tests (light)", () => { const page = await context.newPage(); await page.goto("/account/statistics/locations"); const accessibilityScanResults = await new AxeBuilder({ page }) - .withTags(["wcag2a", "wcag2aa", "wcag21a", "wcag21aa"]) + .withTags([ "wcag2a", "wcag2aa", "wcag21a", "wcag21aa" ]) .analyze(); expect(accessibilityScanResults.violations).toEqual([]); }); @@ -62,7 +62,7 @@ test.describe("accessibility tests (dark)", () => { const page = await context.newPage(); await page.goto("/account/statistics/locations"); const accessibilityScanResults = await new AxeBuilder({ page }) - .withTags(["wcag2a", "wcag2aa", "wcag21a", "wcag21aa"]) + .withTags([ "wcag2a", "wcag2aa", "wcag21a", "wcag21aa" ]) .analyze(); expect(accessibilityScanResults.violations).toEqual([]); }); diff --git a/tests/account/stats/referer.spec.js b/tests/account/stats/referer.spec.js index 934cc47c2b0..c508c06c35b 100644 --- a/tests/account/stats/referer.spec.js +++ b/tests/account/stats/referer.spec.js @@ -23,7 +23,7 @@ test("Logged in free user cannot access premium referers stats", async ({ const page = await context.newPage(); await page.goto("/account/statistics/referers"); await page.waitForLoadState("networkidle"); - await expect(page).toHaveURL(/account\/onboarding/); + await expect(page).toHaveURL(/\/pricing/); }); test("Logged in premium user can access premium referers stats", async ({ @@ -44,7 +44,7 @@ test.describe("accessibility tests (light)", () => { const page = await context.newPage(); await page.goto("/account/statistics/referers"); const accessibilityScanResults = await new AxeBuilder({ page }) - .withTags(["wcag2a", "wcag2aa", "wcag21a", "wcag21aa"]) + .withTags([ "wcag2a", "wcag2aa", "wcag21a", "wcag21aa" ]) .analyze(); expect(accessibilityScanResults.violations).toEqual([]); }); @@ -60,7 +60,7 @@ test.describe("accessibility tests (dark)", () => { const page = await context.newPage(); await page.goto("/account/statistics/referers"); const accessibilityScanResults = await new AxeBuilder({ page }) - .withTags(["wcag2a", "wcag2aa", "wcag21a", "wcag21aa"]) + .withTags([ "wcag2a", "wcag2aa", "wcag21a", "wcag21aa" ]) .analyze(); expect(accessibilityScanResults.violations).toEqual([]); });