From aae5c9b0394c3f24b9f0c7b73faa1b574bb2408c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=A0=9A=E7=84=B6?= <yanran.wwj@alipay.com>
Date: Fri, 23 Mar 2018 15:42:38 +0800
Subject: [PATCH] pass through all the valid headers when proxing the
 WebSocket, and adds related test cases

---
 lib/requestHandler.js                    | 45 ++++++++++++++++++------
 package.json                             |  3 +-
 test/server/server.js                    | 44 +++++++++++++++--------
 test/spec_rule/no_rule_websocket_spec.js | 30 +++++++++-------
 test/util/CommonUtil.js                  | 22 ++++++++++++
 test/util/HttpUtil.js                    | 17 +++++----
 6 files changed, 115 insertions(+), 46 deletions(-)

diff --git a/lib/requestHandler.js b/lib/requestHandler.js
index 57f2881..7f664c8 100644
--- a/lib/requestHandler.js
+++ b/lib/requestHandler.js
@@ -212,18 +212,39 @@ function fetchRemoteResponse(protocol, options, reqData, config) {
  @param @required wsClient the ws client of WebSocket
 *
 */
-function getWsReqInfo(wsClient) {
-  const upgradeReq = wsClient.upgradeReq || {};
-  const header = upgradeReq.headers || {};
-  const host = header.host;
+function getWsReqInfo(wsReq) {
+  const headers = wsReq.headers || {};
+  const host = headers.host;
   const hostName = host.split(':')[0];
   const port = host.split(':')[1];
 
   // TODO 如果是windows机器,url是不是全路径?需要对其过滤,取出
-  const path = upgradeReq.url || '/';
+  const path = wsReq.url || '/';
+
+  const isEncript = true && wsReq.connection && wsReq.connection.encrypted;
+  /**
+   * construct the request headers based on original connection,
+   * but delete the `sec-websocket-*` headers as they are already consumed by AnyProxy
+   */
+  const getNoWsHeaders = () => {
+    const originHeaders = Object.assign({}, headers);
+    const originHeaderKeys = Object.keys(originHeaders);
+    originHeaderKeys.forEach((key) => {
+      // if the key matchs 'sec-websocket', delete it
+      if (/sec-websocket/ig.test(key)) {
+        delete originHeaders[key];
+      }
+    });
+
+    delete originHeaders.connection;
+    delete originHeaders.upgrade;
+    return originHeaders;
+  }
+
 
-  const isEncript = true && upgradeReq.connection && upgradeReq.connection.encrypted;
   return {
+    headers: headers, // the full headers of origin ws connection
+    noWsHeaders: getNoWsHeaders(),
     hostName: hostName,
     port: port,
     path: path,
@@ -664,7 +685,7 @@ function getConnectReqHandler(userRule, recorder, httpsServerMgr) {
 * get a websocket event handler
   @param @required {object} wsClient
 */
-function getWsHandler(userRule, recorder, wsClient) {
+function getWsHandler(userRule, recorder, wsClient, wsReq) {
   const self = this;
   try {
     let resourceInfoId = -1;
@@ -672,10 +693,11 @@ function getWsHandler(userRule, recorder, wsClient) {
       wsMessages: [] // all ws messages go through AnyProxy
     };
     const clientMsgQueue = [];
-    const serverInfo = getWsReqInfo(wsClient);
+    const serverInfo = getWsReqInfo(wsReq);
     const wsUrl = `${serverInfo.protocol}://${serverInfo.hostName}:${serverInfo.port}${serverInfo.path}`;
     const proxyWs = new WebSocket(wsUrl, '', {
-      rejectUnauthorized: !self.dangerouslyIgnoreUnauthorized
+      rejectUnauthorized: !self.dangerouslyIgnoreUnauthorized,
+      headers: serverInfo.noWsHeaders
     });
 
     if (recorder) {
@@ -684,7 +706,7 @@ function getWsHandler(userRule, recorder, wsClient) {
         method: 'WebSocket',
         path: serverInfo.path,
         url: wsUrl,
-        req: wsClient.upgradeReq || {},
+        req: wsReq,
         startTime: new Date().getTime()
       });
       resourceInfoId = recorder.appendRecord(resourceInfo);
@@ -763,8 +785,9 @@ function getWsHandler(userRule, recorder, wsClient) {
     }
 
     // this event is fired when the connection is build and headers is returned
-    proxyWs.on('headers', (headers, response) => {
+    proxyWs.on('upgrade', (response) => {
       resourceInfo.endTime = new Date().getTime();
+      const headers = response.headers;
       resourceInfo.res = { //construct a self-defined res object
         statusCode: response.statusCode,
         headers: headers,
diff --git a/package.json b/package.json
index 9268f5c..2274816 100644
--- a/package.json
+++ b/package.json
@@ -39,7 +39,7 @@
     "stream-throttle": "^0.1.3",
     "svg-inline-react": "^1.0.2",
     "whatwg-fetch": "^1.0.0",
-    "ws": "^2.2.0"
+    "ws": "^5.1.0"
   },
   "devDependencies": {
     "antd": "^2.5.0",
@@ -69,7 +69,6 @@
     "koa-body": "^1.4.0",
     "koa-router": "^5.4.0",
     "koa-send": "^3.2.0",
-    "koa-websocket": "^2.0.0",
     "less": "^2.7.1",
     "less-loader": "^2.2.3",
     "memwatch-next": "^0.3.0",
diff --git a/test/server/server.js b/test/server/server.js
index cdb33e7..fcab17f 100644
--- a/test/server/server.js
+++ b/test/server/server.js
@@ -7,7 +7,6 @@ const https = require('https');
 const certMgr = require('../../lib/certMgr');
 const fs = require('fs');
 const nurl = require('url');
-const websocket = require('koa-websocket');
 const color = require('colorful');
 const WebSocketServer = require('ws').Server;
 const tls = require('tls');
@@ -65,6 +64,23 @@ function KoaServer() {
     yield next;
   };
 
+  this.logWsRequest = function (wsReq) {
+    const headers = wsReq.headers;
+    const host = headers.host;
+    const isEncript = true && wsReq.connection && wsReq.connection.encrypted;
+    const protocol = isEncript ? 'wss' : 'ws';
+    let key = `${protocol}://${host}${wsReq.url}`;
+    // take proxy data with 'proxy-' + url
+    if (headers['via-proxy'] === 'true') {
+      key = PROXY_KEY_PREFIX + key;
+    }
+
+    self.requestRecordMap[key] = {
+      headers: wsReq.headers,
+      body: ''
+    }
+  };
+
   this.start();
 }
 
@@ -236,11 +252,14 @@ KoaServer.prototype.constructRouter = function () {
   return router;
 };
 
-KoaServer.prototype.constructWsRouter = function () {
-  const wsRouter = KoaRouter();
-  const self = this;
-  wsRouter.get('/test/socket', function *(next) {
-    const ws = this.websocket;
+KoaServer.prototype.createWsServer = function (httpServer) {
+  const wsServer = new WebSocketServer({
+    server: httpServer,
+    path: '/test/socket'
+  });
+  wsServer.on('connection', (ws, wsReq) => {
+    const self = this;
+    self.logWsRequest(wsReq);
     const messageObj = {
       type: 'initial',
       content: 'default message'
@@ -251,10 +270,7 @@ KoaServer.prototype.constructWsRouter = function () {
       printLog('message from request socket: ' + message);
       self.handleRecievedMessage(ws, message);
     });
-    yield next;
-  });
-
-  return wsRouter;
+  })
 };
 
 KoaServer.prototype.getRequestRecord = function (key) {
@@ -277,14 +293,13 @@ KoaServer.prototype.handleRecievedMessage = function (ws, message) {
 KoaServer.prototype.start = function () {
   printLog('Starting the server...');
   const router = this.constructRouter();
-  const wsRouter = this.constructWsRouter();
   const self = this;
   const app = Koa();
-  websocket(app);
 
   app.use(router.routes());
-  app.ws.use(wsRouter.routes());
   this.httpServer = app.listen(DEFAULT_PORT);
+  this.createWsServer(this.httpServer);
+
 
   printLog('HTTP is now listening on port :' + DEFAULT_PORT);
 
@@ -303,7 +318,8 @@ KoaServer.prototype.start = function () {
         server: self.httpsServer
       });
 
-      wss.on('connection', (ws) => {
+      wss.on('connection', (ws, wsReq) => {
+        self.logWsRequest(wsReq);
         ws.on('message', (message) => {
           printLog('received in wss: ' + message);
           self.handleRecievedMessage(ws, message);
diff --git a/test/spec_rule/no_rule_websocket_spec.js b/test/spec_rule/no_rule_websocket_spec.js
index 7feae15..9f197fb 100644
--- a/test/spec_rule/no_rule_websocket_spec.js
+++ b/test/spec_rule/no_rule_websocket_spec.js
@@ -6,7 +6,7 @@
 const ProxyServerUtil = require('../util/ProxyServerUtil.js');
 const { generateWsUrl, directWs, proxyWs } = require('../util/HttpUtil.js');
 const Server = require('../server/server.js');
-const { printLog, isArrayEqual } = require('../util/CommonUtil.js');
+const { printLog, isArrayEqual, isCommonReqEqual } = require('../util/CommonUtil.js');
 
 testWebsocket('ws');
 testWebsocket('wss');
@@ -26,6 +26,11 @@ function testWebsocket(protocol, masked = false) {
       'Send the message with default option4'
     ];
 
+    const websocketHeaders = {
+      referer: 'https://www.anyproxy.io/websocket/test',
+      origin: 'www.anyproxy.io'
+    }
+
     beforeAll((done) => {
       jasmine.DEFAULT_TIMEOUT_INTERVAL = 200000;
       printLog('Start server for no_rule_websocket_spec');
@@ -47,11 +52,11 @@ function testWebsocket(protocol, masked = false) {
     it('Default websocket option', done => {
       const directMessages = []; // set the flag for direct message, compare when both direct and proxy got message
       const proxyMessages = [];
-      let directHeaders;
-      let proxyHeaders;
+      let directResHeaders;
+      let proxyResHeaders;
 
-      const ws = directWs(url);
-      const proxyWsRef = proxyWs(url);
+      const ws = directWs(url, websocketHeaders);
+      const proxyWsRef = proxyWs(url, websocketHeaders);
       ws.on('open', () => {
         ws.send(testMessageArray[0], masked);
         for (let i = 1; i < testMessageArray.length; i++) {
@@ -74,13 +79,13 @@ function testWebsocket(protocol, masked = false) {
         }
       });
 
-      ws.on('headers', (headers) => {
-        directHeaders = headers;
+      ws.on('upgrade', (res) => {
+        directResHeaders = res.headers;
         compareMessageIfReady();
       });
 
-      proxyWsRef.on('headers', (headers) => {
-        proxyHeaders = headers;
+      proxyWsRef.on('upgrade', (res) => {
+        proxyResHeaders = res.headers;
         compareMessageIfReady();
       });
 
@@ -114,12 +119,13 @@ function testWebsocket(protocol, masked = false) {
         const targetLen = testMessageArray.length;
         if (directMessages.length === targetLen
           && proxyMessages.length === targetLen
-          && directHeaders && proxyHeaders
+          && directResHeaders && proxyResHeaders
         ) {
           expect(isArrayEqual(directMessages, testMessageArray)).toBe(true);
           expect(isArrayEqual(directMessages, proxyMessages)).toBe(true);
-          expect(directHeaders['x-anyproxy-websocket']).toBeUndefined();
-          expect(proxyHeaders['x-anyproxy-websocket']).toBe('true');
+          expect(directResHeaders['x-anyproxy-websocket']).toBeUndefined();
+          expect(proxyResHeaders['x-anyproxy-websocket']).toBe('true');
+          expect(isCommonReqEqual(url, serverInstance)).toBe(true);
           done();
         }
       }
diff --git a/test/util/CommonUtil.js b/test/util/CommonUtil.js
index 8ba1f94..1c4b7cf 100644
--- a/test/util/CommonUtil.js
+++ b/test/util/CommonUtil.js
@@ -120,6 +120,7 @@ function isCommonResHeaderEqual(directHeaders, proxyHeaders, requestUrl) {
 *
 */
 function isCommonReqEqual(url, serverInstance) {
+  console.info('==> trying to get the url ', url);
   try {
     let isEqual = true;
 
@@ -139,6 +140,27 @@ function isCommonReqEqual(url, serverInstance) {
     delete directReqObj.headers['transfer-encoding'];
     delete proxyReqObj.headers['transfer-encoding'];
 
+    // delete the headers that should not be passed by AnyProxy
+    delete directReqObj.headers.connection;
+    delete proxyReqObj.headers.connection;
+
+    // delete the headers related to websocket establishment
+    const directHeaderKeys = Object.keys(directReqObj.headers);
+    directHeaderKeys.forEach((key) => {
+      // if the key matchs 'sec-websocket', delete it
+      if (/sec-websocket/ig.test(key)) {
+        delete directReqObj.headers[key];
+      }
+    });
+
+    const proxyHeaderKeys = Object.keys(proxyReqObj.headers);
+    proxyHeaderKeys.forEach((key) => {
+      // if the key matchs 'sec-websocaket', delete it
+      if (/sec-websocket/ig.test(key)) {
+        delete proxyReqObj.headers[key];
+      }
+    });
+
     isEqual = isEqual && directReqObj.url === proxyReqObj.url;
     isEqual = isEqual && isObjectEqual(directReqObj.headers, proxyReqObj.headers, url);
     isEqual = isEqual && directReqObj.body === proxyReqObj.body;
diff --git a/test/util/HttpUtil.js b/test/util/HttpUtil.js
index fb3d570..2c03002 100644
--- a/test/util/HttpUtil.js
+++ b/test/util/HttpUtil.js
@@ -187,17 +187,20 @@ function doUpload(url, method, filepath, formParams, headers = {}, isProxy) {
   return requestTask;
 }
 
-function doWebSocket(url, isProxy) {
+function doWebSocket(url, headers = {}, isProxy) {
   let ws;
   if (isProxy) {
+    headers['via-proxy'] = 'true';
     const agent = new HttpsProxyAgent(SOCKET_PROXY_HOST);
     ws = new WebSocket(url, {
       agent,
-      rejectUnauthorized: false
+      rejectUnauthorized: false,
+      headers
     });
   } else {
     ws = new WebSocket(url, {
-      rejectUnauthorized: false
+      rejectUnauthorized: false,
+      headers
     });
   }
 
@@ -252,12 +255,12 @@ function directOptions(url, headers = {}) {
   return directRequest('OPTIONS', url, {}, headers);
 }
 
-function proxyWs(url) {
-  return doWebSocket(url, true);
+function proxyWs(url, headers) {
+  return doWebSocket(url, headers, true);
 }
 
-function directWs(url) {
-  return doWebSocket(url);
+function directWs(url, headers) {
+  return doWebSocket(url, headers);
 }
 
 /**