feat(node-ws): Reject unexpected WebSocket connections (#973)

* Added rejection of WebSocket connections when the app does not expect them

* added changeset

* Updated waiter names; removed strict option
pull/975/head
Dmytro Kulyk 2025-02-23 08:31:44 +02:00 committed by GitHub
parent 57b9f5dbdc
commit 6f90a574c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 66 additions and 11 deletions

View File

@ -0,0 +1,5 @@
---
'@hono/node-ws': minor
---
Added rejection of WebSocket connections when the app does not expect them

View File

@ -51,6 +51,40 @@ describe('WebSocket helper', () => {
expect(await mainPromise).toBe(true) expect(await mainPromise).toBe(true)
}) })
it('Should be rejected if upgradeWebSocket is not used', async () => {
app.get(
'/', (c)=>c.body('')
)
{
const ws = new WebSocket('ws://localhost:3030/')
const mainPromise = new Promise<boolean>((resolve) => {
ws.onerror = () => {
resolve(true)
}
ws.onopen = () => {
resolve(false)
}
})
expect(await mainPromise).toBe(true)
}
{ //also should rejected on fallback
const ws = new WebSocket('ws://localhost:3030/notFound')
const mainPromise = new Promise<boolean>((resolve) => {
ws.onerror = () => {
resolve(true)
}
ws.onopen = () => {
resolve(false)
}
})
expect(await mainPromise).toBe(true)
}
})
it('Should be able to connect', async () => { it('Should be able to connect', async () => {
const mainPromise = new Promise<boolean>((resolve) => const mainPromise = new Promise<boolean>((resolve) =>
app.get( app.get(

View File

@ -5,6 +5,7 @@ import { WebSocketServer } from 'ws'
import type { IncomingMessage } from 'http' import type { IncomingMessage } from 'http'
import type { Server } from 'node:http' import type { Server } from 'node:http'
import type { Http2SecureServer, Http2Server } from 'node:http2' import type { Http2SecureServer, Http2Server } from 'node:http2'
import type { Duplex } from 'node:stream'
import { CloseEvent } from './events' import { CloseEvent } from './events'
export interface NodeWebSocket { export interface NodeWebSocket {
@ -24,25 +25,25 @@ export interface NodeWebSocketInit {
*/ */
export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
const wss = new WebSocketServer({ noServer: true }) const wss = new WebSocketServer({ noServer: true })
const waiter = new Map<IncomingMessage, (ws: WebSocket) => void>() const waiterMap = new Map<IncomingMessage, { resolve: (ws: WebSocket) => void, response: Response }>()
wss.on('connection', (ws, request) => { wss.on('connection', (ws, request) => {
const waiterFn = waiter.get(request) const waiter = waiterMap.get(request)
if (waiterFn) { if (waiter) {
waiterFn(ws) waiter.resolve(ws)
waiter.delete(request) waiterMap.delete(request)
} }
}) })
const nodeUpgradeWebSocket = (request: IncomingMessage) => { const nodeUpgradeWebSocket = (request: IncomingMessage, response: Response) => {
return new Promise<WebSocket>((resolve) => { return new Promise<WebSocket>((resolve) => {
waiter.set(request, resolve) waiterMap.set(request, { resolve, response })
}) })
} }
return { return {
injectWebSocket(server) { injectWebSocket(server) {
server.on('upgrade', async (request, socket, head) => { server.on('upgrade', async (request, socket: Duplex, head) => {
const url = new URL(request.url ?? '/', init.baseUrl ?? 'http://localhost') const url = new URL(request.url ?? '/', init.baseUrl ?? 'http://localhost')
const headers = new Headers() const headers = new Headers()
for (const key in request.headers) { for (const key in request.headers) {
@ -52,11 +53,25 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
} }
headers.append(key, Array.isArray(value) ? value[0] : value) headers.append(key, Array.isArray(value) ? value[0] : value)
} }
await init.app.request(
const response = await init.app.request(
url, url,
{ headers: headers }, { headers: headers },
{ incoming: request, outgoing: undefined } { incoming: request, outgoing: undefined }
) )
const waiter = waiterMap.get(request)
if (!waiter || waiter.response !== response) {
socket.end(
'HTTP/1.1 400 Bad Request\r\n' +
'Connection: close\r\n' +
'Content-Length: 0\r\n' +
'\r\n'
)
waiterMap.delete(request)
return
}
wss.handleUpgrade(request, socket, head, (ws) => { wss.handleUpgrade(request, socket, head, (ws) => {
wss.emit('connection', ws, request) wss.emit('connection', ws, request)
}) })
@ -70,8 +85,9 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
return return
} }
const response = new Response()
;(async () => { ;(async () => {
const ws = await nodeUpgradeWebSocket(c.env.incoming) const ws = await nodeUpgradeWebSocket(c.env.incoming, response)
const events = await createEvents(c) const events = await createEvents(c)
const ctx: WSContext<WebSocket> = { const ctx: WSContext<WebSocket> = {
@ -116,7 +132,7 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => {
}) })
})() })()
return new Response() return response
}, },
} }
} }