diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index 39c26d26..23e62e53 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -347,22 +347,18 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) // Send detect message - array := strings.Split(natHoleRespMsg.VisitorAddr, ":") - if len(array) <= 1 { - xl.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) + host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr) + if err != nil { + xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err) } laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) - /* - for i := 1000; i < 65000; i++ { - pxy.sendDetectMsg(array[0], int64(i), laddr, "a") - } - */ - port, err := strconv.ParseInt(array[1], 10, 64) + + port, err := strconv.ParseInt(portStr, 10, 64) if err != nil { xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) return } - pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) + pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid)) xl.Trace("send all detect msg done") msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}) diff --git a/client/visitor.go b/client/visitor.go index 7526481d..52f4ccd9 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net" + "strconv" "sync" "time" @@ -85,7 +86,7 @@ type STCPVisitor struct { } func (sv *STCPVisitor) Run() (err error) { - sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort)) + sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) if err != nil { return } @@ -174,7 +175,7 @@ type XTCPVisitor struct { } func (sv *XTCPVisitor) Run() (err error) { - sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort)) + sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) if err != nil { return } @@ -352,7 +353,7 @@ type SUDPVisitor struct { func (sv *SUDPVisitor) Run() (err error) { xl := xlog.FromContextSafe(sv.ctx) - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort)) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) if err != nil { return fmt.Errorf("sudp ResolveUDPAddr error: %v", err) } diff --git a/pkg/util/net/udp.go b/pkg/util/net/udp.go index 67d66665..6689732e 100644 --- a/pkg/util/net/udp.go +++ b/pkg/util/net/udp.go @@ -18,6 +18,7 @@ import ( "fmt" "io" "net" + "strconv" "sync" "time" @@ -163,7 +164,7 @@ type UDPListener struct { } func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) { - udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort))) if err != nil { return l, err } diff --git a/pkg/util/net/websocket.go b/pkg/util/net/websocket.go index 7030787e..4ec5c9fe 100644 --- a/pkg/util/net/websocket.go +++ b/pkg/util/net/websocket.go @@ -2,9 +2,9 @@ package net import ( "errors" - "fmt" "net" "net/http" + "strconv" "golang.org/x/net/websocket" ) @@ -52,7 +52,7 @@ func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { } func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) { - tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort))) if err != nil { return nil, err } diff --git a/pkg/util/tcpmux/httpconnect.go b/pkg/util/tcpmux/httpconnect.go index 014f6881..fcc0a88f 100644 --- a/pkg/util/tcpmux/httpconnect.go +++ b/pkg/util/tcpmux/httpconnect.go @@ -48,7 +48,7 @@ func readHTTPConnectRequest(rd io.Reader) (host string, err error) { return } - host = util.GetHostFromAddr(req.Host) + host, _ = util.CanonicalHost(req.Host) return } diff --git a/pkg/util/util/http.go b/pkg/util/util/http.go index 2d6089b1..988ec179 100644 --- a/pkg/util/util/http.go +++ b/pkg/util/util/http.go @@ -34,17 +34,6 @@ func OkResponse() *http.Response { return res } -// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func. -func GetHostFromAddr(addr string) (host string) { - strs := strings.Split(addr, ":") - if len(strs) > 1 { - host = strs[0] - } else { - host = addr - } - return -} - // canonicalHost strips port from host if present and returns the canonicalized // host name. func CanonicalHost(host string) (string, error) { diff --git a/pkg/util/util/util.go b/pkg/util/util/util.go index 50069ea4..eb2ae0b2 100644 --- a/pkg/util/util/util.go +++ b/pkg/util/util/util.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net" "strconv" "strings" ) @@ -52,7 +53,7 @@ func CanonicalAddr(host string, port int) (addr string) { if port == 80 || port == 443 { addr = host } else { - addr = fmt.Sprintf("%s:%d", host, port) + addr = net.JoinHostPort(host, strconv.Itoa(port)) } return } diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go index ee2ab1a1..b9dc32db 100644 --- a/pkg/util/vhost/http.go +++ b/pkg/util/vhost/http.go @@ -59,7 +59,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) * Director: func(req *http.Request) { req.URL.Scheme = "http" url := req.Context().Value(RouteInfoURL).(string) - oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string)) + oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string)) rc := rp.GetRouteConfig(oldHost, url) if rc != nil { if rc.RewriteHost != "" { @@ -81,7 +81,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) * IdleConnTimeout: 60 * time.Second, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { url := ctx.Value(RouteInfoURL).(string) - host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string)) + host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string)) remote := ctx.Value(RouteInfoRemote).(string) return rp.CreateConnection(host, url, remote) }, @@ -191,7 +191,7 @@ func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router } func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - domain := util.GetHostFromAddr(req.Host) + domain, _ := util.CanonicalHost(req.Host) location := req.URL.Path user, passwd, _ := req.BasicAuth() if !rp.CheckAuth(domain, location, user, passwd) { diff --git a/server/group/tcp.go b/server/group/tcp.go index 0128482f..c7fd2b27 100644 --- a/server/group/tcp.go +++ b/server/group/tcp.go @@ -15,8 +15,8 @@ package group import ( - "fmt" "net" + "strconv" "sync" "github.com/fatedier/frp/server/ports" @@ -101,7 +101,7 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr if err != nil { return } - tcpLn, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port)) + tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port))) if errRet != nil { err = errRet return diff --git a/server/ports/ports.go b/server/ports/ports.go index 1dabd450..f852f843 100644 --- a/server/ports/ports.go +++ b/server/ports/ports.go @@ -2,8 +2,8 @@ package ports import ( "errors" - "fmt" "net" + "strconv" "sync" "time" ) @@ -134,7 +134,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) { func (pm *Manager) isPortAvailable(port int) bool { if pm.netType == "udp" { - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port)) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port))) if err != nil { return false } @@ -146,7 +146,7 @@ func (pm *Manager) isPortAvailable(port int) bool { return true } - l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port)) + l, err := net.Listen(pm.netType, net.JoinHostPort(pm.bindAddr, strconv.Itoa(port))) if err != nil { return false } diff --git a/server/proxy/tcp.go b/server/proxy/tcp.go index 420f43fe..0cf9c5f9 100644 --- a/server/proxy/tcp.go +++ b/server/proxy/tcp.go @@ -17,6 +17,7 @@ package proxy import ( "fmt" "net" + "strconv" "github.com/fatedier/frp/pkg/config" ) @@ -54,7 +55,7 @@ func (pxy *TCPProxy) Run() (remoteAddr string, err error) { pxy.rc.TCPPortManager.Release(pxy.realPort) } }() - listener, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort)) + listener, errRet := net.Listen("tcp", net.JoinHostPort(pxy.serverCfg.ProxyBindAddr, strconv.Itoa(pxy.realPort))) if errRet != nil { err = errRet return diff --git a/server/proxy/udp.go b/server/proxy/udp.go index 4540a434..9e3c0675 100644 --- a/server/proxy/udp.go +++ b/server/proxy/udp.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net" + "strconv" "time" "github.com/fatedier/frp/pkg/config" @@ -70,7 +71,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) { remoteAddr = fmt.Sprintf(":%d", pxy.realPort) pxy.cfg.RemotePort = pxy.realPort - addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort)) + addr, errRet := net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.serverCfg.ProxyBindAddr, strconv.Itoa(pxy.realPort))) if errRet != nil { err = errRet return diff --git a/server/service.go b/server/service.go index bc0d48a6..c3c45488 100644 --- a/server/service.go +++ b/server/service.go @@ -124,7 +124,8 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { // Create tcpmux httpconnect multiplexer. if cfg.TCPMuxHTTPConnectPort > 0 { var l net.Listener - l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.TCPMuxHTTPConnectPort)) + address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.TCPMuxHTTPConnectPort)) + l, err = net.Listen("tcp", address) if err != nil { err = fmt.Errorf("Create server listener error, %v", err) return @@ -135,7 +136,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { err = fmt.Errorf("Create vhost tcpMuxer error, %v", err) return } - log.Info("tcpmux httpconnect multiplexer listen on %s:%d", cfg.ProxyBindAddr, cfg.TCPMuxHTTPConnectPort) + log.Info("tcpmux httpconnect multiplexer listen on %s", address) } // Init all plugins @@ -199,7 +200,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { err = fmt.Errorf("Listen on kcp address udp %s error: %v", address, err) return } - log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KCPBindPort) + log.Info("frps kcp listen on udp %s", address) } // Listen for accepting connections from client using websocket protocol. @@ -232,7 +233,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { } } go server.Serve(l) - log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHTTPPort) + log.Info("http service listen on %s", address) } // Create https vhost muxer. @@ -288,7 +289,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { err = fmt.Errorf("Create dashboard web server error, %v", err) return } - log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort) + log.Info("Dashboard listen on %s", address) statsEnable = true } if statsEnable { diff --git a/test/e2e/basic/client_server.go b/test/e2e/basic/client_server.go index 67f1efd3..c7faa421 100644 --- a/test/e2e/basic/client_server.go +++ b/test/e2e/basic/client_server.go @@ -249,4 +249,24 @@ var _ = Describe("[Feature: Client-Server]", func() { }) } }) + + Describe("IPv6 bind address", func() { + supportProtocols := []string{"tcp", "kcp", "websocket"} + for _, protocol := range supportProtocols { + tmp := protocol + defineClientServerTest("IPv6 bind address: "+strings.ToUpper(tmp), f, &generalTestConfigures{ + server: fmt.Sprintf(` + bind_addr = :: + kcp_bind_port = {{ .%s }} + protocol = %s + `, consts.PortServerName, protocol), + client: fmt.Sprintf(` + tls_enable = true + protocol = %s + disable_custom_tls_first_byte = true + `, protocol), + }) + } + }) + }) diff --git a/test/e2e/mock/server/httpserver/server.go b/test/e2e/mock/server/httpserver/server.go index f35c1193..a811ac27 100644 --- a/test/e2e/mock/server/httpserver/server.go +++ b/test/e2e/mock/server/httpserver/server.go @@ -2,7 +2,6 @@ package httpserver import ( "crypto/tls" - "fmt" "net" "net/http" "strconv" @@ -97,7 +96,7 @@ func (s *Server) Close() error { } func (s *Server) initListener() (err error) { - s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)) + s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort))) return } diff --git a/test/e2e/mock/server/streamserver/server.go b/test/e2e/mock/server/streamserver/server.go index bb5b790f..1dde353a 100644 --- a/test/e2e/mock/server/streamserver/server.go +++ b/test/e2e/mock/server/streamserver/server.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "strconv" libnet "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/test/e2e/pkg/rpc" @@ -99,7 +100,7 @@ func (s *Server) Close() error { func (s *Server) initListener() (err error) { switch s.netType { case TCP: - s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)) + s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort))) case UDP: s.l, err = libnet.ListenUDP(s.bindAddr, s.bindPort) case Unix: diff --git a/test/e2e/pkg/port/port.go b/test/e2e/pkg/port/port.go index dc9e1012..298892e7 100644 --- a/test/e2e/pkg/port/port.go +++ b/test/e2e/pkg/port/port.go @@ -3,6 +3,7 @@ package port import ( "fmt" "net" + "strconv" "sync" "k8s.io/apimachinery/pkg/util/sets" @@ -57,7 +58,7 @@ func (pa *Allocator) GetByName(portName string) int { return 0 } - l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + l, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) if err != nil { // Maybe not controlled by us, mark it used. pa.used.Insert(port) @@ -65,7 +66,7 @@ func (pa *Allocator) GetByName(portName string) int { } l.Close() - udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", port)) + udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) if err != nil { continue }