diff --git a/README.md b/README.md index 5fff2b7d..c8008d8e 100644 --- a/README.md +++ b/README.md @@ -560,7 +560,8 @@ This feature is fit for a large number of short connections. ### Load balancing Load balancing is supported by `group`. -This feature is available only for type `tcp` now. + +This feature is available only for type `tcp` and `http` now. ```ini # frpc.ini @@ -583,6 +584,10 @@ group_key = 123 Proxies in same group will accept connections from port 80 randomly. +For `tcp` type, `remote_port` in one group shoud be same. + +For `http` type, `custom_domains, subdomain, locations` shoud be same. + ### Health Check Health check feature can help you achieve high availability with load balancing. diff --git a/README_zh.md b/README_zh.md index 7fdaa0f7..a918930d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -16,7 +16,7 @@ frp 是一个可用于内网穿透的高性能的反向代理应用,支持 tcp * [通过 ssh 访问公司内网机器](#通过-ssh-访问公司内网机器) * [通过自定义域名访问部署于内网的 web 服务](#通过自定义域名访问部署于内网的-web-服务) * [转发 DNS 查询请求](#转发-dns-查询请求) - * [转发 Unix域套接字](#转发-unix域套接字) + * [转发 Unix 域套接字](#转发-unix-域套接字) * [对外提供简单的文件访问服务](#对外提供简单的文件访问服务) * [为本地 HTTP 服务启用 HTTPS](#为本地-http-服务启用-https) * [安全地暴露内网服务](#安全地暴露内网服务) @@ -194,7 +194,7 @@ DNS 查询请求通常使用 UDP 协议,frp 支持对内网 UDP 服务的穿 `dig @x.x.x.x -p 6000 www.google.com` -### 转发 Unix域套接字 +### 转发 Unix 域套接字 通过 tcp 端口访问内网的 unix域套接字(例如和 docker daemon 通信)。 @@ -597,7 +597,7 @@ tcp_mux = false 可以将多个相同类型的 proxy 加入到同一个 group 中,从而实现负载均衡的功能。 -目前只支持 tcp 类型的 proxy。 +目前只支持 TCP 和 HTTP 类型的 proxy。 ```ini # frpc.ini @@ -618,7 +618,9 @@ group_key = 123 用户连接 frps 服务器的 80 端口,frps 会将接收到的用户连接随机分发给其中一个存活的 proxy。这样可以在一台 frpc 机器挂掉后仍然有其他节点能够提供服务。 -要求 `group_key` 相同,做权限验证,且 `remote_port` 相同。 +TCP 类型代理要求 `group_key` 相同,做权限验证,且 `remote_port` 相同。 + +HTTP 类型代理要求 `group_key, custom_domains 或 subdomain 和 locations` 相同。 ### 健康检查 diff --git a/client/service.go b/client/service.go index 32106cad..007e21d1 100644 --- a/client/service.go +++ b/client/service.go @@ -86,6 +86,8 @@ func (svr *Service) Run() error { if g.GlbClientCfg.LoginFailExit { return err } else { + conn.Close() + session.Close() time.Sleep(10 * time.Second) } } else { diff --git a/server/controller/resource.go b/server/controller/resource.go index 8428a1f3..91332b57 100644 --- a/server/controller/resource.go +++ b/server/controller/resource.go @@ -29,6 +29,9 @@ type ResourceController struct { // Tcp Group Controller TcpGroupCtl *group.TcpGroupCtl + // HTTP Group Controller + HTTPGroupCtl *group.HTTPGroupController + // Manage all tcp ports TcpPortManager *ports.PortManager diff --git a/server/group/group.go b/server/group/group.go index a0dae7cd..ab38cf45 100644 --- a/server/group/group.go +++ b/server/group/group.go @@ -23,4 +23,5 @@ var ( ErrGroupParamsInvalid = errors.New("group params invalid") ErrListenerClosed = errors.New("group listener closed") ErrGroupDifferentPort = errors.New("group should have same remote port") + ErrProxyRepeated = errors.New("group proxy repeated") ) diff --git a/server/group/http.go b/server/group/http.go new file mode 100644 index 00000000..538dccf3 --- /dev/null +++ b/server/group/http.go @@ -0,0 +1,157 @@ +package group + +import ( + "fmt" + "sync" + "sync/atomic" + + frpNet "github.com/fatedier/frp/utils/net" + + "github.com/fatedier/frp/utils/vhost" +) + +type HTTPGroupController struct { + groups map[string]*HTTPGroup + + vhostRouter *vhost.VhostRouters + + mu sync.Mutex +} + +func NewHTTPGroupController(vhostRouter *vhost.VhostRouters) *HTTPGroupController { + return &HTTPGroupController{ + groups: make(map[string]*HTTPGroup), + vhostRouter: vhostRouter, + } +} + +func (ctl *HTTPGroupController) Register(proxyName, group, groupKey string, + routeConfig vhost.VhostRouteConfig) (err error) { + + indexKey := httpGroupIndex(group, routeConfig.Domain, routeConfig.Location) + ctl.mu.Lock() + g, ok := ctl.groups[indexKey] + if !ok { + g = NewHTTPGroup(ctl) + ctl.groups[indexKey] = g + } + ctl.mu.Unlock() + + return g.Register(proxyName, group, groupKey, routeConfig) +} + +func (ctl *HTTPGroupController) UnRegister(proxyName, group, domain, location string) { + indexKey := httpGroupIndex(group, domain, location) + ctl.mu.Lock() + defer ctl.mu.Unlock() + g, ok := ctl.groups[indexKey] + if !ok { + return + } + + isEmpty := g.UnRegister(proxyName) + if isEmpty { + delete(ctl.groups, indexKey) + } +} + +type HTTPGroup struct { + group string + groupKey string + domain string + location string + + createFuncs map[string]vhost.CreateConnFunc + pxyNames []string + index uint64 + ctl *HTTPGroupController + mu sync.RWMutex +} + +func NewHTTPGroup(ctl *HTTPGroupController) *HTTPGroup { + return &HTTPGroup{ + createFuncs: make(map[string]vhost.CreateConnFunc), + pxyNames: make([]string, 0), + ctl: ctl, + } +} + +func (g *HTTPGroup) Register(proxyName, group, groupKey string, + routeConfig vhost.VhostRouteConfig) (err error) { + + g.mu.Lock() + defer g.mu.Unlock() + if len(g.createFuncs) == 0 { + // the first proxy in this group + tmp := routeConfig // copy object + tmp.CreateConnFn = g.createConn + err = g.ctl.vhostRouter.Add(routeConfig.Domain, routeConfig.Location, &tmp) + if err != nil { + return + } + + g.group = group + g.groupKey = groupKey + g.domain = routeConfig.Domain + g.location = routeConfig.Location + } else { + if g.group != group || g.domain != routeConfig.Domain || g.location != routeConfig.Location { + err = ErrGroupParamsInvalid + return + } + if g.groupKey != groupKey { + err = ErrGroupAuthFailed + return + } + } + if _, ok := g.createFuncs[proxyName]; ok { + err = ErrProxyRepeated + return + } + g.createFuncs[proxyName] = routeConfig.CreateConnFn + g.pxyNames = append(g.pxyNames, proxyName) + return nil +} + +func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.createFuncs, proxyName) + for i, name := range g.pxyNames { + if name == proxyName { + g.pxyNames = append(g.pxyNames[:i], g.pxyNames[i+1:]...) + break + } + } + + if len(g.createFuncs) == 0 { + isEmpty = true + g.ctl.vhostRouter.Del(g.domain, g.location) + } + return +} + +func (g *HTTPGroup) createConn(remoteAddr string) (frpNet.Conn, error) { + var f vhost.CreateConnFunc + newIndex := atomic.AddUint64(&g.index, 1) + + g.mu.RLock() + group := g.group + domain := g.domain + location := g.location + if len(g.pxyNames) > 0 { + name := g.pxyNames[int(newIndex)%len(g.pxyNames)] + f, _ = g.createFuncs[name] + } + g.mu.RUnlock() + + if f == nil { + return nil, fmt.Errorf("no CreateConnFunc for http group [%s], domain [%s], location [%s]", group, domain, location) + } + + return f(remoteAddr) +} + +func httpGroupIndex(group, domain, location string) string { + return fmt.Sprintf("%s_%s_%s", group, domain, location) +} diff --git a/server/group/tcp.go b/server/group/tcp.go index 8c46be65..9c027b94 100644 --- a/server/group/tcp.go +++ b/server/group/tcp.go @@ -24,46 +24,47 @@ import ( gerr "github.com/fatedier/golib/errors" ) -type TcpGroupListener struct { - groupName string - group *TcpGroup +// TcpGroupCtl manage all TcpGroups +type TcpGroupCtl struct { + groups map[string]*TcpGroup - addr net.Addr - closeCh chan struct{} + // portManager is used to manage port + portManager *ports.PortManager + mu sync.Mutex } -func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener { - return &TcpGroupListener{ - groupName: name, - group: group, - addr: addr, - closeCh: make(chan struct{}), +// NewTcpGroupCtl return a new TcpGroupCtl +func NewTcpGroupCtl(portManager *ports.PortManager) *TcpGroupCtl { + return &TcpGroupCtl{ + groups: make(map[string]*TcpGroup), + portManager: portManager, } } -func (ln *TcpGroupListener) Accept() (c net.Conn, err error) { - var ok bool - select { - case <-ln.closeCh: - return nil, ErrListenerClosed - case c, ok = <-ln.group.Accept(): - if !ok { - return nil, ErrListenerClosed - } - return c, nil +// Listen is the wrapper for TcpGroup's Listen +// If there are no group, we will create one here +func (tgc *TcpGroupCtl) Listen(proxyName string, group string, groupKey string, + addr string, port int) (l net.Listener, realPort int, err error) { + + tgc.mu.Lock() + tcpGroup, ok := tgc.groups[group] + if !ok { + tcpGroup = NewTcpGroup(tgc) + tgc.groups[group] = tcpGroup } + tgc.mu.Unlock() + + return tcpGroup.Listen(proxyName, group, groupKey, addr, port) } -func (ln *TcpGroupListener) Addr() net.Addr { - return ln.addr -} - -func (ln *TcpGroupListener) Close() (err error) { - close(ln.closeCh) - ln.group.CloseListener(ln) - return +// RemoveGroup remove TcpGroup from controller +func (tgc *TcpGroupCtl) RemoveGroup(group string) { + tgc.mu.Lock() + defer tgc.mu.Unlock() + delete(tgc.groups, group) } +// TcpGroup route connections to different proxies type TcpGroup struct { group string groupKey string @@ -79,6 +80,7 @@ type TcpGroup struct { mu sync.Mutex } +// NewTcpGroup return a new TcpGroup func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup { return &TcpGroup{ lns: make([]*TcpGroupListener, 0), @@ -87,10 +89,14 @@ func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup { } } +// Listen will return a new TcpGroupListener +// if TcpGroup already has a listener, just add a new TcpGroupListener to the queues +// otherwise, listen on the real address func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TcpGroupListener, realPort int, err error) { tg.mu.Lock() defer tg.mu.Unlock() if len(tg.lns) == 0 { + // the first listener, listen on the real address realPort, err = tg.ctl.portManager.Acquire(proxyName, port) if err != nil { return @@ -114,6 +120,7 @@ func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr } go tg.worker() } else { + // address and port in the same group must be equal if tg.group != group || tg.addr != addr { err = ErrGroupParamsInvalid return @@ -133,6 +140,7 @@ func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr return } +// worker is called when the real tcp listener has been created func (tg *TcpGroup) worker() { for { c, err := tg.tcpLn.Accept() @@ -152,6 +160,7 @@ func (tg *TcpGroup) Accept() <-chan net.Conn { return tg.acceptCh } +// CloseListener remove the TcpGroupListener from the TcpGroup func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) { tg.mu.Lock() defer tg.mu.Unlock() @@ -169,36 +178,47 @@ func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) { } } -type TcpGroupCtl struct { - groups map[string]*TcpGroup +// TcpGroupListener +type TcpGroupListener struct { + groupName string + group *TcpGroup - portManager *ports.PortManager - mu sync.Mutex + addr net.Addr + closeCh chan struct{} } -func NewTcpGroupCtl(portManager *ports.PortManager) *TcpGroupCtl { - return &TcpGroupCtl{ - groups: make(map[string]*TcpGroup), - portManager: portManager, +func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener { + return &TcpGroupListener{ + groupName: name, + group: group, + addr: addr, + closeCh: make(chan struct{}), } } -func (tgc *TcpGroupCtl) Listen(proxyNanme string, group string, groupKey string, - addr string, port int) (l net.Listener, realPort int, err error) { - - tgc.mu.Lock() - defer tgc.mu.Unlock() - if tcpGroup, ok := tgc.groups[group]; ok { - return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port) - } else { - tcpGroup = NewTcpGroup(tgc) - tgc.groups[group] = tcpGroup - return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port) +// Accept will accept connections from TcpGroup +func (ln *TcpGroupListener) Accept() (c net.Conn, err error) { + var ok bool + select { + case <-ln.closeCh: + return nil, ErrListenerClosed + case c, ok = <-ln.group.Accept(): + if !ok { + return nil, ErrListenerClosed + } + return c, nil } } -func (tgc *TcpGroupCtl) RemoveGroup(group string) { - tgc.mu.Lock() - defer tgc.mu.Unlock() - delete(tgc.groups, group) +func (ln *TcpGroupListener) Addr() net.Addr { + return ln.addr +} + +// Close close the listener +func (ln *TcpGroupListener) Close() (err error) { + close(ln.closeCh) + + // remove self from TcpGroup + ln.group.CloseListener(ln) + return } diff --git a/server/proxy/http.go b/server/proxy/http.go index c5bc1ac4..1fa8765e 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -50,6 +50,12 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) { locations = []string{""} } + defer func() { + if err != nil { + pxy.Close() + } + }() + addrs := make([]string, 0) for _, domain := range pxy.cfg.CustomDomains { if domain == "" { @@ -59,17 +65,31 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) { routeConfig.Domain = domain for _, location := range locations { routeConfig.Location = location - err = pxy.rc.HttpReverseProxy.Register(routeConfig) - if err != nil { - return - } tmpDomain := routeConfig.Domain tmpLocation := routeConfig.Location - addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(g.GlbServerCfg.VhostHttpPort))) - pxy.closeFuncs = append(pxy.closeFuncs, func() { - pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation) - }) - pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) + + // handle group + if pxy.cfg.Group != "" { + err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, routeConfig) + if err != nil { + return + } + + pxy.closeFuncs = append(pxy.closeFuncs, func() { + pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation) + }) + } else { + // no group + err = pxy.rc.HttpReverseProxy.Register(routeConfig) + if err != nil { + return + } + pxy.closeFuncs = append(pxy.closeFuncs, func() { + pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation) + }) + } + addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(g.GlbServerCfg.VhostHttpPort))) + pxy.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group) } } @@ -77,17 +97,31 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) { routeConfig.Domain = pxy.cfg.SubDomain + "." + g.GlbServerCfg.SubDomainHost for _, location := range locations { routeConfig.Location = location - err = pxy.rc.HttpReverseProxy.Register(routeConfig) - if err != nil { - return - } tmpDomain := routeConfig.Domain tmpLocation := routeConfig.Location + + // handle group + if pxy.cfg.Group != "" { + err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, routeConfig) + if err != nil { + return + } + + pxy.closeFuncs = append(pxy.closeFuncs, func() { + pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation) + }) + } else { + err = pxy.rc.HttpReverseProxy.Register(routeConfig) + if err != nil { + return + } + pxy.closeFuncs = append(pxy.closeFuncs, func() { + pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation) + }) + } addrs = append(addrs, util.CanonicalAddr(tmpDomain, g.GlbServerCfg.VhostHttpPort)) - pxy.closeFuncs = append(pxy.closeFuncs, func() { - pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation) - }) - pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) + + pxy.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group) } } remoteAddr = strings.Join(addrs, ",") diff --git a/server/proxy/https.go b/server/proxy/https.go index 888fcbe5..cb5ce928 100644 --- a/server/proxy/https.go +++ b/server/proxy/https.go @@ -31,6 +31,11 @@ type HttpsProxy struct { func (pxy *HttpsProxy) Run() (remoteAddr string, err error) { routeConfig := &vhost.VhostRouteConfig{} + defer func() { + if err != nil { + pxy.Close() + } + }() addrs := make([]string, 0) for _, domain := range pxy.cfg.CustomDomains { if domain == "" { diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index dd7bb79d..1a9a28e1 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -72,6 +72,8 @@ func (pxy *BaseProxy) Close() { } } +// GetWorkConnFromPool try to get a new work connections from pool +// for quickly response, we immediately send the StartWorkConn message to frpc after take out one from pool func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn frpNet.Conn, err error) { // try all connections from the pool for i := 0; i < pxy.poolCount+1; i++ { diff --git a/server/service.go b/server/service.go index d1207f44..e3a1d117 100644 --- a/server/service.go +++ b/server/service.go @@ -76,6 +76,9 @@ type Service struct { // Manage all proxies pxyManager *proxy.ProxyManager + // HTTP vhost router + httpVhostRouter *vhost.VhostRouters + // All resource managers and controllers rc *controller.ResourceController @@ -95,12 +98,16 @@ func NewService() (svr *Service, err error) { TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), }, - tlsConfig: generateTLSConfig(), + httpVhostRouter: vhost.NewVhostRouters(), + tlsConfig: generateTLSConfig(), } // Init group controller svr.rc.TcpGroupCtl = group.NewTcpGroupCtl(svr.rc.TcpPortManager) + // Init HTTP group controller + svr.rc.HTTPGroupCtl = group.NewHTTPGroupController(svr.httpVhostRouter) + // Init assets err = assets.Load(cfg.AssetsDir) if err != nil { @@ -159,7 +166,7 @@ func NewService() (svr *Service, err error) { if cfg.VhostHttpPort > 0 { rp := vhost.NewHttpReverseProxy(vhost.HttpReverseProxyOptions{ ResponseHeaderTimeoutS: cfg.VhostHttpTimeout, - }) + }, svr.httpVhostRouter) svr.rc.HttpReverseProxy = rp address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) diff --git a/utils/version/version.go b/utils/version/version.go index da23fbea..40fd07a7 100644 --- a/utils/version/version.go +++ b/utils/version/version.go @@ -19,7 +19,7 @@ import ( "strings" ) -var version string = "0.27.1" +var version string = "0.28.0" func Full() string { return version diff --git a/utils/vhost/http.go b/utils/vhost/http.go index 7bbc3615..1b3a5bd1 100644 --- a/utils/vhost/http.go +++ b/utils/vhost/http.go @@ -23,7 +23,6 @@ import ( "net" "net/http" "strings" - "sync" "time" frpLog "github.com/fatedier/frp/utils/log" @@ -32,8 +31,7 @@ import ( ) var ( - ErrRouterConfigConflict = errors.New("router config conflict") - ErrNoDomain = errors.New("no such domain") + ErrNoDomain = errors.New("no such domain") ) func getHostFromAddr(addr string) (host string) { @@ -51,21 +49,19 @@ type HttpReverseProxyOptions struct { } type HttpReverseProxy struct { - proxy *ReverseProxy - + proxy *ReverseProxy vhostRouter *VhostRouters responseHeaderTimeout time.Duration - cfgMu sync.RWMutex } -func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy { +func NewHttpReverseProxy(option HttpReverseProxyOptions, vhostRouter *VhostRouters) *HttpReverseProxy { if option.ResponseHeaderTimeoutS <= 0 { option.ResponseHeaderTimeoutS = 60 } rp := &HttpReverseProxy{ responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second, - vhostRouter: NewVhostRouters(), + vhostRouter: vhostRouter, } proxy := &ReverseProxy{ Director: func(req *http.Request) { @@ -106,21 +102,18 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy { return rp } +// Register register the route config to reverse proxy +// reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error { - rp.cfgMu.Lock() - defer rp.cfgMu.Unlock() - _, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location) - if ok { - return ErrRouterConfigConflict - } else { - rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg) + err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg) + if err != nil { + return err } return nil } +// UnRegister unregister route config by domain and location func (rp *HttpReverseProxy) UnRegister(domain string, location string) { - rp.cfgMu.Lock() - defer rp.cfgMu.Unlock() rp.vhostRouter.Del(domain, location) } @@ -140,6 +133,7 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers return } +// CreateConnection create a new connection by route config func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) { vr, ok := rp.getVhost(domain, location) if ok { @@ -163,10 +157,8 @@ func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) boo return true } +// getVhost get vhost router by domain and location func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) { - rp.cfgMu.RLock() - defer rp.cfgMu.RUnlock() - // first we check the full hostname // if not exist, then check the wildcard_domain such as *.example.com vr, ok = rp.vhostRouter.Get(domain, location) diff --git a/utils/vhost/router.go b/utils/vhost/router.go index ea5c347c..bfdcb50b 100644 --- a/utils/vhost/router.go +++ b/utils/vhost/router.go @@ -1,11 +1,16 @@ package vhost import ( + "errors" "sort" "strings" "sync" ) +var ( + ErrRouterConfigConflict = errors.New("router config conflict") +) + type VhostRouters struct { RouterByDomain map[string][]*VhostRouter mutex sync.RWMutex @@ -24,10 +29,14 @@ func NewVhostRouters() *VhostRouters { } } -func (r *VhostRouters) Add(domain, location string, payload interface{}) { +func (r *VhostRouters) Add(domain, location string, payload interface{}) error { r.mutex.Lock() defer r.mutex.Unlock() + if _, exist := r.exist(domain, location); exist { + return ErrRouterConfigConflict + } + vrs, found := r.RouterByDomain[domain] if !found { vrs = make([]*VhostRouter, 0, 1) @@ -42,6 +51,7 @@ func (r *VhostRouters) Add(domain, location string, payload interface{}) { sort.Sort(sort.Reverse(ByLocation(vrs))) r.RouterByDomain[domain] = vrs + return nil } func (r *VhostRouters) Del(domain, location string) { @@ -80,10 +90,7 @@ func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) { return } -func (r *VhostRouters) Exist(host, path string) (vr *VhostRouter, exist bool) { - r.mutex.RLock() - defer r.mutex.RUnlock() - +func (r *VhostRouters) exist(host, path string) (vr *VhostRouter, exist bool) { vrs, found := r.RouterByDomain[host] if !found { return diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index f366e1e1..d3e54fa1 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -15,7 +15,6 @@ package vhost import ( "fmt" "strings" - "sync" "time" "github.com/fatedier/frp/utils/log" @@ -35,7 +34,6 @@ type VhostMuxer struct { authFunc httpAuthFunc rewriteFunc hostRewriteFunc registryRouter *VhostRouters - mutex sync.RWMutex } func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { @@ -53,6 +51,7 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut type CreateConnFunc func(remoteAddr string) (frpNet.Conn, error) +// VhostRouteConfig is the params used to match HTTP requests type VhostRouteConfig struct { Domain string Location string @@ -67,14 +66,6 @@ type VhostRouteConfig struct { // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil // then rewrite the host header to rewriteHost func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) { - v.mutex.Lock() - defer v.mutex.Unlock() - - _, ok := v.registryRouter.Exist(cfg.Domain, cfg.Location) - if ok { - return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", cfg.Domain, cfg.Location) - } - l = &Listener{ name: cfg.Domain, location: cfg.Location, @@ -85,14 +76,14 @@ func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) { accept: make(chan frpNet.Conn), Logger: log.NewPrefixLogger(""), } - v.registryRouter.Add(cfg.Domain, cfg.Location, l) + err = v.registryRouter.Add(cfg.Domain, cfg.Location, l) + if err != nil { + return + } return l, nil } func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { - v.mutex.RLock() - defer v.mutex.RUnlock() - // first we check the full hostname // if not exist, then check the wildcard_domain such as *.example.com vr, found := v.registryRouter.Get(name, path)