diff --git a/.changeset/open-years-itch.md b/.changeset/open-years-itch.md new file mode 100644 index 00000000..d995f648 --- /dev/null +++ b/.changeset/open-years-itch.md @@ -0,0 +1,5 @@ +--- +'@hono/node-ws': patch +--- + +enhance WebSocket connection handling with CORS support and connection symbols diff --git a/packages/node-ws/src/index.test.ts b/packages/node-ws/src/index.test.ts index 9815e4fe..a64a879b 100644 --- a/packages/node-ws/src/index.test.ts +++ b/packages/node-ws/src/index.test.ts @@ -3,6 +3,7 @@ import { serve } from '@hono/node-server' // @ts-ignore import type { ServerType } from '@hono/node-server/dist/types' import { Hono } from 'hono' +import { cors } from 'hono/cors' import type { WSMessageReceive } from 'hono/ws' import { WebSocket } from 'ws' import { createNodeWebSocket } from '.' @@ -244,4 +245,42 @@ describe('WebSocket helper', () => { createNodeWebSocket({ app }) }) }) + + it('Should client can connect when use cors()', async () => { + app.use(cors()) + const mainPromise = new Promise((resolve) => + app.get( + '/', + upgradeWebSocket(() => ({ + onOpen() { + resolve(true) + }, + })) + ) + ) + + new WebSocket('ws://localhost:3030/') + + expect(await mainPromise).toBe(true) + }) + it('Should client can connect even if a response has difference', async () => { + app.use(async (c, next) => { + c.res = new Response(null, c.res) + await next() + }) + const mainPromise = new Promise((resolve) => + app.get( + '/', + upgradeWebSocket(() => ({ + onOpen() { + resolve(true) + }, + })) + ) + ) + + new WebSocket('ws://localhost:3030/') + + expect(await mainPromise).toBe(true) + }) }) diff --git a/packages/node-ws/src/index.ts b/packages/node-ws/src/index.ts index ed0c5c79..8a87deb2 100644 --- a/packages/node-ws/src/index.ts +++ b/packages/node-ws/src/index.ts @@ -23,6 +23,11 @@ export interface NodeWebSocketInit { baseUrl?: string | URL } +const generateConnectionSymbol = () => Symbol('connection') + +/** @example `c.env[CONNECTION_SYMBOL_KEY]` */ +const CONNECTION_SYMBOL_KEY: unique symbol = Symbol('CONNECTION_SYMBOL_KEY') + /** * Create WebSockets for Node.js * @param init Options @@ -32,7 +37,7 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { const wss = new WebSocketServer({ noServer: true }) const waiterMap = new Map< IncomingMessage, - { resolve: (ws: WebSocket) => void; response: Response } + { resolve: (ws: WebSocket) => void; connectionSymbol: symbol } >() wss.on('connection', (ws, request) => { @@ -43,9 +48,9 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { } }) - const nodeUpgradeWebSocket = (request: IncomingMessage, response: Response) => { + const nodeUpgradeWebSocket = (request: IncomingMessage, connectionSymbol: symbol) => { return new Promise((resolve) => { - waiterMap.set(request, { resolve, response }) + waiterMap.set(request, { resolve, connectionSymbol }) }) } @@ -62,14 +67,18 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { headers.append(key, Array.isArray(value) ? value[0] : value) } - const response = await init.app.request( - url, - { headers: headers }, - { incoming: request, outgoing: undefined } - ) - + const env: { + incoming: IncomingMessage + outgoing: undefined + [CONNECTION_SYMBOL_KEY]?: symbol + } = { + incoming: request, + outgoing: undefined, + } + await init.app.request(url, { headers: headers }, env) const waiter = waiterMap.get(request) - if (!waiter || waiter.response !== response) { + + if (!waiter || waiter.connectionSymbol !== env[CONNECTION_SYMBOL_KEY]) { socket.end( 'HTTP/1.1 400 Bad Request\r\n' + 'Connection: close\r\n' + @@ -93,9 +102,10 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { return } - const response = new Response() + const connectionSymbol = generateConnectionSymbol() + c.env[CONNECTION_SYMBOL_KEY] = connectionSymbol ;(async () => { - const ws = await nodeUpgradeWebSocket(c.env.incoming, response) + const ws = await nodeUpgradeWebSocket(c.env.incoming, connectionSymbol) let events: WSEvents try { events = await createEvents(c) @@ -167,7 +177,7 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { }) })() - return response + return new Response() }, } }