252 lines
7.1 KiB
TypeScript
252 lines
7.1 KiB
TypeScript
|
import { Hono } from 'hono'
|
||
|
import { cloudflareAccess } from '../src'
|
||
|
import { describe, expect, it, vi } from 'vitest'
|
||
|
|
||
|
import crypto from 'crypto';
|
||
|
import { promisify } from 'util';
|
||
|
|
||
|
// Promise-returning form of Node's callback-based `crypto.generateKeyPair`,
// awaited by `generateJWTKeyPair`, which destructures `{ publicKey, privateKey }`
// from the result.
// NOTE(review): the function object is passed to `promisify` directly (not
// wrapped/bound) — presumably so Node's own promisified variant, which yields
// that `{ publicKey, privateKey }` object, is picked up; keep it that way.
const generateKeyPair = promisify(crypto.generateKeyPair);
|
||
|
|
||
|
/**
 * The two PEM-encoded halves of an RSA key pair, as returned by
 * `generateJWTKeyPair`.
 */
interface KeyPairResult {
  /** PEM private key (PKCS#8 when produced by `generateJWTKeyPair`). */
  privateKey: string;
  /** PEM public key (SPKI when produced by `generateJWTKeyPair`). */
  publicKey: string;
}
|
||
|
|
||
|
/**
 * An RSA public key in JSON Web Key form, as placed in the `keys` array of the
 * JWKS response that the tests stub out.
 */
interface JWK {
  /** Key type, e.g. 'RSA'. */
  kty: string;
  /** Intended key use, e.g. 'sig' for signature verification. */
  use: string;
  /** Signing algorithm the key is advertised for, e.g. 'RS256'. */
  alg: string;
  /** Key identifier used to select this key from the set. */
  kid: string;
  /** RSA modulus, base64url-encoded. */
  n: string;
  /** RSA public exponent, base64url-encoded. */
  e: string;
}
|
||
|
|
||
|
/**
 * Generates a fresh 2048-bit RSA key pair for signing test JWTs.
 *
 * Both halves are PEM-encoded: the public key as SPKI and the private key as
 * PKCS#8. They can be passed straight to `crypto.createPublicKey` and to
 * `Sign#sign`.
 *
 * @returns The PEM-encoded public and private keys.
 * @throws Error whose message starts with "Failed to generate key pair:" when
 *   key generation rejects.
 */
async function generateJWTKeyPair(): Promise<KeyPairResult> {
  try {
    const { publicKey, privateKey } = await generateKeyPair('rsa', {
      modulusLength: 2048,
      publicKeyEncoding: { type: 'spki', format: 'pem' },
      privateKeyEncoding: { type: 'pkcs8', format: 'pem' },
    });
    return { publicKey, privateKey };
  } catch (error: unknown) {
    // Catch variables are `unknown`: narrow instead of asserting, so a
    // non-Error rejection is reported verbatim rather than as "undefined".
    const reason = error instanceof Error ? error.message : String(error);
    throw new Error(`Failed to generate key pair: ${reason}`);
  }
}
|
||
|
|
||
|
/**
 * Derives a deterministic key id from an RSA modulus.
 *
 * The string is decoded with Node's 'base64' codec, which per the Node Buffer
 * docs also accepts the URL-safe alphabet used by a JWK's `n`. The decoded
 * bytes are then hashed with SHA-256. This is a plain hash of the modulus
 * bytes, not an RFC 7638 JWK thumbprint.
 *
 * @param modulusBase64 - The modulus, base64 or base64url encoded.
 * @returns The SHA-256 digest as a lowercase hex string.
 */
function generateKeyThumbprint(modulusBase64: string): string {
  const modulusBytes = Buffer.from(modulusBase64, 'base64');
  return crypto.createHash('sha256').update(modulusBytes).digest('hex');
}
|
||
|
|
||
|
/**
 * Converts a PEM-encoded RSA public key into the JWK shape served by the
 * stubbed JWKS endpoint.
 *
 * @param publicKey - PEM-encoded RSA public key, for example the SPKI output
 *   of `generateJWTKeyPair`.
 * @returns A JWK advertising the key for RS256 signature verification. Its
 *   `kid` is derived from the modulus via `generateKeyThumbprint`.
 * @throws Error if the key has no RSA modulus/exponent (i.e. is not an RSA key).
 */
