optimize some code (#3801)

This commit is contained in:
fatedier
2023-11-27 15:47:49 +08:00
committed by GitHub
parent d5b41f1e14
commit 69ae2b0b69
52 changed files with 880 additions and 600 deletions

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package util
package http
import (
"encoding/base64"

128
pkg/util/http/server.go Normal file
View File

@@ -0,0 +1,128 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http
import (
"crypto/tls"
"net"
"net/http"
"net/http/pprof"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/fatedier/frp/assets"
v1 "github.com/fatedier/frp/pkg/config/v1"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
var (
defaultReadTimeout = 60 * time.Second
defaultWriteTimeout = 60 * time.Second
)
type Server struct {
addr string
ln net.Listener
tlsCfg *tls.Config
router *mux.Router
hs *http.Server
authMiddleware mux.MiddlewareFunc
}
func NewServer(cfg v1.WebServerConfig) (*Server, error) {
if cfg.AssetsDir != "" {
assets.Load(cfg.AssetsDir)
}
addr := net.JoinHostPort(cfg.Addr, strconv.Itoa(cfg.Port))
if addr == ":" {
addr = ":http"
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
router := mux.NewRouter()
hs := &http.Server{
Addr: addr,
Handler: router,
ReadTimeout: defaultReadTimeout,
WriteTimeout: defaultWriteTimeout,
}
s := &Server{
addr: addr,
ln: ln,
hs: hs,
router: router,
}
if cfg.PprofEnable {
s.registerPprofHandlers()
}
if cfg.TLS != nil {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile)
if err != nil {
return nil, err
}
s.tlsCfg = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
s.authMiddleware = netpkg.NewHTTPAuthMiddleware(cfg.User, cfg.Password).SetAuthFailDelay(200 * time.Millisecond).Middleware
return s, nil
}
func (s *Server) Address() string {
return s.addr
}
func (s *Server) Run() error {
ln := s.ln
if s.tlsCfg != nil {
ln = tls.NewListener(ln, s.tlsCfg)
}
return s.hs.Serve(ln)
}
func (s *Server) Close() error {
return s.hs.Close()
}
type RouterRegisterHelper struct {
Router *mux.Router
AssetsFS http.FileSystem
AuthMiddleware mux.MiddlewareFunc
}
func (s *Server) RouteRegister(register func(helper *RouterRegisterHelper)) {
register(&RouterRegisterHelper{
Router: s.router,
AssetsFS: assets.FileSystem,
AuthMiddleware: s.authMiddleware,
})
}
func (s *Server) registerPprofHandlers() {
s.router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
s.router.HandleFunc("/debug/pprof/profile", pprof.Profile)
s.router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
s.router.HandleFunc("/debug/pprof/trace", pprof.Trace)
s.router.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}

33
pkg/util/net/dns.go Normal file
View File

@@ -0,0 +1,33 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"context"
"net"
)
func SetDefaultDNSAddress(dnsAddress string) {
if _, _, err := net.SplitHostPort(dnsAddress); err != nil {
dnsAddress = net.JoinHostPort(dnsAddress, "53")
}
// Change default dns server
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial("udp", dnsAddress)
},
}
}

View File

@@ -52,7 +52,10 @@ func (l *InternalListener) PutConn(conn net.Conn) error {
conn.Close()
}
})
return err
if err != nil {
return fmt.Errorf("put conn error: listener is closed")
}
return nil
}
func (l *InternalListener) Close() error {

View File

@@ -24,7 +24,7 @@ import (
libnet "github.com/fatedier/golib/net"
"github.com/fatedier/frp/pkg/util/util"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/vhost"
)
@@ -59,10 +59,10 @@ func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, ht
return
}
host, _ = util.CanonicalHost(req.Host)
host, _ = httppkg.CanonicalHost(req.Host)
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
httpUser, httpPwd, _ = util.ParseBasicAuth(proxyAuth)
httpUser, httpPwd, _ = httppkg.ParseBasicAuth(proxyAuth)
}
return
}
@@ -71,7 +71,7 @@ func (muxer *HTTPConnectTCPMuxer) sendConnectResponse(c net.Conn, _ map[string]s
if muxer.passthrough {
return nil
}
res := util.OkResponse()
res := httppkg.OkResponse()
if res.Body != nil {
defer res.Body.Close()
}
@@ -85,7 +85,7 @@ func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, re
return true, nil
}
resp := util.ProxyUnauthorizedResponse()
resp := httppkg.ProxyUnauthorizedResponse()
if resp.Body != nil {
defer resp.Body.Close()
}

