tcpmux: support authentication (#3345)

This commit is contained in:
fatedier
2023-03-07 19:53:32 +08:00
committed by GitHub
parent 54eb704650
commit 862b1642ba
11 changed files with 324 additions and 71 deletions

View File

@@ -31,18 +31,21 @@ import (
type HTTPConnectTCPMuxer struct {
*vhost.Muxer
passthrough bool
authRequired bool // Not supported until we really need this.
// If passthrough is set to true, the CONNECT request will be forwarded to the backend service.
// Otherwise, it will return an OK response to the client and forward the remaining content to the backend service.
passthrough bool
}
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
ret := &HTTPConnectTCPMuxer{passthrough: passthrough, authRequired: false}
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, nil, ret.sendConnectResponse, nil, timeout)
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
mux.SetCheckAuthFunc(ret.auth).
SetSuccessHookFunc(ret.sendConnectResponse)
ret.Muxer = mux
return ret, err
}
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host string, httpUser string, err error) {
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {
bufioReader := bufio.NewReader(rd)
req, err := http.ReadRequest(bufioReader)
@@ -58,7 +61,7 @@ func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host str
host, _ = util.CanonicalHost(req.Host)
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
httpUser, _, _ = util.ParseBasicAuth(proxyAuth)
httpUser, httpPwd, _ = util.ParseBasicAuth(proxyAuth)
}
return
}
@@ -74,11 +77,26 @@ func (muxer *HTTPConnectTCPMuxer) sendConnectResponse(c net.Conn, reqInfo map[st
return res.Write(c)
}
func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, reqInfo map[string]string) (bool, error) {
reqUsername := reqInfo["HTTPUser"]
reqPassword := reqInfo["HTTPPwd"]
if username == reqUsername && password == reqPassword {
return true, nil
}
resp := util.ProxyUnauthorizedResponse()
if resp.Body != nil {
defer resp.Body.Close()
}
_ = resp.Write(c)
return false, nil
}
func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) {
reqInfoMap := make(map[string]string, 0)
sc, rd := gnet.NewSharedConn(c)
host, httpUser, err := muxer.readHTTPConnectRequest(rd)
host, httpUser, httpPwd, err := muxer.readHTTPConnectRequest(rd)
if err != nil {
return nil, reqInfoMap, err
}
@@ -86,18 +104,11 @@ func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn,
reqInfoMap["Host"] = host
reqInfoMap["Scheme"] = "tcp"
reqInfoMap["HTTPUser"] = httpUser
reqInfoMap["HTTPPwd"] = httpPwd
outConn := c
if muxer.passthrough {
outConn = sc
if muxer.authRequired && httpUser == "" {
resp := util.ProxyUnauthorizedResponse()
if resp.Body != nil {
defer resp.Body.Close()
}
_ = resp.Write(c)
outConn = c
}
}
return outConn, reqInfoMap, nil
}

View File

@@ -28,7 +28,10 @@ type HTTPSMuxer struct {
}
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
mux, err := NewMuxer(listener, GetHTTPSHostname, nil, nil, nil, timeout)
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
if err != nil {
return nil, err
}
return &HTTPSMuxer{mux}, err
}

View File

