diff --git a/.changeset/olive-ducks-end.md b/.changeset/olive-ducks-end.md new file mode 100644 index 00000000..3f357703 --- /dev/null +++ b/.changeset/olive-ducks-end.md @@ -0,0 +1,5 @@ +--- +'@hono/node-ws': patch +--- + +Fixed bug with multiple connections in node-ws diff --git a/packages/node-ws/src/index.test.ts b/packages/node-ws/src/index.test.ts index 42d951a0..0033a3d9 100644 --- a/packages/node-ws/src/index.test.ts +++ b/packages/node-ws/src/index.test.ts @@ -5,35 +5,105 @@ import { WebSocket } from 'ws' import { createNodeWebSocket } from '.' describe('WebSocket helper', () => { - const app = new Hono() - const { injectWebSocket, upgradeWebSocket } = createNodeWebSocket({ app }) + let app: Hono + let server: ServerType + let injectWebSocket: ReturnType['injectWebSocket'] + let upgradeWebSocket: ReturnType['upgradeWebSocket'] - const mainPromise = new Promise((resolve) => - app.get( - '/', - upgradeWebSocket(() => ({ - onOpen() { - resolve(true) - }, - })) - ) - ) + beforeEach(async () => { + app = new Hono() + ;({ injectWebSocket, upgradeWebSocket } = createNodeWebSocket({ app })) - it('Should be able to connect', async () => { - const server = await new Promise((resolve) => { - const server = serve( - { - fetch: app.fetch, - port: 3030, - }, - () => { - resolve(server) - } - ) + server = await new Promise((resolve) => { + const server = serve({ fetch: app.fetch, port: 3030 }, () => resolve(server)) }) injectWebSocket(server) + }) + + afterEach(() => { + server.close() + }) + + it('Should be able to connect', async () => { + const mainPromise = new Promise((resolve) => + app.get( + '/', + upgradeWebSocket(() => ({ + onOpen() { + resolve(true) + }, + })) + ) + ) + new WebSocket('ws://localhost:3030/') expect(await mainPromise).toBe(true) }) + + it('Should be able to send and receive messages', async () => { + const mainPromise = new Promise((resolve) => + app.get( + '/', + upgradeWebSocket(() => ({ + onMessage(data) { + resolve(data.data) + }, + })) + ) + ) + + const ws = new WebSocket('ws://localhost:3030/') + await new Promise((resolve) => ws.on('open', resolve)) + ws.send('Hello') + + expect(await mainPromise).toBe('Hello') + }) + + it('Should handle multiple concurrent connections', async () => { + const connectionCount = 5 + let openConnections = 0 + const messages: string[] = [] + + app.get( + '/', + upgradeWebSocket(() => ({ + onOpen() { + openConnections++ + }, + onMessage(data, ws) { + messages.push(data.data as string) + ws.send(data.data as string) + }, + })) + ) + + const connections = await Promise.all( + Array(connectionCount) + .fill(null) + .map(async () => { + const ws = new WebSocket('ws://localhost:3030/') + await new Promise((resolve) => ws.on('open', resolve)) + return ws + }) + ) + + expect(openConnections).toBe(connectionCount) + + await Promise.all( + connections.map((ws, index) => { + return new Promise((resolve) => { + ws.send(`Hello from connection ${index + 1}`) + ws.on('message', () => resolve()) + }) + }) + ) + + expect(messages.length).toBe(connectionCount) + messages.forEach((msg, index) => { + expect(msg).toBe(`Hello from connection ${index + 1}`) + }) + + connections.forEach((ws) => ws.close()) + }) }) diff --git a/packages/node-ws/src/index.ts b/packages/node-ws/src/index.ts index 34c4b1f0..2fd7f165 100644 --- a/packages/node-ws/src/index.ts +++ b/packages/node-ws/src/index.ts @@ -3,7 +3,9 @@ import type { Server } from 'node:http' import type { Http2SecureServer, Http2Server } from 'node:http2' import type { Hono } from 'hono' import type { UpgradeWebSocket, WSContext } from 'hono/ws' +import type { WebSocket } from 'ws' import { WebSocketServer } from 'ws' +import type { IncomingMessage } from 'http' export interface NodeWebSocket { upgradeWebSocket: UpgradeWebSocket @@ -20,7 +22,22 @@ export interface NodeWebSocketInit { * @returns NodeWebSocket */ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { - const wss = new WebSocketServer({noServer: true}) + const wss = new WebSocketServer({ noServer: true }) + const waiter = new Map void>() + + wss.on('connection', (ws, request) => { + const waiterFn = waiter.get(request) + if (waiterFn) { + waiterFn(ws) + waiter.delete(request) + } + }) + + const nodeUpgradeWebSocket = (request: IncomingMessage) => { + return new Promise((resolve) => { + waiter.set(request, resolve) + }) + } return { injectWebSocket(server) { @@ -34,9 +51,11 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { } headers.append(key, Array.isArray(value) ? value[0] : value) } - await init.app.request(url, { - headers: headers, - }) + await init.app.request( + url, + { headers: headers }, + { incoming: request, outgoing: undefined } + ) wss.handleUpgrade(request, socket, head, (ws) => { wss.emit('connection', ws, request) }) @@ -49,8 +68,11 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { await next() return } - const events = await createEvents(c) - wss.on('connection', (ws) => { + + ;(async () => { + const events = await createEvents(c) + const ws = await nodeUpgradeWebSocket(c.env.incoming) + const ctx: WSContext = { binaryType: 'arraybuffer', close(code, reason) { @@ -92,7 +114,7 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { ctx ) }) - }) + })() return new Response() },