From c71efde303102cdb2a16a868391ac72c0828bee1 Mon Sep 17 00:00:00 2001 From: fatedier Date: Sun, 28 May 2023 16:50:43 +0800 Subject: [PATCH] refactor the code related to xtcp (#3449) --- Makefile | 3 + client/admin_api.go | 5 +- client/control.go | 48 +- client/proxy/proxy.go | 491 +------------------- client/proxy/proxy_manager.go | 48 +- client/proxy/proxy_wrapper.go | 28 +- client/proxy/sudp.go | 190 ++++++++ client/proxy/udp.go | 157 +++++++ client/proxy/xtcp.go | 200 ++++++++ client/service.go | 10 +- client/visitor.go | 575 ----------------------- client/visitor/stcp.go | 118 +++++ client/visitor/sudp.go | 262 +++++++++++ client/visitor/visitor.go | 77 ++++ client/{ => visitor}/visitor_manager.go | 48 +- client/visitor/xtcp.go | 410 +++++++++++++++++ cmd/frpc/sub/nathole.go | 47 +- cmd/frpc/sub/root.go | 45 +- go.mod | 9 +- go.sum | 18 +- pkg/config/client_test.go | 3 + pkg/config/proxy.go | 3 - pkg/config/server.go | 49 +- pkg/config/server_test.go | 78 ++-- pkg/config/visitor.go | 27 +- pkg/config/visitor_test.go | 3 + pkg/metrics/mem/server.go | 13 +- pkg/msg/msg.go | 142 +++--- pkg/nathole/analysis.go | 328 ++++++++++++++ pkg/nathole/classify.go | 75 ++- pkg/nathole/controller.go | 382 ++++++++++++++++ pkg/nathole/discovery.go | 73 +-- pkg/nathole/nathole.go | 579 ++++++++++++++++-------- pkg/nathole/utils.go | 47 ++ pkg/transport/message.go | 119 +++++ pkg/transport/tls.go | 14 + pkg/util/net/udp.go | 8 + pkg/util/util/slice.go | 25 - pkg/util/util/slice_test.go | 49 -- pkg/util/util/util.go | 19 +- pkg/util/util/util_test.go | 43 +- server/control.go | 55 ++- server/proxy/xtcp.go | 29 +- server/service.go | 52 ++- 44 files changed, 3305 insertions(+), 1699 deletions(-) create mode 100644 client/proxy/sudp.go create mode 100644 client/proxy/udp.go create mode 100644 client/proxy/xtcp.go delete mode 100644 client/visitor.go create mode 100644 client/visitor/stcp.go create mode 100644 client/visitor/sudp.go create mode 100644 client/visitor/visitor.go rename client/{ => visitor}/visitor_manager.go (70%) create mode 100644 client/visitor/xtcp.go create mode 100644 pkg/nathole/analysis.go create mode 100644 pkg/nathole/controller.go create mode 100644 pkg/transport/message.go delete mode 100644 pkg/util/util/slice.go delete mode 100644 pkg/util/util/slice_test.go diff --git a/Makefile b/Makefile index 4600abcb..e9b9ec54 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,9 @@ fmt: fmt-more: gofumpt -l -w . +gci: + gci write -s standard -s default -s "prefix(github.com/fatedier/frp/)" ./ + vet: go vet ./... diff --git a/client/admin_api.go b/client/admin_api.go index d0177076..2a3633ae 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -25,10 +25,11 @@ import ( "strconv" "strings" + "github.com/samber/lo" + "github.com/fatedier/frp/client/proxy" "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/util/log" - "github.com/fatedier/frp/pkg/util/util" ) type GeneralResponse struct { @@ -98,7 +99,7 @@ func NewProxyStatusResp(status *proxy.WorkingStatus, serverAddr string) ProxySta if status.Err == "" { psr.RemoteAddr = status.RemoteAddr - if util.InSlice(status.Type, []string{"tcp", "udp"}) { + if lo.Contains([]string{"tcp", "udp"}, status.Type) { psr.RemoteAddr = serverAddr + psr.RemoteAddr } } diff --git a/client/control.go b/client/control.go index 9fb15c01..7626bde8 100644 --- a/client/control.go +++ b/client/control.go @@ -25,14 +25,21 @@ import ( "github.com/fatedier/golib/crypto" "github.com/fatedier/frp/client/proxy" + "github.com/fatedier/frp/client/visitor" "github.com/fatedier/frp/pkg/auth" "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" ) type Control struct { - // uniq id got from frps, attach it in loginMsg + // service context + ctx context.Context + xl *xlog.Logger + + // Unique ID obtained from frps. + // It should be attached to the login message when reconnecting. runID string // manage all proxies @@ -40,7 +47,7 @@ type Control struct { pm *proxy.Manager // manage all visitors - vm *VisitorManager + vm *visitor.Manager // control connection conn net.Conn @@ -68,16 +75,10 @@ type Control struct { writerShutdown *shutdown.Shutdown msgHandlerShutdown *shutdown.Shutdown - // The UDP port that the server is listening on - serverUDPPort int - - xl *xlog.Logger - - // service context - ctx context.Context - // sets authentication based on selected method authSetter auth.Setter + + msgTransporter transport.MessageTransporter } func NewControl( @@ -85,11 +86,12 @@ func NewControl( clientCfg config.ClientCommonConf, pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.VisitorConf, - serverUDPPort int, authSetter auth.Setter, ) *Control { // new xlog instance ctl := &Control{ + ctx: ctx, + xl: xlog.FromContextSafe(ctx), runID: runID, conn: conn, cm: cm, @@ -102,14 +104,12 @@ func NewControl( readerShutdown: shutdown.New(), writerShutdown: shutdown.New(), msgHandlerShutdown: shutdown.New(), - serverUDPPort: serverUDPPort, - xl: xlog.FromContextSafe(ctx), - ctx: ctx, authSetter: authSetter, } - ctl.pm = proxy.NewManager(ctl.ctx, ctl.sendCh, clientCfg, serverUDPPort) + ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) + ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) - ctl.vm = NewVisitorManager(ctl.ctx, ctl) + ctl.vm = visitor.NewManager(ctl.ctx, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) ctl.vm.Reload(visitorCfgs) return ctl } @@ -173,6 +173,16 @@ func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { } } +func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) { + xl := ctl.xl + + // Dispatch the NatHoleResp message to the related proxy. + ok := ctl.msgTransporter.DispatchWithType(inMsg, msg.TypeNameNatHoleResp, inMsg.TransactionID) + if !ok { + xl.Trace("dispatch NatHoleResp message to related proxy error") + } +} + func (ctl *Control) Close() error { return ctl.GracefulClose(0) } @@ -188,7 +198,7 @@ func (ctl *Control) GracefulClose(d time.Duration) error { return nil } -// ClosedDoneCh returns a channel which will be closed after all resources are released +// ClosedDoneCh returns a channel that will be closed after all resources are released func (ctl *Control) ClosedDoneCh() <-chan struct{} { return ctl.closedDoneCh } @@ -250,7 +260,7 @@ func (ctl *Control) writer() { } } -// msgHandler handles all channel events and do corresponding operations. +// msgHandler handles all channel events and performs corresponding operations. func (ctl *Control) msgHandler() { xl := ctl.xl defer func() { @@ -307,6 +317,8 @@ func (ctl *Control) msgHandler() { go ctl.HandleReqWorkConn(m) case *msg.NewProxyResp: ctl.HandleNewProxyResp(m) + case *msg.NatHoleResp: + ctl.HandleNatHoleResp(m) case *msg.Pong: if m.Error != "" { xl.Error("Pong contains error: %s", m.Error) diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index c8ea4f2c..61e6763e 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -24,20 +24,16 @@ import ( "sync" "time" - "github.com/fatedier/golib/errors" frpIo "github.com/fatedier/golib/io" libdial "github.com/fatedier/golib/net/dial" - "github.com/fatedier/golib/pool" - fmux "github.com/hashicorp/yamux" pp "github.com/pires/go-proxyproto" "golang.org/x/time/rate" "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/msg" plugin "github.com/fatedier/frp/pkg/plugin/client" - "github.com/fatedier/frp/pkg/proto/udp" + "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/limit" - frpNet "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -51,7 +47,12 @@ type Proxy interface { Close() } -func NewProxy(ctx context.Context, pxyConf config.ProxyConf, clientCfg config.ClientCommonConf, serverUDPPort int) (pxy Proxy) { +func NewProxy( + ctx context.Context, + pxyConf config.ProxyConf, + clientCfg config.ClientCommonConf, + msgTransporter transport.MessageTransporter, +) (pxy Proxy) { var limiter *rate.Limiter limitBytes := pxyConf.GetBaseInfo().BandwidthLimit.Bytes() if limitBytes > 0 && pxyConf.GetBaseInfo().BandwidthLimitMode == config.BandwidthLimitModeClient { @@ -59,11 +60,11 @@ func NewProxy(ctx context.Context, pxyConf config.ProxyConf, clientCfg config.Cl } baseProxy := BaseProxy{ - clientCfg: clientCfg, - serverUDPPort: serverUDPPort, - limiter: limiter, - xl: xlog.FromContextSafe(ctx), - ctx: ctx, + clientCfg: clientCfg, + limiter: limiter, + msgTransporter: msgTransporter, + xl: xlog.FromContextSafe(ctx), + ctx: ctx, } switch cfg := pxyConf.(type) { case *config.TCPProxyConf: @@ -112,10 +113,10 @@ func NewProxy(ctx context.Context, pxyConf config.ProxyConf, clientCfg config.Cl } type BaseProxy struct { - closed bool - clientCfg config.ClientCommonConf - serverUDPPort int - limiter *rate.Limiter + closed bool + clientCfg config.ClientCommonConf + msgTransporter transport.MessageTransporter + limiter *rate.Limiter mu sync.RWMutex xl *xlog.Logger @@ -267,466 +268,6 @@ func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { conn, []byte(pxy.clientCfg.Token), m) } -// XTCP -type XTCPProxy struct { - *BaseProxy - - cfg *config.XTCPProxyConf - proxyPlugin plugin.Plugin -} - -func (pxy *XTCPProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) - if err != nil { - return - } - } - return -} - -func (pxy *XTCPProxy) Close() { - if pxy.proxyPlugin != nil { - pxy.proxyPlugin.Close() - } -} - -func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - xl := pxy.xl - defer conn.Close() - var natHoleSidMsg msg.NatHoleSid - err := msg.ReadMsgInto(conn, &natHoleSidMsg) - if err != nil { - xl.Error("xtcp read from workConn error: %v", err) - return - } - - natHoleClientMsg := &msg.NatHoleClient{ - ProxyName: pxy.cfg.ProxyName, - Sid: natHoleSidMsg.Sid, - } - serverAddr := pxy.clientCfg.NatHoleServerAddr - if serverAddr == "" { - serverAddr = pxy.clientCfg.ServerAddr - } - raddr, _ := net.ResolveUDPAddr("udp", - net.JoinHostPort(serverAddr, strconv.Itoa(pxy.serverUDPPort))) - clientConn, err := net.DialUDP("udp", nil, raddr) - if err != nil { - xl.Error("dial server udp addr error: %v", err) - return - } - defer clientConn.Close() - - err = msg.WriteMsg(clientConn, natHoleClientMsg) - if err != nil { - xl.Error("send natHoleClientMsg to server error: %v", err) - return - } - - // Wait for client address at most 5 seconds. - var natHoleRespMsg msg.NatHoleResp - _ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) - - buf := pool.GetBuf(1024) - n, err := clientConn.Read(buf) - if err != nil { - xl.Error("get natHoleRespMsg error: %v", err) - return - } - err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) - if err != nil { - xl.Error("get natHoleRespMsg error: %v", err) - return - } - _ = clientConn.SetReadDeadline(time.Time{}) - _ = clientConn.Close() - - if natHoleRespMsg.Error != "" { - xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) - return - } - - xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) - - // Send detect message - host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr) - if err != nil { - xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err) - } - laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) - - port, err := strconv.ParseInt(portStr, 10, 64) - if err != nil { - xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) - return - } - _ = pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid)) - xl.Trace("send all detect msg done") - - if err := msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}); err != nil { - xl.Error("write message error: %v", err) - return - } - - // Listen for clientConn's address and wait for visitor connection - lConn, err := net.ListenUDP("udp", laddr) - if err != nil { - xl.Error("listen on visitorConn's local address error: %v", err) - return - } - defer lConn.Close() - - _ = lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) - sidBuf := pool.GetBuf(1024) - var uAddr *net.UDPAddr - n, uAddr, err = lConn.ReadFromUDP(sidBuf) - if err != nil { - xl.Warn("get sid from visitor error: %v", err) - return - } - _ = lConn.SetReadDeadline(time.Time{}) - if string(sidBuf[:n]) != natHoleRespMsg.Sid { - xl.Warn("incorrect sid from visitor") - return - } - pool.PutBuf(sidBuf) - xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) - - if _, err := lConn.WriteToUDP(sidBuf[:n], uAddr); err != nil { - xl.Error("write uaddr error: %v", err) - return - } - - kcpConn, err := frpNet.NewKCPConnFromUDP(lConn, false, uAddr.String()) - if err != nil { - xl.Error("create kcp connection from udp connection error: %v", err) - return - } - - fmuxCfg := fmux.DefaultConfig() - fmuxCfg.KeepAliveInterval = 5 * time.Second - fmuxCfg.LogOutput = io.Discard - sess, err := fmux.Server(kcpConn, fmuxCfg) - if err != nil { - xl.Error("create yamux server from kcp connection error: %v", err) - return - } - defer sess.Close() - muxConn, err := sess.Accept() - if err != nil { - xl.Error("accept for yamux connection error: %v", err) - return - } - - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, - muxConn, []byte(pxy.cfg.Sk), m) -} - -func (pxy *XTCPProxy) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) { - daddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(addr, strconv.Itoa(port))) - if err != nil { - return err - } - - tConn, err := net.DialUDP("udp", laddr, daddr) - if err != nil { - return err - } - - // uConn := ipv4.NewConn(tConn) - // uConn.SetTTL(3) - - if _, err := tConn.Write(content); err != nil { - return err - } - return tConn.Close() -} - -// UDP -type UDPProxy struct { - *BaseProxy - - cfg *config.UDPProxyConf - - localAddr *net.UDPAddr - readCh chan *msg.UDPPacket - - // include msg.UDPPacket and msg.Ping - sendCh chan msg.Message - workConn net.Conn -} - -func (pxy *UDPProxy) Run() (err error) { - pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort))) - if err != nil { - return - } - return -} - -func (pxy *UDPProxy) Close() { - pxy.mu.Lock() - defer pxy.mu.Unlock() - - if !pxy.closed { - pxy.closed = true - if pxy.workConn != nil { - pxy.workConn.Close() - } - if pxy.readCh != nil { - close(pxy.readCh) - } - if pxy.sendCh != nil { - close(pxy.sendCh) - } - } -} - -func (pxy *UDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - xl := pxy.xl - xl.Info("incoming a new work connection for udp proxy, %s", conn.RemoteAddr().String()) - // close resources releated with old workConn - pxy.Close() - - var rwc io.ReadWriteCloser = conn - var err error - if pxy.limiter != nil { - rwc = frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { - return conn.Close() - }) - } - if pxy.cfg.UseEncryption { - rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.clientCfg.Token)) - if err != nil { - conn.Close() - xl.Error("create encryption stream error: %v", err) - return - } - } - if pxy.cfg.UseCompression { - rwc = frpIo.WithCompression(rwc) - } - conn = frpNet.WrapReadWriteCloserToConn(rwc, conn) - - pxy.mu.Lock() - pxy.workConn = conn - pxy.readCh = make(chan *msg.UDPPacket, 1024) - pxy.sendCh = make(chan msg.Message, 1024) - pxy.closed = false - pxy.mu.Unlock() - - workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) { - for { - var udpMsg msg.UDPPacket - if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil { - xl.Warn("read from workConn for udp error: %v", errRet) - return - } - if errRet := errors.PanicToError(func() { - xl.Trace("get udp package from workConn: %s", udpMsg.Content) - readCh <- &udpMsg - }); errRet != nil { - xl.Info("reader goroutine for udp work connection closed: %v", errRet) - return - } - } - } - workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) { - defer func() { - xl.Info("writer goroutine for udp work connection closed") - }() - var errRet error - for rawMsg := range sendCh { - switch m := rawMsg.(type) { - case *msg.UDPPacket: - xl.Trace("send udp package to workConn: %s", m.Content) - case *msg.Ping: - xl.Trace("send ping message to udp workConn") - } - if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil { - xl.Error("udp work write error: %v", errRet) - return - } - } - } - heartbeatFn := func(sendCh chan msg.Message) { - var errRet error - for { - time.Sleep(time.Duration(30) * time.Second) - if errRet = errors.PanicToError(func() { - sendCh <- &msg.Ping{} - }); errRet != nil { - xl.Trace("heartbeat goroutine for udp work connection closed") - break - } - } - } - - go workConnSenderFn(pxy.workConn, pxy.sendCh) - go workConnReaderFn(pxy.workConn, pxy.readCh) - go heartbeatFn(pxy.sendCh) - udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize)) -} - -type SUDPProxy struct { - *BaseProxy - - cfg *config.SUDPProxyConf - - localAddr *net.UDPAddr - - closeCh chan struct{} -} - -func (pxy *SUDPProxy) Run() (err error) { - pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort))) - if err != nil { - return - } - return -} - -func (pxy *SUDPProxy) Close() { - pxy.mu.Lock() - defer pxy.mu.Unlock() - select { - case <-pxy.closeCh: - return - default: - close(pxy.closeCh) - } -} - -func (pxy *SUDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - xl := pxy.xl - xl.Info("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String()) - - var rwc io.ReadWriteCloser = conn - var err error - if pxy.limiter != nil { - rwc = frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { - return conn.Close() - }) - } - if pxy.cfg.UseEncryption { - rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.clientCfg.Token)) - if err != nil { - conn.Close() - xl.Error("create encryption stream error: %v", err) - return - } - } - if pxy.cfg.UseCompression { - rwc = frpIo.WithCompression(rwc) - } - conn = frpNet.WrapReadWriteCloserToConn(rwc, conn) - - workConn := conn - readCh := make(chan *msg.UDPPacket, 1024) - sendCh := make(chan msg.Message, 1024) - isClose := false - - mu := &sync.Mutex{} - - closeFn := func() { - mu.Lock() - defer mu.Unlock() - if isClose { - return - } - - isClose = true - if workConn != nil { - workConn.Close() - } - close(readCh) - close(sendCh) - } - - // udp service <- frpc <- frps <- frpc visitor <- user - workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) { - defer closeFn() - - for { - // first to check sudp proxy is closed or not - select { - case <-pxy.closeCh: - xl.Trace("frpc sudp proxy is closed") - return - default: - } - - var udpMsg msg.UDPPacket - if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil { - xl.Warn("read from workConn for sudp error: %v", errRet) - return - } - - if errRet := errors.PanicToError(func() { - readCh <- &udpMsg - }); errRet != nil { - xl.Warn("reader goroutine for sudp work connection closed: %v", errRet) - return - } - } - } - - // udp service -> frpc -> frps -> frpc visitor -> user - workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) { - defer func() { - closeFn() - xl.Info("writer goroutine for sudp work connection closed") - }() - - var errRet error - for rawMsg := range sendCh { - switch m := rawMsg.(type) { - case *msg.UDPPacket: - xl.Trace("frpc send udp package to frpc visitor, [udp local: %v, remote: %v], [tcp work conn local: %v, remote: %v]", - m.LocalAddr.String(), m.RemoteAddr.String(), conn.LocalAddr().String(), conn.RemoteAddr().String()) - case *msg.Ping: - xl.Trace("frpc send ping message to frpc visitor") - } - - if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil { - xl.Error("sudp work write error: %v", errRet) - return - } - } - } - - heartbeatFn := func(sendCh chan msg.Message) { - ticker := time.NewTicker(30 * time.Second) - defer func() { - ticker.Stop() - closeFn() - }() - - var errRet error - for { - select { - case <-ticker.C: - if errRet = errors.PanicToError(func() { - sendCh <- &msg.Ping{} - }); errRet != nil { - xl.Warn("heartbeat goroutine for sudp work connection closed") - return - } - case <-pxy.closeCh: - xl.Trace("frpc sudp proxy is closed") - return - } - } - } - - go workConnSenderFn(workConn, sendCh) - go workConnReaderFn(workConn, readCh) - go heartbeatFn(sendCh) - - udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize)) -} - // Common handler for tcp work connections. func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin, baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn, diff --git a/client/proxy/proxy_manager.go b/client/proxy/proxy_manager.go index 563531e8..f5d7502c 100644 --- a/client/proxy/proxy_manager.go +++ b/client/proxy/proxy_manager.go @@ -1,3 +1,17 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package proxy import ( @@ -6,37 +20,36 @@ import ( "net" "sync" - "github.com/fatedier/golib/errors" - "github.com/fatedier/frp/client/event" "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" ) type Manager struct { - sendCh chan (msg.Message) - proxies map[string]*Wrapper + proxies map[string]*Wrapper + msgTransporter transport.MessageTransporter closed bool mu sync.RWMutex clientCfg config.ClientCommonConf - // The UDP port that the server is listening on - serverUDPPort int - ctx context.Context } -func NewManager(ctx context.Context, msgSendCh chan (msg.Message), clientCfg config.ClientCommonConf, serverUDPPort int) *Manager { +func NewManager( + ctx context.Context, + clientCfg config.ClientCommonConf, + msgTransporter transport.MessageTransporter, +) *Manager { return &Manager{ - sendCh: msgSendCh, - proxies: make(map[string]*Wrapper), - closed: false, - clientCfg: clientCfg, - serverUDPPort: serverUDPPort, - ctx: ctx, + proxies: make(map[string]*Wrapper), + msgTransporter: msgTransporter, + closed: false, + clientCfg: clientCfg, + ctx: ctx, } } @@ -86,10 +99,7 @@ func (pm *Manager) HandleEvent(payload interface{}) error { return event.ErrPayloadType } - err := errors.PanicToError(func() { - pm.sendCh <- m - }) - return err + return pm.msgTransporter.Send(m) } func (pm *Manager) GetAllProxyStatus() []*WorkingStatus { @@ -131,7 +141,7 @@ func (pm *Manager) Reload(pxyCfgs map[string]config.ProxyConf) { addPxyNames := make([]string, 0) for name, cfg := range pxyCfgs { if _, ok := pm.proxies[name]; !ok { - pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.serverUDPPort) + pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter) pm.proxies[name] = pxy addPxyNames = append(addPxyNames, name) diff --git a/client/proxy/proxy_wrapper.go b/client/proxy/proxy_wrapper.go index af217f06..f2caa618 100644 --- a/client/proxy/proxy_wrapper.go +++ b/client/proxy/proxy_wrapper.go @@ -1,3 +1,17 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package proxy import ( @@ -14,6 +28,7 @@ import ( "github.com/fatedier/frp/client/health" "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -56,6 +71,8 @@ type Wrapper struct { // event handler handler event.Handler + msgTransporter transport.MessageTransporter + health uint32 lastSendStartMsg time.Time lastStartErr time.Time @@ -67,7 +84,13 @@ type Wrapper struct { ctx context.Context } -func NewWrapper(ctx context.Context, cfg config.ProxyConf, clientCfg config.ClientCommonConf, eventHandler event.Handler, serverUDPPort int) *Wrapper { +func NewWrapper( + ctx context.Context, + cfg config.ProxyConf, + clientCfg config.ClientCommonConf, + eventHandler event.Handler, + msgTransporter transport.MessageTransporter, +) *Wrapper { baseInfo := cfg.GetBaseInfo() xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(baseInfo.ProxyName) pw := &Wrapper{ @@ -80,6 +103,7 @@ func NewWrapper(ctx context.Context, cfg config.ProxyConf, clientCfg config.Clie closeCh: make(chan struct{}), healthNotifyCh: make(chan struct{}), handler: eventHandler, + msgTransporter: msgTransporter, xl: xl, ctx: xlog.NewContext(ctx, xl), } @@ -92,7 +116,7 @@ func NewWrapper(ctx context.Context, cfg config.ProxyConf, clientCfg config.Clie xl.Trace("enable health check monitor") } - pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, serverUDPPort) + pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, pw.msgTransporter) return pw } diff --git a/client/proxy/sudp.go b/client/proxy/sudp.go new file mode 100644 index 00000000..f5405903 --- /dev/null +++ b/client/proxy/sudp.go @@ -0,0 +1,190 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "io" + "net" + "strconv" + "sync" + "time" + + "github.com/fatedier/golib/errors" + frpIo "github.com/fatedier/golib/io" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/udp" + "github.com/fatedier/frp/pkg/util/limit" + frpNet "github.com/fatedier/frp/pkg/util/net" +) + +type SUDPProxy struct { + *BaseProxy + + cfg *config.SUDPProxyConf + + localAddr *net.UDPAddr + + closeCh chan struct{} +} + +func (pxy *SUDPProxy) Run() (err error) { + pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort))) + if err != nil { + return + } + return +} + +func (pxy *SUDPProxy) Close() { + pxy.mu.Lock() + defer pxy.mu.Unlock() + select { + case <-pxy.closeCh: + return + default: + close(pxy.closeCh) + } +} + +func (pxy *SUDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + xl := pxy.xl + xl.Info("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String()) + + var rwc io.ReadWriteCloser = conn + var err error + if pxy.limiter != nil { + rwc = frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { + return conn.Close() + }) + } + if pxy.cfg.UseEncryption { + rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.clientCfg.Token)) + if err != nil { + conn.Close() + xl.Error("create encryption stream error: %v", err) + return + } + } + if pxy.cfg.UseCompression { + rwc = frpIo.WithCompression(rwc) + } + conn = frpNet.WrapReadWriteCloserToConn(rwc, conn) + + workConn := conn + readCh := make(chan *msg.UDPPacket, 1024) + sendCh := make(chan msg.Message, 1024) + isClose := false + + mu := &sync.Mutex{} + + closeFn := func() { + mu.Lock() + defer mu.Unlock() + if isClose { + return + } + + isClose = true + if workConn != nil { + workConn.Close() + } + close(readCh) + close(sendCh) + } + + // udp service <- frpc <- frps <- frpc visitor <- user + workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) { + defer closeFn() + + for { + // first to check sudp proxy is closed or not + select { + case <-pxy.closeCh: + xl.Trace("frpc sudp proxy is closed") + return + default: + } + + var udpMsg msg.UDPPacket + if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil { + xl.Warn("read from workConn for sudp error: %v", errRet) + return + } + + if errRet := errors.PanicToError(func() { + readCh <- &udpMsg + }); errRet != nil { + xl.Warn("reader goroutine for sudp work connection closed: %v", errRet) + return + } + } + } + + // udp service -> frpc -> frps -> frpc visitor -> user + workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) { + defer func() { + closeFn() + xl.Info("writer goroutine for sudp work connection closed") + }() + + var errRet error + for rawMsg := range sendCh { + switch m := rawMsg.(type) { + case *msg.UDPPacket: + xl.Trace("frpc send udp package to frpc visitor, [udp local: %v, remote: %v], [tcp work conn local: %v, remote: %v]", + m.LocalAddr.String(), m.RemoteAddr.String(), conn.LocalAddr().String(), conn.RemoteAddr().String()) + case *msg.Ping: + xl.Trace("frpc send ping message to frpc visitor") + } + + if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil { + xl.Error("sudp work write error: %v", errRet) + return + } + } + } + + heartbeatFn := func(sendCh chan msg.Message) { + ticker := time.NewTicker(30 * time.Second) + defer func() { + ticker.Stop() + closeFn() + }() + + var errRet error + for { + select { + case <-ticker.C: + if errRet = errors.PanicToError(func() { + sendCh <- &msg.Ping{} + }); errRet != nil { + xl.Warn("heartbeat goroutine for sudp work connection closed") + return + } + case <-pxy.closeCh: + xl.Trace("frpc sudp proxy is closed") + return + } + } + } + + go workConnSenderFn(workConn, sendCh) + go workConnReaderFn(workConn, readCh) + go heartbeatFn(sendCh) + + udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize)) +} diff --git a/client/proxy/udp.go b/client/proxy/udp.go new file mode 100644 index 00000000..8a599367 --- /dev/null +++ b/client/proxy/udp.go @@ -0,0 +1,157 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "io" + "net" + "strconv" + "time" + + "github.com/fatedier/golib/errors" + frpIo "github.com/fatedier/golib/io" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/udp" + "github.com/fatedier/frp/pkg/util/limit" + frpNet "github.com/fatedier/frp/pkg/util/net" +) + +// UDP +type UDPProxy struct { + *BaseProxy + + cfg *config.UDPProxyConf + + localAddr *net.UDPAddr + readCh chan *msg.UDPPacket + + // include msg.UDPPacket and msg.Ping + sendCh chan msg.Message + workConn net.Conn +} + +func (pxy *UDPProxy) Run() (err error) { + pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort))) + if err != nil { + return + } + return +} + +func (pxy *UDPProxy) Close() { + pxy.mu.Lock() + defer pxy.mu.Unlock() + + if !pxy.closed { + pxy.closed = true + if pxy.workConn != nil { + pxy.workConn.Close() + } + if pxy.readCh != nil { + close(pxy.readCh) + } + if pxy.sendCh != nil { + close(pxy.sendCh) + } + } +} + +func (pxy *UDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + xl := pxy.xl + xl.Info("incoming a new work connection for udp proxy, %s", conn.RemoteAddr().String()) + // close resources releated with old workConn + pxy.Close() + + var rwc io.ReadWriteCloser = conn + var err error + if pxy.limiter != nil { + rwc = frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { + return conn.Close() + }) + } + if pxy.cfg.UseEncryption { + rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.clientCfg.Token)) + if err != nil { + conn.Close() + xl.Error("create encryption stream error: %v", err) + return + } + } + if pxy.cfg.UseCompression { + rwc = frpIo.WithCompression(rwc) + } + conn = frpNet.WrapReadWriteCloserToConn(rwc, conn) + + pxy.mu.Lock() + pxy.workConn = conn + pxy.readCh = make(chan *msg.UDPPacket, 1024) + pxy.sendCh = make(chan msg.Message, 1024) + pxy.closed = false + pxy.mu.Unlock() + + workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) { + for { + var udpMsg msg.UDPPacket + if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil { + xl.Warn("read from workConn for udp error: %v", errRet) + return + } + if errRet := errors.PanicToError(func() { + xl.Trace("get udp package from workConn: %s", udpMsg.Content) + readCh <- &udpMsg + }); errRet != nil { + xl.Info("reader goroutine for udp work connection closed: %v", errRet) + return + } + } + } + workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) { + defer func() { + xl.Info("writer goroutine for udp work connection closed") + }() + var errRet error + for rawMsg := range sendCh { + switch m := rawMsg.(type) { + case *msg.UDPPacket: + xl.Trace("send udp package to workConn: %s", m.Content) + case *msg.Ping: + xl.Trace("send ping message to udp workConn") + } + if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil { + xl.Error("udp work write error: %v", errRet) + return + } + } + } + heartbeatFn := func(sendCh chan msg.Message) { + var errRet error + for { + time.Sleep(time.Duration(30) * time.Second) + if errRet = errors.PanicToError(func() { + sendCh <- &msg.Ping{} + }); errRet != nil { + xl.Trace("heartbeat goroutine for udp work connection closed") + break + } + } + } + + go workConnSenderFn(pxy.workConn, pxy.sendCh) + go workConnReaderFn(pxy.workConn, pxy.readCh) + go heartbeatFn(pxy.sendCh) + udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize)) +} diff --git a/client/proxy/xtcp.go b/client/proxy/xtcp.go new file mode 100644 index 00000000..9535a314 --- /dev/null +++ b/client/proxy/xtcp.go @@ -0,0 +1,200 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "io" + "net" + "time" + + fmux "github.com/hashicorp/yamux" + "github.com/quic-go/quic-go" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/nathole" + plugin "github.com/fatedier/frp/pkg/plugin/client" + "github.com/fatedier/frp/pkg/transport" + frpNet "github.com/fatedier/frp/pkg/util/net" +) + +// XTCP +type XTCPProxy struct { + *BaseProxy + + cfg *config.XTCPProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *XTCPProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *XTCPProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *XTCPProxy) InWorkConn(conn net.Conn, startWorkConnMsg *msg.StartWorkConn) { + xl := pxy.xl + defer conn.Close() + var natHoleSidMsg msg.NatHoleSid + err := msg.ReadMsgInto(conn, &natHoleSidMsg) + if err != nil { + xl.Error("xtcp read from workConn error: %v", err) + return + } + + prepareResult, err := nathole.Prepare([]string{pxy.clientCfg.NatHoleSTUNServer}) + if err != nil { + xl.Warn("nathole prepare error: %v", err) + return + } + xl.Info("nathole prepare success, nat type: %s, behavior: %s, addresses: %v, assistedAddresses: %v", + prepareResult.NatType, prepareResult.Behavior, prepareResult.Addrs, prepareResult.AssistedAddrs) + defer prepareResult.ListenConn.Close() + + // send NatHoleClient msg to server + transactionID := nathole.NewTransactionID() + natHoleClientMsg := &msg.NatHoleClient{ + TransactionID: transactionID, + ProxyName: pxy.cfg.ProxyName, + Sid: natHoleSidMsg.Sid, + MappedAddrs: prepareResult.Addrs, + AssistedAddrs: prepareResult.AssistedAddrs, + } + + natHoleRespMsg, err := nathole.ExchangeInfo(pxy.ctx, pxy.msgTransporter, transactionID, natHoleClientMsg, 5*time.Second) + if err != nil { + xl.Warn("nathole exchange info error: %v", err) + return + } + + xl.Info("get natHoleRespMsg, sid [%s], protocol [%s], candidate address %v, assisted address %v, detectBehavior: %+v", + natHoleRespMsg.Sid, natHoleRespMsg.Protocol, natHoleRespMsg.CandidateAddrs, + natHoleRespMsg.AssistedAddrs, natHoleRespMsg.DetectBehavior) + + listenConn := prepareResult.ListenConn + newListenConn, raddr, err := nathole.MakeHole(pxy.ctx, listenConn, natHoleRespMsg, []byte(pxy.cfg.Sk)) + if err != nil { + listenConn.Close() + xl.Warn("make hole error: %v", err) + _ = pxy.msgTransporter.Send(&msg.NatHoleReport{ + Sid: natHoleRespMsg.Sid, + Success: false, + }) + return + } + listenConn = newListenConn + xl.Info("establishing nat hole connection successful, sid [%s], remoteAddr [%s]", natHoleRespMsg.Sid, raddr) + + _ = pxy.msgTransporter.Send(&msg.NatHoleReport{ + Sid: natHoleRespMsg.Sid, + Success: true, + }) + + if natHoleRespMsg.Protocol == "kcp" { + pxy.listenByKCP(listenConn, raddr, startWorkConnMsg) + return + } + + // default is quic + pxy.listenByQUIC(listenConn, raddr, startWorkConnMsg) +} + +func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, startWorkConnMsg *msg.StartWorkConn) { + xl := pxy.xl + listenConn.Close() + laddr, _ := net.ResolveUDPAddr("udp", listenConn.LocalAddr().String()) + lConn, err := net.DialUDP("udp", laddr, raddr) + if err != nil { + xl.Warn("dial udp error: %v", err) + return + } + defer lConn.Close() + + remote, err := frpNet.NewKCPConnFromUDP(lConn, true, raddr.String()) + if err != nil { + xl.Warn("create kcp connection from udp connection error: %v", err) + return + } + + fmuxCfg := fmux.DefaultConfig() + fmuxCfg.KeepAliveInterval = 10 * time.Second + fmuxCfg.MaxStreamWindowSize = 2 * 1024 * 1024 + fmuxCfg.LogOutput = io.Discard + session, err := fmux.Server(remote, fmuxCfg) + if err != nil { + xl.Error("create mux session error: %v", err) + return + } + defer session.Close() + + for { + muxConn, err := session.Accept() + if err != nil { + xl.Error("accept connection error: %v", err) + return + } + go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + muxConn, []byte(pxy.cfg.Sk), startWorkConnMsg) + } +} + +func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, startWorkConnMsg *msg.StartWorkConn) { + xl := pxy.xl + defer listenConn.Close() + + tlsConfig, err := transport.NewServerTLSConfig("", "", "") + if err != nil { + xl.Warn("create tls config error: %v", err) + return + } + tlsConfig.NextProtos = []string{"frp"} + quicListener, err := quic.Listen(listenConn, tlsConfig, + &quic.Config{ + MaxIdleTimeout: time.Duration(pxy.clientCfg.QUICMaxIdleTimeout) * time.Second, + MaxIncomingStreams: int64(pxy.clientCfg.QUICMaxIncomingStreams), + KeepAlivePeriod: time.Duration(pxy.clientCfg.QUICKeepalivePeriod) * time.Second, + }, + ) + if err != nil { + xl.Warn("dial quic error: %v", err) + return + } + // only accept one connection from raddr + c, err := quicListener.Accept(pxy.ctx) + if err != nil { + xl.Error("quic accept connection error: %v", err) + return + } + for { + stream, err := c.AcceptStream(pxy.ctx) + if err != nil { + xl.Debug("quic accept stream error: %v", err) + _ = c.CloseWithError(0, "") + return + } + go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + frpNet.QuicStreamToNetConn(stream, c), []byte(pxy.cfg.Sk), startWorkConnMsg) + } +} diff --git a/client/service.go b/client/service.go index bac57167..83439c7f 100644 --- a/client/service.go +++ b/client/service.go @@ -72,9 +72,6 @@ type Service struct { // string if no configuration file was used. cfgFile string - // This is configured by the login response from frps - serverUDPPort int - exit uint32 // 0 means not exit // service context @@ -141,7 +138,7 @@ func (svr *Service) Run() error { util.RandomSleep(10*time.Second, 0.9, 1.1) } else { // login success - ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter) + ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) ctl.Run() svr.ctlMu.Lock() svr.ctl = ctl @@ -223,7 +220,7 @@ func (svr *Service) keepControllerWorking() { // reconnect success, init delayTime delayTime = time.Second - ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter) + ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) ctl.Run() svr.ctlMu.Lock() if svr.ctl != nil { @@ -295,8 +292,7 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) { xl.ResetPrefixes() xl.AppendPrefix(svr.runID) - svr.serverUDPPort = loginRespMsg.ServerUDPPort - xl.Info("login to server success, get run id [%s], server udp port [%d]", loginRespMsg.RunID, loginRespMsg.ServerUDPPort) + xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID) return } diff --git a/client/visitor.go b/client/visitor.go deleted file mode 100644 index 9d51ee33..00000000 --- a/client/visitor.go +++ /dev/null @@ -1,575 +0,0 @@ -// Copyright 2017 fatedier, fatedier@gmail.com -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "bytes" - "context" - "fmt" - "io" - "net" - "strconv" - "sync" - "time" - - "github.com/fatedier/golib/errors" - frpIo "github.com/fatedier/golib/io" - "github.com/fatedier/golib/pool" - fmux "github.com/hashicorp/yamux" - - "github.com/fatedier/frp/pkg/config" - "github.com/fatedier/frp/pkg/msg" - "github.com/fatedier/frp/pkg/proto/udp" - frpNet "github.com/fatedier/frp/pkg/util/net" - "github.com/fatedier/frp/pkg/util/util" - "github.com/fatedier/frp/pkg/util/xlog" -) - -// Visitor is used for forward traffics from local port to remote service. -type Visitor interface { - Run() error - Close() -} - -func NewVisitor(ctx context.Context, ctl *Control, cfg config.VisitorConf) (visitor Visitor) { - xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseInfo().ProxyName) - baseVisitor := BaseVisitor{ - ctl: ctl, - ctx: xlog.NewContext(ctx, xl), - } - switch cfg := cfg.(type) { - case *config.STCPVisitorConf: - visitor = &STCPVisitor{ - BaseVisitor: &baseVisitor, - cfg: cfg, - } - case *config.XTCPVisitorConf: - visitor = &XTCPVisitor{ - BaseVisitor: &baseVisitor, - cfg: cfg, - } - case *config.SUDPVisitorConf: - visitor = &SUDPVisitor{ - BaseVisitor: &baseVisitor, - cfg: cfg, - checkCloseCh: make(chan struct{}), - } - } - return -} - -type BaseVisitor struct { - ctl *Control - l net.Listener - - mu sync.RWMutex - ctx context.Context -} - -type STCPVisitor struct { - *BaseVisitor - - cfg *config.STCPVisitorConf -} - -func (sv *STCPVisitor) Run() (err error) { - sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) - if err != nil { - return - } - - go sv.worker() - return -} - -func (sv *STCPVisitor) Close() { - sv.l.Close() -} - -func (sv *STCPVisitor) worker() { - xl := xlog.FromContextSafe(sv.ctx) - for { - conn, err := sv.l.Accept() - if err != nil { - xl.Warn("stcp local listener closed") - return - } - - go sv.handleConn(conn) - } -} - -func (sv *STCPVisitor) handleConn(userConn net.Conn) { - xl := xlog.FromContextSafe(sv.ctx) - defer userConn.Close() - - xl.Debug("get a new stcp user connection") - visitorConn, err := sv.ctl.connectServer() - if err != nil { - return - } - defer visitorConn.Close() - - now := time.Now().Unix() - newVisitorConnMsg := &msg.NewVisitorConn{ - ProxyName: sv.cfg.ServerName, - SignKey: util.GetAuthKey(sv.cfg.Sk, now), - Timestamp: now, - UseEncryption: sv.cfg.UseEncryption, - UseCompression: sv.cfg.UseCompression, - } - err = msg.WriteMsg(visitorConn, newVisitorConnMsg) - if err != nil { - xl.Warn("send newVisitorConnMsg to server error: %v", err) - return - } - - var newVisitorConnRespMsg msg.NewVisitorConnResp - _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) - if err != nil { - xl.Warn("get newVisitorConnRespMsg error: %v", err) - return - } - _ = visitorConn.SetReadDeadline(time.Time{}) - - if newVisitorConnRespMsg.Error != "" { - xl.Warn("start new visitor connection error: %s", newVisitorConnRespMsg.Error) - return - } - - var remote io.ReadWriteCloser - remote = visitorConn - if sv.cfg.UseEncryption { - remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) - if err != nil { - xl.Error("create encryption stream error: %v", err) - return - } - } - - if sv.cfg.UseCompression { - remote = frpIo.WithCompression(remote) - } - - frpIo.Join(userConn, remote) -} - -type XTCPVisitor struct { - *BaseVisitor - - cfg *config.XTCPVisitorConf -} - -func (sv *XTCPVisitor) Run() (err error) { - sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) - if err != nil { - return - } - - go sv.worker() - return -} - -func (sv *XTCPVisitor) Close() { - sv.l.Close() -} - -func (sv *XTCPVisitor) worker() { - xl := xlog.FromContextSafe(sv.ctx) - for { - conn, err := sv.l.Accept() - if err != nil { - xl.Warn("xtcp local listener closed") - return - } - - go sv.handleConn(conn) - } -} - -func (sv *XTCPVisitor) handleConn(userConn net.Conn) { - xl := xlog.FromContextSafe(sv.ctx) - defer userConn.Close() - - xl.Debug("get a new xtcp user connection") - if sv.ctl.serverUDPPort == 0 { - xl.Error("xtcp is not supported by server") - return - } - - serverAddr := sv.ctl.clientCfg.NatHoleServerAddr - if serverAddr == "" { - serverAddr = sv.ctl.clientCfg.ServerAddr - } - raddr, err := net.ResolveUDPAddr("udp", - net.JoinHostPort(serverAddr, strconv.Itoa(sv.ctl.serverUDPPort))) - if err != nil { - xl.Error("resolve server UDP addr error") - return - } - - visitorConn, err := net.DialUDP("udp", nil, raddr) - if err != nil { - xl.Warn("dial server udp addr error: %v", err) - return - } - defer visitorConn.Close() - - now := time.Now().Unix() - natHoleVisitorMsg := &msg.NatHoleVisitor{ - ProxyName: sv.cfg.ServerName, - SignKey: util.GetAuthKey(sv.cfg.Sk, now), - Timestamp: now, - } - err = msg.WriteMsg(visitorConn, natHoleVisitorMsg) - if err != nil { - xl.Warn("send natHoleVisitorMsg to server error: %v", err) - return - } - - // Wait for client address at most 10 seconds. - var natHoleRespMsg msg.NatHoleResp - _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - buf := pool.GetBuf(1024) - n, err := visitorConn.Read(buf) - if err != nil { - xl.Warn("get natHoleRespMsg error: %v", err) - return - } - - err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) - if err != nil { - xl.Warn("get natHoleRespMsg error: %v", err) - return - } - _ = visitorConn.SetReadDeadline(time.Time{}) - pool.PutBuf(buf) - - if natHoleRespMsg.Error != "" { - xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) - return - } - - xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) - - // Close visitorConn, so we can use it's local address. - visitorConn.Close() - - // send sid message to client - laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) - daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) - if err != nil { - xl.Error("resolve client udp address error: %v", err) - return - } - lConn, err := net.DialUDP("udp", laddr, daddr) - if err != nil { - xl.Error("dial client udp address error: %v", err) - return - } - defer lConn.Close() - - if _, err := lConn.Write([]byte(natHoleRespMsg.Sid)); err != nil { - xl.Error("write sid error: %v", err) - return - } - - // read ack sid from client - sidBuf := pool.GetBuf(1024) - _ = lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) - n, err = lConn.Read(sidBuf) - if err != nil { - xl.Warn("get sid from client error: %v", err) - return - } - _ = lConn.SetReadDeadline(time.Time{}) - if string(sidBuf[:n]) != natHoleRespMsg.Sid { - xl.Warn("incorrect sid from client") - return - } - pool.PutBuf(sidBuf) - - xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) - - // wrap kcp connection - var remote io.ReadWriteCloser - remote, err = frpNet.NewKCPConnFromUDP(lConn, true, natHoleRespMsg.ClientAddr) - if err != nil { - xl.Error("create kcp connection from udp connection error: %v", err) - return - } - - fmuxCfg := fmux.DefaultConfig() - fmuxCfg.KeepAliveInterval = 5 * time.Second - fmuxCfg.LogOutput = io.Discard - sess, err := fmux.Client(remote, fmuxCfg) - if err != nil { - xl.Error("create yamux session error: %v", err) - return - } - defer sess.Close() - muxConn, err := sess.Open() - if err != nil { - xl.Error("open yamux stream error: %v", err) - return - } - - var muxConnRWCloser io.ReadWriteCloser = muxConn - if sv.cfg.UseEncryption { - muxConnRWCloser, err = frpIo.WithEncryption(muxConnRWCloser, []byte(sv.cfg.Sk)) - if err != nil { - xl.Error("create encryption stream error: %v", err) - return - } - } - if sv.cfg.UseCompression { - muxConnRWCloser = frpIo.WithCompression(muxConnRWCloser) - } - - _, _, errs := frpIo.Join(userConn, muxConnRWCloser) - xl.Debug("join connections closed") - if len(errs) > 0 { - xl.Trace("join connections errors: %v", errs) - } -} - -type SUDPVisitor struct { - *BaseVisitor - - checkCloseCh chan struct{} - // udpConn is the listener of udp packet - udpConn *net.UDPConn - readCh chan *msg.UDPPacket - sendCh chan *msg.UDPPacket - - cfg *config.SUDPVisitorConf -} - -// SUDP Run start listen a udp port -func (sv *SUDPVisitor) Run() (err error) { - xl := xlog.FromContextSafe(sv.ctx) - - addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) - if err != nil { - return fmt.Errorf("sudp ResolveUDPAddr error: %v", err) - } - - sv.udpConn, err = net.ListenUDP("udp", addr) - if err != nil { - return fmt.Errorf("listen udp port %s error: %v", addr.String(), err) - } - - sv.sendCh = make(chan *msg.UDPPacket, 1024) - sv.readCh = make(chan *msg.UDPPacket, 1024) - - xl.Info("sudp start to work, listen on %s", addr) - - go sv.dispatcher() - go udp.ForwardUserConn(sv.udpConn, sv.readCh, sv.sendCh, int(sv.ctl.clientCfg.UDPPacketSize)) - - return -} - -func (sv *SUDPVisitor) dispatcher() { - xl := xlog.FromContextSafe(sv.ctx) - - var ( - visitorConn net.Conn - err error - - firstPacket *msg.UDPPacket - ) - - for { - select { - case firstPacket = <-sv.sendCh: - if firstPacket == nil { - xl.Info("frpc sudp visitor proxy is closed") - return - } - case <-sv.checkCloseCh: - xl.Info("frpc sudp visitor proxy is closed") - return - } - - visitorConn, err = sv.getNewVisitorConn() - if err != nil { - xl.Warn("newVisitorConn to frps error: %v, try to reconnect", err) - continue - } - - // visitorConn always be closed when worker done. - sv.worker(visitorConn, firstPacket) - - select { - case <-sv.checkCloseCh: - return - default: - } - } -} - -func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) { - xl := xlog.FromContextSafe(sv.ctx) - xl.Debug("starting sudp proxy worker") - - wg := &sync.WaitGroup{} - wg.Add(2) - closeCh := make(chan struct{}) - - // udp service -> frpc -> frps -> frpc visitor -> user - workConnReaderFn := func(conn net.Conn) { - defer func() { - conn.Close() - close(closeCh) - wg.Done() - }() - - for { - var ( - rawMsg msg.Message - errRet error - ) - - // frpc will send heartbeat in workConn to frpc visitor for keeping alive - _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - if rawMsg, errRet = msg.ReadMsg(conn); errRet != nil { - xl.Warn("read from workconn for user udp conn error: %v", errRet) - return - } - - _ = conn.SetReadDeadline(time.Time{}) - switch m := rawMsg.(type) { - case *msg.Ping: - xl.Debug("frpc visitor get ping message from frpc") - continue - case *msg.UDPPacket: - if errRet := errors.PanicToError(func() { - sv.readCh <- m - xl.Trace("frpc visitor get udp packet from workConn: %s", m.Content) - }); errRet != nil { - xl.Info("reader goroutine for udp work connection closed") - return - } - } - } - } - - // udp service <- frpc <- frps <- frpc visitor <- user - workConnSenderFn := func(conn net.Conn) { - defer func() { - conn.Close() - wg.Done() - }() - - var errRet error - if firstPacket != nil { - if errRet = msg.WriteMsg(conn, firstPacket); errRet != nil { - xl.Warn("sender goroutine for udp work connection closed: %v", errRet) - return - } - xl.Trace("send udp package to workConn: %s", firstPacket.Content) - } - - for { - select { - case udpMsg, ok := <-sv.sendCh: - if !ok { - xl.Info("sender goroutine for udp work connection closed") - return - } - - if errRet = msg.WriteMsg(conn, udpMsg); errRet != nil { - xl.Warn("sender goroutine for udp work connection closed: %v", errRet) - return - } - xl.Trace("send udp package to workConn: %s", udpMsg.Content) - case <-closeCh: - return - } - } - } - - go workConnReaderFn(workConn) - go workConnSenderFn(workConn) - - wg.Wait() - xl.Info("sudp worker is closed") -} - -func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) { - xl := xlog.FromContextSafe(sv.ctx) - visitorConn, err := sv.ctl.connectServer() - if err != nil { - return nil, fmt.Errorf("frpc connect frps error: %v", err) - } - - now := time.Now().Unix() - newVisitorConnMsg := &msg.NewVisitorConn{ - ProxyName: sv.cfg.ServerName, - SignKey: util.GetAuthKey(sv.cfg.Sk, now), - Timestamp: now, - UseEncryption: sv.cfg.UseEncryption, - UseCompression: sv.cfg.UseCompression, - } - err = msg.WriteMsg(visitorConn, newVisitorConnMsg) - if err != nil { - return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err) - } - - var newVisitorConnRespMsg msg.NewVisitorConnResp - _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) - if err != nil { - return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err) - } - _ = visitorConn.SetReadDeadline(time.Time{}) - - if newVisitorConnRespMsg.Error != "" { - return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error) - } - - var remote io.ReadWriteCloser - remote = visitorConn - if sv.cfg.UseEncryption { - remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) - if err != nil { - xl.Error("create encryption stream error: %v", err) - return nil, err - } - } - if sv.cfg.UseCompression { - remote = frpIo.WithCompression(remote) - } - return frpNet.WrapReadWriteCloserToConn(remote, visitorConn), nil -} - -func (sv *SUDPVisitor) Close() { - sv.mu.Lock() - defer sv.mu.Unlock() - - select { - case <-sv.checkCloseCh: - return - default: - close(sv.checkCloseCh) - } - if sv.udpConn != nil { - sv.udpConn.Close() - } - close(sv.readCh) - close(sv.sendCh) -} diff --git a/client/visitor/stcp.go b/client/visitor/stcp.go new file mode 100644 index 00000000..2a7cf640 --- /dev/null +++ b/client/visitor/stcp.go @@ -0,0 +1,118 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package visitor + +import ( + "io" + "net" + "strconv" + "time" + + frpIo "github.com/fatedier/golib/io" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/pkg/util/xlog" +) + +type STCPVisitor struct { + *BaseVisitor + + cfg *config.STCPVisitorConf +} + +func (sv *STCPVisitor) Run() (err error) { + sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) + if err != nil { + return + } + + go sv.worker() + return +} + +func (sv *STCPVisitor) Close() { + sv.l.Close() +} + +func (sv *STCPVisitor) worker() { + xl := xlog.FromContextSafe(sv.ctx) + for { + conn, err := sv.l.Accept() + if err != nil { + xl.Warn("stcp local listener closed") + return + } + + go sv.handleConn(conn) + } +} + +func (sv *STCPVisitor) handleConn(userConn net.Conn) { + xl := xlog.FromContextSafe(sv.ctx) + defer userConn.Close() + + xl.Debug("get a new stcp user connection") + visitorConn, err := sv.connectServer() + if err != nil { + return + } + defer visitorConn.Close() + + now := time.Now().Unix() + newVisitorConnMsg := &msg.NewVisitorConn{ + ProxyName: sv.cfg.ServerName, + SignKey: util.GetAuthKey(sv.cfg.Sk, now), + Timestamp: now, + UseEncryption: sv.cfg.UseEncryption, + UseCompression: sv.cfg.UseCompression, + } + err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + if err != nil { + xl.Warn("send newVisitorConnMsg to server error: %v", err) + return + } + + var newVisitorConnRespMsg msg.NewVisitorConnResp + _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) + if err != nil { + xl.Warn("get newVisitorConnRespMsg error: %v", err) + return + } + _ = visitorConn.SetReadDeadline(time.Time{}) + + if newVisitorConnRespMsg.Error != "" { + xl.Warn("start new visitor connection error: %s", newVisitorConnRespMsg.Error) + return + } + + var remote io.ReadWriteCloser + remote = visitorConn + if sv.cfg.UseEncryption { + remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) + if err != nil { + xl.Error("create encryption stream error: %v", err) + return + } + } + + if sv.cfg.UseCompression { + remote = frpIo.WithCompression(remote) + } + + frpIo.Join(userConn, remote) +} diff --git a/client/visitor/sudp.go b/client/visitor/sudp.go new file mode 100644 index 00000000..cd6f2afe --- /dev/null +++ b/client/visitor/sudp.go @@ -0,0 +1,262 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package visitor + +import ( + "fmt" + "io" + "net" + "strconv" + "sync" + "time" + + "github.com/fatedier/golib/errors" + frpIo "github.com/fatedier/golib/io" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/udp" + frpNet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/pkg/util/xlog" +) + +type SUDPVisitor struct { + *BaseVisitor + + checkCloseCh chan struct{} + // udpConn is the listener of udp packet + udpConn *net.UDPConn + readCh chan *msg.UDPPacket + sendCh chan *msg.UDPPacket + + cfg *config.SUDPVisitorConf +} + +// SUDP Run start listen a udp port +func (sv *SUDPVisitor) Run() (err error) { + xl := xlog.FromContextSafe(sv.ctx) + + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) + if err != nil { + return fmt.Errorf("sudp ResolveUDPAddr error: %v", err) + } + + sv.udpConn, err = net.ListenUDP("udp", addr) + if err != nil { + return fmt.Errorf("listen udp port %s error: %v", addr.String(), err) + } + + sv.sendCh = make(chan *msg.UDPPacket, 1024) + sv.readCh = make(chan *msg.UDPPacket, 1024) + + xl.Info("sudp start to work, listen on %s", addr) + + go sv.dispatcher() + go udp.ForwardUserConn(sv.udpConn, sv.readCh, sv.sendCh, int(sv.clientCfg.UDPPacketSize)) + + return +} + +func (sv *SUDPVisitor) dispatcher() { + xl := xlog.FromContextSafe(sv.ctx) + + var ( + visitorConn net.Conn + err error + + firstPacket *msg.UDPPacket + ) + + for { + select { + case firstPacket = <-sv.sendCh: + if firstPacket == nil { + xl.Info("frpc sudp visitor proxy is closed") + return + } + case <-sv.checkCloseCh: + xl.Info("frpc sudp visitor proxy is closed") + return + } + + visitorConn, err = sv.getNewVisitorConn() + if err != nil { + xl.Warn("newVisitorConn to frps error: %v, try to reconnect", err) + continue + } + + // visitorConn always be closed when worker done. + sv.worker(visitorConn, firstPacket) + + select { + case <-sv.checkCloseCh: + return + default: + } + } +} + +func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) { + xl := xlog.FromContextSafe(sv.ctx) + xl.Debug("starting sudp proxy worker") + + wg := &sync.WaitGroup{} + wg.Add(2) + closeCh := make(chan struct{}) + + // udp service -> frpc -> frps -> frpc visitor -> user + workConnReaderFn := func(conn net.Conn) { + defer func() { + conn.Close() + close(closeCh) + wg.Done() + }() + + for { + var ( + rawMsg msg.Message + errRet error + ) + + // frpc will send heartbeat in workConn to frpc visitor for keeping alive + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + if rawMsg, errRet = msg.ReadMsg(conn); errRet != nil { + xl.Warn("read from workconn for user udp conn error: %v", errRet) + return + } + + _ = conn.SetReadDeadline(time.Time{}) + switch m := rawMsg.(type) { + case *msg.Ping: + xl.Debug("frpc visitor get ping message from frpc") + continue + case *msg.UDPPacket: + if errRet := errors.PanicToError(func() { + sv.readCh <- m + xl.Trace("frpc visitor get udp packet from workConn: %s", m.Content) + }); errRet != nil { + xl.Info("reader goroutine for udp work connection closed") + return + } + } + } + } + + // udp service <- frpc <- frps <- frpc visitor <- user + workConnSenderFn := func(conn net.Conn) { + defer func() { + conn.Close() + wg.Done() + }() + + var errRet error + if firstPacket != nil { + if errRet = msg.WriteMsg(conn, firstPacket); errRet != nil { + xl.Warn("sender goroutine for udp work connection closed: %v", errRet) + return + } + xl.Trace("send udp package to workConn: %s", firstPacket.Content) + } + + for { + select { + case udpMsg, ok := <-sv.sendCh: + if !ok { + xl.Info("sender goroutine for udp work connection closed") + return + } + + if errRet = msg.WriteMsg(conn, udpMsg); errRet != nil { + xl.Warn("sender goroutine for udp work connection closed: %v", errRet) + return + } + xl.Trace("send udp package to workConn: %s", udpMsg.Content) + case <-closeCh: + return + } + } + } + + go workConnReaderFn(workConn) + go workConnSenderFn(workConn) + + wg.Wait() + xl.Info("sudp worker is closed") +} + +func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) { + xl := xlog.FromContextSafe(sv.ctx) + visitorConn, err := sv.connectServer() + if err != nil { + return nil, fmt.Errorf("frpc connect frps error: %v", err) + } + + now := time.Now().Unix() + newVisitorConnMsg := &msg.NewVisitorConn{ + ProxyName: sv.cfg.ServerName, + SignKey: util.GetAuthKey(sv.cfg.Sk, now), + Timestamp: now, + UseEncryption: sv.cfg.UseEncryption, + UseCompression: sv.cfg.UseCompression, + } + err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + if err != nil { + return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err) + } + + var newVisitorConnRespMsg msg.NewVisitorConnResp + _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) + if err != nil { + return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err) + } + _ = visitorConn.SetReadDeadline(time.Time{}) + + if newVisitorConnRespMsg.Error != "" { + return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error) + } + + var remote io.ReadWriteCloser + remote = visitorConn + if sv.cfg.UseEncryption { + remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) + if err != nil { + xl.Error("create encryption stream error: %v", err) + return nil, err + } + } + if sv.cfg.UseCompression { + remote = frpIo.WithCompression(remote) + } + return frpNet.WrapReadWriteCloserToConn(remote, visitorConn), nil +} + +func (sv *SUDPVisitor) Close() { + sv.mu.Lock() + defer sv.mu.Unlock() + + select { + case <-sv.checkCloseCh: + return + default: + close(sv.checkCloseCh) + } + if sv.udpConn != nil { + sv.udpConn.Close() + } + close(sv.readCh) + close(sv.sendCh) +} diff --git a/client/visitor/visitor.go b/client/visitor/visitor.go new file mode 100644 index 00000000..1f6471d9 --- /dev/null +++ b/client/visitor/visitor.go @@ -0,0 +1,77 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package visitor + +import ( + "context" + "net" + "sync" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/transport" + "github.com/fatedier/frp/pkg/util/xlog" +) + +// Visitor is used for forward traffics from local port tot remote service. +type Visitor interface { + Run() error + Close() +} + +func NewVisitor( + ctx context.Context, + cfg config.VisitorConf, + clientCfg config.ClientCommonConf, + connectServer func() (net.Conn, error), + msgTransporter transport.MessageTransporter, +) (visitor Visitor) { + xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseInfo().ProxyName) + baseVisitor := BaseVisitor{ + clientCfg: clientCfg, + connectServer: connectServer, + msgTransporter: msgTransporter, + ctx: xlog.NewContext(ctx, xl), + } + switch cfg := cfg.(type) { + case *config.STCPVisitorConf: + visitor = &STCPVisitor{ + BaseVisitor: &baseVisitor, + cfg: cfg, + } + case *config.XTCPVisitorConf: + visitor = &XTCPVisitor{ + BaseVisitor: &baseVisitor, + cfg: cfg, + startTunnelCh: make(chan struct{}), + } + case *config.SUDPVisitorConf: + visitor = &SUDPVisitor{ + BaseVisitor: &baseVisitor, + cfg: cfg, + checkCloseCh: make(chan struct{}), + } + } + return +} + +type BaseVisitor struct { + clientCfg config.ClientCommonConf + connectServer func() (net.Conn, error) + msgTransporter transport.MessageTransporter + l net.Listener + + mu sync.RWMutex + ctx context.Context +} diff --git a/client/visitor_manager.go b/client/visitor/visitor_manager.go similarity index 70% rename from client/visitor_manager.go rename to client/visitor/visitor_manager.go index 8df47150..02b7e493 100644 --- a/client/visitor_manager.go +++ b/client/visitor/visitor_manager.go @@ -12,22 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package visitor import ( "context" + "net" "sync" "time" "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" ) -type VisitorManager struct { - ctl *Control - - cfgs map[string]config.VisitorConf - visitors map[string]Visitor +type Manager struct { + clientCfg config.ClientCommonConf + connectServer func() (net.Conn, error) + msgTransporter transport.MessageTransporter + cfgs map[string]config.VisitorConf + visitors map[string]Visitor checkInterval time.Duration @@ -37,18 +40,25 @@ type VisitorManager struct { stopCh chan struct{} } -func NewVisitorManager(ctx context.Context, ctl *Control) *VisitorManager { - return &VisitorManager{ - ctl: ctl, - cfgs: make(map[string]config.VisitorConf), - visitors: make(map[string]Visitor), - checkInterval: 10 * time.Second, - ctx: ctx, - stopCh: make(chan struct{}), +func NewManager( + ctx context.Context, + clientCfg config.ClientCommonConf, + connectServer func() (net.Conn, error), + msgTransporter transport.MessageTransporter, +) *Manager { + return &Manager{ + clientCfg: clientCfg, + connectServer: connectServer, + msgTransporter: msgTransporter, + cfgs: make(map[string]config.VisitorConf), + visitors: make(map[string]Visitor), + checkInterval: 10 * time.Second, + ctx: ctx, + stopCh: make(chan struct{}), } } -func (vm *VisitorManager) Run() { +func (vm *Manager) Run() { xl := xlog.FromContextSafe(vm.ctx) ticker := time.NewTicker(vm.checkInterval) @@ -74,10 +84,10 @@ func (vm *VisitorManager) Run() { } // Hold lock before calling this function. -func (vm *VisitorManager) startVisitor(cfg config.VisitorConf) (err error) { +func (vm *Manager) startVisitor(cfg config.VisitorConf) (err error) { xl := xlog.FromContextSafe(vm.ctx) name := cfg.GetBaseInfo().ProxyName - visitor := NewVisitor(vm.ctx, vm.ctl, cfg) + visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.connectServer, vm.msgTransporter) err = visitor.Run() if err != nil { xl.Warn("start error: %v", err) @@ -88,7 +98,7 @@ func (vm *VisitorManager) startVisitor(cfg config.VisitorConf) (err error) { return } -func (vm *VisitorManager) Reload(cfgs map[string]config.VisitorConf) { +func (vm *Manager) Reload(cfgs map[string]config.VisitorConf) { xl := xlog.FromContextSafe(vm.ctx) vm.mu.Lock() defer vm.mu.Unlock() @@ -129,7 +139,7 @@ func (vm *VisitorManager) Reload(cfgs map[string]config.VisitorConf) { } } -func (vm *VisitorManager) Close() { +func (vm *Manager) Close() { vm.mu.Lock() defer vm.mu.Unlock() for _, v := range vm.visitors { diff --git a/client/visitor/xtcp.go b/client/visitor/xtcp.go new file mode 100644 index 00000000..6ace7688 --- /dev/null +++ b/client/visitor/xtcp.go @@ -0,0 +1,410 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package visitor + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "strconv" + "sync" + "time" + + frpIo "github.com/fatedier/golib/io" + fmux "github.com/hashicorp/yamux" + quic "github.com/quic-go/quic-go" + "golang.org/x/time/rate" + + "github.com/fatedier/frp/pkg/config" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/nathole" + "github.com/fatedier/frp/pkg/transport" + frpNet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/pkg/util/xlog" +) + +var ErrNoTunnelSession = errors.New("no tunnel session") + +type XTCPVisitor struct { + *BaseVisitor + session TunnelSession + startTunnelCh chan struct{} + retryLimiter *rate.Limiter + cancel context.CancelFunc + + cfg *config.XTCPVisitorConf +} + +func (sv *XTCPVisitor) Run() (err error) { + sv.ctx, sv.cancel = context.WithCancel(sv.ctx) + + if sv.cfg.Protocol == "kcp" { + sv.session = NewKCPTunnelSession() + } else { + sv.session = NewQUICTunnelSession(&sv.clientCfg) + } + + sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort))) + if err != nil { + return + } + + go sv.worker() + go sv.processTunnelStartEvents() + if sv.cfg.KeepTunnelOpen { + sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour) + go sv.keepTunnelOpenWorker() + } + return +} + +func (sv *XTCPVisitor) Close() { + sv.l.Close() + sv.cancel() + if sv.session != nil { + sv.session.Close() + } +} + +func (sv *XTCPVisitor) worker() { + xl := xlog.FromContextSafe(sv.ctx) + for { + conn, err := sv.l.Accept() + if err != nil { + xl.Warn("xtcp local listener closed") + return + } + + go sv.handleConn(conn) + } +} + +func (sv *XTCPVisitor) processTunnelStartEvents() { + for { + select { + case <-sv.ctx.Done(): + return + case <-sv.startTunnelCh: + start := time.Now() + sv.makeNatHole() + duration := time.Since(start) + // avoid too frequently + if duration < 10*time.Second { + time.Sleep(10*time.Second - duration) + } + } + } +} + +func (sv *XTCPVisitor) keepTunnelOpenWorker() { + xl := xlog.FromContextSafe(sv.ctx) + ticker := time.NewTicker(time.Duration(sv.cfg.MinRetryInterval) * time.Second) + defer ticker.Stop() + + sv.startTunnelCh <- struct{}{} + for { + select { + case <-sv.ctx.Done(): + return + case <-ticker.C: + xl.Debug("keepTunnelOpenWorker try to check tunnel...") + conn, err := sv.getTunnelConn() + if err != nil { + xl.Warn("keepTunnelOpenWorker get tunnel connection error: %v", err) + _ = sv.retryLimiter.Wait(sv.ctx) + continue + } + xl.Debug("keepTunnelOpenWorker check success") + if conn != nil { + conn.Close() + } + } + } +} + +func (sv *XTCPVisitor) handleConn(userConn net.Conn) { + xl := xlog.FromContextSafe(sv.ctx) + defer userConn.Close() + + xl.Debug("get a new xtcp user connection") + + // Open a tunnel connection to the server. If there is already a successful hole-punching connection, + // it will be reused. Otherwise, it will block and wait for a successful hole-punching connection until timeout. + tunnelConn, err := sv.openTunnel() + if err != nil { + xl.Error("open tunnel error: %v", err) + return + } + + var muxConnRWCloser io.ReadWriteCloser = tunnelConn + if sv.cfg.UseEncryption { + muxConnRWCloser, err = frpIo.WithEncryption(muxConnRWCloser, []byte(sv.cfg.Sk)) + if err != nil { + xl.Error("create encryption stream error: %v", err) + return + } + } + if sv.cfg.UseCompression { + muxConnRWCloser = frpIo.WithCompression(muxConnRWCloser) + } + + _, _, errs := frpIo.Join(userConn, muxConnRWCloser) + xl.Debug("join connections closed") + if len(errs) > 0 { + xl.Trace("join connections errors: %v", errs) + } +} + +// openTunnel will open a tunnel connection to the target server. +func (sv *XTCPVisitor) openTunnel() (conn net.Conn, err error) { + xl := xlog.FromContextSafe(sv.ctx) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + timeoutC := time.After(20 * time.Second) + immediateTrigger := make(chan struct{}, 1) + defer close(immediateTrigger) + immediateTrigger <- struct{}{} + + for { + select { + case <-sv.ctx.Done(): + return nil, sv.ctx.Err() + case <-immediateTrigger: + conn, err = sv.getTunnelConn() + case <-ticker.C: + conn, err = sv.getTunnelConn() + case <-timeoutC: + return nil, fmt.Errorf("open tunnel timeout") + } + + if err != nil { + if err != ErrNoTunnelSession { + xl.Warn("get tunnel connection error: %v", err) + } + continue + } + return conn, nil + } +} + +func (sv *XTCPVisitor) getTunnelConn() (net.Conn, error) { + conn, err := sv.session.OpenConn(sv.ctx) + if err == nil { + return conn, nil + } + sv.session.Close() + + select { + case sv.startTunnelCh <- struct{}{}: + default: + } + return nil, err +} + +// 0. PreCheck +// 1. Prepare +// 2. ExchangeInfo +// 3. MakeNATHole +// 4. Create a tunnel session using an underlying UDP connection. +func (sv *XTCPVisitor) makeNatHole() { + xl := xlog.FromContextSafe(sv.ctx) + if err := nathole.PreCheck(sv.ctx, sv.msgTransporter, sv.cfg.ServerName, 5*time.Second); err != nil { + xl.Warn("nathole precheck error: %v", err) + return + } + + prepareResult, err := nathole.Prepare([]string{sv.clientCfg.NatHoleSTUNServer}) + if err != nil { + xl.Warn("nathole prepare error: %v", err) + return + } + xl.Info("nathole prepare success, nat type: %s, behavior: %s, addresses: %v, assistedAddresses: %v", + prepareResult.NatType, prepareResult.Behavior, prepareResult.Addrs, prepareResult.AssistedAddrs) + + listenConn := prepareResult.ListenConn + + // send NatHoleVisitor to server + now := time.Now().Unix() + transactionID := nathole.NewTransactionID() + natHoleVisitorMsg := &msg.NatHoleVisitor{ + TransactionID: transactionID, + ProxyName: sv.cfg.ServerName, + Protocol: sv.cfg.Protocol, + SignKey: util.GetAuthKey(sv.cfg.Sk, now), + Timestamp: now, + MappedAddrs: prepareResult.Addrs, + AssistedAddrs: prepareResult.AssistedAddrs, + } + + natHoleRespMsg, err := nathole.ExchangeInfo(sv.ctx, sv.msgTransporter, transactionID, natHoleVisitorMsg, 5*time.Second) + if err != nil { + listenConn.Close() + xl.Warn("nathole exchange info error: %v", err) + return + } + + xl.Info("get natHoleRespMsg, sid [%s], protocol [%s], candidate address %v, assisted address %v, detectBehavior: %+v", + natHoleRespMsg.Sid, natHoleRespMsg.Protocol, natHoleRespMsg.CandidateAddrs, + natHoleRespMsg.AssistedAddrs, natHoleRespMsg.DetectBehavior) + + newListenConn, raddr, err := nathole.MakeHole(sv.ctx, listenConn, natHoleRespMsg, []byte(sv.cfg.Sk)) + if err != nil { + listenConn.Close() + xl.Warn("make hole error: %v", err) + return + } + listenConn = newListenConn + xl.Info("establishing nat hole connection successful, sid [%s], remoteAddr [%s]", natHoleRespMsg.Sid, raddr) + + if err := sv.session.Init(listenConn, raddr); err != nil { + listenConn.Close() + xl.Warn("init tunnel session error: %v", err) + return + } +} + +type TunnelSession interface { + Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error + OpenConn(context.Context) (net.Conn, error) + Close() +} + +type KCPTunnelSession struct { + session *fmux.Session + lConn *net.UDPConn + mu sync.RWMutex +} + +func NewKCPTunnelSession() TunnelSession { + return &KCPTunnelSession{} +} + +func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error { + listenConn.Close() + laddr, _ := net.ResolveUDPAddr("udp", listenConn.LocalAddr().String()) + lConn, err := net.DialUDP("udp", laddr, raddr) + if err != nil { + return fmt.Errorf("dial udp error: %v", err) + } + remote, err := frpNet.NewKCPConnFromUDP(lConn, true, raddr.String()) + if err != nil { + return fmt.Errorf("create kcp connection from udp connection error: %v", err) + } + + fmuxCfg := fmux.DefaultConfig() + fmuxCfg.KeepAliveInterval = 10 * time.Second + fmuxCfg.MaxStreamWindowSize = 2 * 1024 * 1024 + fmuxCfg.LogOutput = io.Discard + session, err := fmux.Client(remote, fmuxCfg) + if err != nil { + remote.Close() + return fmt.Errorf("initial client session error: %v", err) + } + ks.mu.Lock() + ks.session = session + ks.lConn = lConn + ks.mu.Unlock() + return nil +} + +func (ks *KCPTunnelSession) OpenConn(ctx context.Context) (net.Conn, error) { + ks.mu.RLock() + defer ks.mu.RUnlock() + session := ks.session + if session == nil { + return nil, ErrNoTunnelSession + } + return session.Open() +} + +func (ks *KCPTunnelSession) Close() { + ks.mu.Lock() + defer ks.mu.Unlock() + if ks.session != nil { + _ = ks.session.Close() + ks.session = nil + } + if ks.lConn != nil { + _ = ks.lConn.Close() + ks.lConn = nil + } +} + +type QUICTunnelSession struct { + session quic.Connection + listenConn *net.UDPConn + mu sync.RWMutex + + clientCfg *config.ClientCommonConf +} + +func NewQUICTunnelSession(clientCfg *config.ClientCommonConf) TunnelSession { + return &QUICTunnelSession{ + clientCfg: clientCfg, + } +} + +func (qs *QUICTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error { + tlsConfig, err := transport.NewClientTLSConfig("", "", "", raddr.String()) + if err != nil { + return fmt.Errorf("create tls config error: %v", err) + } + tlsConfig.NextProtos = []string{"frp"} + quicConn, err := quic.Dial(listenConn, raddr, raddr.String(), tlsConfig, + &quic.Config{ + MaxIdleTimeout: time.Duration(qs.clientCfg.QUICMaxIdleTimeout) * time.Second, + MaxIncomingStreams: int64(qs.clientCfg.QUICMaxIncomingStreams), + KeepAlivePeriod: time.Duration(qs.clientCfg.QUICKeepalivePeriod) * time.Second, + }) + if err != nil { + return fmt.Errorf("dial quic error: %v", err) + } + qs.mu.Lock() + qs.session = quicConn + qs.listenConn = listenConn + qs.mu.Unlock() + return nil +} + +func (qs *QUICTunnelSession) OpenConn(ctx context.Context) (net.Conn, error) { + qs.mu.RLock() + defer qs.mu.RUnlock() + session := qs.session + if session == nil { + return nil, ErrNoTunnelSession + } + stream, err := session.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + return frpNet.QuicStreamToNetConn(stream, session), nil +} + +func (qs *QUICTunnelSession) Close() { + qs.mu.Lock() + defer qs.mu.Unlock() + if qs.session != nil { + _ = qs.session.CloseWithError(0, "") + qs.session = nil + } + if qs.listenConn != nil { + _ = qs.listenConn.Close() + qs.listenConn = nil + } +} diff --git a/cmd/frpc/sub/nathole.go b/cmd/frpc/sub/nathole.go index f08990bf..459f8d63 100644 --- a/cmd/frpc/sub/nathole.go +++ b/cmd/frpc/sub/nathole.go @@ -16,9 +16,7 @@ package sub import ( "fmt" - "net" "os" - "strconv" "github.com/spf13/cobra" @@ -28,7 +26,7 @@ import ( var ( natHoleSTUNServer string - serverUDPPort int + natHoleLocalAddr string ) func init() { @@ -37,8 +35,8 @@ func init() { rootCmd.AddCommand(natholeCmd) natholeCmd.AddCommand(natholeDiscoveryCmd) - natholeCmd.PersistentFlags().StringVarP(&natHoleSTUNServer, "nat_hole_stun_server", "", "stun.easyvoip.com:3478", "STUN server address for nathole") - natholeCmd.PersistentFlags().IntVarP(&serverUDPPort, "server_udp_port", "", 0, "UDP port of frps for nathole") + natholeCmd.PersistentFlags().StringVarP(&natHoleSTUNServer, "nat_hole_stun_server", "", "", "STUN server address for nathole") + natholeCmd.PersistentFlags().StringVarP(&natHoleLocalAddr, "nat_hole_local_addr", "l", "", "local address to connect STUN server") } var natholeCmd = &cobra.Command{ @@ -48,48 +46,45 @@ var natholeCmd = &cobra.Command{ var natholeDiscoveryCmd = &cobra.Command{ Use: "discover", - Short: "Discover nathole information by frps and stun server", + Short: "Discover nathole information from stun server", RunE: func(cmd *cobra.Command, args []string) error { // ignore error here, because we can use command line pameters - cfg, _, _, _ := config.ParseClientConfig(cfgFile) + cfg, _, _, err := config.ParseClientConfig(cfgFile) + if err != nil { + cfg = config.GetDefaultClientConf() + } if natHoleSTUNServer != "" { cfg.NatHoleSTUNServer = natHoleSTUNServer } - if serverUDPPort != 0 { - cfg.ServerUDPPort = serverUDPPort - } if err := validateForNatHoleDiscovery(cfg); err != nil { fmt.Println(err) os.Exit(1) } - serverAddr := "" - if cfg.ServerUDPPort != 0 { - serverAddr = net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort)) - } - addresses, err := nathole.Discover( - serverAddr, - []string{cfg.NatHoleSTUNServer}, - []byte(cfg.Token), - ) + addrs, localAddr, err := nathole.Discover([]string{cfg.NatHoleSTUNServer}, natHoleLocalAddr) if err != nil { fmt.Println("discover error:", err) os.Exit(1) } - if len(addresses) < 2 { - fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addresses) + if len(addrs) < 2 { + fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addrs) os.Exit(1) } - natType, behavior, err := nathole.ClassifyNATType(addresses) + localIPs, _ := nathole.ListLocalIPsForNatHole(10) + + natFeature, err := nathole.ClassifyNATFeature(addrs, localIPs) if err != nil { - fmt.Println("classify nat type error:", err) + fmt.Println("classify nat feature error:", err) os.Exit(1) } - fmt.Println("Your NAT type is:", natType) - fmt.Println("Behavior is:", behavior) - fmt.Println("External address is:", addresses) + fmt.Println("STUN server:", cfg.NatHoleSTUNServer) + fmt.Println("Your NAT type is:", natFeature.NatType) + fmt.Println("Behavior is:", natFeature.Behavior) + fmt.Println("External address is:", addrs) + fmt.Println("Local address is:", localAddr.String()) + fmt.Println("Public Network:", natFeature.PublicNetwork) return nil }, } diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index cd39d882..8271d065 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -53,6 +53,7 @@ var ( logFile string logMaxDays int disableLogColor bool + dnsServer string proxyName string localIP string @@ -94,6 +95,7 @@ func RegisterCommonFlags(cmd *cobra.Command) { cmd.PersistentFlags().IntVarP(&logMaxDays, "log_max_days", "", 3, "log file reversed days") cmd.PersistentFlags().BoolVarP(&disableLogColor, "disable_log_color", "", false, "disable log color in console") cmd.PersistentFlags().BoolVarP(&tlsEnable, "tls_enable", "", false, "enable frpc tls") + cmd.PersistentFlags().StringVarP(&dnsServer, "dns_server", "", "", "specify dns server instead of using system default one") } var rootCmd = &cobra.Command{ @@ -108,26 +110,7 @@ var rootCmd = &cobra.Command{ // If cfgDir is not empty, run multiple frpc service for each config file in cfgDir. // Note that it's only designed for testing. It's not guaranteed to be stable. if cfgDir != "" { - var wg sync.WaitGroup - _ = filepath.WalkDir(cfgDir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - return nil - } - wg.Add(1) - time.Sleep(time.Millisecond) - go func() { - defer wg.Done() - err := runClient(path) - if err != nil { - fmt.Printf("frpc service error for config file [%s]\n", path) - } - }() - return nil - }) - wg.Wait() + _ = runMultipleClients(cfgDir) return nil } @@ -141,6 +124,27 @@ var rootCmd = &cobra.Command{ }, } +func runMultipleClients(cfgDir string) error { + var wg sync.WaitGroup + err := filepath.WalkDir(cfgDir, func(path string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + wg.Add(1) + time.Sleep(time.Millisecond) + go func() { + defer wg.Done() + err := runClient(path) + if err != nil { + fmt.Printf("frpc service error for config file [%s]\n", path) + } + }() + return nil + }) + wg.Wait() + return err +} + func Execute() { if err := rootCmd.Execute(); err != nil { os.Exit(1) @@ -177,6 +181,7 @@ func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) { cfg.LogFile = logFile cfg.LogMaxDays = int64(logMaxDays) cfg.DisableLogColor = disableLogColor + cfg.DNSServer = dnsServer // Only token authentication is supported in cmd mode cfg.ClientConfig = auth.GetDefaultClientConf() diff --git a/go.mod b/go.mod index ba23a538..86801843 100644 --- a/go.mod +++ b/go.mod @@ -18,12 +18,14 @@ require ( github.com/pion/stun v0.4.0 github.com/pires/go-proxyproto v0.6.2 github.com/prometheus/client_golang v1.13.0 - github.com/quic-go/quic-go v0.32.0 + github.com/quic-go/quic-go v0.34.0 github.com/rodaine/table v1.0.1 + github.com/samber/lo v1.38.1 github.com/spf13/cobra v1.1.3 github.com/stretchr/testify v1.8.1 golang.org/x/net v0.7.0 golang.org/x/oauth2 v0.3.0 + golang.org/x/sync v0.1.0 golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 gopkg.in/ini.v1 v1.67.0 k8s.io/apimachinery v0.26.1 @@ -55,9 +57,8 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect - github.com/quic-go/qtls-go1-18 v0.2.0 // indirect - github.com/quic-go/qtls-go1-19 v0.2.0 // indirect - github.com/quic-go/qtls-go1-20 v0.1.0 // indirect + github.com/quic-go/qtls-go1-19 v0.3.2 // indirect + github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161 // indirect github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b // indirect diff --git a/go.sum b/go.sum index d708984a..7950ad62 100644 --- a/go.sum +++ b/go.sum @@ -381,14 +381,12 @@ github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U= -github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc= -github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk= -github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI= -github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= -github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA= -github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= +github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= +github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= +github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= +github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= +github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ= github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= @@ -399,6 +397,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= +github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -602,6 +602,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/pkg/config/client_test.go b/pkg/config/client_test.go index e5325107..bff47cd4 100644 --- a/pkg/config/client_test.go +++ b/pkg/config/client_test.go @@ -661,6 +661,9 @@ func Test_LoadClientBasicConf(t *testing.T) { BindAddr: "127.0.0.1", BindPort: 9001, }, + Protocol: "quic", + MaxRetriesAnHour: 8, + MinRetryInterval: 90, }, } diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index de4953ca..133d59ee 100644 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -1078,7 +1078,6 @@ func (cfg *XTCPProxyConf) Compare(cmp ProxyConf) bool { cfg.Sk != cmpConf.Sk { return false } - return true } @@ -1092,7 +1091,6 @@ func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section * if cfg.Role == "" { cfg.Role = "server" } - return nil } @@ -1120,7 +1118,6 @@ func (cfg *XTCPProxyConf) CheckForCli() (err error) { if cfg.Role != "server" { return fmt.Errorf("role should be 'server'") } - return } diff --git a/pkg/config/server.go b/pkg/config/server.go index cdc29daa..f54a32f3 100644 --- a/pkg/config/server.go +++ b/pkg/config/server.go @@ -196,35 +196,38 @@ type ServerCommonConf struct { // Enable golang pprof handlers in dashboard listener. // Dashboard port must be set first. PprofEnable bool `ini:"pprof_enable" json:"pprof_enable"` + // NatHoleAnalysisDataReserveHours specifies the hours to reserve nat hole analysis data. + NatHoleAnalysisDataReserveHours int64 `ini:"nat_hole_analysis_data_reserve_hours" json:"nat_hole_analysis_data_reserve_hours"` } // GetDefaultServerConf returns a server configuration with reasonable // defaults. func GetDefaultServerConf() ServerCommonConf { return ServerCommonConf{ - ServerConfig: auth.GetDefaultServerConf(), - BindAddr: "0.0.0.0", - BindPort: 7000, - QUICKeepalivePeriod: 10, - QUICMaxIdleTimeout: 30, - QUICMaxIncomingStreams: 100000, - VhostHTTPTimeout: 60, - DashboardAddr: "0.0.0.0", - LogFile: "console", - LogWay: "console", - LogLevel: "info", - LogMaxDays: 3, - DetailedErrorsToClient: true, - TCPMux: true, - TCPMuxKeepaliveInterval: 60, - TCPKeepAlive: 7200, - AllowPorts: make(map[int]struct{}), - MaxPoolCount: 5, - MaxPortsPerClient: 0, - HeartbeatTimeout: 90, - UserConnTimeout: 10, - HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), - UDPPacketSize: 1500, + ServerConfig: auth.GetDefaultServerConf(), + BindAddr: "0.0.0.0", + BindPort: 7000, + QUICKeepalivePeriod: 10, + QUICMaxIdleTimeout: 30, + QUICMaxIncomingStreams: 100000, + VhostHTTPTimeout: 60, + DashboardAddr: "0.0.0.0", + LogFile: "console", + LogWay: "console", + LogLevel: "info", + LogMaxDays: 3, + DetailedErrorsToClient: true, + TCPMux: true, + TCPMuxKeepaliveInterval: 60, + TCPKeepAlive: 7200, + AllowPorts: make(map[int]struct{}), + MaxPoolCount: 5, + MaxPortsPerClient: 0, + HeartbeatTimeout: 90, + UserConnTimeout: 10, + HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), + UDPPacketSize: 1500, + NatHoleAnalysisDataReserveHours: 7 * 24, } } diff --git a/pkg/config/server_test.go b/pkg/config/server_test.go index c00f5970..e93ba3d7 100644 --- a/pkg/config/server_test.go +++ b/pkg/config/server_test.go @@ -134,18 +134,19 @@ func Test_LoadServerCommonConf(t *testing.T) { 12: {}, 99: {}, }, - AllowPortsStr: "10-12,99", - MaxPoolCount: 59, - MaxPortsPerClient: 9, - TLSOnly: true, - TLSCertFile: "server.crt", - TLSKeyFile: "server.key", - TLSTrustedCaFile: "ca.crt", - SubDomainHost: "frps.com", - TCPMux: true, - TCPMuxKeepaliveInterval: 60, - TCPKeepAlive: 7200, - UDPPacketSize: 1509, + AllowPortsStr: "10-12,99", + MaxPoolCount: 59, + MaxPortsPerClient: 9, + TLSOnly: true, + TLSCertFile: "server.crt", + TLSKeyFile: "server.key", + TLSTrustedCaFile: "ca.crt", + SubDomainHost: "frps.com", + TCPMux: true, + TCPMuxKeepaliveInterval: 60, + TCPKeepAlive: 7200, + UDPPacketSize: 1509, + NatHoleAnalysisDataReserveHours: 7 * 24, HTTPPlugins: map[string]plugin.HTTPPluginOptions{ "user-manager": { @@ -180,32 +181,33 @@ func Test_LoadServerCommonConf(t *testing.T) { AuthenticateNewWorkConns: false, }, }, - BindAddr: "0.0.0.9", - BindPort: 7009, - BindUDPPort: 7008, - QUICKeepalivePeriod: 10, - QUICMaxIdleTimeout: 30, - QUICMaxIncomingStreams: 100000, - ProxyBindAddr: "0.0.0.9", - VhostHTTPTimeout: 60, - DashboardAddr: "0.0.0.0", - DashboardUser: "", - DashboardPwd: "", - EnablePrometheus: false, - LogFile: "console", - LogWay: "console", - LogLevel: "info", - LogMaxDays: 3, - DetailedErrorsToClient: true, - TCPMux: true, - TCPMuxKeepaliveInterval: 60, - TCPKeepAlive: 7200, - AllowPorts: make(map[int]struct{}), - MaxPoolCount: 5, - HeartbeatTimeout: 90, - UserConnTimeout: 10, - HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), - UDPPacketSize: 1500, + BindAddr: "0.0.0.9", + BindPort: 7009, + BindUDPPort: 7008, + QUICKeepalivePeriod: 10, + QUICMaxIdleTimeout: 30, + QUICMaxIncomingStreams: 100000, + ProxyBindAddr: "0.0.0.9", + VhostHTTPTimeout: 60, + DashboardAddr: "0.0.0.0", + DashboardUser: "", + DashboardPwd: "", + EnablePrometheus: false, + LogFile: "console", + LogWay: "console", + LogLevel: "info", + LogMaxDays: 3, + DetailedErrorsToClient: true, + TCPMux: true, + TCPMuxKeepaliveInterval: 60, + TCPKeepAlive: 7200, + AllowPorts: make(map[int]struct{}), + MaxPoolCount: 5, + HeartbeatTimeout: 90, + UserConnTimeout: 10, + HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), + UDPPacketSize: 1500, + NatHoleAnalysisDataReserveHours: 7 * 24, }, }, } diff --git a/pkg/config/visitor.go b/pkg/config/visitor.go index 3d9440cd..0ee010fb 100644 --- a/pkg/config/visitor.go +++ b/pkg/config/visitor.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" + "github.com/samber/lo" "gopkg.in/ini.v1" "github.com/fatedier/frp/pkg/consts" @@ -61,6 +62,11 @@ type STCPVisitorConf struct { type XTCPVisitorConf struct { BaseVisitorConf `ini:",extends"` + + Protocol string `ini:"protocol" json:"protocol,omitempty"` + KeepTunnelOpen bool `ini:"keep_tunnel_open" json:"keep_tunnel_open,omitempty"` + MaxRetriesAnHour int `ini:"max_retries_an_hour" json:"max_retries_an_hour,omitempty"` + MinRetryInterval int `ini:"min_retry_interval" json:"min_retry_interval,omitempty"` } // DefaultVisitorConf creates a empty VisitorConf object by visitorType. @@ -259,7 +265,12 @@ func (cfg *XTCPVisitorConf) Compare(cmp VisitorConf) bool { } // Add custom login equal, if exists - + if cfg.Protocol != cmpConf.Protocol || + cfg.KeepTunnelOpen != cmpConf.KeepTunnelOpen || + cfg.MaxRetriesAnHour != cmpConf.MaxRetriesAnHour || + cfg.MinRetryInterval != cmpConf.MinRetryInterval { + return false + } return true } @@ -270,7 +281,15 @@ func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section } // Add custom logic unmarshal, if exists - + if cfg.Protocol == "" { + cfg.Protocol = "quic" + } + if cfg.MaxRetriesAnHour <= 0 { + cfg.MaxRetriesAnHour = 8 + } + if cfg.MinRetryInterval <= 0 { + cfg.MinRetryInterval = 90 + } return } @@ -280,6 +299,8 @@ func (cfg *XTCPVisitorConf) Check() (err error) { } // Add custom logic validate, if exists - + if !lo.Contains([]string{"", "kcp", "quic"}, cfg.Protocol) { + return fmt.Errorf("protocol should be 'kcp' or 'quic'") + } return } diff --git a/pkg/config/visitor_test.go b/pkg/config/visitor_test.go index cdfbbf46..d91f90bc 100644 --- a/pkg/config/visitor_test.go +++ b/pkg/config/visitor_test.go @@ -87,6 +87,9 @@ func Test_Visitor_UnmarshalFromIni(t *testing.T) { BindAddr: "127.0.0.1", BindPort: 9001, }, + Protocol: "quic", + MaxRetriesAnHour: 8, + MinRetryInterval: 90, }, }, } diff --git a/pkg/metrics/mem/server.go b/pkg/metrics/mem/server.go index 38a8c68b..e8ebd9d4 100644 --- a/pkg/metrics/mem/server.go +++ b/pkg/metrics/mem/server.go @@ -60,25 +60,30 @@ func (m *serverMetrics) run() { go func() { for { time.Sleep(12 * time.Hour) - log.Debug("start to clear useless proxy statistics data...") - m.clearUselessInfo() - log.Debug("finish to clear useless proxy statistics data") + start := time.Now() + count, total := m.clearUselessInfo() + log.Debug("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start)) } }() } -func (m *serverMetrics) clearUselessInfo() { +func (m *serverMetrics) clearUselessInfo() (int, int) { + count := 0 + total := 0 // To check if there are proxies that closed than 7 days and drop them. m.mu.Lock() defer m.mu.Unlock() + total = len(m.info.ProxyStatistics) for name, data := range m.info.ProxyStatistics { if !data.LastCloseTime.IsZero() && data.LastStartTime.Before(data.LastCloseTime) && time.Since(data.LastCloseTime) > time.Duration(7*24)*time.Hour { delete(m.info.ProxyStatistics, name) + count++ log.Trace("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String()) } } + return count, total } func (m *serverMetrics) NewClient() { diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index b50c2b45..b6a209d2 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -16,54 +16,53 @@ package msg import ( "net" + "reflect" ) const ( - TypeLogin = 'o' - TypeLoginResp = '1' - TypeNewProxy = 'p' - TypeNewProxyResp = '2' - TypeCloseProxy = 'c' - TypeNewWorkConn = 'w' - TypeReqWorkConn = 'r' - TypeStartWorkConn = 's' - TypeNewVisitorConn = 'v' - TypeNewVisitorConnResp = '3' - TypePing = 'h' - TypePong = '4' - TypeUDPPacket = 'u' - TypeNatHoleVisitor = 'i' - TypeNatHoleClient = 'n' - TypeNatHoleResp = 'm' - TypeNatHoleClientDetectOK = 'd' - TypeNatHoleSid = '5' - TypeNatHoleBinding = 'b' - TypeNatHoleBindingResp = '6' + TypeLogin = 'o' + TypeLoginResp = '1' + TypeNewProxy = 'p' + TypeNewProxyResp = '2' + TypeCloseProxy = 'c' + TypeNewWorkConn = 'w' + TypeReqWorkConn = 'r' + TypeStartWorkConn = 's' + TypeNewVisitorConn = 'v' + TypeNewVisitorConnResp = '3' + TypePing = 'h' + TypePong = '4' + TypeUDPPacket = 'u' + TypeNatHoleVisitor = 'i' + TypeNatHoleClient = 'n' + TypeNatHoleResp = 'm' + TypeNatHoleSid = '5' + TypeNatHoleReport = '6' ) var msgTypeMap = map[byte]interface{}{ - TypeLogin: Login{}, - TypeLoginResp: LoginResp{}, - TypeNewProxy: NewProxy{}, - TypeNewProxyResp: NewProxyResp{}, - TypeCloseProxy: CloseProxy{}, - TypeNewWorkConn: NewWorkConn{}, - TypeReqWorkConn: ReqWorkConn{}, - TypeStartWorkConn: StartWorkConn{}, - TypeNewVisitorConn: NewVisitorConn{}, - TypeNewVisitorConnResp: NewVisitorConnResp{}, - TypePing: Ping{}, - TypePong: Pong{}, - TypeUDPPacket: UDPPacket{}, - TypeNatHoleVisitor: NatHoleVisitor{}, - TypeNatHoleClient: NatHoleClient{}, - TypeNatHoleResp: NatHoleResp{}, - TypeNatHoleClientDetectOK: NatHoleClientDetectOK{}, - TypeNatHoleSid: NatHoleSid{}, - TypeNatHoleBinding: NatHoleBinding{}, - TypeNatHoleBindingResp: NatHoleBindingResp{}, + TypeLogin: Login{}, + TypeLoginResp: LoginResp{}, + TypeNewProxy: NewProxy{}, + TypeNewProxyResp: NewProxyResp{}, + TypeCloseProxy: CloseProxy{}, + TypeNewWorkConn: NewWorkConn{}, + TypeReqWorkConn: ReqWorkConn{}, + TypeStartWorkConn: StartWorkConn{}, + TypeNewVisitorConn: NewVisitorConn{}, + TypeNewVisitorConnResp: NewVisitorConnResp{}, + TypePing: Ping{}, + TypePong: Pong{}, + TypeUDPPacket: UDPPacket{}, + TypeNatHoleVisitor: NatHoleVisitor{}, + TypeNatHoleClient: NatHoleClient{}, + TypeNatHoleResp: NatHoleResp{}, + TypeNatHoleSid: NatHoleSid{}, + TypeNatHoleReport: NatHoleReport{}, } +var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name() + // When frpc start, client send this message to login to server. type Login struct { Version string `json:"version,omitempty"` @@ -175,35 +174,58 @@ type UDPPacket struct { } type NatHoleVisitor struct { - ProxyName string `json:"proxy_name,omitempty"` - SignKey string `json:"sign_key,omitempty"` - Timestamp int64 `json:"timestamp,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + ProxyName string `json:"proxy_name,omitempty"` + PreCheck bool `json:"pre_check,omitempty"` + Protocol string `json:"protocol,omitempty"` + SignKey string `json:"sign_key,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + MappedAddrs []string `json:"mapped_addrs,omitempty"` + AssistedAddrs []string `json:"assisted_addrs,omitempty"` } type NatHoleClient struct { - ProxyName string `json:"proxy_name,omitempty"` - Sid string `json:"sid,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + ProxyName string `json:"proxy_name,omitempty"` + Sid string `json:"sid,omitempty"` + MappedAddrs []string `json:"mapped_addrs,omitempty"` + AssistedAddrs []string `json:"assisted_addrs,omitempty"` +} + +type PortsRange struct { + From int `json:"from,omitempty"` + To int `json:"to,omitempty"` +} + +type NatHoleDetectBehavior struct { + Role string `json:"role,omitempty"` // sender or receiver + Mode int `json:"mode,omitempty"` // 0, 1, 2... + TTL int `json:"ttl,omitempty"` + SendDelayMs int `json:"send_delay_ms,omitempty"` + ReadTimeoutMs int `json:"read_timeout,omitempty"` + CandidatePorts []PortsRange `json:"candidate_ports,omitempty"` + SendRandomPorts int `json:"send_random_ports,omitempty"` + ListenRandomPorts int `json:"listen_random_ports,omitempty"` } type NatHoleResp struct { - Sid string `json:"sid,omitempty"` - VisitorAddr string `json:"visitor_addr,omitempty"` - ClientAddr string `json:"client_addr,omitempty"` - Error string `json:"error,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + Sid string `json:"sid,omitempty"` + Protocol string `json:"protocol,omitempty"` + CandidateAddrs []string `json:"candidate_addrs,omitempty"` + AssistedAddrs []string `json:"assisted_addrs,omitempty"` + DetectBehavior NatHoleDetectBehavior `json:"detect_behavior,omitempty"` + Error string `json:"error,omitempty"` } -type NatHoleClientDetectOK struct{} - type NatHoleSid struct { - Sid string `json:"sid,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + Sid string `json:"sid,omitempty"` + Response bool `json:"response,omitempty"` + Nonce string `json:"nonce,omitempty"` } -type NatHoleBinding struct { - TransactionID string `json:"transaction_id,omitempty"` -} - -type NatHoleBindingResp struct { - TransactionID string `json:"transaction_id,omitempty"` - Address string `json:"address,omitempty"` - Error string `json:"error,omitempty"` +type NatHoleReport struct { + Sid string `json:"sid,omitempty"` + Success bool `json:"success,omitempty"` } diff --git a/pkg/nathole/analysis.go b/pkg/nathole/analysis.go new file mode 100644 index 00000000..772fd8f8 --- /dev/null +++ b/pkg/nathole/analysis.go @@ -0,0 +1,328 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nathole + +import ( + "sync" + "time" + + "github.com/samber/lo" +) + +var ( + // mode 0, both EasyNAT, PublicNetwork is always receiver + // sender | receiver, ttl 7 + // receiver, ttl 7 | sender + // sender | receiver, ttl 4 + // receiver, ttl 4 | sender + // sender | receiver + // receiver | sender + // sender, sendDelayMs 5000 | receiver + // sender, sendDelayMs 10000 | receiver + // receiver | sender, sendDelayMs 5000 + // receiver | sender, sendDelayMs 10000 + mode0Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{ + lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}, RecommandBehavior{Role: DetectRoleSender}), + lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}, RecommandBehavior{Role: DetectRoleSender}), + lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}, RecommandBehavior{Role: DetectRoleReceiver}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}, RecommandBehavior{Role: DetectRoleReceiver}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}), + } + + // mode 1, HardNAT is sender, EasyNAT is receiver, port changes is regular + // sender | receiver, ttl 7, portsRangeNumber max 10 + // sender, sendDelayMs 2000 | receiver, ttl 7, portsRangeNumber max 10 + // sender | receiver, ttl 4, portsRangeNumber max 10 + // sender, sendDelayMs 2000 | receiver, ttl 4, portsRangeNumber max 10 + // sender | receiver, portsRangeNumber max 10 + // sender, sendDelayMs 2000 | receiver, portsRangeNumber max 10 + mode1Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{ + lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}), + } + + // mode 2, HardNAT is receiver, EasyNAT is sender + // sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports, ttl 7 + // sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports, ttl 4 + // sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports + mode2Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{ + lo.T2( + RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000}, + RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7}, + ), + lo.T2( + RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000}, + RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4}, + ), + lo.T2( + RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000}, + RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256}, + ), + } + + // mode 3, For HardNAT & HardNAT, both changes in the ports are regular + // sender, portsRangeNumber 10 | receiver, ttl 7, portsRangeNumber 10 + // sender, portsRangeNumber 10 | receiver, ttl 4, portsRangeNumber 10 + // sender, portsRangeNumber 10 | receiver, portsRangeNumber 10 + // receiver, ttl 7, portsRangeNumber 10 | sender, portsRangeNumber 10 + // receiver, ttl 4, portsRangeNumber 10 | sender, portsRangeNumber 10 + // receiver, portsRangeNumber 10 | sender, portsRangeNumber 10 + mode3Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{ + lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}), + lo.T2(RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}), + } + + // mode 4, Regular ports changes are usually the sender. + // sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 7, portsRangeNumber 10 + // sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 4, portsRangeNumber 10 + // sender, portsRandomNumber 1000, SendDelayMs: 2000 | receiver, listen 256 ports, portsRangeNumber 10 + mode4Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{ + lo.T2( + RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000}, + RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7, PortsRangeNumber: 10}, + ), + lo.T2( + RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000}, + RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4, PortsRangeNumber: 10}, + ), + lo.T2( + RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000}, + RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, PortsRangeNumber: 10}, + ), + } +) + +func getBehaviorByMode(mode int) []lo.Tuple2[RecommandBehavior, RecommandBehavior] { + switch mode { + case 0: + return mode0Behaviors + case 1: + return mode1Behaviors + case 2: + return mode2Behaviors + case 3: + return mode3Behaviors + case 4: + return mode4Behaviors + } + // default + return mode0Behaviors +} + +func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, RecommandBehavior) { + behaviors := getBehaviorByMode(mode) + if index >= len(behaviors) { + return RecommandBehavior{}, RecommandBehavior{} + } + return behaviors[index].A, behaviors[index].B +} + +func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore { + return getBehaviorScoresByMode2(mode, defaultScore, defaultScore) +} + +func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore { + behaviors := getBehaviorByMode(mode) + scores := make([]*BehaviorScore, 0, len(behaviors)) + for i := 0; i < len(behaviors); i++ { + score := receiverScore + if behaviors[i].A.Role == DetectRoleSender { + score = senderScore + } + scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score}) + } + return scores +} + +type RecommandBehavior struct { + Role string + TTL int + SendDelayMs int + PortsRangeNumber int + PortsRandomNumber int + ListenRandomPorts int +} + +type MakeHoleRecords struct { + mu sync.Mutex + scores []*BehaviorScore + LastUpdateTime time.Time +} + +func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords { + scores := []*BehaviorScore{} + easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v}) + appendMode0 := func() { + switch { + case c.PublicNetwork: + scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 0, 1)...) + case v.PublicNetwork: + scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 1, 0)...) + default: + scores = append(scores, getBehaviorScoresByMode(DetectMode0, 0)...) + } + } + + switch { + case easyCount == 2: + appendMode0() + case hardCount == 1 && portsChangedRegularCount == 1: + scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...) + scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...) + appendMode0() + case hardCount == 1 && portsChangedRegularCount == 0: + scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...) + scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...) + appendMode0() + case hardCount == 2 && portsChangedRegularCount == 2: + scores = append(scores, getBehaviorScoresByMode(DetectMode3, 0)...) + scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...) + case hardCount == 2 && portsChangedRegularCount == 1: + scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...) + default: + // hard to make hole, just trying it out. + scores = append(scores, getBehaviorScoresByMode(DetectMode0, 1)...) + scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...) + scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...) + } + return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()} +} + +func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) { + mhr.mu.Lock() + defer mhr.mu.Unlock() + mhr.LastUpdateTime = time.Now() + for i := range mhr.scores { + score := mhr.scores[i] + if score.Mode != mode || score.Index != index { + continue + } + + score.Score += 2 + score.Score = lo.Min([]int{score.Score, 10}) + return + } +} + +func (mhr *MakeHoleRecords) Recommand() (mode, index int) { + mhr.mu.Lock() + defer mhr.mu.Unlock() + + maxScore := lo.MaxBy(mhr.scores, func(item, max *BehaviorScore) bool { + return item.Score > max.Score + }) + if maxScore == nil { + return 0, 0 + } + maxScore.Score-- + mhr.LastUpdateTime = time.Now() + return maxScore.Mode, maxScore.Index +} + +type BehaviorScore struct { + Mode int + Index int + // between -10 and 10 + Score int +} + +type Analyzer struct { + // key is client ip + visitor ip + records map[string]*MakeHoleRecords + dataReserveDuration time.Duration + + mu sync.Mutex +} + +func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer { + return &Analyzer{ + records: make(map[string]*MakeHoleRecords), + dataReserveDuration: dataReserveDuration, + } +} + +func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, index int, _ RecommandBehavior, _ RecommandBehavior) { + a.mu.Lock() + records, ok := a.records[key] + if !ok { + records = NewMakeHoleRecords(c, v) + a.records[key] = records + } + a.mu.Unlock() + + mode, index = records.Recommand() + cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index) + + switch mode { + case DetectMode1: + // HardNAT is always the sender + if c.NatType == EasyNAT { + cBehavior, vBehavior = vBehavior, cBehavior + } + case DetectMode2: + // HardNAT is always the receiver + if c.NatType == HardNAT { + cBehavior, vBehavior = vBehavior, cBehavior + } + case DetectMode4: + // Regular ports changes is always the sender + if !c.RegularPortsChange { + cBehavior, vBehavior = vBehavior, cBehavior + } + } + return mode, index, cBehavior, vBehavior +} + +func (a *Analyzer) ReportSuccess(key string, mode, index int) { + a.mu.Lock() + records, ok := a.records[key] + a.mu.Unlock() + if !ok { + return + } + records.ReportSuccess(mode, index) +} + +func (a *Analyzer) Clean() (int, int) { + now := time.Now() + total := 0 + count := 0 + + // cleanup 10w records may take 5ms + a.mu.Lock() + defer a.mu.Unlock() + total = len(a.records) + // clean up records that have not been used for a period of time. + for key, records := range a.records { + if now.Sub(records.LastUpdateTime) > a.dataReserveDuration { + delete(a.records, key) + count++ + } + } + return count, total +} diff --git a/pkg/nathole/classify.go b/pkg/nathole/classify.go index c667e078..79f8efe0 100644 --- a/pkg/nathole/classify.go +++ b/pkg/nathole/classify.go @@ -17,6 +17,9 @@ package nathole import ( "fmt" "net" + "strconv" + + "github.com/samber/lo" ) const ( @@ -29,46 +32,96 @@ const ( BehaviorBothChanged = "BehaviorBothChanged" ) -// ClassifyNATType classify NAT type by given addresses. -func ClassifyNATType(addresses []string) (string, string, error) { +type NatFeature struct { + NatType string + Behavior string + PortsDifference int + RegularPortsChange bool + PublicNetwork bool +} + +func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, error) { if len(addresses) <= 1 { - return "", "", fmt.Errorf("not enough addresses") + return nil, fmt.Errorf("not enough addresses") } + natFeature := &NatFeature{} ipChanged := false portChanged := false var baseIP, basePort string + var portMax, portMin int for _, addr := range addresses { ip, port, err := net.SplitHostPort(addr) if err != nil { - return "", "", err + return nil, err } + portNum, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + if lo.Contains(localIPs, ip) { + natFeature.PublicNetwork = true + } + if baseIP == "" { baseIP = ip basePort = port + portMax = portNum + portMin = portNum continue } + if portNum > portMax { + portMax = portNum + } + if portNum < portMin { + portMin = portNum + } if baseIP != ip { ipChanged = true } if basePort != port { portChanged = true } + } - if ipChanged && portChanged { - break - } + natFeature.PortsDifference = portMax - portMin + if natFeature.PortsDifference <= 10 && natFeature.PortsDifference >= 1 { + natFeature.RegularPortsChange = true } switch { case ipChanged && portChanged: - return HardNAT, BehaviorBothChanged, nil + natFeature.NatType = HardNAT + natFeature.Behavior = BehaviorBothChanged case ipChanged: - return HardNAT, BehaviorIPChanged, nil + natFeature.NatType = HardNAT + natFeature.Behavior = BehaviorIPChanged case portChanged: - return HardNAT, BehaviorPortChanged, nil + natFeature.NatType = HardNAT + natFeature.Behavior = BehaviorPortChanged default: - return EasyNAT, BehaviorNoChange, nil + natFeature.NatType = EasyNAT + natFeature.Behavior = BehaviorNoChange } + return natFeature, nil +} + +func ClassifyFeatureCount(features []*NatFeature) (int, int, int) { + easyCount := 0 + hardCount := 0 + // for HardNAT + portsChangedRegularCount := 0 + for _, feature := range features { + if feature.NatType == EasyNAT { + easyCount++ + continue + } + + hardCount++ + if feature.RegularPortsChange { + portsChangedRegularCount++ + } + } + return easyCount, hardCount, portsChangedRegularCount } diff --git a/pkg/nathole/controller.go b/pkg/nathole/controller.go new file mode 100644 index 00000000..71feb1be --- /dev/null +++ b/pkg/nathole/controller.go @@ -0,0 +1,382 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nathole + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "net" + "strconv" + "sync" + "time" + + "github.com/fatedier/golib/errors" + "github.com/samber/lo" + "golang.org/x/sync/errgroup" + + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/transport" + "github.com/fatedier/frp/pkg/util/log" + "github.com/fatedier/frp/pkg/util/util" +) + +// NatHoleTimeout seconds. +var NatHoleTimeout int64 = 10 + +func NewTransactionID() string { + id, _ := util.RandID() + return fmt.Sprintf("%d%s", time.Now().Unix(), id) +} + +type ClientCfg struct { + name string + sk string + sidCh chan string +} + +type Session struct { + sid string + analysisKey string + recommandMode int + recommandIndex int + + visitorMsg *msg.NatHoleVisitor + visitorTransporter transport.MessageTransporter + vResp *msg.NatHoleResp + vNatFeature *NatFeature + vBehavior RecommandBehavior + + clientMsg *msg.NatHoleClient + clientTransporter transport.MessageTransporter + cResp *msg.NatHoleResp + cNatFeature *NatFeature + cBehavior RecommandBehavior + + notifyCh chan struct{} +} + +func (s *Session) genAnalysisKey() { + hash := md5.New() + vIPs := lo.Uniq(parseIPs(s.visitorMsg.MappedAddrs)) + if len(vIPs) > 0 { + hash.Write([]byte(vIPs[0])) + } + hash.Write([]byte(s.vNatFeature.NatType)) + hash.Write([]byte(s.vNatFeature.Behavior)) + hash.Write([]byte(strconv.FormatBool(s.vNatFeature.RegularPortsChange))) + + cIPs := lo.Uniq(parseIPs(s.clientMsg.MappedAddrs)) + if len(cIPs) > 0 { + hash.Write([]byte(cIPs[0])) + } + hash.Write([]byte(s.cNatFeature.NatType)) + hash.Write([]byte(s.cNatFeature.Behavior)) + hash.Write([]byte(strconv.FormatBool(s.cNatFeature.RegularPortsChange))) + s.analysisKey = hex.EncodeToString(hash.Sum(nil)) +} + +type Controller struct { + clientCfgs map[string]*ClientCfg + sessions map[string]*Session + analyzer *Analyzer + + mu sync.RWMutex +} + +func NewController(analysisDataReserveDuration time.Duration) (*Controller, error) { + return &Controller{ + clientCfgs: make(map[string]*ClientCfg), + sessions: make(map[string]*Session), + analyzer: NewAnalyzer(analysisDataReserveDuration), + }, nil +} + +func (c *Controller) CleanWorker(ctx context.Context) { + ticker := time.NewTicker(time.Hour) + defer ticker.Stop() + for { + select { + case <-ticker.C: + start := time.Now() + count, total := c.analyzer.Clean() + log.Trace("clean %d/%d nathole analysis data, cost %v", count, total, time.Since(start)) + case <-ctx.Done(): + return + } + } +} + +func (c *Controller) ListenClient(name string, sk string) chan string { + cfg := &ClientCfg{ + name: name, + sk: sk, + sidCh: make(chan string), + } + c.mu.Lock() + defer c.mu.Unlock() + c.clientCfgs[name] = cfg + return cfg.sidCh +} + +func (c *Controller) CloseClient(name string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.clientCfgs, name) +} + +func (c *Controller) GenSid() string { + t := time.Now().Unix() + id, _ := util.RandID() + return fmt.Sprintf("%d%s", t, id) +} + +func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter) { + if m.PreCheck { + _, ok := c.clientCfgs[m.ProxyName] + if !ok { + _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName))) + } else { + _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, "")) + } + return + } + + sid := c.GenSid() + session := &Session{ + sid: sid, + visitorMsg: m, + visitorTransporter: transporter, + notifyCh: make(chan struct{}, 1), + } + var ( + clientCfg *ClientCfg + ok bool + ) + err := func() error { + c.mu.Lock() + defer c.mu.Unlock() + + clientCfg, ok = c.clientCfgs[m.ProxyName] + if !ok { + return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName) + } + if m.SignKey != util.GetAuthKey(clientCfg.sk, m.Timestamp) { + return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName) + } + c.sessions[sid] = session + return nil + }() + if err != nil { + log.Warn("handle visitorMsg error: %v", err) + _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, err.Error())) + return + } + log.Trace("handle visitor message, sid [%s]", sid) + + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.sessions, sid) + }() + + if err := errors.PanicToError(func() { + clientCfg.sidCh <- sid + }); err != nil { + return + } + + // wait for NatHoleClient message + select { + case <-session.notifyCh: + case <-time.After(time.Duration(NatHoleTimeout) * time.Second): + log.Debug("wait for NatHoleClient message timeout, sid [%s]", sid) + return + } + + // Make hole-punching decisions based on the NAT information of the client and visitor. + vResp, cResp, err := c.analysis(session) + if err != nil { + log.Debug("sid [%s] analysis error: %v", err) + vResp = c.GenNatHoleResponse(session.visitorMsg.TransactionID, nil, err.Error()) + cResp = c.GenNatHoleResponse(session.clientMsg.TransactionID, nil, err.Error()) + } + session.cResp = cResp + session.vResp = vResp + + // send response to visitor and client + var g errgroup.Group + g.Go(func() error { + // if it's sender, wait for a while to make sure the client has send the detect messages + if vResp.DetectBehavior.Role == "sender" { + time.Sleep(1 * time.Second) + } + _ = session.visitorTransporter.Send(vResp) + return nil + }) + g.Go(func() error { + // if it's sender, wait for a while to make sure the client has send the detect messages + if cResp.DetectBehavior.Role == "sender" { + time.Sleep(1 * time.Second) + } + _ = session.clientTransporter.Send(cResp) + return nil + }) + _ = g.Wait() + + time.Sleep(time.Duration(cResp.DetectBehavior.ReadTimeoutMs+30000) * time.Millisecond) +} + +func (c *Controller) HandleClient(m *msg.NatHoleClient, transporter transport.MessageTransporter) { + c.mu.RLock() + session, ok := c.sessions[m.Sid] + c.mu.RUnlock() + if !ok { + return + } + log.Trace("handle client message, sid [%s]", session.sid) + session.clientMsg = m + session.clientTransporter = transporter + select { + case session.notifyCh <- struct{}{}: + default: + } +} + +func (c *Controller) HandleReport(m *msg.NatHoleReport) { + c.mu.RLock() + session, ok := c.sessions[m.Sid] + c.mu.RUnlock() + if !ok { + log.Trace("sid [%s] report make hole success: %v, but session not found", m.Sid, m.Success) + return + } + if m.Success { + c.analyzer.ReportSuccess(session.analysisKey, session.recommandMode, session.recommandIndex) + } + log.Info("sid [%s] report make hole success: %v, mode %v, index %v", + m.Sid, m.Success, session.recommandMode, session.recommandIndex) +} + +func (c *Controller) GenNatHoleResponse(transactionID string, session *Session, errInfo string) *msg.NatHoleResp { + var sid string + if session != nil { + sid = session.sid + } + return &msg.NatHoleResp{ + TransactionID: transactionID, + Sid: sid, + Error: errInfo, + } +} + +// analysis analyzes the NAT type and behavior of the visitor and client, then makes hole-punching decisions. +// return the response to the visitor and client. +func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleResp, error) { + cm := session.clientMsg + vm := session.visitorMsg + + cNatFeature, err := ClassifyNATFeature(cm.MappedAddrs, parseIPs(cm.AssistedAddrs)) + if err != nil { + return nil, nil, fmt.Errorf("classify client nat feature error: %v", err) + } + + vNatFeature, err := ClassifyNATFeature(vm.MappedAddrs, parseIPs(vm.AssistedAddrs)) + if err != nil { + return nil, nil, fmt.Errorf("classify visitor nat feature error: %v", err) + } + session.cNatFeature = cNatFeature + session.vNatFeature = vNatFeature + session.genAnalysisKey() + + mode, index, cBehavior, vBehavior := c.analyzer.GetRecommandBehaviors(session.analysisKey, cNatFeature, vNatFeature) + session.recommandMode = mode + session.recommandIndex = index + session.cBehavior = cBehavior + session.vBehavior = vBehavior + + timeoutMs := lo.Max([]int{cBehavior.SendDelayMs, vBehavior.SendDelayMs}) + 5000 + if cBehavior.ListenRandomPorts > 0 || vBehavior.ListenRandomPorts > 0 { + timeoutMs += 30000 + } + + protocol := vm.Protocol + vResp := &msg.NatHoleResp{ + TransactionID: vm.TransactionID, + Sid: session.sid, + Protocol: protocol, + CandidateAddrs: lo.Uniq(cm.MappedAddrs), + AssistedAddrs: lo.Uniq(cm.AssistedAddrs), + DetectBehavior: msg.NatHoleDetectBehavior{ + Mode: mode, + Role: vBehavior.Role, + TTL: vBehavior.TTL, + SendDelayMs: vBehavior.SendDelayMs, + ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs, + SendRandomPorts: vBehavior.PortsRandomNumber, + ListenRandomPorts: vBehavior.ListenRandomPorts, + CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber), + }, + } + cResp := &msg.NatHoleResp{ + TransactionID: cm.TransactionID, + Sid: session.sid, + Protocol: protocol, + CandidateAddrs: lo.Uniq(vm.MappedAddrs), + AssistedAddrs: lo.Uniq(vm.AssistedAddrs), + DetectBehavior: msg.NatHoleDetectBehavior{ + Mode: mode, + Role: cBehavior.Role, + TTL: cBehavior.TTL, + SendDelayMs: cBehavior.SendDelayMs, + ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs, + SendRandomPorts: cBehavior.PortsRandomNumber, + ListenRandomPorts: cBehavior.ListenRandomPorts, + CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber), + }, + } + + log.Debug("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s", + session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol) + log.Debug("sid [%s] visitor detect behavior: %+v", session.sid, vResp.DetectBehavior) + log.Debug("sid [%s] client detect behavior: %+v", session.sid, cResp.DetectBehavior) + return vResp, cResp, nil +} + +func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange { + if maxNumber <= 0 { + return nil + } + + addr, err := lo.Last(addrs) + if err != nil { + return nil + } + var ports []msg.PortsRange + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil + } + ports = append(ports, msg.PortsRange{ + From: lo.Max([]int{port - difference - 5, port - maxNumber, 1}), + To: lo.Min([]int{port + difference + 5, port + maxNumber, 65535}), + }) + return ports +} diff --git a/pkg/nathole/discovery.go b/pkg/nathole/discovery.go index 4c684b3f..0be41260 100644 --- a/pkg/nathole/discovery.go +++ b/pkg/nathole/discovery.go @@ -20,8 +20,6 @@ import ( "time" "github.com/pion/stun" - - "github.com/fatedier/frp/pkg/msg" ) var responseTimeout = 3 * time.Second @@ -31,35 +29,27 @@ type Message struct { Addr string } -func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) { +// If the localAddr is empty, it will listen on a random port. +func Discover(stunServers []string, localAddr string) ([]string, net.Addr, error) { // create a discoverConn and get response from messageChan - discoverConn, err := listen() + discoverConn, err := listen(localAddr) if err != nil { - return nil, err + return nil, nil, err } defer discoverConn.Close() go discoverConn.readLoop() - addresses := make([]string, 0, len(stunServers)+1) - if serverAddress != "" { - // get external address from frp server - externalAddr, err := discoverConn.discoverFromServer(serverAddress, key) - if err != nil { - return nil, err - } - addresses = append(addresses, externalAddr) - } - + addresses := make([]string, 0, len(stunServers)) for _, addr := range stunServers { // get external address from stun server externalAddrs, err := discoverConn.discoverFromStunServer(addr) if err != nil { - return nil, err + return nil, nil, err } addresses = append(addresses, externalAddrs...) } - return addresses, nil + return addresses, discoverConn.localAddr, nil } type stunResponse struct { @@ -74,8 +64,16 @@ type discoverConn struct { messageChan chan *Message } -func listen() (*discoverConn, error) { - conn, err := net.ListenUDP("udp4", nil) +func listen(localAddr string) (*discoverConn, error) { + var local *net.UDPAddr + if localAddr != "" { + addr, err := net.ResolveUDPAddr("udp4", localAddr) + if err != nil { + return nil, err + } + local = addr + } + conn, err := net.ListenUDP("udp4", local) if err != nil { return nil, err } @@ -159,43 +157,6 @@ func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) { return resp, nil } -func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) { - addr, err := net.ResolveUDPAddr("udp4", serverAddress) - if err != nil { - return "", err - } - m := &msg.NatHoleBinding{ - TransactionID: NewTransactionID(), - } - - buf, err := EncodeMessage(m, key) - if err != nil { - return "", err - } - - if _, err := c.conn.WriteTo(buf, addr); err != nil { - return "", err - } - - var respMsg msg.NatHoleBindingResp - select { - case rawMsg := <-c.messageChan: - if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil { - return "", err - } - case <-time.After(responseTimeout): - return "", fmt.Errorf("wait response from frp server timeout") - } - - if respMsg.TransactionID == "" { - return "", fmt.Errorf("error format: no transaction id found") - } - if respMsg.Error != "" { - return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error) - } - return respMsg.Address, nil -} - func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) { resp, err := c.doSTUNRequest(addr) if err != nil { diff --git a/pkg/nathole/nathole.go b/pkg/nathole/nathole.go index 2b114bce..a4d5e463 100644 --- a/pkg/nathole/nathole.go +++ b/pkg/nathole/nathole.go @@ -15,249 +15,426 @@ package nathole import ( - "bytes" + "context" "fmt" + "math/rand" "net" - "sync" + "strconv" + "strings" "time" - "github.com/fatedier/golib/crypto" - "github.com/fatedier/golib/errors" "github.com/fatedier/golib/pool" + "github.com/samber/lo" + "golang.org/x/net/ipv4" + "k8s.io/apimachinery/pkg/util/sets" "github.com/fatedier/frp/pkg/msg" - "github.com/fatedier/frp/pkg/util/log" - "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/pkg/transport" + "github.com/fatedier/frp/pkg/util/xlog" ) -// NatHoleTimeout seconds. -var NatHoleTimeout int64 = 10 +var ( + // mode 0: simple detect mode, usually for both EasyNAT or HardNAT & EasyNAT(Public Network) + // a. receiver sends detect message with low TTL + // b. sender sends normal detect message to receiver + // c. receiver receives detect message and sends back a message to sender + // + // mode 1: For HardNAT & EasyNAT, send detect messages to multiple guessed ports. + // Usually applicable to scenarios where port changes are regular. + // Most of the steps are the same as mode 0, but EasyNAT is fixed as the receiver and will send detect messages + // with low TTL to multiple guessed ports of the sender. + // + // mode 2: For HardNAT & EasyNAT, ports changes are not regular. + // a. HardNAT machine will listen on multiple ports and send detect messages with low TTL to EasyNAT machine + // b. EasyNAT machine will send detect messages to random ports of HardNAT machine. + // + // mode 3: For HardNAT & HardNAT, both changes in the ports are regular. + // Most of the steps are the same as mode 1, but the sender also needs to send detect messages to multiple guessed + // ports of the receiver. + // + // mode 4: For HardNAT & HardNAT, one of the changes in the ports is regular. + // Regular port changes are usually on the sender side. + // a. Receiver listens on multiple ports and sends detect messages with low TTL to the sender's guessed range ports. + // b. Sender sends detect messages to random ports of the receiver. + SupportedModes = []int{DetectMode0, DetectMode1, DetectMode2, DetectMode3, DetectMode4} + SupportedRoles = []string{DetectRoleSender, DetectRoleReceiver} -func NewTransactionID() string { - id, _ := util.RandID() - return fmt.Sprintf("%d%s", time.Now().Unix(), id) + DetectMode0 = 0 + DetectMode1 = 1 + DetectMode2 = 2 + DetectMode3 = 3 + DetectMode4 = 4 + DetectRoleSender = "sender" + DetectRoleReceiver = "receiver" +) + +type PrepareResult struct { + Addrs []string + AssistedAddrs []string + ListenConn *net.UDPConn + NatType string + Behavior string } -type SidRequest struct { - Sid string - NotifyCh chan struct{} -} +// PreCheck is used to check if the proxy is ready for penetration. +// Call this function before calling Prepare to avoid unnecessary preparation work. +func PreCheck( + ctx context.Context, transporter transport.MessageTransporter, + proxyName string, timeout time.Duration, +) error { + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() -type Controller struct { - listener *net.UDPConn - - clientCfgs map[string]*ClientCfg - sessions map[string]*Session - - encryptionKey []byte - mu sync.RWMutex -} - -func NewController(udpBindAddr string, encryptionKey []byte) (nc *Controller, err error) { - addr, err := net.ResolveUDPAddr("udp", udpBindAddr) + var natHoleRespMsg *msg.NatHoleResp + transactionID := NewTransactionID() + m, err := transporter.Do(timeoutCtx, &msg.NatHoleVisitor{ + TransactionID: transactionID, + ProxyName: proxyName, + PreCheck: true, + }, transactionID, msg.TypeNameNatHoleResp) if err != nil { - return nil, err + return fmt.Errorf("get natHoleRespMsg error: %v", err) } - lconn, err := net.ListenUDP("udp", addr) + mm, ok := m.(*msg.NatHoleResp) + if !ok { + return fmt.Errorf("get natHoleRespMsg error: invalid message type") + } + natHoleRespMsg = mm + + if natHoleRespMsg.Error != "" { + return fmt.Errorf("%s", natHoleRespMsg.Error) + } + return nil +} + +// Prepare is used to do some preparation work before penetration. +func Prepare(stunServers []string) (*PrepareResult, error) { + // discover for Nat type + addrs, localAddr, err := Discover(stunServers, "") if err != nil { - return nil, err + return nil, fmt.Errorf("discover error: %v", err) } - nc = &Controller{ - listener: lconn, - clientCfgs: make(map[string]*ClientCfg), - sessions: make(map[string]*Session), - encryptionKey: encryptionKey, + if len(addrs) < 2 { + return nil, fmt.Errorf("discover error: not enough addresses") } - return nc, nil + + localIPs, _ := ListLocalIPsForNatHole(10) + natFeature, err := ClassifyNATFeature(addrs, localIPs) + if err != nil { + return nil, fmt.Errorf("classify nat feature error: %v", err) + } + + laddr, err := net.ResolveUDPAddr("udp4", localAddr.String()) + if err != nil { + return nil, fmt.Errorf("resolve local udp addr error: %v", err) + } + listenConn, err := net.ListenUDP("udp4", laddr) + if err != nil { + return nil, fmt.Errorf("listen local udp addr error: %v", err) + } + + assistedAddrs := make([]string, 0, len(localIPs)) + for _, ip := range localIPs { + assistedAddrs = append(assistedAddrs, net.JoinHostPort(ip, strconv.Itoa(laddr.Port))) + } + return &PrepareResult{ + Addrs: addrs, + AssistedAddrs: assistedAddrs, + ListenConn: listenConn, + NatType: natFeature.NatType, + Behavior: natFeature.Behavior, + }, nil } -func (nc *Controller) ListenClient(name string, sk string) (sidCh chan *SidRequest) { - clientCfg := &ClientCfg{ - Name: name, - Sk: sk, - SidCh: make(chan *SidRequest), +// ExchangeInfo is used to exchange information between client and visitor. +// 1. Send input message to server by msgTransporter. +// 2. Server will gather information from client and visitor and analyze it. Then send back a NatHoleResp message to them to tell them how to do next. +// 3. Receive NatHoleResp message from server. +func ExchangeInfo( + ctx context.Context, transporter transport.MessageTransporter, + laneKey string, m msg.Message, timeout time.Duration, +) (*msg.NatHoleResp, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + var natHoleRespMsg *msg.NatHoleResp + m, err := transporter.Do(timeoutCtx, m, laneKey, msg.TypeNameNatHoleResp) + if err != nil { + return nil, fmt.Errorf("get natHoleRespMsg error: %v", err) } - nc.mu.Lock() - nc.clientCfgs[name] = clientCfg - nc.mu.Unlock() - return clientCfg.SidCh + mm, ok := m.(*msg.NatHoleResp) + if !ok { + return nil, fmt.Errorf("get natHoleRespMsg error: invalid message type") + } + natHoleRespMsg = mm + + if natHoleRespMsg.Error != "" { + return nil, fmt.Errorf("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) + } + if len(natHoleRespMsg.CandidateAddrs) == 0 { + return nil, fmt.Errorf("natHoleRespMsg get empty candidate addresses") + } + return natHoleRespMsg, nil } -func (nc *Controller) CloseClient(name string) { - nc.mu.Lock() - defer nc.mu.Unlock() - delete(nc.clientCfgs, name) +// MakeHole is used to make a NAT hole between client and visitor. +func MakeHole(ctx context.Context, listenConn *net.UDPConn, m *msg.NatHoleResp, key []byte) (*net.UDPConn, *net.UDPAddr, error) { + xl := xlog.FromContextSafe(ctx) + transactionID := NewTransactionID() + sendToRangePortsFunc := func(conn *net.UDPConn, addr string) error { + return sendSidMessage(ctx, conn, m.Sid, transactionID, addr, key, m.DetectBehavior.TTL) + } + + listenConns := []*net.UDPConn{listenConn} + var detectAddrs []string + if m.DetectBehavior.Role == DetectRoleSender { + // sender + if m.DetectBehavior.SendDelayMs > 0 { + time.Sleep(time.Duration(m.DetectBehavior.SendDelayMs) * time.Millisecond) + } + detectAddrs = m.AssistedAddrs + detectAddrs = append(detectAddrs, m.CandidateAddrs...) + } else { + // receiver + if len(m.DetectBehavior.CandidatePorts) == 0 { + detectAddrs = m.CandidateAddrs + } + + if m.DetectBehavior.ListenRandomPorts > 0 { + for i := 0; i < m.DetectBehavior.ListenRandomPorts; i++ { + tmpConn, err := net.ListenUDP("udp4", nil) + if err != nil { + xl.Warn("listen random udp addr error: %v", err) + continue + } + listenConns = append(listenConns, tmpConn) + } + } + } + + detectAddrs = lo.Uniq(detectAddrs) + for _, detectAddr := range detectAddrs { + for _, conn := range listenConns { + if err := sendSidMessage(ctx, conn, m.Sid, transactionID, detectAddr, key, m.DetectBehavior.TTL); err != nil { + xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) + } + } + } + if len(m.DetectBehavior.CandidatePorts) > 0 { + for _, conn := range listenConns { + sendSidMessageToRangePorts(ctx, conn, m.CandidateAddrs, m.DetectBehavior.CandidatePorts, sendToRangePortsFunc) + } + } + if m.DetectBehavior.SendRandomPorts > 0 { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + for i := range listenConns { + go sendSidMessageToRandomPorts(ctx, listenConns[i], m.CandidateAddrs, m.DetectBehavior.SendRandomPorts, sendToRangePortsFunc) + } + } + + timeout := 5 * time.Second + if m.DetectBehavior.ReadTimeoutMs > 0 { + timeout = time.Duration(m.DetectBehavior.ReadTimeoutMs) * time.Millisecond + } + + if len(listenConns) == 1 { + raddr, err := waitDetectMessage(ctx, listenConns[0], m.Sid, key, timeout, m.DetectBehavior.Role) + if err != nil { + return nil, nil, fmt.Errorf("wait detect message error: %v", err) + } + return listenConns[0], raddr, nil + } + + type result struct { + lConn *net.UDPConn + raddr *net.UDPAddr + } + resultCh := make(chan result) + for _, conn := range listenConns { + go func(lConn *net.UDPConn) { + addr, err := waitDetectMessage(ctx, lConn, m.Sid, key, timeout, m.DetectBehavior.Role) + if err != nil { + lConn.Close() + return + } + select { + case resultCh <- result{lConn: lConn, raddr: addr}: + default: + lConn.Close() + } + }(conn) + } + + select { + case result := <-resultCh: + return result.lConn, result.raddr, nil + case <-time.After(timeout): + return nil, nil, fmt.Errorf("wait detect message timeout") + case <-ctx.Done(): + return nil, nil, fmt.Errorf("wait detect message canceled") + } } -func (nc *Controller) Run() { +func waitDetectMessage( + ctx context.Context, conn *net.UDPConn, sid string, key []byte, + timeout time.Duration, role string, +) (*net.UDPAddr, error) { + xl := xlog.FromContextSafe(ctx) for { buf := pool.GetBuf(1024) - n, raddr, err := nc.listener.ReadFromUDP(buf) + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + n, raddr, err := conn.ReadFromUDP(buf) + _ = conn.SetReadDeadline(time.Time{}) if err != nil { - log.Warn("nat hole listener read from udp error: %v", err) - return + return nil, err } - plain, err := crypto.Decode(buf[:n], nc.encryptionKey) - if err != nil { - log.Warn("nathole listener decode from %s error: %v", raddr.String(), err) - continue - } - - rawMsg, err := msg.ReadMsg(bytes.NewReader(plain)) - if err != nil { - log.Warn("read nat hole message error: %v", err) - continue - } - - switch m := rawMsg.(type) { - case *msg.NatHoleBinding: - go nc.HandleBinding(m, raddr) - case *msg.NatHoleVisitor: - go nc.HandleVisitor(m, raddr) - case *msg.NatHoleClient: - go nc.HandleClient(m, raddr) - default: - log.Trace("unknown nat hole message type") + xl.Debug("get udp message local %s, from %s", conn.LocalAddr(), raddr) + var m msg.NatHoleSid + if err := DecodeMessageInto(buf[:n], key, &m); err != nil { + xl.Warn("decode sid message error: %v", err) continue } pool.PutBuf(buf) - } -} -func (nc *Controller) GenSid() string { - t := time.Now().Unix() - id, _ := util.RandID() - return fmt.Sprintf("%d%s", t, id) -} - -func (nc *Controller) HandleBinding(m *msg.NatHoleBinding, raddr *net.UDPAddr) { - log.Trace("handle binding message from %s", raddr.String()) - resp := &msg.NatHoleBindingResp{ - TransactionID: m.TransactionID, - Address: raddr.String(), - } - plain, err := msg.Pack(resp) - if err != nil { - log.Error("pack nat hole binding response error: %v", err) - return - } - buf, err := crypto.Encode(plain, nc.encryptionKey) - if err != nil { - log.Error("encode nat hole binding response error: %v", err) - return - } - _, err = nc.listener.WriteToUDP(buf, raddr) - if err != nil { - log.Error("write nat hole binding response to %s error: %v", raddr.String(), err) - return - } -} - -func (nc *Controller) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) { - sid := nc.GenSid() - session := &Session{ - Sid: sid, - VisitorAddr: raddr, - NotifyCh: make(chan struct{}), - } - nc.mu.Lock() - clientCfg, ok := nc.clientCfgs[m.ProxyName] - if !ok { - nc.mu.Unlock() - errInfo := fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName) - log.Debug(errInfo) - _, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr) - return - } - if m.SignKey != util.GetAuthKey(clientCfg.Sk, m.Timestamp) { - nc.mu.Unlock() - errInfo := fmt.Sprintf("xtcp connection of [%s] auth failed", m.ProxyName) - log.Debug(errInfo) - _, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr) - return - } - - nc.sessions[sid] = session - nc.mu.Unlock() - log.Trace("handle visitor message, sid [%s]", sid) - - defer func() { - nc.mu.Lock() - delete(nc.sessions, sid) - nc.mu.Unlock() - }() - - err := errors.PanicToError(func() { - clientCfg.SidCh <- &SidRequest{ - Sid: sid, - NotifyCh: session.NotifyCh, + if m.Sid != sid { + xl.Warn("get sid message with wrong sid: %s, expect: %s", m.Sid, sid) + continue } - }) + + if !m.Response { + // only wait for response messages if we are a sender + if role == DetectRoleSender { + continue + } + + m.Response = true + buf2, err := EncodeMessage(&m, key) + if err != nil { + xl.Warn("encode sid message error: %v", err) + continue + } + _, _ = conn.WriteToUDP(buf2, raddr) + } + return raddr, nil + } +} + +func sendSidMessage( + ctx context.Context, conn *net.UDPConn, + sid string, transactionID string, addr string, key []byte, ttl int, +) error { + xl := xlog.FromContextSafe(ctx) + ttlStr := "" + if ttl > 0 { + ttlStr = fmt.Sprintf(" with ttl %d", ttl) + } + xl.Trace("send sid message from %s to %s%s", conn.LocalAddr(), addr, ttlStr) + raddr, err := net.ResolveUDPAddr("udp4", addr) if err != nil { - return + return err } - - // Wait client connections. - select { - case <-session.NotifyCh: - resp := nc.GenNatHoleResponse(session, "") - log.Trace("send nat hole response to visitor") - _, _ = nc.listener.WriteToUDP(resp, raddr) - case <-time.After(time.Duration(NatHoleTimeout) * time.Second): - return + if transactionID == "" { + transactionID = NewTransactionID() } -} - -func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) { - nc.mu.RLock() - session, ok := nc.sessions[m.Sid] - nc.mu.RUnlock() - if !ok { - return + m := &msg.NatHoleSid{ + TransactionID: transactionID, + Sid: sid, + Response: false, + Nonce: strings.Repeat("0", rand.Intn(20)), } - log.Trace("handle client message, sid [%s]", session.Sid) - session.ClientAddr = raddr - - resp := nc.GenNatHoleResponse(session, "") - log.Trace("send nat hole response to client") - _, _ = nc.listener.WriteToUDP(resp, raddr) -} - -func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byte { - var ( - sid string - visitorAddr string - clientAddr string - ) - if session != nil { - sid = session.Sid - visitorAddr = session.VisitorAddr.String() - clientAddr = session.ClientAddr.String() - } - m := &msg.NatHoleResp{ - Sid: sid, - VisitorAddr: visitorAddr, - ClientAddr: clientAddr, - Error: errInfo, - } - b := bytes.NewBuffer(nil) - err := msg.WriteMsg(b, m) + buf, err := EncodeMessage(m, key) if err != nil { - return []byte("") + return err } - return b.Bytes() + if ttl > 0 { + uConn := ipv4.NewConn(conn) + original, err := uConn.TTL() + if err != nil { + xl.Trace("get ttl error %v", err) + return err + } + xl.Trace("original ttl %d", original) + + err = uConn.SetTTL(ttl) + if err != nil { + xl.Trace("set ttl error %v", err) + } else { + defer func() { + _ = uConn.SetTTL(original) + }() + } + } + + if _, err := conn.WriteToUDP(buf, raddr); err != nil { + return err + } + return nil } -type Session struct { - Sid string - VisitorAddr *net.UDPAddr - ClientAddr *net.UDPAddr - - NotifyCh chan struct{} +func sendSidMessageToRangePorts( + ctx context.Context, conn *net.UDPConn, addrs []string, ports []msg.PortsRange, + sendFunc func(*net.UDPConn, string) error, +) { + xl := xlog.FromContextSafe(ctx) + for _, ip := range lo.Uniq(parseIPs(addrs)) { + for _, portsRange := range ports { + for i := portsRange.From; i <= portsRange.To; i++ { + detectAddr := net.JoinHostPort(ip, strconv.Itoa(i)) + if err := sendFunc(conn, detectAddr); err != nil { + xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) + } + time.Sleep(5 * time.Millisecond) + } + } + } } -type ClientCfg struct { - Name string - Sk string - SidCh chan *SidRequest +func sendSidMessageToRandomPorts( + ctx context.Context, conn *net.UDPConn, addrs []string, count int, + sendFunc func(*net.UDPConn, string) error, +) { + xl := xlog.FromContextSafe(ctx) + used := sets.New[int]() + getUnusedPort := func() int { + for i := 0; i < 10; i++ { + port := rand.Intn(65535-1024) + 1024 + if !used.Has(port) { + used.Insert(port) + return port + } + } + return 0 + } + + for i := 0; i < count; i++ { + select { + case <-ctx.Done(): + return + default: + } + + port := getUnusedPort() + if port == 0 { + continue + } + + for _, ip := range lo.Uniq(parseIPs(addrs)) { + detectAddr := net.JoinHostPort(ip, strconv.Itoa(port)) + if err := sendFunc(conn, detectAddr); err != nil { + xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) + } + time.Sleep(time.Millisecond * 15) + } + } +} + +func parseIPs(addrs []string) []string { + var ips []string + for _, addr := range addrs { + if ip, _, err := net.SplitHostPort(addr); err == nil { + ips = append(ips, ip) + } + } + return ips } diff --git a/pkg/nathole/utils.go b/pkg/nathole/utils.go index 75eda1a6..c889c7eb 100644 --- a/pkg/nathole/utils.go +++ b/pkg/nathole/utils.go @@ -16,6 +16,7 @@ package nathole import ( "bytes" + "fmt" "net" "strconv" @@ -63,3 +64,49 @@ func (s *ChangedAddress) GetFrom(m *stun.Message) error { func (s *ChangedAddress) String() string { return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port)) } + +func ListAllLocalIPs() ([]net.IP, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil { + continue + } + ips = append(ips, ip) + } + return ips, nil +} + +func ListLocalIPsForNatHole(max int) ([]string, error) { + if max <= 0 { + return nil, fmt.Errorf("max must be greater than 0") + } + + ips, err := ListAllLocalIPs() + if err != nil { + return nil, err + } + + filtered := make([]string, 0, max) + for _, ip := range ips { + if len(filtered) >= max { + break + } + + // ignore ipv6 address + if ip.To4() == nil { + continue + } + // ignore localhost IP + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + continue + } + + filtered = append(filtered, ip.String()) + } + return filtered, nil +} diff --git a/pkg/transport/message.go b/pkg/transport/message.go new file mode 100644 index 00000000..6bcd8ce8 --- /dev/null +++ b/pkg/transport/message.go @@ -0,0 +1,119 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "context" + "reflect" + "sync" + + "github.com/fatedier/golib/errors" + + "github.com/fatedier/frp/pkg/msg" +) + +type MessageTransporter interface { + Send(msg.Message) error + // Recv(ctx context.Context, laneKey string, msgType string) (Message, error) + // Do will first send msg, then recv msg with the same laneKey and specified msgType. + Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error) + Dispatch(m msg.Message, laneKey string) bool + DispatchWithType(m msg.Message, msgType, laneKey string) bool +} + +func NewMessageTransporter(sendCh chan msg.Message) MessageTransporter { + return &transporterImpl{ + sendCh: sendCh, + registry: make(map[string]map[string]chan msg.Message), + } +} + +type transporterImpl struct { + sendCh chan msg.Message + + // First key is message type and second key is lane key. + // Dispatch will dispatch message to releated channel by its message type + // and lane key. + registry map[string]map[string]chan msg.Message + mu sync.RWMutex +} + +func (impl *transporterImpl) Send(m msg.Message) error { + return errors.PanicToError(func() { + impl.sendCh <- m + }) +} + +func (impl *transporterImpl) Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error) { + ch := make(chan msg.Message, 1) + defer close(ch) + unregisterFn := impl.registerMsgChan(ch, laneKey, recvMsgType) + defer unregisterFn() + + if err := impl.Send(req); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resp := <-ch: + return resp, nil + } +} + +func (impl *transporterImpl) DispatchWithType(m msg.Message, msgType, laneKey string) bool { + var ch chan msg.Message + impl.mu.RLock() + byLaneKey, ok := impl.registry[msgType] + if ok { + ch = byLaneKey[laneKey] + } + impl.mu.RUnlock() + + if ch == nil { + return false + } + + if err := errors.PanicToError(func() { + ch <- m + }); err != nil { + return false + } + return true +} + +func (impl *transporterImpl) Dispatch(m msg.Message, laneKey string) bool { + msgType := reflect.TypeOf(m).Elem().Name() + return impl.DispatchWithType(m, msgType, laneKey) +} + +func (impl *transporterImpl) registerMsgChan(recvCh chan msg.Message, laneKey string, msgType string) (unregister func()) { + impl.mu.Lock() + byLaneKey, ok := impl.registry[msgType] + if !ok { + byLaneKey = make(map[string]chan msg.Message) + impl.registry[msgType] = byLaneKey + } + byLaneKey[laneKey] = recvCh + impl.mu.Unlock() + + unregister = func() { + impl.mu.Lock() + delete(byLaneKey, laneKey) + impl.mu.Unlock() + } + return +} diff --git a/pkg/transport/tls.go b/pkg/transport/tls.go index 02950cc5..13e241b2 100644 --- a/pkg/transport/tls.go +++ b/pkg/transport/tls.go @@ -1,3 +1,17 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package transport import ( diff --git a/pkg/util/net/udp.go b/pkg/util/net/udp.go index 82999a5f..599447c3 100644 --- a/pkg/util/net/udp.go +++ b/pkg/util/net/udp.go @@ -256,3 +256,11 @@ func (l *UDPListener) Close() error { func (l *UDPListener) Addr() net.Addr { return l.addr } + +// ConnectedUDPConn is a wrapper for net.UDPConn which converts WriteTo syscalls +// to Write syscalls that are 4 times faster on some OS'es. This should only be +// used for connections that were produced by a net.Dial* call. +type ConnectedUDPConn struct{ *net.UDPConn } + +// WriteTo redirects all writes to the Write syscall, which is 4 times faster. +func (c *ConnectedUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { return c.Write(b) } diff --git a/pkg/util/util/slice.go b/pkg/util/util/slice.go deleted file mode 100644 index 6d4f14ec..00000000 --- a/pkg/util/util/slice.go +++ /dev/null @@ -1,25 +0,0 @@ -package util - -func InSlice[T comparable](v T, s []T) bool { - for _, vv := range s { - if v == vv { - return true - } - } - return false -} - -func InSliceAny[T any](v T, s []T, equalFn func(a, b T) bool) bool { - for _, vv := range s { - if equalFn(v, vv) { - return true - } - } - return false -} - -func InSliceAnyFunc[T any](equalFn func(a, b T) bool) func(v T, s []T) bool { - return func(v T, s []T) bool { - return InSliceAny(v, s, equalFn) - } -} diff --git a/pkg/util/util/slice_test.go b/pkg/util/util/slice_test.go deleted file mode 100644 index 3e9d1428..00000000 --- a/pkg/util/util/slice_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package util - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestInSlice(t *testing.T) { - require := require.New(t) - require.True(InSlice(1, []int{1, 2, 3})) - require.False(InSlice(0, []int{1, 2, 3})) - require.True(InSlice("foo", []string{"foo", "bar"})) - require.False(InSlice("not exist", []string{"foo", "bar"})) -} - -type testStructA struct { - Name string - Age int -} - -func TestInSliceAny(t *testing.T) { - require := require.New(t) - - a := testStructA{Name: "foo", Age: 20} - b := testStructA{Name: "foo", Age: 30} - c := testStructA{Name: "bar", Age: 20} - - equalFn := func(o, p testStructA) bool { - return o.Name == p.Name - } - require.True(InSliceAny(a, []testStructA{b, c}, equalFn)) - require.False(InSliceAny(c, []testStructA{a, b}, equalFn)) -} - -func TestInSliceAnyFunc(t *testing.T) { - require := require.New(t) - - a := testStructA{Name: "foo", Age: 20} - b := testStructA{Name: "foo", Age: 30} - c := testStructA{Name: "bar", Age: 20} - - equalFn := func(o, p testStructA) bool { - return o.Name == p.Name - } - testStructAInSlice := InSliceAnyFunc(equalFn) - require.True(testStructAInSlice(a, []testStructA{b, c})) - require.False(testStructAInSlice(c, []testStructA{a, b})) -} diff --git a/pkg/util/util/util.go b/pkg/util/util/util.go index b72209c6..b437f3e3 100644 --- a/pkg/util/util/util.go +++ b/pkg/util/util/util.go @@ -28,19 +28,32 @@ import ( // RandID return a rand string used in frp. func RandID() (id string, err error) { - return RandIDWithLen(8) + return RandIDWithLen(16) } // RandIDWithLen return a rand string with idLen length. func RandIDWithLen(idLen int) (id string, err error) { - b := make([]byte, idLen) + if idLen <= 0 { + return "", nil + } + b := make([]byte, idLen/2+1) _, err = rand.Read(b) if err != nil { return } id = fmt.Sprintf("%x", b) - return + return id[:idLen], nil +} + +// RandIDWithRandLen return a rand string with length between [start, end). +func RandIDWithRandLen(start, end int) (id string, err error) { + if start >= end { + err = fmt.Errorf("start should be less than end") + return + } + idLen := mathrand.Intn(end-start) + start + return RandIDWithLen(idLen) } func GetAuthKey(token string, timestamp int64) (key string) { diff --git a/pkg/util/util/util_test.go b/pkg/util/util/util_test.go index 311732b4..cf76485b 100644 --- a/pkg/util/util/util_test.go +++ b/pkg/util/util/util_test.go @@ -14,10 +14,51 @@ func TestRandId(t *testing.T) { assert.Equal(16, len(id)) } +func TestRandIDWithRandLen(t *testing.T) { + tests := []struct { + name string + start int + end int + expectErr bool + }{ + { + name: "start and end are equal", + start: 5, + end: 5, + expectErr: true, + }, + { + name: "start is less than end", + start: 5, + end: 10, + expectErr: false, + }, + { + name: "start is greater than end", + start: 10, + end: 5, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + id, err := RandIDWithRandLen(tt.start, tt.end) + if tt.expectErr { + assert.Error(err) + } else { + assert.NoError(err) + assert.GreaterOrEqual(len(id), tt.start) + assert.Less(len(id), tt.end) + } + }) + } +} + func TestGetAuthKey(t *testing.T) { assert := assert.New(t) key := GetAuthKey("1234", 1488720000) - t.Log(key) assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) } diff --git a/server/control.go b/server/control.go index e80e4242..f905e74c 100644 --- a/server/control.go +++ b/server/control.go @@ -33,6 +33,7 @@ import ( frpErr "github.com/fatedier/frp/pkg/errors" "github.com/fatedier/frp/pkg/msg" plugin "github.com/fatedier/frp/pkg/plugin/server" + "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/xlog" @@ -82,6 +83,16 @@ func (cm *ControlManager) GetByID(runID string) (ctl *Control, ok bool) { return } +func (cm *ControlManager) Close() error { + cm.mu.Lock() + defer cm.mu.Unlock() + for _, ctl := range cm.ctlsByRunID { + ctl.Close() + } + cm.ctlsByRunID = make(map[string]*Control) + return nil +} + type Control struct { // all resource managers and controllers rc *controller.ResourceController @@ -95,6 +106,9 @@ type Control struct { // verifies authentication based on selected method authVerifier auth.Verifier + // other components can use this to communicate with client + msgTransporter transport.MessageTransporter + // login message loginMsg *msg.Login @@ -158,7 +172,7 @@ func NewControl( if poolCount > int(serverCfg.MaxPoolCount) { poolCount = int(serverCfg.MaxPoolCount) } - return &Control{ + ctl := &Control{ rc: rc, pxyManager: pxyManager, pluginManager: pluginManager, @@ -182,6 +196,8 @@ func NewControl( xl: xlog.FromContextSafe(ctx), ctx: ctx, } + ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) + return ctl } // Start send a login success message to client and start working. @@ -204,6 +220,18 @@ func (ctl *Control) Start() { go ctl.stoper() } +func (ctl *Control) Close() error { + ctl.allShutdown.Start() + return nil +} + +func (ctl *Control) Replaced(newCtl *Control) { + xl := ctl.xl + xl.Info("Replaced by client [%s]", newCtl.runID) + ctl.runID = "" + ctl.allShutdown.Start() +} + func (ctl *Control) RegisterWorkConn(conn net.Conn) error { xl := ctl.xl defer func() { @@ -275,13 +303,6 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { return } -func (ctl *Control) Replaced(newCtl *Control) { - xl := ctl.xl - xl.Info("Replaced by client [%s]", newCtl.runID) - ctl.runID = "" - ctl.allShutdown.Start() -} - func (ctl *Control) writer() { xl := ctl.xl defer func() { @@ -465,6 +486,12 @@ func (ctl *Control) manager() { metrics.Server.NewProxy(m.ProxyName, m.ProxyType) } ctl.sendCh <- resp + case *msg.NatHoleVisitor: + go ctl.HandleNatHoleVisitor(m) + case *msg.NatHoleClient: + go ctl.HandleNatHoleClient(m) + case *msg.NatHoleReport: + go ctl.HandleNatHoleReport(m) case *msg.CloseProxy: _ = ctl.CloseProxy(m) xl.Info("close proxy [%s] success", m.ProxyName) @@ -497,6 +524,18 @@ func (ctl *Control) manager() { } } +func (ctl *Control) HandleNatHoleVisitor(m *msg.NatHoleVisitor) { + ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter) +} + +func (ctl *Control) HandleNatHoleClient(m *msg.NatHoleClient) { + ctl.rc.NatHoleController.HandleClient(m, ctl.msgTransporter) +} + +func (ctl *Control) HandleNatHoleReport(m *msg.NatHoleReport) { + ctl.rc.NatHoleController.HandleReport(m) +} + func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { var pxyConf config.ProxyConf // Load configures from NewProxy message and check. diff --git a/server/proxy/xtcp.go b/server/proxy/xtcp.go index b6c7be3a..8b2717bf 100644 --- a/server/proxy/xtcp.go +++ b/server/proxy/xtcp.go @@ -44,41 +44,20 @@ func (pxy *XTCPProxy) Run() (remoteAddr string, err error) { for { select { case <-pxy.closeCh: - break - case sidRequest := <-sidCh: - sr := sidRequest + return + case sid := <-sidCh: workConn, errRet := pxy.GetWorkConnFromPool(nil, nil) if errRet != nil { continue } m := &msg.NatHoleSid{ - Sid: sr.Sid, + Sid: sid, } errRet = msg.WriteMsg(workConn, m) if errRet != nil { xl.Warn("write nat hole sid package error, %v", errRet) - workConn.Close() - break } - - go func() { - raw, errRet := msg.ReadMsg(workConn) - if errRet != nil { - xl.Warn("read nat hole client ok package error: %v", errRet) - workConn.Close() - return - } - if _, ok := raw.(*msg.NatHoleClientDetectOK); !ok { - xl.Warn("read nat hole client ok package format error") - workConn.Close() - return - } - - select { - case sr.NotifyCh <- struct{}{}: - default: - } - }() + workConn.Close() } } }() diff --git a/server/service.go b/server/service.go index 4ce5f992..0cdd2666 100644 --- a/server/service.go +++ b/server/service.go @@ -99,6 +99,11 @@ type Service struct { tlsConfig *tls.Config cfg config.ServerCommonConf + + // service context + ctx context.Context + // call cancel to stop service + cancel context.CancelFunc } func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { @@ -110,6 +115,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { return } + ctx, cancel := context.WithCancel(context.Background()) svr = &Service{ ctlManager: NewControlManager(), pxyManager: proxy.NewManager(), @@ -123,6 +129,8 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { authVerifier: auth.NewAuthVerifier(cfg.ServerConfig), tlsConfig: tlsConfig, cfg: cfg, + ctx: ctx, + cancel: cancel, } // Create tcpmux httpconnect multiplexer. @@ -290,17 +298,12 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { }) // Create nat hole controller. - if cfg.BindUDPPort > 0 { - var nc *nathole.Controller - address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.BindUDPPort)) - nc, err = nathole.NewController(address, []byte(cfg.Token)) - if err != nil { - err = fmt.Errorf("create nat hole controller error, %v", err) - return - } - svr.rc.NatHoleController = nc - log.Info("nat hole udp service listen on %s", address) + nc, err := nathole.NewController(time.Duration(cfg.NatHoleAnalysisDataReserveHours) * time.Hour) + if err != nil { + err = fmt.Errorf("create nat hole controller error, %v", err) + return } + svr.rc.NatHoleController = nc var statsEnable bool // Create dashboard web server. @@ -327,22 +330,43 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { } func (svr *Service) Run() { - if svr.rc.NatHoleController != nil { - go svr.rc.NatHoleController.Run() - } if svr.kcpListener != nil { go svr.HandleListener(svr.kcpListener) } if svr.quicListener != nil { go svr.HandleQUICListener(svr.quicListener) } - go svr.HandleListener(svr.websocketListener) go svr.HandleListener(svr.tlsListener) + if svr.rc.NatHoleController != nil { + go svr.rc.NatHoleController.CleanWorker(svr.ctx) + } svr.HandleListener(svr.listener) } +func (svr *Service) Close() error { + if svr.kcpListener != nil { + svr.kcpListener.Close() + } + if svr.quicListener != nil { + svr.quicListener.Close() + } + if svr.websocketListener != nil { + svr.websocketListener.Close() + } + if svr.tlsListener != nil { + svr.tlsListener.Close() + } + if svr.listener != nil { + svr.listener.Close() + } + svr.cancel() + + svr.ctlManager.Close() + return nil +} + func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) { xl := xlog.FromContextSafe(ctx)