@@ -85,17 +85,3 @@ func notFoundResponse() *http.Response {
}
return res
}
func noAuthResponse() *http.Response {
header := make(map[string][]string)
header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`}
res := &http.Response{
Status: "401 Not authorized",
StatusCode: 401,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: header,
}
return res
}

View File

@@ -43,43 +43,55 @@ type RequestRouteInfo struct {
type (
muxFunc func(net.Conn) (net.Conn, map[string]string, error)
httpAuthFunc func(net.Conn, string, string, string) (bool, error)
authFunc func(conn net.Conn, username, password string, reqInfoMap map[string]string) (bool, error)
hostRewriteFunc func(net.Conn, string) (net.Conn, error)
successFunc func(net.Conn, map[string]string) error
successHookFunc func(net.Conn, map[string]string) error
)
// Muxer is only used for https and tcpmux proxy.
// Muxer is a functional component used for https and tcpmux proxies.
// It accepts connections and extracts vhost information from the beginning of the connection data.
// It then routes the connection to its appropriate listener.
type Muxer struct {
listener net.Listener
timeout time.Duration
listener net.Listener
timeout time.Duration
vhostFunc muxFunc
authFunc httpAuthFunc
successFunc successFunc
rewriteFunc hostRewriteFunc
checkAuth authFunc
successHook successHookFunc
rewriteHost hostRewriteFunc
registryRouter *Routers
}
func NewMuxer(
listener net.Listener,
vhostFunc muxFunc,
authFunc httpAuthFunc,
successFunc successFunc,
rewriteFunc hostRewriteFunc,
timeout time.Duration,
) (mux *Muxer, err error) {
mux = &Muxer{
listener: listener,
timeout: timeout,
vhostFunc: vhostFunc,
authFunc: authFunc,
successFunc: successFunc,
rewriteFunc: rewriteFunc,
registryRouter: NewRouters(),
}
go mux.run()
return mux, nil
}
func (v *Muxer) SetCheckAuthFunc(f authFunc) *Muxer {
v.checkAuth = f
return v
}
func (v *Muxer) SetSuccessHookFunc(f successHookFunc) *Muxer {
v.successHook = f
return v
}
func (v *Muxer) SetRewriteHostFunc(f hostRewriteFunc) *Muxer {
v.rewriteHost = f
return v
}
type ChooseEndpointFunc func() (string, error)
type CreateConnFunc func(remoteAddr string) (net.Conn, error)
@@ -101,7 +113,7 @@ type RouteConfig struct {
CreateConnByEndpointFn CreateConnByEndpointFunc
}
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil
// listen for a new domain name, if rewriteHost is not empty and rewriteHost func is not nil,
// then rewrite the host header to rewriteHost
func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err error) {
l = &Listener{
@@ -109,8 +121,8 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err
location: cfg.Location,
routeByHTTPUser: cfg.RouteByHTTPUser,
rewriteHost: cfg.RewriteHost,
userName: cfg.Username,
passWord: cfg.Password,
username: cfg.Username,
password: cfg.Password,
mux: v,
accept: make(chan net.Conn),
ctx: ctx,
@@ -205,25 +217,20 @@ func (v *Muxer) handle(c net.Conn) {
}
xl := xlog.FromContextSafe(l.ctx)
if v.successFunc != nil {
if err := v.successFunc(c, reqInfoMap); err != nil {
if v.successHook != nil {
if err := v.successHook(c, reqInfoMap); err != nil {
xl.Info("success func failure on vhost connection: %v", err)
_ = c.Close()
return
}
}
// if authFunc is exist and username/password is set
// if checkAuth func is exist and username/password is set
// then verify user access
if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" {
bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"])
if !bAccess || err != nil {
xl.Debug("check http Authorization failed")
res := noAuthResponse()
if res.Body != nil {
defer res.Body.Close()
}
_ = res.Write(c)
if l.mux.checkAuth != nil && l.username != "" {
ok, err := l.mux.checkAuth(c, l.username, l.password, reqInfoMap)
if !ok || err != nil {
xl.Debug("auth failed for user: %s", l.username)
_ = c.Close()
return
}
@@ -249,8 +256,8 @@ type Listener struct {
location string
routeByHTTPUser string
rewriteHost string
userName string
passWord string
username string
password string
mux *Muxer // for closing Muxer
accept chan net.Conn
ctx context.Context
@@ -263,11 +270,11 @@ func (l *Listener) Accept() (net.Conn, error) {
return nil, fmt.Errorf("Listener closed")
}
// if rewriteFunc is exist
// if rewriteHost func is exist
// rewrite http requests with a modified host header
// if l.rewriteHost is empty, nothing to do
if l.mux.rewriteFunc != nil {
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
if l.mux.rewriteHost != nil {
sConn, err := l.mux.rewriteHost(conn, l.rewriteHost)
if err != nil {
xl.Warn("host header rewrite failed: %v", err)
return nil, fmt.Errorf("host header rewrite failed")