From e54c62875161c0d2b7e21c0d9108d311e23072a8 Mon Sep 17 00:00:00 2001 From: Wang Guan Date: Fri, 19 Jul 2024 07:43:18 +0900 Subject: [PATCH] feat(oauth-providers): allow custom redirect_uri (#601) * oauth-providers: allow set of redirect_uri * only check state in req with auth code * update readme.md * add changeset file * update readme * run lint:fix * add test * test other providers * revent unreleated test --- .changeset/cold-tigers-lie.md | 5 + packages/oauth-providers/README.md | 45 ++++++ .../src/providers/discord/discordAuth.ts | 21 +-- .../src/providers/facebook/facebookAuth.ts | 21 +-- .../src/providers/github/authFlow.ts | 4 +- .../src/providers/github/githubAuth.ts | 5 +- .../src/providers/google/googleAuth.ts | 19 +-- .../src/providers/linkedin/authFlow.ts | 2 +- .../src/providers/linkedin/linkedinAuth.ts | 19 +-- .../oauth-providers/src/providers/x/xAuth.ts | 3 +- packages/oauth-providers/test/handlers.ts | 6 +- packages/oauth-providers/test/index.test.ts | 133 +++++++++++++++--- 12 files changed, 219 insertions(+), 64 deletions(-) create mode 100644 .changeset/cold-tigers-lie.md diff --git a/.changeset/cold-tigers-lie.md b/.changeset/cold-tigers-lie.md new file mode 100644 index 00000000..07ce041e --- /dev/null +++ b/.changeset/cold-tigers-lie.md @@ -0,0 +1,5 @@ +--- +'@hono/oauth-providers': minor +--- + +allow override of redirect_uri diff --git a/packages/oauth-providers/README.md b/packages/oauth-providers/README.md index 80f4cc27..0376bd5f 100644 --- a/packages/oauth-providers/README.md +++ b/packages/oauth-providers/README.md @@ -918,6 +918,51 @@ app.post('/remove-user', async (c, next) => { }) ``` +## Advance Usage + +### Customize `redirect_uri` + +All the provider middlewares also accept a `redirect_uri` parameter that overrides the default `redirect_uri = c.req.url` behavior. + +This parameters can be useful if + +1. `hono` process cannot infer correct redirect_uri from the request. For example, when the server runs behind a reverse proxy and have no access to its internet hostname. +2. Or, in need to start oauth flow from a different route. +3. Or, in need to encode more info into `redirect_uri`. + +```ts +const app = new Hono(); + +const SITE_ORIGIN = `https://my-site.com`; +const OAUTH_CALLBACK_PATH = `/oauth/google`; + +app.get('/*', + async (c, next) => { + const session = readSession(c); + if (!session) { + // start oauth flow + const redirectUri = `${SITE_ORIGIN}${OAUTH_CALLBACK_PATH}?redirect=${encodeURIComponent(c.req.path)}`; + const oauth = googleAuth({ redirect_uri: redirectUri, ...more }); + return await oauth(c, next) + } + }, + async (c, next) => { + // if we are here, the req should contain either a valid session or a valid auth code + const session = readSession(c); + const authedGoogleUser = c.get('user-google') + if (authedGoogleUser) { + await saveSession(c, authedGoogleUser); + } else if (!session) { + throw new HttpException(401) + } + return next(); + }, + async (c, next) => { + // serve protected content + } +); +``` + ## Author monoald https://github.com/monoald diff --git a/packages/oauth-providers/src/providers/discord/discordAuth.ts b/packages/oauth-providers/src/providers/discord/discordAuth.ts index 8caac164..729bff38 100644 --- a/packages/oauth-providers/src/providers/discord/discordAuth.ts +++ b/packages/oauth-providers/src/providers/discord/discordAuth.ts @@ -1,5 +1,5 @@ import type { MiddlewareHandler } from 'hono' -import { setCookie, getCookie } from 'hono/cookie' +import { getCookie, setCookie } from 'hono/cookie' import { env } from 'hono/adapter' import { HTTPException } from 'hono/http-exception' @@ -11,6 +11,7 @@ export function discordAuth(options: { scope: Scopes[] client_id?: string client_secret?: string + redirect_uri?: string }): MiddlewareHandler { return async (c, next) => { // Generate encoded "keys" @@ -20,7 +21,7 @@ export function discordAuth(options: { const auth = new AuthFlow({ client_id: options.client_id || (env(c).DISCORD_ID as string), client_secret: options.client_secret || (env(c).DISCORD_SECRET as string), - redirect_uri: c.req.url.split('?')[0], + redirect_uri: options.redirect_uri || c.req.url.split('?')[0], scope: options.scope, state: newState, code: c.req.query('code'), @@ -30,14 +31,6 @@ export function discordAuth(options: { }, }) - // Avoid CSRF attack by checking state - if (c.req.url.includes('?')) { - const storedState = getCookie(c, 'state') - if (c.req.query('state') !== storedState) { - throw new HTTPException(401) - } - } - // Redirect to login dialog if (!auth.code) { setCookie(c, 'state', newState, { @@ -49,6 +42,14 @@ export function discordAuth(options: { return c.redirect(auth.redirect()) } + // Avoid CSRF attack by checking state + if (c.req.url.includes('?')) { + const storedState = getCookie(c, 'state') + if (c.req.query('state') !== storedState) { + throw new HTTPException(401) + } + } + // Retrieve user data from discord await auth.getUserData() diff --git a/packages/oauth-providers/src/providers/facebook/facebookAuth.ts b/packages/oauth-providers/src/providers/facebook/facebookAuth.ts index 07adde06..cdf75403 100644 --- a/packages/oauth-providers/src/providers/facebook/facebookAuth.ts +++ b/packages/oauth-providers/src/providers/facebook/facebookAuth.ts @@ -1,5 +1,5 @@ import type { MiddlewareHandler } from 'hono' -import { setCookie, getCookie } from 'hono/cookie' +import { getCookie, setCookie } from 'hono/cookie' import { env } from 'hono/adapter' import { HTTPException } from 'hono/http-exception' @@ -12,6 +12,7 @@ export function facebookAuth(options: { fields: Fields[] client_id?: string client_secret?: string + redirect_uri?: string }): MiddlewareHandler { return async (c, next) => { const newState = getRandomState() @@ -19,7 +20,7 @@ export function facebookAuth(options: { const auth = new AuthFlow({ client_id: options.client_id || (env(c).FACEBOOK_ID as string), client_secret: options.client_secret || (env(c).FACEBOOK_SECRET as string), - redirect_uri: c.req.url.split('?')[0], + redirect_uri: options.redirect_uri || c.req.url.split('?')[0], scope: options.scope, fields: options.fields, state: newState, @@ -30,14 +31,6 @@ export function facebookAuth(options: { }, }) - // Avoid CSRF attack by checking state - if (c.req.url.includes('?')) { - const storedState = getCookie(c, 'state') - if (c.req.query('state') !== storedState) { - throw new HTTPException(401) - } - } - // Redirect to login dialog if (!auth.code) { setCookie(c, 'state', newState, { @@ -49,6 +42,14 @@ export function facebookAuth(options: { return c.redirect(auth.redirect()) } + // Avoid CSRF attack by checking state + if (c.req.url.includes('?')) { + const storedState = getCookie(c, 'state') + if (c.req.query('state') !== storedState) { + throw new HTTPException(401) + } + } + // Retrieve user data from facebook await auth.getUserData() diff --git a/packages/oauth-providers/src/providers/github/authFlow.ts b/packages/oauth-providers/src/providers/github/authFlow.ts index 4344e115..97ff687b 100644 --- a/packages/oauth-providers/src/providers/github/authFlow.ts +++ b/packages/oauth-providers/src/providers/github/authFlow.ts @@ -2,11 +2,11 @@ import { HTTPException } from 'hono/http-exception' import { toQueryParams } from '../../utils/objectToQuery' import type { + GitHubEmailResponse, GitHubErrorResponse, + GitHubScope, GitHubTokenResponse, GitHubUser, - GitHubScope, - GitHubEmailResponse, } from './types' type GithubAuthFlow = { diff --git a/packages/oauth-providers/src/providers/github/githubAuth.ts b/packages/oauth-providers/src/providers/github/githubAuth.ts index eb71be1b..9729337d 100644 --- a/packages/oauth-providers/src/providers/github/githubAuth.ts +++ b/packages/oauth-providers/src/providers/github/githubAuth.ts @@ -12,6 +12,7 @@ export function githubAuth(options: { client_secret?: string scope?: GitHubScope[] oauthApp?: boolean + redirect_uri?: string }): MiddlewareHandler { return async (c, next) => { const newState = getRandomState() @@ -46,7 +47,9 @@ export function githubAuth(options: { // As such, we want to make sure we call back to the same location // for GitHub apps and not the first configured callbackURL in the app config. return c.redirect( - auth.redirect().concat(options.oauthApp ? '' : `&redirect_uri=${c.req.url}`) + auth + .redirect() + .concat(options.oauthApp ? '' : `&redirect_uri=${options.redirect_uri || c.req.url}`) ) } diff --git a/packages/oauth-providers/src/providers/google/googleAuth.ts b/packages/oauth-providers/src/providers/google/googleAuth.ts index bad111a8..eac4edd8 100644 --- a/packages/oauth-providers/src/providers/google/googleAuth.ts +++ b/packages/oauth-providers/src/providers/google/googleAuth.ts @@ -13,6 +13,7 @@ export function googleAuth(options: { client_id?: string client_secret?: string state?: string + redirect_uri?: string }): MiddlewareHandler { return async (c, next) => { const newState = options.state || getRandomState() @@ -20,7 +21,7 @@ export function googleAuth(options: { const auth = new AuthFlow({ client_id: options.client_id || (env(c).GOOGLE_ID as string), client_secret: options.client_secret || (env(c).GOOGLE_SECRET as string), - redirect_uri: c.req.url.split('?')[0], + redirect_uri: options.redirect_uri || c.req.url.split('?')[0], login_hint: options.login_hint, prompt: options.prompt, scope: options.scope, @@ -32,14 +33,6 @@ export function googleAuth(options: { }, }) - // Avoid CSRF attack by checking state - if (c.req.url.includes('?')) { - const storedState = getCookie(c, 'state') - if (c.req.query('state') !== storedState) { - throw new HTTPException(401) - } - } - // Redirect to login dialog if (!auth.code) { setCookie(c, 'state', newState, { @@ -51,6 +44,14 @@ export function googleAuth(options: { return c.redirect(auth.redirect()) } + // Avoid CSRF attack by checking state + if (c.req.url.includes('?')) { + const storedState = getCookie(c, 'state') + if (c.req.query('state') !== storedState) { + throw new HTTPException(401) + } + } + // Retrieve user data from google await auth.getUserData() diff --git a/packages/oauth-providers/src/providers/linkedin/authFlow.ts b/packages/oauth-providers/src/providers/linkedin/authFlow.ts index 597dbcb9..cdf5665b 100644 --- a/packages/oauth-providers/src/providers/linkedin/authFlow.ts +++ b/packages/oauth-providers/src/providers/linkedin/authFlow.ts @@ -4,9 +4,9 @@ import type { Token } from '../../types' import { toQueryParams } from '../../utils/objectToQuery' import type { LinkedInErrorResponse, + LinkedInScope, LinkedInTokenResponse, LinkedInUser, - LinkedInScope, } from './types' export type LinkedInAuthFlow = { diff --git a/packages/oauth-providers/src/providers/linkedin/linkedinAuth.ts b/packages/oauth-providers/src/providers/linkedin/linkedinAuth.ts index 71eacb8e..297460dc 100644 --- a/packages/oauth-providers/src/providers/linkedin/linkedinAuth.ts +++ b/packages/oauth-providers/src/providers/linkedin/linkedinAuth.ts @@ -12,6 +12,7 @@ export function linkedinAuth(options: { client_secret?: string scope?: LinkedInScope[] appAuth?: boolean + redirect_uri?: string }): MiddlewareHandler { return async (c, next) => { const newState = getRandomState() @@ -19,21 +20,13 @@ export function linkedinAuth(options: { const auth = new AuthFlow({ client_id: options.client_id || (env(c).LINKEDIN_ID as string), client_secret: options.client_secret || (env(c).LINKEDIN_SECRET as string), - redirect_uri: c.req.url.split('?')[0], + redirect_uri: options.redirect_uri || c.req.url.split('?')[0], scope: options.scope, state: newState, appAuth: options.appAuth || false, code: c.req.query('code'), }) - // Avoid CSRF attack by checking state - if (c.req.url.includes('?')) { - const storedState = getCookie(c, 'state') - if (c.req.query('state') !== storedState) { - throw new HTTPException(401) - } - } - // Redirect to login dialog if (!auth.code && !options.appAuth) { setCookie(c, 'state', newState, { @@ -45,6 +38,14 @@ export function linkedinAuth(options: { return c.redirect(auth.redirect()) } + // Avoid CSRF attack by checking state + if (c.req.url.includes('?')) { + const storedState = getCookie(c, 'state') + if (c.req.query('state') !== storedState) { + throw new HTTPException(401) + } + } + if (options.appAuth) { await auth.getAppToken() } else { diff --git a/packages/oauth-providers/src/providers/x/xAuth.ts b/packages/oauth-providers/src/providers/x/xAuth.ts index 886ddc0b..00f9ff95 100644 --- a/packages/oauth-providers/src/providers/x/xAuth.ts +++ b/packages/oauth-providers/src/providers/x/xAuth.ts @@ -13,6 +13,7 @@ export function xAuth(options: { fields?: XFields[] client_id?: string client_secret?: string + redirect_uri?: string }): MiddlewareHandler { return async (c, next) => { // Generate encoded "keys" @@ -22,7 +23,7 @@ export function xAuth(options: { const auth = new AuthFlow({ client_id: options.client_id || (env(c).X_ID as string), client_secret: options.client_secret || (env(c).X_SECRET as string), - redirect_uri: c.req.url.split('?')[0], + redirect_uri: options.redirect_uri || c.req.url.split('?')[0], scope: options.scope, fields: options.fields, state: newState, diff --git a/packages/oauth-providers/test/handlers.ts b/packages/oauth-providers/test/handlers.ts index 40e69bb1..ae17f113 100644 --- a/packages/oauth-providers/test/handlers.ts +++ b/packages/oauth-providers/test/handlers.ts @@ -8,11 +8,7 @@ import type { FacebookUser, } from '../src/providers/facebook' import type { GitHubErrorResponse, GitHubTokenResponse } from '../src/providers/github' -import type { - GoogleErrorResponse, - GoogleTokenResponse, - GoogleUser, -} from '../src/providers/google/types' +import type { GoogleErrorResponse, GoogleTokenResponse, GoogleUser } from '../src/providers/google' import type { LinkedInErrorResponse, LinkedInTokenResponse } from '../src/providers/linkedin' import type { XErrorResponse, XRevokeResponse, XTokenResponse } from '../src/providers/x' diff --git a/packages/oauth-providers/test/index.test.ts b/packages/oauth-providers/test/index.test.ts index 4fdc7b01..ac151843 100644 --- a/packages/oauth-providers/test/index.test.ts +++ b/packages/oauth-providers/test/index.test.ts @@ -2,9 +2,9 @@ import { Hono } from 'hono' import { setupServer } from 'msw/node' import type { DiscordUser } from '../src/providers/discord' import { + discordAuth, refreshToken as discordRefresh, revokeToken as discordRevoke, - discordAuth, } from '../src/providers/discord' import { facebookAuth } from '../src/providers/facebook' import type { FacebookUser } from '../src/providers/facebook' @@ -18,30 +18,30 @@ import type { XUser } from '../src/providers/x' import { refreshToken, revokeToken, xAuth } from '../src/providers/x' import type { Token } from '../src/types' import { + discordCodeError, + discordRefreshToken, + discordRefreshTokenError, + discordToken, + discordUser, + dummyCode, dummyToken, + facebookCodeError, + facebookUser, + githubCodeError, + githubToken, + githubUser, + googleCodeError, googleUser, handlers, - facebookUser, - githubUser, - dummyCode, - googleCodeError, - facebookCodeError, - githubToken, - githubCodeError, linkedInCodeError, - linkedInUser, linkedInToken, + linkedInUser, xCodeError, - xUser, - xToken, xRefreshToken, xRefreshTokenError, xRevokeTokenError, - discordCodeError, - discordUser, - discordToken, - discordRefreshToken, - discordRefreshTokenError, + xToken, + xUser, } from './handlers' const server = setupServer(...handlers) @@ -62,6 +62,14 @@ describe('OAuth Middleware', () => { scope: ['openid', 'email', 'profile'], }) ) + app.use('/google-custom-redirect', (c, next) => { + return googleAuth({ + client_id, + client_secret, + scope: ['openid', 'email', 'profile'], + redirect_uri: 'http://localhost:3000/google', + })(c, next) + }) app.get('/google', (c) => { const user = c.get('user-google') const token = c.get('token') @@ -93,6 +101,15 @@ describe('OAuth Middleware', () => { ], }) ) + app.use('/facebook-custom-redirect', (c, next) => + facebookAuth({ + client_id, + client_secret, + scope: [], + fields: [], + redirect_uri: 'http://localhost:3000/facebook', + })(c, next) + ) app.get('/facebook', (c) => { const user = c.get('user-facebook') const token = c.get('token') @@ -113,6 +130,13 @@ describe('OAuth Middleware', () => { client_secret, }) ) + app.use('/github/app-custom-redirect', (c, next) => + githubAuth({ + client_id, + client_secret, + redirect_uri: 'http://localhost:3000/github/app', + })(c, next) + ) app.get('/github/app', (c) => { const token = c.get('token') const refreshToken = c.get('refresh-token') @@ -156,6 +180,14 @@ describe('OAuth Middleware', () => { scope: ['email', 'openid', 'profile'], }) ) + app.use('/linkedin-custom-redirect', (c, next) => + linkedinAuth({ + client_id, + client_secret, + scope: ['email', 'openid', 'profile'], + redirect_uri: 'http://localhost:3000/linkedin', + })(c, next) + ) app.get('/linkedin', (c) => { const token = c.get('token') const refreshToken = c.get('refresh-token') @@ -194,6 +226,15 @@ describe('OAuth Middleware', () => { ], }) ) + app.use('/x-custom-redirect', (c, next) => + xAuth({ + client_id, + client_secret, + scope: [], + fields: [], + redirect_uri: 'http://localhost:3000/x', + })(c, next) + ) app.get('/x', (c) => { const token = c.get('token') const refreshToken = c.get('refresh-token') @@ -241,6 +282,14 @@ describe('OAuth Middleware', () => { scope: ['identify', 'email'], }) ) + app.use('/discord-custom-redirect', (c, next) => + discordAuth({ + client_id, + client_secret, + scope: ['identify', 'email'], + redirect_uri: 'http://localhost:3000/discord', + })(c, next) + ) app.get('/discord', (c) => { const token = c.get('token') const refreshToken = c.get('refresh-token') @@ -295,6 +344,16 @@ describe('OAuth Middleware', () => { expect(res).not.toBeNull() expect(res.status).toBe(302) + expect(res.headers) + }) + + it('Should redirect to custom redirect_uri', async () => { + const res = await app.request('/google-custom-redirect') + expect(res).not.toBeNull() + expect(res.status).toBe(302) + const redirectLocation = res.headers.get('location')! + const redirectUrl = new URL(redirectLocation) + expect(redirectUrl.searchParams.get('redirect_uri')).toBe('http://localhost:3000/google') }) it('Prevent CSRF attack', async () => { @@ -339,6 +398,14 @@ describe('OAuth Middleware', () => { expect(res.status).toBe(302) }) + it('Should redirect to custom redirect_uri', async () => { + const res = await app.request('/facebook-custom-redirect') + expect(res?.status).toBe(302) + const redirectLocation = res.headers.get('location')! + const redirectUrl = new URL(redirectLocation) + expect(redirectUrl.searchParams.get('redirect_uri')).toBe('http://localhost:3000/facebook') + }) + it('Prevent CSRF attack', async () => { const res = await app.request(`/facebook?code=${dummyCode}&state=malware-state`) expect(res).not.toBeNull() @@ -383,6 +450,16 @@ describe('OAuth Middleware', () => { expect(res.status).toBe(302) }) + it('Should redirect to custom redirect_uri', async () => { + const res = await app.request('/github/app-custom-redirect') + expect(res?.status).toBe(302) + const redirectLocation = res.headers.get('location')! + const redirectUrl = new URL(redirectLocation) + expect(redirectUrl.searchParams.get('redirect_uri')).toBe( + 'http://localhost:3000/github/app' + ) + }) + it('Should throw error for invalide code', async () => { const res = await app.request('/github/app?code=9348ffdsd-sdsdbad-code') @@ -459,6 +536,14 @@ describe('OAuth Middleware', () => { expect(res.status).toBe(302) }) + it('Should redirect to custom redirect_uri', async () => { + const res = await app.request('/linkedin-custom-redirect') + expect(res?.status).toBe(302) + const redirectLocation = res.headers.get('location')! + const redirectUrl = new URL(redirectLocation) + expect(redirectUrl.searchParams.get('redirect_uri')).toBe('http://localhost:3000/linkedin') + }) + it('Should throw error for invalide code', async () => { const res = await app.request('/linkedin?code=9348ffdsd-sdsdbad-code') @@ -500,6 +585,14 @@ describe('OAuth Middleware', () => { expect(res.status).toBe(302) }) + it('Should redirect to custom redirect_uri', async () => { + const res = await app.request('/x-custom-redirect') + expect(res?.status).toBe(302) + const redirectLocation = res.headers.get('location')! + const redirectUrl = new URL(redirectLocation) + expect(redirectUrl.searchParams.get('redirect_uri')).toBe('http://localhost:3000/x') + }) + it('Prevent CSRF attack', async () => { const res = await app.request(`/x?code=${dummyCode}&state=malware-state`) expect(res).not.toBeNull() @@ -588,6 +681,14 @@ describe('OAuth Middleware', () => { expect(res.status).toBe(302) }) + it('Should redirect to custom redirect_uri', async () => { + const res = await app.request('/discord-custom-redirect') + expect(res?.status).toBe(302) + const redirectLocation = res.headers.get('location')! + const redirectUrl = new URL(redirectLocation) + expect(redirectUrl.searchParams.get('redirect_uri')).toBe('http://localhost:3000/discord') + }) + it('Prevent CSRF attack', async () => { const res = await app.request(`/discord?code=${dummyCode}&state=malware-state`) expect(res).not.toBeNull()