use net.JoinHostPort instead of fmt.Sprintf (#2791)

This commit is contained in:
fatedier
2022-02-09 15:19:35 +08:00
committed by GitHub
parent b2311e55e7
commit 6194273615
17 changed files with 61 additions and 49 deletions

View File

@@ -18,6 +18,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
@@ -163,7 +164,7 @@ type UDPListener struct {
}
func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
if err != nil {
return l, err
}

View File

@@ -2,9 +2,9 @@ package net
import (
"errors"
"fmt"
"net"
"net/http"
"strconv"
"golang.org/x/net/websocket"
)
@@ -52,7 +52,7 @@ func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
}
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
if err != nil {
return nil, err
}

View File

@@ -48,7 +48,7 @@ func readHTTPConnectRequest(rd io.Reader) (host string, err error) {
return
}
host = util.GetHostFromAddr(req.Host)
host, _ = util.CanonicalHost(req.Host)
return
}

View File

@@ -34,17 +34,6 @@ func OkResponse() *http.Response {
return res
}
// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func.
func GetHostFromAddr(addr string) (host string) {
strs := strings.Split(addr, ":")
if len(strs) > 1 {
host = strs[0]
} else {
host = addr
}
return
}
// canonicalHost strips port from host if present and returns the canonicalized
// host name.
func CanonicalHost(host string) (string, error) {

View File

@@ -19,6 +19,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
)
@@ -52,7 +53,7 @@ func CanonicalAddr(host string, port int) (addr string) {
if port == 80 || port == 443 {
addr = host
} else {
addr = fmt.Sprintf("%s:%d", host, port)
addr = net.JoinHostPort(host, strconv.Itoa(port))
}
return
}

View File

@@ -59,7 +59,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
Director: func(req *http.Request) {
req.URL.Scheme = "http"
url := req.Context().Value(RouteInfoURL).(string)
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string))
rc := rp.GetRouteConfig(oldHost, url)
if rc != nil {
if rc.RewriteHost != "" {
@@ -81,7 +81,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
IdleConnTimeout: 60 * time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value(RouteInfoURL).(string)
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string))
remote := ctx.Value(RouteInfoRemote).(string)
return rp.CreateConnection(host, url, remote)
},
@@ -191,7 +191,7 @@ func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router
}
func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
domain := util.GetHostFromAddr(req.Host)
domain, _ := util.CanonicalHost(req.Host)
location := req.URL.Path
user, passwd, _ := req.BasicAuth()
if !rp.CheckAuth(domain, location, user, passwd) {