View File

@@ -31,8 +31,8 @@ import (
libio "github.com/fatedier/golib/io"
"github.com/fatedier/golib/pool"
frpLog "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
httppkg "github.com/fatedier/frp/pkg/util/http"
logpkg "github.com/fatedier/frp/pkg/util/log"
)
var ErrNoRouteFound = errors.New("no route found")
@@ -61,7 +61,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
Director: func(req *http.Request) {
req.URL.Scheme = "http"
reqRouteInfo := req.Context().Value(RouteInfoKey).(*RequestRouteInfo)
oldHost, _ := util.CanonicalHost(reqRouteInfo.Host)
oldHost, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
rc := rp.GetRouteConfig(oldHost, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
if rc != nil {
@@ -74,7 +74,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
// ignore error here, it will use CreateConnFn instead later
endpoint, _ = rc.ChooseEndpointFn()
reqRouteInfo.Endpoint = endpoint
frpLog.Trace("choose endpoint name [%s] for http request host [%s] path [%s] httpuser [%s]",
logpkg.Trace("choose endpoint name [%s] for http request host [%s] path [%s] httpuser [%s]",
endpoint, oldHost, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
}
// Set {domain}.{location}.{routeByHTTPUser}.{endpoint} as URL host here to let http transport reuse connections.
@@ -116,7 +116,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
BufferPool: newWrapPool(),
ErrorLog: log.New(newWrapLogger(), "", 0),
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
frpLog.Warn("do http proxy request [host: %s] error: %v", req.Host, err)
logpkg.Warn("do http proxy request [host: %s] error: %v", req.Host, err)
rw.WriteHeader(http.StatusNotFound)
_, _ = rw.Write(getNotFoundPageContent())
},
@@ -143,7 +143,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
if ok {
frpLog.Debug("get new HTTP request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
logpkg.Debug("get new HTTP request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
return vr.payload.(*RouteConfig)
}
return nil
@@ -159,7 +159,7 @@ func (rp *HTTPReverseProxy) GetHeaders(domain, location, routeByHTTPUser string)
// CreateConnection create a new connection by route config
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
host, _ := util.CanonicalHost(reqRouteInfo.Host)
host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
if ok {
if byEndpoint {
@@ -303,7 +303,7 @@ func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Requ
}
func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
domain, _ := util.CanonicalHost(req.Host)
domain, _ := httppkg.CanonicalHost(req.Host)
location := req.URL.Path
user, passwd, _ := req.BasicAuth()
if !rp.CheckAuth(domain, location, user, user, passwd) {
@@ -333,6 +333,6 @@ type wrapLogger struct{}
func newWrapLogger() *wrapLogger { return &wrapLogger{} }
func (l *wrapLogger) Write(p []byte) (n int, err error) {
frpLog.Warn("%s", string(bytes.TrimRight(p, "\n")))
logpkg.Warn("%s", string(bytes.TrimRight(p, "\n")))
return len(p), nil
}

View File

@@ -20,7 +20,7 @@ import (
"net/http"
"os"
frpLog "github.com/fatedier/frp/pkg/util/log"
logpkg "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/version"
)
@@ -58,7 +58,7 @@ func getNotFoundPageContent() []byte {
if NotFoundPagePath != "" {
buf, err = os.ReadFile(NotFoundPagePath)
if err != nil {
frpLog.Warn("read custom 404 page error: %v", err)
logpkg.Warn("read custom 404 page error: %v", err)
buf = []byte(NotFound)
}
} else {

View File

@@ -22,7 +22,7 @@ import (
"github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -284,7 +284,7 @@ func (l *Listener) Accept() (net.Conn, error) {
xl.Debug("rewrite host to [%s] success", l.rewriteHost)
conn = sConn
}
return utilnet.NewContextConn(l.ctx, conn), nil
return netpkg.NewContextConn(l.ctx, conn), nil
}
func (l *Listener) Close() error {