Merge pull request #351 from alibaba/ws-header-fix

pass the headers when proxy websocket request
This commit is contained in:
Otto Mao 2018-04-26 21:14:40 +08:00 committed by GitHub
commit 10f84d0f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 115 additions and 46 deletions

View File

@ -212,18 +212,39 @@ function fetchRemoteResponse(protocol, options, reqData, config) {
@param @required wsClient the ws client of WebSocket @param @required wsClient the ws client of WebSocket
* *
*/ */
function getWsReqInfo(wsClient) { function getWsReqInfo(wsReq) {
const upgradeReq = wsClient.upgradeReq || {}; const headers = wsReq.headers || {};
const header = upgradeReq.headers || {}; const host = headers.host;
const host = header.host;
const hostName = host.split(':')[0]; const hostName = host.split(':')[0];
const port = host.split(':')[1]; const port = host.split(':')[1];
// TODO 如果是windows机器url是不是全路径需要对其过滤取出 // TODO 如果是windows机器url是不是全路径需要对其过滤取出
const path = upgradeReq.url || '/'; const path = wsReq.url || '/';
const isEncript = true && wsReq.connection && wsReq.connection.encrypted;
/**
* construct the request headers based on original connection,
* but delete the `sec-websocket-*` headers as they are already consumed by AnyProxy
*/
const getNoWsHeaders = () => {
const originHeaders = Object.assign({}, headers);
const originHeaderKeys = Object.keys(originHeaders);
originHeaderKeys.forEach((key) => {
// if the key matchs 'sec-websocket', delete it
if (/sec-websocket/ig.test(key)) {
delete originHeaders[key];
}
});
delete originHeaders.connection;
delete originHeaders.upgrade;
return originHeaders;
}
const isEncript = true && upgradeReq.connection && upgradeReq.connection.encrypted;
return { return {
headers: headers, // the full headers of origin ws connection
noWsHeaders: getNoWsHeaders(),
hostName: hostName, hostName: hostName,
port: port, port: port,
path: path, path: path,
@ -664,7 +685,7 @@ function getConnectReqHandler(userRule, recorder, httpsServerMgr) {
* get a websocket event handler * get a websocket event handler
@param @required {object} wsClient @param @required {object} wsClient
*/ */
function getWsHandler(userRule, recorder, wsClient) { function getWsHandler(userRule, recorder, wsClient, wsReq) {
const self = this; const self = this;
try { try {
let resourceInfoId = -1; let resourceInfoId = -1;
@ -672,10 +693,11 @@ function getWsHandler(userRule, recorder, wsClient) {
wsMessages: [] // all ws messages go through AnyProxy wsMessages: [] // all ws messages go through AnyProxy
}; };
const clientMsgQueue = []; const clientMsgQueue = [];
const serverInfo = getWsReqInfo(wsClient); const serverInfo = getWsReqInfo(wsReq);
const wsUrl = `${serverInfo.protocol}://${serverInfo.hostName}:${serverInfo.port}${serverInfo.path}`; const wsUrl = `${serverInfo.protocol}://${serverInfo.hostName}:${serverInfo.port}${serverInfo.path}`;
const proxyWs = new WebSocket(wsUrl, '', { const proxyWs = new WebSocket(wsUrl, '', {
rejectUnauthorized: !self.dangerouslyIgnoreUnauthorized rejectUnauthorized: !self.dangerouslyIgnoreUnauthorized,
headers: serverInfo.noWsHeaders
}); });
if (recorder) { if (recorder) {
@ -684,7 +706,7 @@ function getWsHandler(userRule, recorder, wsClient) {
method: 'WebSocket', method: 'WebSocket',
path: serverInfo.path, path: serverInfo.path,
url: wsUrl, url: wsUrl,
req: wsClient.upgradeReq || {}, req: wsReq,
startTime: new Date().getTime() startTime: new Date().getTime()
}); });
resourceInfoId = recorder.appendRecord(resourceInfo); resourceInfoId = recorder.appendRecord(resourceInfo);
@ -763,8 +785,9 @@ function getWsHandler(userRule, recorder, wsClient) {
} }
// this event is fired when the connection is build and headers is returned // this event is fired when the connection is build and headers is returned
proxyWs.on('headers', (headers, response) => { proxyWs.on('upgrade', (response) => {
resourceInfo.endTime = new Date().getTime(); resourceInfo.endTime = new Date().getTime();
const headers = response.headers;
resourceInfo.res = { //construct a self-defined res object resourceInfo.res = { //construct a self-defined res object
statusCode: response.statusCode, statusCode: response.statusCode,
headers: headers, headers: headers,

View File

@ -39,7 +39,7 @@
"stream-throttle": "^0.1.3", "stream-throttle": "^0.1.3",
"svg-inline-react": "^1.0.2", "svg-inline-react": "^1.0.2",
"whatwg-fetch": "^1.0.0", "whatwg-fetch": "^1.0.0",
"ws": "^2.2.0" "ws": "^5.1.0"
}, },
"devDependencies": { "devDependencies": {
"antd": "^2.5.0", "antd": "^2.5.0",
@ -69,7 +69,6 @@
"koa-body": "^1.4.0", "koa-body": "^1.4.0",
"koa-router": "^5.4.0", "koa-router": "^5.4.0",
"koa-send": "^3.2.0", "koa-send": "^3.2.0",
"koa-websocket": "^2.0.0",
"less": "^2.7.1", "less": "^2.7.1",
"less-loader": "^2.2.3", "less-loader": "^2.2.3",
"memwatch-next": "^0.3.0", "memwatch-next": "^0.3.0",

View File

@ -7,7 +7,6 @@ const https = require('https');
const certMgr = require('../../lib/certMgr'); const certMgr = require('../../lib/certMgr');
const fs = require('fs'); const fs = require('fs');
const nurl = require('url'); const nurl = require('url');
const websocket = require('koa-websocket');
const color = require('colorful'); const color = require('colorful');
const WebSocketServer = require('ws').Server; const WebSocketServer = require('ws').Server;
const tls = require('tls'); const tls = require('tls');
@ -65,6 +64,23 @@ function KoaServer() {
yield next; yield next;
}; };
this.logWsRequest = function (wsReq) {
const headers = wsReq.headers;
const host = headers.host;
const isEncript = true && wsReq.connection && wsReq.connection.encrypted;
const protocol = isEncript ? 'wss' : 'ws';
let key = `${protocol}://${host}${wsReq.url}`;
// take proxy data with 'proxy-' + url
if (headers['via-proxy'] === 'true') {
key = PROXY_KEY_PREFIX + key;
}
self.requestRecordMap[key] = {
headers: wsReq.headers,
body: ''
}
};
this.start(); this.start();
} }
@ -236,11 +252,14 @@ KoaServer.prototype.constructRouter = function () {
return router; return router;
}; };
KoaServer.prototype.constructWsRouter = function () { KoaServer.prototype.createWsServer = function (httpServer) {
const wsRouter = KoaRouter(); const wsServer = new WebSocketServer({
const self = this; server: httpServer,
wsRouter.get('/test/socket', function *(next) { path: '/test/socket'
const ws = this.websocket; });
wsServer.on('connection', (ws, wsReq) => {
const self = this;
self.logWsRequest(wsReq);
const messageObj = { const messageObj = {
type: 'initial', type: 'initial',
content: 'default message' content: 'default message'
@ -251,10 +270,7 @@ KoaServer.prototype.constructWsRouter = function () {
printLog('message from request socket: ' + message); printLog('message from request socket: ' + message);
self.handleRecievedMessage(ws, message); self.handleRecievedMessage(ws, message);
}); });
yield next; })
});
return wsRouter;
}; };
KoaServer.prototype.getRequestRecord = function (key) { KoaServer.prototype.getRequestRecord = function (key) {
@ -277,14 +293,13 @@ KoaServer.prototype.handleRecievedMessage = function (ws, message) {
KoaServer.prototype.start = function () { KoaServer.prototype.start = function () {
printLog('Starting the server...'); printLog('Starting the server...');
const router = this.constructRouter(); const router = this.constructRouter();
const wsRouter = this.constructWsRouter();
const self = this; const self = this;
const app = Koa(); const app = Koa();
websocket(app);
app.use(router.routes()); app.use(router.routes());
app.ws.use(wsRouter.routes());
this.httpServer = app.listen(DEFAULT_PORT); this.httpServer = app.listen(DEFAULT_PORT);
this.createWsServer(this.httpServer);
printLog('HTTP is now listening on port :' + DEFAULT_PORT); printLog('HTTP is now listening on port :' + DEFAULT_PORT);
@ -303,7 +318,8 @@ KoaServer.prototype.start = function () {
server: self.httpsServer server: self.httpsServer
}); });
wss.on('connection', (ws) => { wss.on('connection', (ws, wsReq) => {
self.logWsRequest(wsReq);
ws.on('message', (message) => { ws.on('message', (message) => {
printLog('received in wss: ' + message); printLog('received in wss: ' + message);
self.handleRecievedMessage(ws, message); self.handleRecievedMessage(ws, message);

View File

@ -6,7 +6,7 @@
const ProxyServerUtil = require('../util/ProxyServerUtil.js'); const ProxyServerUtil = require('../util/ProxyServerUtil.js');
const { generateWsUrl, directWs, proxyWs } = require('../util/HttpUtil.js'); const { generateWsUrl, directWs, proxyWs } = require('../util/HttpUtil.js');
const Server = require('../server/server.js'); const Server = require('../server/server.js');
const { printLog, isArrayEqual } = require('../util/CommonUtil.js'); const { printLog, isArrayEqual, isCommonReqEqual } = require('../util/CommonUtil.js');
testWebsocket('ws'); testWebsocket('ws');
testWebsocket('wss'); testWebsocket('wss');
@ -26,6 +26,11 @@ function testWebsocket(protocol, masked = false) {
'Send the message with default option4' 'Send the message with default option4'
]; ];
const websocketHeaders = {
referer: 'https://www.anyproxy.io/websocket/test',
origin: 'www.anyproxy.io'
}
beforeAll((done) => { beforeAll((done) => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 200000; jasmine.DEFAULT_TIMEOUT_INTERVAL = 200000;
printLog('Start server for no_rule_websocket_spec'); printLog('Start server for no_rule_websocket_spec');
@ -47,11 +52,11 @@ function testWebsocket(protocol, masked = false) {
it('Default websocket option', done => { it('Default websocket option', done => {
const directMessages = []; // set the flag for direct message, compare when both direct and proxy got message const directMessages = []; // set the flag for direct message, compare when both direct and proxy got message
const proxyMessages = []; const proxyMessages = [];
let directHeaders; let directResHeaders;
let proxyHeaders; let proxyResHeaders;
const ws = directWs(url); const ws = directWs(url, websocketHeaders);
const proxyWsRef = proxyWs(url); const proxyWsRef = proxyWs(url, websocketHeaders);
ws.on('open', () => { ws.on('open', () => {
ws.send(testMessageArray[0], masked); ws.send(testMessageArray[0], masked);
for (let i = 1; i < testMessageArray.length; i++) { for (let i = 1; i < testMessageArray.length; i++) {
@ -74,13 +79,13 @@ function testWebsocket(protocol, masked = false) {
} }
}); });
ws.on('headers', (headers) => { ws.on('upgrade', (res) => {
directHeaders = headers; directResHeaders = res.headers;
compareMessageIfReady(); compareMessageIfReady();
}); });
proxyWsRef.on('headers', (headers) => { proxyWsRef.on('upgrade', (res) => {
proxyHeaders = headers; proxyResHeaders = res.headers;
compareMessageIfReady(); compareMessageIfReady();
}); });
@ -114,12 +119,13 @@ function testWebsocket(protocol, masked = false) {
const targetLen = testMessageArray.length; const targetLen = testMessageArray.length;
if (directMessages.length === targetLen if (directMessages.length === targetLen
&& proxyMessages.length === targetLen && proxyMessages.length === targetLen
&& directHeaders && proxyHeaders && directResHeaders && proxyResHeaders
) { ) {
expect(isArrayEqual(directMessages, testMessageArray)).toBe(true); expect(isArrayEqual(directMessages, testMessageArray)).toBe(true);
expect(isArrayEqual(directMessages, proxyMessages)).toBe(true); expect(isArrayEqual(directMessages, proxyMessages)).toBe(true);
expect(directHeaders['x-anyproxy-websocket']).toBeUndefined(); expect(directResHeaders['x-anyproxy-websocket']).toBeUndefined();
expect(proxyHeaders['x-anyproxy-websocket']).toBe('true'); expect(proxyResHeaders['x-anyproxy-websocket']).toBe('true');
expect(isCommonReqEqual(url, serverInstance)).toBe(true);
done(); done();
} }
} }

View File

@ -120,6 +120,7 @@ function isCommonResHeaderEqual(directHeaders, proxyHeaders, requestUrl) {
* *
*/ */
function isCommonReqEqual(url, serverInstance) { function isCommonReqEqual(url, serverInstance) {
console.info('==> trying to get the url ', url);
try { try {
let isEqual = true; let isEqual = true;
@ -139,6 +140,27 @@ function isCommonReqEqual(url, serverInstance) {
delete directReqObj.headers['transfer-encoding']; delete directReqObj.headers['transfer-encoding'];
delete proxyReqObj.headers['transfer-encoding']; delete proxyReqObj.headers['transfer-encoding'];
// delete the headers that should not be passed by AnyProxy
delete directReqObj.headers.connection;
delete proxyReqObj.headers.connection;
// delete the headers related to websocket establishment
const directHeaderKeys = Object.keys(directReqObj.headers);
directHeaderKeys.forEach((key) => {
// if the key matchs 'sec-websocket', delete it
if (/sec-websocket/ig.test(key)) {
delete directReqObj.headers[key];
}
});
const proxyHeaderKeys = Object.keys(proxyReqObj.headers);
proxyHeaderKeys.forEach((key) => {
// if the key matchs 'sec-websocaket', delete it
if (/sec-websocket/ig.test(key)) {
delete proxyReqObj.headers[key];
}
});
isEqual = isEqual && directReqObj.url === proxyReqObj.url; isEqual = isEqual && directReqObj.url === proxyReqObj.url;
isEqual = isEqual && isObjectEqual(directReqObj.headers, proxyReqObj.headers, url); isEqual = isEqual && isObjectEqual(directReqObj.headers, proxyReqObj.headers, url);
isEqual = isEqual && directReqObj.body === proxyReqObj.body; isEqual = isEqual && directReqObj.body === proxyReqObj.body;

View File

@ -187,17 +187,20 @@ function doUpload(url, method, filepath, formParams, headers = {}, isProxy) {
return requestTask; return requestTask;
} }
function doWebSocket(url, isProxy) { function doWebSocket(url, headers = {}, isProxy) {
let ws; let ws;
if (isProxy) { if (isProxy) {
headers['via-proxy'] = 'true';
const agent = new HttpsProxyAgent(SOCKET_PROXY_HOST); const agent = new HttpsProxyAgent(SOCKET_PROXY_HOST);
ws = new WebSocket(url, { ws = new WebSocket(url, {
agent, agent,
rejectUnauthorized: false rejectUnauthorized: false,
headers
}); });
} else { } else {
ws = new WebSocket(url, { ws = new WebSocket(url, {
rejectUnauthorized: false rejectUnauthorized: false,
headers
}); });
} }
@ -252,12 +255,12 @@ function directOptions(url, headers = {}) {
return directRequest('OPTIONS', url, {}, headers); return directRequest('OPTIONS', url, {}, headers);
} }
function proxyWs(url) { function proxyWs(url, headers) {
return doWebSocket(url, true); return doWebSocket(url, headers, true);
} }
function directWs(url) { function directWs(url, headers) {
return doWebSocket(url); return doWebSocket(url, headers);
} }
/** /**