diff --git a/lib/requestHandler.js b/lib/requestHandler.js index 57f2881..7f664c8 100644 --- a/lib/requestHandler.js +++ b/lib/requestHandler.js @@ -212,18 +212,39 @@ function fetchRemoteResponse(protocol, options, reqData, config) { @param @required wsClient the ws client of WebSocket * */ -function getWsReqInfo(wsClient) { - const upgradeReq = wsClient.upgradeReq || {}; - const header = upgradeReq.headers || {}; - const host = header.host; +function getWsReqInfo(wsReq) { + const headers = wsReq.headers || {}; + const host = headers.host; const hostName = host.split(':')[0]; const port = host.split(':')[1]; // TODO 如果是windows机器,url是不是全路径?需要对其过滤,取出 - const path = upgradeReq.url || '/'; + const path = wsReq.url || '/'; + + const isEncript = true && wsReq.connection && wsReq.connection.encrypted; + /** + * construct the request headers based on original connection, + * but delete the `sec-websocket-*` headers as they are already consumed by AnyProxy + */ + const getNoWsHeaders = () => { + const originHeaders = Object.assign({}, headers); + const originHeaderKeys = Object.keys(originHeaders); + originHeaderKeys.forEach((key) => { + // if the key matchs 'sec-websocket', delete it + if (/sec-websocket/ig.test(key)) { + delete originHeaders[key]; + } + }); + + delete originHeaders.connection; + delete originHeaders.upgrade; + return originHeaders; + } + - const isEncript = true && upgradeReq.connection && upgradeReq.connection.encrypted; return { + headers: headers, // the full headers of origin ws connection + noWsHeaders: getNoWsHeaders(), hostName: hostName, port: port, path: path, @@ -664,7 +685,7 @@ function getConnectReqHandler(userRule, recorder, httpsServerMgr) { * get a websocket event handler @param @required {object} wsClient */ -function getWsHandler(userRule, recorder, wsClient) { +function getWsHandler(userRule, recorder, wsClient, wsReq) { const self = this; try { let resourceInfoId = -1; @@ -672,10 +693,11 @@ function getWsHandler(userRule, recorder, wsClient) { wsMessages: [] // all ws messages go through AnyProxy }; const clientMsgQueue = []; - const serverInfo = getWsReqInfo(wsClient); + const serverInfo = getWsReqInfo(wsReq); const wsUrl = `${serverInfo.protocol}://${serverInfo.hostName}:${serverInfo.port}${serverInfo.path}`; const proxyWs = new WebSocket(wsUrl, '', { - rejectUnauthorized: !self.dangerouslyIgnoreUnauthorized + rejectUnauthorized: !self.dangerouslyIgnoreUnauthorized, + headers: serverInfo.noWsHeaders }); if (recorder) { @@ -684,7 +706,7 @@ function getWsHandler(userRule, recorder, wsClient) { method: 'WebSocket', path: serverInfo.path, url: wsUrl, - req: wsClient.upgradeReq || {}, + req: wsReq, startTime: new Date().getTime() }); resourceInfoId = recorder.appendRecord(resourceInfo); @@ -763,8 +785,9 @@ function getWsHandler(userRule, recorder, wsClient) { } // this event is fired when the connection is build and headers is returned - proxyWs.on('headers', (headers, response) => { + proxyWs.on('upgrade', (response) => { resourceInfo.endTime = new Date().getTime(); + const headers = response.headers; resourceInfo.res = { //construct a self-defined res object statusCode: response.statusCode, headers: headers, diff --git a/package.json b/package.json index 9268f5c..2274816 100644 --- a/package.json +++ b/package.json @@ -39,7 +39,7 @@ "stream-throttle": "^0.1.3", "svg-inline-react": "^1.0.2", "whatwg-fetch": "^1.0.0", - "ws": "^2.2.0" + "ws": "^5.1.0" }, "devDependencies": { "antd": "^2.5.0", @@ -69,7 +69,6 @@ "koa-body": "^1.4.0", "koa-router": "^5.4.0", "koa-send": "^3.2.0", - "koa-websocket": "^2.0.0", "less": "^2.7.1", "less-loader": "^2.2.3", "memwatch-next": "^0.3.0", diff --git a/test/server/server.js b/test/server/server.js index cdb33e7..fcab17f 100644 --- a/test/server/server.js +++ b/test/server/server.js @@ -7,7 +7,6 @@ const https = require('https'); const certMgr = require('../../lib/certMgr'); const fs = require('fs'); const nurl = require('url'); -const websocket = require('koa-websocket'); const color = require('colorful'); const WebSocketServer = require('ws').Server; const tls = require('tls'); @@ -65,6 +64,23 @@ function KoaServer() { yield next; }; + this.logWsRequest = function (wsReq) { + const headers = wsReq.headers; + const host = headers.host; + const isEncript = true && wsReq.connection && wsReq.connection.encrypted; + const protocol = isEncript ? 'wss' : 'ws'; + let key = `${protocol}://${host}${wsReq.url}`; + // take proxy data with 'proxy-' + url + if (headers['via-proxy'] === 'true') { + key = PROXY_KEY_PREFIX + key; + } + + self.requestRecordMap[key] = { + headers: wsReq.headers, + body: '' + } + }; + this.start(); } @@ -236,11 +252,14 @@ KoaServer.prototype.constructRouter = function () { return router; }; -KoaServer.prototype.constructWsRouter = function () { - const wsRouter = KoaRouter(); - const self = this; - wsRouter.get('/test/socket', function *(next) { - const ws = this.websocket; +KoaServer.prototype.createWsServer = function (httpServer) { + const wsServer = new WebSocketServer({ + server: httpServer, + path: '/test/socket' + }); + wsServer.on('connection', (ws, wsReq) => { + const self = this; + self.logWsRequest(wsReq); const messageObj = { type: 'initial', content: 'default message' @@ -251,10 +270,7 @@ KoaServer.prototype.constructWsRouter = function () { printLog('message from request socket: ' + message); self.handleRecievedMessage(ws, message); }); - yield next; - }); - - return wsRouter; + }) }; KoaServer.prototype.getRequestRecord = function (key) { @@ -277,14 +293,13 @@ KoaServer.prototype.handleRecievedMessage = function (ws, message) { KoaServer.prototype.start = function () { printLog('Starting the server...'); const router = this.constructRouter(); - const wsRouter = this.constructWsRouter(); const self = this; const app = Koa(); - websocket(app); app.use(router.routes()); - app.ws.use(wsRouter.routes()); this.httpServer = app.listen(DEFAULT_PORT); + this.createWsServer(this.httpServer); + printLog('HTTP is now listening on port :' + DEFAULT_PORT); @@ -303,7 +318,8 @@ KoaServer.prototype.start = function () { server: self.httpsServer }); - wss.on('connection', (ws) => { + wss.on('connection', (ws, wsReq) => { + self.logWsRequest(wsReq); ws.on('message', (message) => { printLog('received in wss: ' + message); self.handleRecievedMessage(ws, message); diff --git a/test/spec_rule/no_rule_websocket_spec.js b/test/spec_rule/no_rule_websocket_spec.js index 7feae15..9f197fb 100644 --- a/test/spec_rule/no_rule_websocket_spec.js +++ b/test/spec_rule/no_rule_websocket_spec.js @@ -6,7 +6,7 @@ const ProxyServerUtil = require('../util/ProxyServerUtil.js'); const { generateWsUrl, directWs, proxyWs } = require('../util/HttpUtil.js'); const Server = require('../server/server.js'); -const { printLog, isArrayEqual } = require('../util/CommonUtil.js'); +const { printLog, isArrayEqual, isCommonReqEqual } = require('../util/CommonUtil.js'); testWebsocket('ws'); testWebsocket('wss'); @@ -26,6 +26,11 @@ function testWebsocket(protocol, masked = false) { 'Send the message with default option4' ]; + const websocketHeaders = { + referer: 'https://www.anyproxy.io/websocket/test', + origin: 'www.anyproxy.io' + } + beforeAll((done) => { jasmine.DEFAULT_TIMEOUT_INTERVAL = 200000; printLog('Start server for no_rule_websocket_spec'); @@ -47,11 +52,11 @@ function testWebsocket(protocol, masked = false) { it('Default websocket option', done => { const directMessages = []; // set the flag for direct message, compare when both direct and proxy got message const proxyMessages = []; - let directHeaders; - let proxyHeaders; + let directResHeaders; + let proxyResHeaders; - const ws = directWs(url); - const proxyWsRef = proxyWs(url); + const ws = directWs(url, websocketHeaders); + const proxyWsRef = proxyWs(url, websocketHeaders); ws.on('open', () => { ws.send(testMessageArray[0], masked); for (let i = 1; i < testMessageArray.length; i++) { @@ -74,13 +79,13 @@ function testWebsocket(protocol, masked = false) { } }); - ws.on('headers', (headers) => { - directHeaders = headers; + ws.on('upgrade', (res) => { + directResHeaders = res.headers; compareMessageIfReady(); }); - proxyWsRef.on('headers', (headers) => { - proxyHeaders = headers; + proxyWsRef.on('upgrade', (res) => { + proxyResHeaders = res.headers; compareMessageIfReady(); }); @@ -114,12 +119,13 @@ function testWebsocket(protocol, masked = false) { const targetLen = testMessageArray.length; if (directMessages.length === targetLen && proxyMessages.length === targetLen - && directHeaders && proxyHeaders + && directResHeaders && proxyResHeaders ) { expect(isArrayEqual(directMessages, testMessageArray)).toBe(true); expect(isArrayEqual(directMessages, proxyMessages)).toBe(true); - expect(directHeaders['x-anyproxy-websocket']).toBeUndefined(); - expect(proxyHeaders['x-anyproxy-websocket']).toBe('true'); + expect(directResHeaders['x-anyproxy-websocket']).toBeUndefined(); + expect(proxyResHeaders['x-anyproxy-websocket']).toBe('true'); + expect(isCommonReqEqual(url, serverInstance)).toBe(true); done(); } } diff --git a/test/util/CommonUtil.js b/test/util/CommonUtil.js index 8ba1f94..1c4b7cf 100644 --- a/test/util/CommonUtil.js +++ b/test/util/CommonUtil.js @@ -120,6 +120,7 @@ function isCommonResHeaderEqual(directHeaders, proxyHeaders, requestUrl) { * */ function isCommonReqEqual(url, serverInstance) { + console.info('==> trying to get the url ', url); try { let isEqual = true; @@ -139,6 +140,27 @@ function isCommonReqEqual(url, serverInstance) { delete directReqObj.headers['transfer-encoding']; delete proxyReqObj.headers['transfer-encoding']; + // delete the headers that should not be passed by AnyProxy + delete directReqObj.headers.connection; + delete proxyReqObj.headers.connection; + + // delete the headers related to websocket establishment + const directHeaderKeys = Object.keys(directReqObj.headers); + directHeaderKeys.forEach((key) => { + // if the key matchs 'sec-websocket', delete it + if (/sec-websocket/ig.test(key)) { + delete directReqObj.headers[key]; + } + }); + + const proxyHeaderKeys = Object.keys(proxyReqObj.headers); + proxyHeaderKeys.forEach((key) => { + // if the key matchs 'sec-websocaket', delete it + if (/sec-websocket/ig.test(key)) { + delete proxyReqObj.headers[key]; + } + }); + isEqual = isEqual && directReqObj.url === proxyReqObj.url; isEqual = isEqual && isObjectEqual(directReqObj.headers, proxyReqObj.headers, url); isEqual = isEqual && directReqObj.body === proxyReqObj.body; diff --git a/test/util/HttpUtil.js b/test/util/HttpUtil.js index fb3d570..2c03002 100644 --- a/test/util/HttpUtil.js +++ b/test/util/HttpUtil.js @@ -187,17 +187,20 @@ function doUpload(url, method, filepath, formParams, headers = {}, isProxy) { return requestTask; } -function doWebSocket(url, isProxy) { +function doWebSocket(url, headers = {}, isProxy) { let ws; if (isProxy) { + headers['via-proxy'] = 'true'; const agent = new HttpsProxyAgent(SOCKET_PROXY_HOST); ws = new WebSocket(url, { agent, - rejectUnauthorized: false + rejectUnauthorized: false, + headers }); } else { ws = new WebSocket(url, { - rejectUnauthorized: false + rejectUnauthorized: false, + headers }); } @@ -252,12 +255,12 @@ function directOptions(url, headers = {}) { return directRequest('OPTIONS', url, {}, headers); } -function proxyWs(url) { - return doWebSocket(url, true); +function proxyWs(url, headers) { + return doWebSocket(url, headers, true); } -function directWs(url) { - return doWebSocket(url); +function directWs(url, headers) { + return doWebSocket(url, headers); } /**