From ca88b07ecf3c034df505bb6520473bc88cdac18d Mon Sep 17 00:00:00 2001
From: fatedier <fatedier@gmail.com>
Date: Wed, 8 Aug 2018 11:18:38 +0800
Subject: [PATCH] optimize

---
 cmd/frps/root.go               |  6 +++---
 conf/frps_full.ini             |  6 +++---
 models/config/server_common.go | 29 ++++++++++++++++-------------
 server/service.go              |  4 +++-
 utils/vhost/newhttp.go         | 20 +++++++++++++-------
 5 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/cmd/frps/root.go b/cmd/frps/root.go
index 19997326..76a1acd9 100644
--- a/cmd/frps/root.go
+++ b/cmd/frps/root.go
@@ -45,6 +45,7 @@ var (
 	proxyBindAddr     string
 	vhostHttpPort     int
 	vhostHttpsPort    int
+	vhostHttpTimeout  int64
 	dashboardAddr     string
 	dashboardPort     int
 	dashboardUser     string
@@ -61,7 +62,6 @@ var (
 	allowPorts        string
 	maxPoolCount      int64
 	maxPortsPerClient int64
-	vhostHttpTimeout  int64
 )
 
 func init() {
@@ -75,6 +75,7 @@ func init() {
 	rootCmd.PersistentFlags().StringVarP(&proxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address")
 	rootCmd.PersistentFlags().IntVarP(&vhostHttpPort, "vhost_http_port", "", 0, "vhost http port")
 	rootCmd.PersistentFlags().IntVarP(&vhostHttpsPort, "vhost_https_port", "", 0, "vhost https port")
+	rootCmd.PersistentFlags().Int64VarP(&vhostHttpTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout")
 	rootCmd.PersistentFlags().StringVarP(&dashboardAddr, "dashboard_addr", "", "0.0.0.0", "dasboard address")
 	rootCmd.PersistentFlags().IntVarP(&dashboardPort, "dashboard_port", "", 0, "dashboard port")
 	rootCmd.PersistentFlags().StringVarP(&dashboardUser, "dashboard_user", "", "admin", "dashboard user")
@@ -88,7 +89,6 @@ func init() {
 	rootCmd.PersistentFlags().StringVarP(&subDomainHost, "subdomain_host", "", "", "subdomain host")
 	rootCmd.PersistentFlags().StringVarP(&allowPorts, "allow_ports", "", "", "allow ports")
 	rootCmd.PersistentFlags().Int64VarP(&maxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client")
-	rootCmd.PersistentFlags().Int64VarP(&vhostHttpTimeout, "vhost_http_timeout", "", 30, "vhost http timeout")
 }
 
 var rootCmd = &cobra.Command{
@@ -169,6 +169,7 @@ func parseServerCommonCfgFromCmd() (err error) {
 	g.GlbServerCfg.ProxyBindAddr = proxyBindAddr
 	g.GlbServerCfg.VhostHttpPort = vhostHttpPort
 	g.GlbServerCfg.VhostHttpsPort = vhostHttpsPort
+	g.GlbServerCfg.VhostHttpTimeout = vhostHttpTimeout
 	g.GlbServerCfg.DashboardAddr = dashboardAddr
 	g.GlbServerCfg.DashboardPort = dashboardPort
 	g.GlbServerCfg.DashboardUser = dashboardUser
@@ -193,7 +194,6 @@ func parseServerCommonCfgFromCmd() (err error) {
 		}
 	}
 	g.GlbServerCfg.MaxPortsPerClient = maxPortsPerClient
-	g.GlbServerCfg.VhostHttpTimeout = vhostHttpTimeout
 	return
 }
 
diff --git a/conf/frps_full.ini b/conf/frps_full.ini
index 4b61facb..a1fc50c9 100644
--- a/conf/frps_full.ini
+++ b/conf/frps_full.ini
@@ -20,6 +20,9 @@ kcp_bind_port = 7000
 vhost_http_port = 80
 vhost_https_port = 443
 
+# response header timeout(seconds) for vhost http server, default is 60s
+# vhost_http_timeout = 60
+
 # set dashboard_addr and dashboard_port to view dashboard of frps
 # dashboard_addr's default value is same with bind_addr
 # dashboard is available only if dashboard_port is set
@@ -66,6 +69,3 @@ subdomain_host = frps.com
 
 # if tcp stream multiplexing is used, default is true
 tcp_mux = true
-
-# if long connection for more than 30 seconds and disconnection of the server ,fix the pars .
-vhost_http_timeout = 30
\ No newline at end of file
diff --git a/models/config/server_common.go b/models/config/server_common.go
index a92b9d2f..9df432ed 100644
--- a/models/config/server_common.go
+++ b/models/config/server_common.go
@@ -51,8 +51,11 @@ type ServerCommonConf struct {
 	VhostHttpPort int `json:"vhost_http_port"`
 
 	// if VhostHttpsPort equals 0, don't listen a public port for https protocol
-	VhostHttpsPort int    `json:"vhost_http_port"`
-	DashboardAddr  string `json:"dashboard_addr"`
+	VhostHttpsPort int `json:"vhost_http_port"`
+
+	VhostHttpTimeout int64 `json:"vhost_http_timeout"`
+
+	DashboardAddr string `json:"dashboard_addr"`
 
 	// if DashboardPort equals 0, dashboard is not available
 	DashboardPort int    `json:"dashboard_port"`
@@ -73,7 +76,6 @@ type ServerCommonConf struct {
 	MaxPortsPerClient int64 `json:"max_ports_per_client"`
 	HeartBeatTimeout  int64 `json:"heart_beat_timeout"`
 	UserConnTimeout   int64 `json:"user_conn_timeout"`
-	VhostHttpTimeout  int64 `json:"vhost_http_timeout "`
 }
 
 func GetDefaultServerConf() *ServerCommonConf {
@@ -85,6 +87,7 @@ func GetDefaultServerConf() *ServerCommonConf {
 		ProxyBindAddr:     "0.0.0.0",
 		VhostHttpPort:     0,
 		VhostHttpsPort:    0,
+		VhostHttpTimeout:  60,
 		DashboardAddr:     "0.0.0.0",
 		DashboardPort:     0,
 		DashboardUser:     "admin",
@@ -103,7 +106,6 @@ func GetDefaultServerConf() *ServerCommonConf {
 		MaxPortsPerClient: 0,
 		HeartBeatTimeout:  90,
 		UserConnTimeout:   10,
-		VhostHttpTimeout:  30,
 	}
 }
 
@@ -183,6 +185,16 @@ func UnmarshalServerConfFromIni(defaultCfg *ServerCommonConf, content string) (c
 		cfg.VhostHttpsPort = 0
 	}
 
+	if tmpStr, ok = conf.Get("common", "vhost_http_timeout"); ok {
+		v, errRet := strconv.ParseInt(tmpStr, 10, 64)
+		if errRet != nil || v < 0 {
+			err = fmt.Errorf("Parse conf error: invalid vhost_http_timeout")
+			return
+		} else {
+			cfg.VhostHttpTimeout = v
+		}
+	}
+
 	if tmpStr, ok = conf.Get("common", "dashboard_addr"); ok {
 		cfg.DashboardAddr = tmpStr
 	} else {
@@ -302,15 +314,6 @@ func UnmarshalServerConfFromIni(defaultCfg *ServerCommonConf, content string) (c
 			cfg.HeartBeatTimeout = v
 		}
 	}
-	if tmpStr, ok = conf.Get("common", "vhost_http_timeout"); ok {
-		v, errRet := strconv.ParseInt(tmpStr, 10, 64)
-		if errRet != nil {
-			err = fmt.Errorf("Parse conf error: vhost_http_timeout is incorrect")
-			return
-		} else {
-			cfg.VhostHttpTimeout = v
-		}
-	}
 	return
 }
 
diff --git a/server/service.go b/server/service.go
index a9b14a62..65a5d5af 100644
--- a/server/service.go
+++ b/server/service.go
@@ -139,7 +139,9 @@ func NewService() (svr *Service, err error) {
 
 	// Create http vhost muxer.
 	if cfg.VhostHttpPort > 0 {
-		rp := vhost.NewHttpReverseProxy()
+		rp := vhost.NewHttpReverseProxy(vhost.HttpReverseProxyOptions{
+			ResponseHeaderTimeoutS: cfg.VhostHttpTimeout,
+		})
 		svr.httpReverseProxy = rp
 
 		address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
diff --git a/utils/vhost/newhttp.go b/utils/vhost/newhttp.go
index 6cfadf9c..fef991fa 100644
--- a/utils/vhost/newhttp.go
+++ b/utils/vhost/newhttp.go
@@ -25,15 +25,12 @@ import (
 	"sync"
 	"time"
 
-	"github.com/fatedier/frp/g"
 	frpLog "github.com/fatedier/frp/utils/log"
 
 	"github.com/fatedier/golib/pool"
 )
 
 var (
-	responseHeaderTimeout = time.Duration(g.GlbServerCfg.VhostHttpTimeout) * time.Second
-
 	ErrRouterConfigConflict = errors.New("router config conflict")
 	ErrNoDomain             = errors.New("no such domain")
 )
@@ -48,17 +45,26 @@ func getHostFromAddr(addr string) (host string) {
 	return
 }
 
+type HttpReverseProxyOptions struct {
+	ResponseHeaderTimeoutS int64
+}
+
 type HttpReverseProxy struct {
 	proxy *ReverseProxy
 
 	vhostRouter *VhostRouters
 
-	cfgMu sync.RWMutex
+	responseHeaderTimeout time.Duration
+	cfgMu                 sync.RWMutex
 }
 
-func NewHttpReverseProxy() *HttpReverseProxy {
+func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy {
+	if option.ResponseHeaderTimeoutS <= 0 {
+		option.ResponseHeaderTimeoutS = 60
+	}
 	rp := &HttpReverseProxy{
-		vhostRouter: NewVhostRouters(),
+		responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second,
+		vhostRouter:           NewVhostRouters(),
 	}
 	proxy := &ReverseProxy{
 		Director: func(req *http.Request) {
@@ -77,7 +83,7 @@ func NewHttpReverseProxy() *HttpReverseProxy {
 			}
 		},
 		Transport: &http.Transport{
-			ResponseHeaderTimeout: responseHeaderTimeout,
+			ResponseHeaderTimeout: rp.responseHeaderTimeout,
 			DisableKeepAlives:     true,
 			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 				url := ctx.Value("url").(string)