fix(auth-js): bun req cloning (#598)

pull/602/head
divyam234 2024-07-01 06:27:02 +05:30 committed by GitHub
parent ea7ec3c6ed
commit eb7e597aaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 13 deletions

View File

@ -0,0 +1,5 @@
---
'@hono/auth-js': patch
---
fix bun req cloning

View File

@ -4,10 +4,11 @@ import type { AdapterUser } from '@auth/core/adapters'
import type { JWT } from '@auth/core/jwt' import type { JWT } from '@auth/core/jwt'
import type { Session } from '@auth/core/types' import type { Session } from '@auth/core/types'
import type { Context, MiddlewareHandler } from 'hono' import type { Context, MiddlewareHandler } from 'hono'
import { env } from 'hono/adapter' import { env ,getRuntimeKey} from 'hono/adapter'
import { HTTPException } from 'hono/http-exception' import { HTTPException } from 'hono/http-exception'
import { setEnvDefaults as coreSetEnvDefaults } from '@auth/core' import { setEnvDefaults as coreSetEnvDefaults } from '@auth/core'
declare module 'hono' { declare module 'hono' {
interface ContextVariableMap { interface ContextVariableMap {
authUser: AuthUser authUser: AuthUser
@ -38,33 +39,64 @@ export function setEnvDefaults(env: AuthEnv, config: AuthConfig) {
coreSetEnvDefaults(env, config) coreSetEnvDefaults(env, config)
} }
export function reqWithEnvUrl(req: Request, authUrl?: string): Request { async function cloneRequest(input: URL | string, request: Request){
if ( getRuntimeKey() === "bun") {
return new Request(input, {
method: request.method,
headers:new Headers(request.headers),
body:
request.method === "GET" || request.method === "HEAD"
? undefined
: await request.blob(),
// @ts-ignore: TS2353
referrer: "referrer" in request ? (request.referrer as string) : undefined,
// deno-lint-ignore no-explicit-any
referrerPolicy: request.referrerPolicy as any,
mode: request.mode,
credentials: request.credentials,
// @ts-ignore: TS2353
cache: request.cache,
redirect: request.redirect,
integrity: request.integrity,
keepalive: request.keepalive,
signal: request.signal
})
}
return new Request(input, request)
}
export async function reqWithEnvUrl(req: Request, authUrl?: string){
if (authUrl) { if (authUrl) {
const reqUrlObj = new URL(req.url) const reqUrlObj = new URL(req.url)
const authUrlObj = new URL(authUrl) const authUrlObj = new URL(authUrl)
const props = ['hostname', 'protocol', 'port', 'password', 'username'] as const const props = ['hostname', 'protocol', 'port', 'password', 'username'] as const
props.forEach((prop) => (reqUrlObj[prop] = authUrlObj[prop])) props.forEach((prop) => (reqUrlObj[prop] = authUrlObj[prop]))
return new Request(reqUrlObj.href, req) return cloneRequest(reqUrlObj.href, req)
} else { } else {
const url = new URL(req.url) const url = new URL(req.url)
const proto = req.headers.get('x-forwarded-proto') const proto = req.headers.get('x-forwarded-proto')
const host = req.headers.get('x-forwarded-host') ?? req.headers.get('host') const host = req.headers.get('x-forwarded-host') ?? req.headers.get('host')
if (proto != null) url.protocol = proto.endsWith(':') ? proto : proto + ':' if (proto != null) url.protocol = proto.endsWith(':') ? proto : proto + ':'
if (host) { if (host!=null) {
url.host = host url.host = host
const portMatch = host.match(/:(\d+)$/) const portMatch = host.match(/:(\d+)$/)
if (portMatch) url.port = portMatch[1] if (portMatch) url.port = portMatch[1]
else url.port = '' else url.port = ''
req.headers.delete("x-forwarded-host")
req.headers.delete("Host")
req.headers.set("Host", host)
} }
return new Request(url.href, req) return cloneRequest(url.href, req)
} }
} }
export async function getAuthUser(c: Context): Promise<AuthUser | null> { export async function getAuthUser(c: Context): Promise<AuthUser | null> {
const config = c.get('authConfig') const config = c.get('authConfig')
let ctxEnv = env(c) as AuthEnv const ctxEnv = env(c) as AuthEnv
setEnvDefaults(ctxEnv, config) setEnvDefaults(ctxEnv, config)
const origin = new URL(reqWithEnvUrl(c.req.raw, ctxEnv.AUTH_URL).url).origin const authReq = await reqWithEnvUrl(c.req.raw, ctxEnv.AUTH_URL)
const origin = new URL(authReq.url).origin
const request = new Request(`${origin}${config.basePath}/session`, { const request = new Request(`${origin}${config.basePath}/session`, {
headers: { cookie: c.req.header('cookie') ?? '' }, headers: { cookie: c.req.header('cookie') ?? '' },
}) })
@ -117,7 +149,7 @@ export function initAuthConfig(cb: ConfigHandler): MiddlewareHandler {
export function authHandler(): MiddlewareHandler { export function authHandler(): MiddlewareHandler {
return async (c) => { return async (c) => {
const config = c.get('authConfig') const config = c.get('authConfig')
let ctxEnv = env(c) as AuthEnv const ctxEnv = env(c) as AuthEnv
setEnvDefaults(ctxEnv, config) setEnvDefaults(ctxEnv, config)
@ -125,7 +157,8 @@ export function authHandler(): MiddlewareHandler {
throw new HTTPException(500, { message: 'Missing AUTH_SECRET' }) throw new HTTPException(500, { message: 'Missing AUTH_SECRET' })
} }
const res = await Auth(reqWithEnvUrl(c.req.raw, ctxEnv.AUTH_URL), config) const authReq = await reqWithEnvUrl(c.req.raw, ctxEnv.AUTH_URL)
const res = await Auth(authReq, config)
return new Response(res.body, res) return new Response(res.body, res)
} }
} }

View File

@ -3,7 +3,7 @@ import { skipCSRFCheck } from '@auth/core'
import type { Adapter } from '@auth/core/adapters' import type { Adapter } from '@auth/core/adapters'
import Credentials from '@auth/core/providers/credentials' import Credentials from '@auth/core/providers/credentials'
import { Hono } from 'hono' import { Hono } from 'hono'
import { describe, expect, it, vi } from 'vitest' import { describe, expect, it, vi } from "vitest"
import type { AuthConfig } from '../src' import type { AuthConfig } from '../src'
import { authHandler, verifyAuth, initAuthConfig, reqWithEnvUrl } from '../src' import { authHandler, verifyAuth, initAuthConfig, reqWithEnvUrl } from '../src'
@ -78,9 +78,9 @@ describe('Config', () => {
}) })
}) })
describe('reqWithEnvUrl()', () => { describe('reqWithEnvUrl()', async() => {
const req = new Request('http://request-base/request-path') const req = new Request('http://request-base/request-path')
const newReq = reqWithEnvUrl(req, 'https://auth-url-base/auth-url-path') const newReq = await reqWithEnvUrl(req, 'https://auth-url-base/auth-url-path')
it('Should rewrite the base path', () => { it('Should rewrite the base path', () => {
expect(newReq.url.toString()).toBe('https://auth-url-base/request-path') expect(newReq.url.toString()).toBe('https://auth-url-base/request-path')
}) })
@ -126,6 +126,7 @@ describe('Credentials Provider', () => {
password: {}, password: {},
}, },
authorize: (credentials) => { authorize: (credentials) => {
if (credentials.password === 'password') { if (credentials.password === 'password') {
return user return user
} }
@ -200,4 +201,4 @@ describe('Credentials Provider', () => {
let html = await res.text() let html = await res.text()
expect(html).toContain('action="https://example.com/api/auth/callback/credentials"') expect(html).toContain('action="https://example.com/api/auth/callback/credentials"')
}) })
}) })