function publicKeyToJWK(publicKey: string): JWK {
  const keyObject = crypto.createPublicKey(publicKey);
  const { n, e } = keyObject.export({ format: 'jwk' });

  // Only RSA keys export `n` and `e`. Fail with a clear message instead of
  // casting `undefined` to string and crashing later inside the thumbprint.
  if (typeof n !== 'string' || typeof e !== 'string') {
    throw new Error(
      `Expected an RSA public key, got ${keyObject.asymmetricKeyType ?? 'unknown'} key`
    );
  }

  return {
    kid: generateKeyThumbprint(n),
    kty: 'RSA',
    alg: 'RS256',
    use: 'sig',
    e,
    n,
  };
}
|
||
|
|
||
|
|
||
|
/**
 * Encodes text or raw bytes as unpadded base64url (RFC 4648 §5). This is the
 * encoding used for every segment of a compact JWT.
 *
 * Accepting `Buffer` lets callers encode binary data such as a signature
 * without resorting to `@ts-ignore`. Strings are encoded as UTF-8 first.
 *
 * @param str - Text (encoded as UTF-8) or raw bytes to encode.
 * @returns The base64url string using '-' and '_', with '=' padding omitted.
 */
function base64URLEncode(str: string | Buffer): string {
  const bytes = typeof str === 'string' ? Buffer.from(str, 'utf8') : str;
  // Node's 'base64url' codec already swaps '+/' for '-_' and drops padding.
  return bytes.toString('base64url');
}
|
||
|
|
||
|
/**
 * Builds a compact RS256-signed JWT for exercising the middleware.
 *
 * `iat` is set to the current Unix time (seconds) and `exp` to
 * `iat + expiresIn`. Both are spread after `payload`, so they override any
 * `iat`/`exp` the caller supplies. Pass a negative `expiresIn` to mint an
 * already-expired token.
 *
 * @param privateKey - PEM-encoded RSA private key used to sign.
 * @param payload - Claims to embed (e.g. `sub`, `iss`).
 * @param expiresIn - Lifetime in seconds relative to now (default 3600).
 * @returns `<header>.<payload>.<signature>`, each segment base64url-encoded.
 */
