diff --git a/client/control.go b/client/control.go index 03618125..4d16ce99 100644 --- a/client/control.go +++ b/client/control.go @@ -17,10 +17,10 @@ package client import ( "context" "crypto/tls" - "fmt" "io" "net" "runtime/debug" + "strconv" "sync" "time" @@ -222,8 +222,10 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { return } } - conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, - fmt.Sprintf("%s:%d", ctl.clientCfg.ServerAddr, ctl.clientCfg.ServerPort), tlsConfig) + + address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)) + conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig) + if err != nil { xl.Warn("start new connection to server error: %v", err) return diff --git a/client/service.go b/client/service.go index f5039189..2b1b38aa 100644 --- a/client/service.go +++ b/client/service.go @@ -21,6 +21,7 @@ import ( "io/ioutil" "net" "runtime" + "strconv" "sync" "sync/atomic" "time" @@ -215,8 +216,9 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { return } } - conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, - fmt.Sprintf("%s:%d", svr.cfg.ServerAddr, svr.cfg.ServerPort), tlsConfig) + + address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)) + conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig) if err != nil { return } diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index d84ce827..a085dbc4 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -157,17 +157,16 @@ func parseClientCommonCfgFromIni(content string) (config.ClientCommonConf, error func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) { cfg = config.GetDefaultClientConf() - strs := strings.Split(serverAddr, ":") - if len(strs) < 2 { - err = fmt.Errorf("invalid server_addr") + ipStr, portStr, err := net.SplitHostPort(serverAddr) + if err != nil { + err = fmt.Errorf("invalid server_addr: %v", err) return } - if strs[0] != "" { - cfg.ServerAddr = strs[0] - } - cfg.ServerPort, err = strconv.Atoi(strs[1]) + + cfg.ServerAddr = ipStr + cfg.ServerPort, err = strconv.Atoi(portStr) if err != nil { - err = fmt.Errorf("invalid server_addr") + err = fmt.Errorf("invalid server_addr: %v", err) return } diff --git a/pkg/util/util/http.go b/pkg/util/util/http.go index bbd3f879..e48ef4ab 100644 --- a/pkg/util/util/http.go +++ b/pkg/util/util/http.go @@ -15,6 +15,7 @@ package util import ( + "net" "net/http" "strings" ) @@ -33,6 +34,7 @@ func OkResponse() *http.Response { return res } +// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func. func GetHostFromAddr(addr string) (host string) { strs := strings.Split(addr, ":") if len(strs) > 1 { @@ -42,3 +44,34 @@ func GetHostFromAddr(addr string) (host string) { } return } + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func CanonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return host, nil +} + +// hasPort reports whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} \ No newline at end of file