function generateJWT(privateKey: string, payload: Record<string, unknown>, expiresIn = 3600): string {
  const header = { alg: 'RS256', typ: 'JWT' };

  const now = Math.floor(Date.now() / 1000);
  const fullPayload = { ...payload, iat: now, exp: now + expiresIn };

  const encodedHeader = base64URLEncode(JSON.stringify(header));
  const encodedPayload = base64URLEncode(JSON.stringify(fullPayload));

  // RS256 signs "<header>.<payload>" with RSASSA-PKCS1-v1_5 / SHA-256.
  const signingInput = `${encodedHeader}.${encodedPayload}`;
  const signature = crypto.createSign('RSA-SHA256').update(signingInput).sign(privateKey);

  // `sign` returns raw bytes. Encode them directly as unpadded base64url
  // rather than forcing a Buffer through a string-typed helper with @ts-ignore.
  const encodedSignature = signature.toString('base64url');

  return `${signingInput}.${encodedSignature}`;
}
|
||
|
|
||
|
|
||
|
describe('Cloudflare Access middleware', async () => {
  // Three independent RSA key pairs. Pairs 1 and 2 are published by the
  // stubbed JWKS response; pair 3 is deliberately left out to act as an
  // unknown signer.
  const [keyPair1, keyPair2, keyPair3] = await Promise.all([
    generateJWTKeyPair(),
    generateJWTKeyPair(),
    generateJWTKeyPair(),
  ])

  // Every global fetch during these tests (presumably the middleware's key
  // lookup) is answered with the public JWKs of keyPair1 and keyPair2.
  vi.stubGlobal('fetch', async () =>
    Response.json({
      keys: [keyPair1, keyPair2].map((pair) => publicKeyToJWK(pair.publicKey)),
    })
  )

  const app = new Hono()

  app.use('/*', cloudflareAccess('my-cool-team-name'))
  app.get('/hello-behind-access', (c) => c.text('foo'))
  app.get('/access-payload', (c) => c.json(c.get('accessPayload')))

  const TEAM_ISSUER = 'https://my-cool-team-name.cloudflareaccess.com'

  /** Requests `path` on the app, sending `token` as the Access JWT header when given. */
  const callApp = (path: string, token?: string) =>
    token === undefined
      ? app.request(`http://localhost${path}`)
      : app.request(`http://localhost${path}`, {
          headers: { 'cf-access-jwt-assertion': token },
        })

  /** Mints a token with the standard test subject and the team's own issuer. */
  const signForTeam = (privateKey: string) =>
    generateJWT(privateKey, { sub: '1234567890', iss: TEAM_ISSUER })

  /** Asserts a 401 response whose body is exactly `body`. */
  const expectUnauthorized = async (res: Response, body: string) => {
    expect(res).not.toBeNull()
    expect(res.status).toBe(401)
    expect(await res.text()).toBe(body)
  }

  /** Asserts the protected route let the request through. */
  const expectGreeting = async (res: Response) => {
    expect(res).not.toBeNull()
    expect(res.status).toBe(200)
    expect(await res.text()).toBe('foo')
  }

  it('Should be throw Missing bearer token when nothing is sent', async () => {
    const res = await callApp('/hello-behind-access')
    await expectUnauthorized(res, 'Authentication error: Missing bearer token')
  })

  it('Should be throw Unable to decode Bearer token when sending garbage', async () => {
    const res = await callApp('/hello-behind-access', 'asdasdasda')
    await expectUnauthorized(res, 'Authentication error: Unable to decode Bearer token')
  })

  it('Should be throw Token is expired when sending expired token', async () => {
    const expired = generateJWT(keyPair1.privateKey, { sub: '1234567890' }, -3600)
    const res = await callApp('/hello-behind-access', expired)
    await expectUnauthorized(res, 'Authentication error: Token is expired')
  })

  it('Should be throw Expected team name x, but received y when sending invalid iss', async () => {
    const foreign = generateJWT(keyPair1.privateKey, {
      sub: '1234567890',
      iss: 'https://different-team.cloudflareaccess.com',
    })
    const res = await callApp('/hello-behind-access', foreign)
    await expectUnauthorized(
      res,
      'Authentication error: Expected team name https://my-cool-team-name.cloudflareaccess.com, but received https://different-team.cloudflareaccess.com'
    )
  })

  it('Should be throw Invalid token when sending token signed with private key not in the allowed list', async () => {
    const res = await callApp('/hello-behind-access', signForTeam(keyPair3.privateKey))
    await expectUnauthorized(res, 'Authentication error: Invalid Token')
  })

  it('Should work when sending everything correctly', async () => {
    const res = await callApp('/hello-behind-access', signForTeam(keyPair1.privateKey))
    await expectGreeting(res)
  })

  it('Should work with tokens signed by the 2º key in the public keys list', async () => {
    const res = await callApp('/hello-behind-access', signForTeam(keyPair2.privateKey))
    await expectGreeting(res)
  })

  it('Should be able to retrieve the JWT payload from Hono context', async () => {
    const res = await callApp('/access-payload', signForTeam(keyPair1.privateKey))
    expect(res).not.toBeNull()
    expect(res.status).toBe(200)
    expect(await res.json()).toEqual({
      sub: '1234567890',
      iss: 'https://my-cool-team-name.cloudflareaccess.com',
      iat: expect.any(Number),
      exp: expect.any(Number),
    })
  })
})
|