Code refactoring related to message handling and retry logic. (#3745)

This commit is contained in:
fatedier 2023-11-06 10:51:48 +08:00 committed by GitHub
parent 5760c1cf92
commit 184223cb2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 701 additions and 540 deletions

View File

@ -1,3 +1,4 @@
### Fixes ### Fixes
* frpc: Return code 1 when the first login attempt fails and exits. * frpc: Return code 1 when the first login attempt fails and exits.
* When auth.method is `oidc` and auth.additionalScopes contains `HeartBeats`, if obtaining AccessToken fails, the application will be unresponsive.

View File

@ -144,7 +144,14 @@ func (svr *Service) apiStatus(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write(buf) _, _ = w.Write(buf)
}() }()
ps := svr.ctl.pm.GetAllProxyStatus() svr.ctlMu.RLock()
ctl := svr.ctl
svr.ctlMu.RUnlock()
if ctl == nil {
return
}
ps := ctl.pm.GetAllProxyStatus()
for _, status := range ps { for _, status := range ps {
res[status.Type] = append(res[status.Type], NewProxyStatusResp(status, svr.cfg.ServerAddr)) res[status.Type] = append(res[status.Type], NewProxyStatusResp(status, svr.cfg.ServerAddr))
} }

View File

@ -16,13 +16,10 @@ package client
import ( import (
"context" "context"
"io"
"net" "net"
"runtime/debug" "sync/atomic"
"time" "time"
"github.com/fatedier/golib/control/shutdown"
"github.com/fatedier/golib/crypto"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/fatedier/frp/client/proxy" "github.com/fatedier/frp/client/proxy"
@ -31,6 +28,8 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -39,6 +38,12 @@ type Control struct {
ctx context.Context ctx context.Context
xl *xlog.Logger xl *xlog.Logger
// The client configuration
clientCfg *v1.ClientCommonConfig
// sets authentication based on selected method
authSetter auth.Setter
// Unique ID obtained from frps. // Unique ID obtained from frps.
// It should be attached to the login message when reconnecting. // It should be attached to the login message when reconnecting.
runID string runID string
@ -50,36 +55,25 @@ type Control struct {
// manage all visitors // manage all visitors
vm *visitor.Manager vm *visitor.Manager
// control connection // control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
conn net.Conn conn net.Conn
// use cm to create new connections, which could be real TCP connections or virtual streams.
cm *ConnectionManager cm *ConnectionManager
// put a message in this channel to send it over control connection to server doneCh chan struct{}
sendCh chan (msg.Message)
// read from this channel to get the next message sent by server // of time.Time, last time got the Pong message
readCh chan (msg.Message) lastPong atomic.Value
// goroutines can block by reading from this channel, it will be closed only in reader() when control connection is closed
closedCh chan struct{}
closedDoneCh chan struct{}
// last time got the Pong message
lastPong time.Time
// The client configuration
clientCfg *v1.ClientCommonConfig
readerShutdown *shutdown.Shutdown
writerShutdown *shutdown.Shutdown
msgHandlerShutdown *shutdown.Shutdown
// sets authentication based on selected method
authSetter auth.Setter
// The role of msgTransporter is similar to HTTP2.
// It allows multiple messages to be sent simultaneously on the same control connection.
// The server's response messages will be dispatched to the corresponding waiting goroutines based on the laneKey and message type.
msgTransporter transport.MessageTransporter msgTransporter transport.MessageTransporter
// msgDispatcher is a wrapper for control connection.
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher
} }
func NewControl( func NewControl(
@ -88,31 +82,34 @@ func NewControl(
pxyCfgs []v1.ProxyConfigurer, pxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer, visitorCfgs []v1.VisitorConfigurer,
authSetter auth.Setter, authSetter auth.Setter,
) *Control { ) (*Control, error) {
// new xlog instance // new xlog instance
ctl := &Control{ ctl := &Control{
ctx: ctx, ctx: ctx,
xl: xlog.FromContextSafe(ctx), xl: xlog.FromContextSafe(ctx),
clientCfg: clientCfg,
authSetter: authSetter,
runID: runID, runID: runID,
pxyCfgs: pxyCfgs,
conn: conn, conn: conn,
cm: cm, cm: cm,
pxyCfgs: pxyCfgs, doneCh: make(chan struct{}),
sendCh: make(chan msg.Message, 100),
readCh: make(chan msg.Message, 100),
closedCh: make(chan struct{}),
closedDoneCh: make(chan struct{}),
clientCfg: clientCfg,
readerShutdown: shutdown.New(),
writerShutdown: shutdown.New(),
msgHandlerShutdown: shutdown.New(),
authSetter: authSetter,
} }
ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) ctl.lastPong.Store(time.Now())
ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
cryptoRW, err := utilnet.NewCryptoReadWriter(conn, []byte(clientCfg.Auth.Token))
if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter)
ctl.vm.Reload(visitorCfgs) ctl.vm.Reload(visitorCfgs)
return ctl return ctl, nil
} }
func (ctl *Control) Run() { func (ctl *Control) Run() {
@ -125,7 +122,7 @@ func (ctl *Control) Run() {
go ctl.vm.Run() go ctl.vm.Run()
} }
func (ctl *Control) HandleReqWorkConn(_ *msg.ReqWorkConn) { func (ctl *Control) handleReqWorkConn(_ msg.Message) {
xl := ctl.xl xl := ctl.xl
workConn, err := ctl.connectServer() workConn, err := ctl.connectServer()
if err != nil { if err != nil {
@ -162,8 +159,9 @@ func (ctl *Control) HandleReqWorkConn(_ *msg.ReqWorkConn) {
ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg) ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg)
} }
func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { func (ctl *Control) handleNewProxyResp(m msg.Message) {
xl := ctl.xl xl := ctl.xl
inMsg := m.(*msg.NewProxyResp)
// Server will return NewProxyResp message to each NewProxy message. // Server will return NewProxyResp message to each NewProxy message.
// Start a new proxy handler if no error got // Start a new proxy handler if no error got
err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)
@ -174,8 +172,9 @@ func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) {
} }
} }
func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) { func (ctl *Control) handleNatHoleResp(m msg.Message) {
xl := ctl.xl xl := ctl.xl
inMsg := m.(*msg.NatHoleResp)
// Dispatch the NatHoleResp message to the related proxy. // Dispatch the NatHoleResp message to the related proxy.
ok := ctl.msgTransporter.DispatchWithType(inMsg, msg.TypeNameNatHoleResp, inMsg.TransactionID) ok := ctl.msgTransporter.DispatchWithType(inMsg, msg.TypeNameNatHoleResp, inMsg.TransactionID)
@ -184,6 +183,19 @@ func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) {
} }
} }
func (ctl *Control) handlePong(m msg.Message) {
xl := ctl.xl
inMsg := m.(*msg.Pong)
if inMsg.Error != "" {
xl.Error("Pong message contains error: %s", inMsg.Error)
ctl.conn.Close()
return
}
ctl.lastPong.Store(time.Now())
xl.Debug("receive heartbeat from server")
}
func (ctl *Control) Close() error { func (ctl *Control) Close() error {
return ctl.GracefulClose(0) return ctl.GracefulClose(0)
} }
@ -199,9 +211,9 @@ func (ctl *Control) GracefulClose(d time.Duration) error {
return nil return nil
} }
// ClosedDoneCh returns a channel that will be closed after all resources are released // Done returns a channel that will be closed after all resources are released
func (ctl *Control) ClosedDoneCh() <-chan struct{} { func (ctl *Control) Done() <-chan struct{} {
return ctl.closedDoneCh return ctl.doneCh
} }
// connectServer return a new connection to frps // connectServer return a new connection to frps
@ -209,151 +221,70 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
return ctl.cm.Connect() return ctl.cm.Connect()
} }
// reader read all messages from frps and send to readCh func (ctl *Control) registerMsgHandlers() {
func (ctl *Control) reader() { ctl.msgDispatcher.RegisterHandler(&msg.ReqWorkConn{}, msg.AsyncHandler(ctl.handleReqWorkConn))
ctl.msgDispatcher.RegisterHandler(&msg.NewProxyResp{}, ctl.handleNewProxyResp)
ctl.msgDispatcher.RegisterHandler(&msg.NatHoleResp{}, ctl.handleNatHoleResp)
ctl.msgDispatcher.RegisterHandler(&msg.Pong{}, ctl.handlePong)
}
// headerWorker sends heartbeat to server and check heartbeat timeout.
func (ctl *Control) heartbeatWorker() {
xl := ctl.xl xl := ctl.xl
defer func() {
if err := recover(); err != nil {
xl.Error("panic error: %v", err)
xl.Error(string(debug.Stack()))
}
}()
defer ctl.readerShutdown.Done()
defer close(ctl.closedCh)
encReader := crypto.NewReader(ctl.conn, []byte(ctl.clientCfg.Auth.Token)) // TODO(fatedier): Change default value of HeartbeatInterval to -1 if tcpmux is enabled.
for { // Users can still enable heartbeat feature by setting HeartbeatInterval to a positive value.
m, err := msg.ReadMsg(encReader)
if err != nil {
if err == io.EOF {
xl.Debug("read from control connection EOF")
return
}
xl.Warn("read error: %v", err)
ctl.conn.Close()
return
}
ctl.readCh <- m
}
}
// writer writes messages got from sendCh to frps
func (ctl *Control) writer() {
xl := ctl.xl
defer ctl.writerShutdown.Done()
encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.clientCfg.Auth.Token))
if err != nil {
xl.Error("crypto new writer error: %v", err)
ctl.conn.Close()
return
}
for {
m, ok := <-ctl.sendCh
if !ok {
xl.Info("control writer is closing")
return
}
if err := msg.WriteMsg(encWriter, m); err != nil {
xl.Warn("write message to control connection error: %v", err)
return
}
}
}
// msgHandler handles all channel events and performs corresponding operations.
func (ctl *Control) msgHandler() {
xl := ctl.xl
defer func() {
if err := recover(); err != nil {
xl.Error("panic error: %v", err)
xl.Error(string(debug.Stack()))
}
}()
defer ctl.msgHandlerShutdown.Done()
var hbSendCh <-chan time.Time
// TODO(fatedier): disable heartbeat if TCPMux is enabled.
// Just keep it here to keep compatible with old version frps.
if ctl.clientCfg.Transport.HeartbeatInterval > 0 { if ctl.clientCfg.Transport.HeartbeatInterval > 0 {
hbSend := time.NewTicker(time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second)
defer hbSend.Stop()
hbSendCh = hbSend.C
}
var hbCheckCh <-chan time.Time
// Check heartbeat timeout only if TCPMux is not enabled and users don't disable heartbeat feature.
if ctl.clientCfg.Transport.HeartbeatInterval > 0 && ctl.clientCfg.Transport.HeartbeatTimeout > 0 &&
!lo.FromPtr(ctl.clientCfg.Transport.TCPMux) {
hbCheck := time.NewTicker(time.Second)
defer hbCheck.Stop()
hbCheckCh = hbCheck.C
}
ctl.lastPong = time.Now()
for {
select {
case <-hbSendCh:
// send heartbeat to server // send heartbeat to server
sendHeartBeat := func() error {
xl.Debug("send heartbeat to server") xl.Debug("send heartbeat to server")
pingMsg := &msg.Ping{} pingMsg := &msg.Ping{}
if err := ctl.authSetter.SetPing(pingMsg); err != nil { if err := ctl.authSetter.SetPing(pingMsg); err != nil {
xl.Warn("error during ping authentication: %v. skip sending ping message", err) xl.Warn("error during ping authentication: %v, skip sending ping message", err)
continue return err
} }
ctl.sendCh <- pingMsg _ = ctl.msgDispatcher.Send(pingMsg)
case <-hbCheckCh: return nil
if time.Since(ctl.lastPong) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second { }
go wait.BackoffUntil(sendHeartBeat,
wait.NewFastBackoffManager(wait.FastBackoffOptions{
Duration: time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second,
InitDurationIfFail: time.Second,
Factor: 2.0,
Jitter: 0.1,
MaxDuration: time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second,
}),
true, ctl.doneCh,
)
}
// Check heartbeat timeout only if TCPMux is not enabled and users don't disable heartbeat feature.
if ctl.clientCfg.Transport.HeartbeatInterval > 0 && ctl.clientCfg.Transport.HeartbeatTimeout > 0 &&
!lo.FromPtr(ctl.clientCfg.Transport.TCPMux) {
go wait.Until(func() {
if time.Since(ctl.lastPong.Load().(time.Time)) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warn("heartbeat timeout") xl.Warn("heartbeat timeout")
// let reader() stop
ctl.conn.Close() ctl.conn.Close()
return return
} }
case rawMsg, ok := <-ctl.readCh: }, time.Second, ctl.doneCh)
if !ok {
return
}
switch m := rawMsg.(type) {
case *msg.ReqWorkConn:
go ctl.HandleReqWorkConn(m)
case *msg.NewProxyResp:
ctl.HandleNewProxyResp(m)
case *msg.NatHoleResp:
ctl.HandleNatHoleResp(m)
case *msg.Pong:
if m.Error != "" {
xl.Error("Pong contains error: %s", m.Error)
ctl.conn.Close()
return
}
ctl.lastPong = time.Now()
xl.Debug("receive heartbeat from server")
}
}
} }
} }
// If controler is notified by closedCh, reader and writer and handler will exit
func (ctl *Control) worker() { func (ctl *Control) worker() {
go ctl.msgHandler() go ctl.heartbeatWorker()
go ctl.reader() go ctl.msgDispatcher.Run()
go ctl.writer()
<-ctl.closedCh <-ctl.msgDispatcher.Done()
// close related channels and wait until other goroutines done ctl.conn.Close()
close(ctl.readCh)
ctl.readerShutdown.WaitDone()
ctl.msgHandlerShutdown.WaitDone()
close(ctl.sendCh)
ctl.writerShutdown.WaitDone()
ctl.pm.Close() ctl.pm.Close()
ctl.vm.Close() ctl.vm.Close()
close(ctl.closedDoneCh)
ctl.cm.Close() ctl.cm.Close()
close(ctl.doneCh)
} }
func (ctl *Control) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error { func (ctl *Control) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {

View File

@ -17,6 +17,7 @@ package client
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -24,7 +25,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/fatedier/golib/crypto" "github.com/fatedier/golib/crypto"
@ -40,8 +40,8 @@ import (
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -70,12 +70,11 @@ type Service struct {
// string if no configuration file was used. // string if no configuration file was used.
cfgFile string cfgFile string
exit uint32 // 0 means not exit
// service context // service context
ctx context.Context ctx context.Context
// call cancel to stop service // call cancel to stop service
cancel context.CancelFunc cancel context.CancelFunc
gracefulDuration time.Duration
} }
func NewService( func NewService(
@ -91,7 +90,6 @@ func NewService(
pxyCfgs: pxyCfgs, pxyCfgs: pxyCfgs,
visitorCfgs: visitorCfgs, visitorCfgs: visitorCfgs,
ctx: context.Background(), ctx: context.Background(),
exit: 0,
} }
} }
@ -106,8 +104,6 @@ func (svr *Service) Run(ctx context.Context) error {
svr.ctx = xlog.NewContext(ctx, xlog.New()) svr.ctx = xlog.NewContext(ctx, xlog.New())
svr.cancel = cancel svr.cancel = cancel
xl := xlog.FromContextSafe(svr.ctx)
// set custom DNSServer // set custom DNSServer
if svr.cfg.DNSServer != "" { if svr.cfg.DNSServer != "" {
dnsAddr := svr.cfg.DNSServer dnsAddr := svr.cfg.DNSServer
@ -124,26 +120,9 @@ func (svr *Service) Run(ctx context.Context) error {
} }
// login to frps // login to frps
for { svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.cfg.LoginFailExit))
conn, cm, err := svr.login() if svr.ctl == nil {
if err != nil { return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled")
xl.Warn("login to server failed: %v", err)
// if login_fail_exit is true, just exit this program
// otherwise sleep a while and try again to connect to server
if lo.FromPtr(svr.cfg.LoginFailExit) {
return err
}
util.RandomSleep(5*time.Second, 0.9, 1.1)
} else {
// login success
ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
ctl.Run()
svr.ctlMu.Lock()
svr.ctl = ctl
svr.ctlMu.Unlock()
break
}
} }
go svr.keepControllerWorking() go svr.keepControllerWorking()
@ -160,80 +139,35 @@ func (svr *Service) Run(ctx context.Context) error {
log.Info("admin server listen on %s:%d", svr.cfg.WebServer.Addr, svr.cfg.WebServer.Port) log.Info("admin server listen on %s:%d", svr.cfg.WebServer.Addr, svr.cfg.WebServer.Port)
} }
<-svr.ctx.Done() <-svr.ctx.Done()
// service context may not be canceled by svr.Close(), we should call it here to release resources svr.stop()
if atomic.LoadUint32(&svr.exit) == 0 {
svr.Close()
}
return nil return nil
} }
func (svr *Service) keepControllerWorking() { func (svr *Service) keepControllerWorking() {
xl := xlog.FromContextSafe(svr.ctx) <-svr.ctl.Done()
maxDelayTime := 20 * time.Second
delayTime := time.Second
// if frpc reconnect frps, we need to limit retry times in 1min // There is a situation where the login is successful but due to certain reasons,
// current retry logic is sleep 0s, 0s, 0s, 1s, 2s, 4s, 8s, ... // the control immediately exits. It is necessary to limit the frequency of reconnection in this case.
// when exceed 1min, we will reset delay and counts // The interval for the first three retries in 1 minute will be very short, and then it will increase exponentially.
cutoffTime := time.Now().Add(time.Minute) // The maximum interval is 20 seconds.
reconnectDelay := time.Second wait.BackoffUntil(func() error {
reconnectCounts := 1 // loopLoginUntilSuccess is another layer of loop that will continuously attempt to
// login to the server until successful.
for { svr.loopLoginUntilSuccess(20*time.Second, false)
<-svr.ctl.ClosedDoneCh() <-svr.ctl.Done()
if atomic.LoadUint32(&svr.exit) != 0 { return errors.New("control is closed and try another loop")
return }, wait.NewFastBackoffManager(
} wait.FastBackoffOptions{
Duration: time.Second,
// the first three attempts with a low delay Factor: 2,
if reconnectCounts > 3 { Jitter: 0.1,
util.RandomSleep(reconnectDelay, 0.9, 1.1) MaxDuration: 20 * time.Second,
xl.Info("wait %v to reconnect", reconnectDelay) FastRetryCount: 3,
reconnectDelay *= 2 FastRetryDelay: 200 * time.Millisecond,
} else { FastRetryWindow: time.Minute,
util.RandomSleep(time.Second, 0, 0.5) FastRetryJitter: 0.5,
} },
reconnectCounts++ ), true, svr.ctx.Done())
now := time.Now()
if now.After(cutoffTime) {
// reset
cutoffTime = now.Add(time.Minute)
reconnectDelay = time.Second
reconnectCounts = 1
}
for {
if atomic.LoadUint32(&svr.exit) != 0 {
return
}
xl.Info("try to reconnect to server...")
conn, cm, err := svr.login()
if err != nil {
xl.Warn("reconnect to server error: %v, wait %v for another retry", err, delayTime)
util.RandomSleep(delayTime, 0.9, 1.1)
delayTime *= 2
if delayTime > maxDelayTime {
delayTime = maxDelayTime
}
continue
}
// reconnect success, init delayTime
delayTime = time.Second
ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
ctl.Run()
svr.ctlMu.Lock()
if svr.ctl != nil {
svr.ctl.Close()
}
svr.ctl = ctl
svr.ctlMu.Unlock()
break
}
}
} }
// login creates a connection to frps and registers it self as a client // login creates a connection to frps and registers it self as a client
@ -299,6 +233,54 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
return return
} }
func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
xl := xlog.FromContextSafe(svr.ctx)
successCh := make(chan struct{})
loginFunc := func() error {
xl.Info("try to connect to server...")
conn, cm, err := svr.login()
if err != nil {
xl.Warn("connect to server error: %v", err)
if firstLoginExit {
svr.cancel()
}
return err
}
ctl, err := NewControl(svr.ctx, svr.runID, conn, cm,
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
if err != nil {
conn.Close()
xl.Error("NewControl error: %v", err)
return err
}
ctl.Run()
// close and replace previous control
svr.ctlMu.Lock()
if svr.ctl != nil {
svr.ctl.Close()
}
svr.ctl = ctl
svr.ctlMu.Unlock()
close(successCh)
return nil
}
// try to reconnect to server until success
wait.BackoffUntil(loginFunc, wait.NewFastBackoffManager(
wait.FastBackoffOptions{
Duration: time.Second,
Factor: 2,
Jitter: 0.1,
MaxDuration: maxInterval,
}),
true,
wait.MergeAndCloseOnAnyStopChannel(svr.ctx.Done(), successCh))
}
func (svr *Service) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error { func (svr *Service) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {
svr.cfgMu.Lock() svr.cfgMu.Lock()
svr.pxyCfgs = pxyCfgs svr.pxyCfgs = pxyCfgs
@ -320,20 +302,20 @@ func (svr *Service) Close() {
} }
func (svr *Service) GracefulClose(d time.Duration) { func (svr *Service) GracefulClose(d time.Duration) {
atomic.StoreUint32(&svr.exit, 1) svr.gracefulDuration = d
svr.ctlMu.RLock()
if svr.ctl != nil {
svr.ctl.GracefulClose(d)
svr.ctl = nil
}
svr.ctlMu.RUnlock()
if svr.cancel != nil {
svr.cancel() svr.cancel()
} }
func (svr *Service) stop() {
svr.ctlMu.Lock()
defer svr.ctlMu.Unlock()
if svr.ctl != nil {
svr.ctl.GracefulClose(svr.gracefulDuration)
svr.ctl = nil
}
} }
// ConnectionManager is a wrapper for establishing connections to the server.
type ConnectionManager struct { type ConnectionManager struct {
ctx context.Context ctx context.Context
cfg *v1.ClientCommonConfig cfg *v1.ClientCommonConfig
@ -349,6 +331,10 @@ func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *Conn
} }
} }
// OpenConnection opens a underlying connection to the server.
// The underlying connection is either a TCP connection or a QUIC connection.
// After the underlying connection is established, you can call Connect() to get a stream.
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
func (cm *ConnectionManager) OpenConnection() error { func (cm *ConnectionManager) OpenConnection() error {
xl := xlog.FromContextSafe(cm.ctx) xl := xlog.FromContextSafe(cm.ctx)
@ -411,6 +397,7 @@ func (cm *ConnectionManager) OpenConnection() error {
return nil return nil
} }
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
func (cm *ConnectionManager) Connect() (net.Conn, error) { func (cm *ConnectionManager) Connect() (net.Conn, error) {
if cm.quicConn != nil { if cm.quicConn != nil {
stream, err := cm.quicConn.OpenStreamSync(context.Background()) stream, err := cm.quicConn.OpenStreamSync(context.Background())

View File

@ -1,3 +1,17 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package metrics package metrics
import ( import (

103
pkg/msg/handler.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"io"
"reflect"
)
func AsyncHandler(f func(Message)) func(Message) {
return func(m Message) {
go f(m)
}
}
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
type Dispatcher struct {
rw io.ReadWriter
sendCh chan Message
doneCh chan struct{}
msgHandlers map[reflect.Type]func(Message)
defaultHandler func(Message)
}
func NewDispatcher(rw io.ReadWriter) *Dispatcher {
return &Dispatcher{
rw: rw,
sendCh: make(chan Message, 100),
doneCh: make(chan struct{}),
msgHandlers: make(map[reflect.Type]func(Message)),
}
}
// Run will block until io.EOF or some error occurs.
func (d *Dispatcher) Run() {
go d.sendLoop()
go d.readLoop()
}
func (d *Dispatcher) sendLoop() {
for {
select {
case <-d.doneCh:
return
case m := <-d.sendCh:
_ = WriteMsg(d.rw, m)
}
}
}
func (d *Dispatcher) readLoop() {
for {
m, err := ReadMsg(d.rw)
if err != nil {
close(d.doneCh)
return
}
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
handler(m)
} else if d.defaultHandler != nil {
d.defaultHandler(m)
}
}
}
func (d *Dispatcher) Send(m Message) error {
select {
case <-d.doneCh:
return io.EOF
case d.sendCh <- m:
return nil
}
}
func (d *Dispatcher) SendChannel() chan Message {
return d.sendCh
}
func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
d.msgHandlers[reflect.TypeOf(msg)] = handler
}
func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
d.defaultHandler = handler
}
func (d *Dispatcher) Done() chan struct{} {
return d.doneCh
}

View File

@ -29,7 +29,9 @@ type MessageTransporter interface {
// Recv(ctx context.Context, laneKey string, msgType string) (Message, error) // Recv(ctx context.Context, laneKey string, msgType string) (Message, error)
// Do will first send msg, then recv msg with the same laneKey and specified msgType. // Do will first send msg, then recv msg with the same laneKey and specified msgType.
Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error) Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error)
// Dispatch will dispatch message to releated channel registered in Do function by its message type and laneKey.
Dispatch(m msg.Message, laneKey string) bool Dispatch(m msg.Message, laneKey string) bool
// Same with Dispatch but with specified message type.
DispatchWithType(m msg.Message, msgType, laneKey string) bool DispatchWithType(m msg.Message, msgType, laneKey string) bool
} }

View File

@ -22,6 +22,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/fatedier/golib/crypto"
quic "github.com/quic-go/quic-go" quic "github.com/quic-go/quic-go"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
@ -216,3 +217,18 @@ func (conn *wrapQuicStream) Close() error {
conn.Stream.CancelRead(0) conn.Stream.CancelRead(0)
return conn.Stream.Close() return conn.Stream.Close()
} }
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
encReader := crypto.NewReader(rw, key)
encWriter, err := crypto.NewWriter(rw, key)
if err != nil {
return nil, err
}
return struct {
io.Reader
io.Writer
}{
Reader: encReader,
Writer: encWriter,
}, nil
}

197
pkg/util/wait/backoff.go Normal file
View File

@ -0,0 +1,197 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wait
import (
"math/rand"
"time"
"github.com/samber/lo"
"github.com/fatedier/frp/pkg/util/util"
)
type BackoffFunc func(previousDuration time.Duration, previousConditionError bool) time.Duration
func (f BackoffFunc) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
return f(previousDuration, previousConditionError)
}
type BackoffManager interface {
Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration
}
type FastBackoffOptions struct {
Duration time.Duration
Factor float64
Jitter float64
MaxDuration time.Duration
InitDurationIfFail time.Duration
// If FastRetryCount > 0, then within the FastRetryWindow time window,
// the retry will be performed with a delay of FastRetryDelay for the first FastRetryCount calls.
FastRetryCount int
FastRetryDelay time.Duration
FastRetryJitter float64
FastRetryWindow time.Duration
}
type fastBackoffImpl struct {
options FastBackoffOptions
lastCalledTime time.Time
consecutiveErrCount int
fastRetryCutoffTime time.Time
countsInFastRetryWindow int
}
func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
return &fastBackoffImpl{
options: options,
countsInFastRetryWindow: 1,
}
}
func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
if f.lastCalledTime.IsZero() {
f.lastCalledTime = time.Now()
return f.options.Duration
}
now := time.Now()
f.lastCalledTime = now
if previousConditionError {
f.consecutiveErrCount++
} else {
f.consecutiveErrCount = 0
}
if f.options.FastRetryCount > 0 && previousConditionError {
f.countsInFastRetryWindow++
if f.countsInFastRetryWindow <= f.options.FastRetryCount {
return Jitter(f.options.FastRetryDelay, f.options.FastRetryJitter)
}
if now.After(f.fastRetryCutoffTime) {
// reset
f.fastRetryCutoffTime = now.Add(f.options.FastRetryWindow)
f.countsInFastRetryWindow = 0
}
}
if previousConditionError {
var duration time.Duration
if f.consecutiveErrCount == 1 {
duration = util.EmptyOr(f.options.InitDurationIfFail, previousDuration)
} else {
duration = previousDuration
}
duration = util.EmptyOr(duration, time.Second)
if f.options.Factor != 0 {
duration = time.Duration(float64(duration) * f.options.Factor)
}
if f.options.Jitter > 0 {
duration = Jitter(duration, f.options.Jitter)
}
if f.options.MaxDuration > 0 && duration > f.options.MaxDuration {
duration = f.options.MaxDuration
}
return duration
}
return f.options.Duration
}
func BackoffUntil(f func() error, backoff BackoffManager, sliding bool, stopCh <-chan struct{}) {
var delay time.Duration
previousError := false
ticker := time.NewTicker(backoff.Backoff(delay, previousError))
defer ticker.Stop()
for {
select {
case <-stopCh:
return
default:
}
if !sliding {
delay = backoff.Backoff(delay, previousError)
}
if err := f(); err != nil {
previousError = true
} else {
previousError = false
}
if sliding {
delay = backoff.Backoff(delay, previousError)
}
ticker.Reset(delay)
select {
case <-stopCh:
return
default:
}
select {
case <-stopCh:
return
case <-ticker.C:
}
}
}
// Jitter returns a time.Duration between duration and duration + maxFactor *
// duration.
//
// This allows clients to avoid converging on periodic behavior. If maxFactor
// is 0.0, a suggested default value will be chosen.
func Jitter(duration time.Duration, maxFactor float64) time.Duration {
if maxFactor <= 0.0 {
maxFactor = 1.0
}
wait := duration + time.Duration(rand.Float64()*maxFactor*float64(duration))
return wait
}
func Until(f func(), period time.Duration, stopCh <-chan struct{}) {
ff := func() error {
f()
return nil
}
BackoffUntil(ff, BackoffFunc(func(time.Duration, bool) time.Duration {
return period
}), true, stopCh)
}
func MergeAndCloseOnAnyStopChannel[T any](upstreams ...<-chan T) <-chan T {
out := make(chan T)
for _, upstream := range upstreams {
ch := upstream
go lo.Try0(func() {
select {
case <-ch:
close(out)
case <-out:
}
})
}
return out
}

View File

@ -17,15 +17,12 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"runtime/debug" "runtime/debug"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/fatedier/golib/control/shutdown"
"github.com/fatedier/golib/crypto"
"github.com/fatedier/golib/errors"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/fatedier/frp/pkg/auth" "github.com/fatedier/frp/pkg/auth"
@ -35,8 +32,10 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/server/controller" "github.com/fatedier/frp/server/controller"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
@ -111,18 +110,16 @@ type Control struct {
// other components can use this to communicate with client // other components can use this to communicate with client
msgTransporter transport.MessageTransporter msgTransporter transport.MessageTransporter
// msgDispatcher is a wrapper for control connection.
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher
// login message // login message
loginMsg *msg.Login loginMsg *msg.Login
// control connection // control connection
conn net.Conn conn net.Conn
// put a message in this channel to send it over control connection to client
sendCh chan (msg.Message)
// read from this channel to get the next message sent by client
readCh chan (msg.Message)
// work connections // work connections
workConnCh chan net.Conn workConnCh chan net.Conn
@ -136,20 +133,13 @@ type Control struct {
portsUsedNum int portsUsedNum int
// last time got the Ping message // last time got the Ping message
lastPing time.Time lastPing atomic.Value
// A new run id will be generated when a new client login. // A new run id will be generated when a new client login.
// If run id got from login message has same run id, it means it's the same client, so we can // If run id got from login message has same run id, it means it's the same client, so we can
// replace old controller instantly. // replace old controller instantly.
runID string runID string
readerShutdown *shutdown.Shutdown
writerShutdown *shutdown.Shutdown
managerShutdown *shutdown.Shutdown
allShutdown *shutdown.Shutdown
started bool
mu sync.RWMutex mu sync.RWMutex
// Server configuration information // Server configuration information
@ -157,6 +147,7 @@ type Control struct {
xl *xlog.Logger xl *xlog.Logger
ctx context.Context ctx context.Context
doneCh chan struct{}
} }
func NewControl( func NewControl(
@ -168,7 +159,7 @@ func NewControl(
ctlConn net.Conn, ctlConn net.Conn,
loginMsg *msg.Login, loginMsg *msg.Login,
serverCfg *v1.ServerConfig, serverCfg *v1.ServerConfig,
) *Control { ) (*Control, error) {
poolCount := loginMsg.PoolCount poolCount := loginMsg.PoolCount
if poolCount > int(serverCfg.Transport.MaxPoolCount) { if poolCount > int(serverCfg.Transport.MaxPoolCount) {
poolCount = int(serverCfg.Transport.MaxPoolCount) poolCount = int(serverCfg.Transport.MaxPoolCount)
@ -180,24 +171,26 @@ func NewControl(
authVerifier: authVerifier, authVerifier: authVerifier,
conn: ctlConn, conn: ctlConn,
loginMsg: loginMsg, loginMsg: loginMsg,
sendCh: make(chan msg.Message, 10),
readCh: make(chan msg.Message, 10),
workConnCh: make(chan net.Conn, poolCount+10), workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy), proxies: make(map[string]proxy.Proxy),
poolCount: poolCount, poolCount: poolCount,
portsUsedNum: 0, portsUsedNum: 0,
lastPing: time.Now(),
runID: loginMsg.RunID, runID: loginMsg.RunID,
readerShutdown: shutdown.New(),
writerShutdown: shutdown.New(),
managerShutdown: shutdown.New(),
allShutdown: shutdown.New(),
serverCfg: serverCfg, serverCfg: serverCfg,
xl: xlog.FromContextSafe(ctx), xl: xlog.FromContextSafe(ctx),
ctx: ctx, ctx: ctx,
doneCh: make(chan struct{}),
} }
ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) ctl.lastPing.Store(time.Now())
return ctl
cryptoRW, err := utilnet.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token))
if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
return ctl, nil
} }
// Start send a login success message to client and start working. // Start send a login success message to client and start working.
@ -208,27 +201,18 @@ func (ctl *Control) Start() {
Error: "", Error: "",
} }
_ = msg.WriteMsg(ctl.conn, loginRespMsg) _ = msg.WriteMsg(ctl.conn, loginRespMsg)
ctl.mu.Lock()
ctl.started = true
ctl.mu.Unlock()
go ctl.writer()
go func() { go func() {
for i := 0; i < ctl.poolCount; i++ { for i := 0; i < ctl.poolCount; i++ {
// ignore error here, that means that this control is closed // ignore error here, that means that this control is closed
_ = errors.PanicToError(func() { _ = ctl.msgDispatcher.Send(&msg.ReqWorkConn{})
ctl.sendCh <- &msg.ReqWorkConn{}
})
} }
}() }()
go ctl.worker()
go ctl.manager()
go ctl.reader()
go ctl.stoper()
} }
func (ctl *Control) Close() error { func (ctl *Control) Close() error {
ctl.allShutdown.Start() ctl.conn.Close()
return nil return nil
} }
@ -236,7 +220,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl xl := ctl.xl
xl.Info("Replaced by client [%s]", newCtl.runID) xl.Info("Replaced by client [%s]", newCtl.runID)
ctl.runID = "" ctl.runID = ""
ctl.allShutdown.Start() ctl.conn.Close()
} }
func (ctl *Control) RegisterWorkConn(conn net.Conn) error { func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@ -282,9 +266,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
xl.Debug("get work connection from pool") xl.Debug("get work connection from pool")
default: default:
// no work connections available in the poll, send message to frpc to get more // no work connections available in the poll, send message to frpc to get more
if err = errors.PanicToError(func() { if err := ctl.msgDispatcher.Send(&msg.ReqWorkConn{}); err != nil {
ctl.sendCh <- &msg.ReqWorkConn{}
}); err != nil {
return nil, fmt.Errorf("control is already closed") return nil, fmt.Errorf("control is already closed")
} }
@ -304,92 +286,39 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
} }
// When we get a work connection from pool, replace it with a new one. // When we get a work connection from pool, replace it with a new one.
_ = errors.PanicToError(func() { _ = ctl.msgDispatcher.Send(&msg.ReqWorkConn{})
ctl.sendCh <- &msg.ReqWorkConn{}
})
return return
} }
func (ctl *Control) writer() { func (ctl *Control) heartbeatWorker() {
xl := ctl.xl xl := ctl.xl
defer func() {
if err := recover(); err != nil {
xl.Error("panic error: %v", err)
xl.Error(string(debug.Stack()))
}
}()
defer ctl.allShutdown.Start() // Don't need application heartbeat if TCPMux is enabled,
defer ctl.writerShutdown.Done() // yamux will do same thing.
// TODO(fatedier): let default HeartbeatTimeout to -1 if TCPMux is enabled. Users can still set it to positive value to enable it.
encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) if !lo.FromPtr(ctl.serverCfg.Transport.TCPMux) && ctl.serverCfg.Transport.HeartbeatTimeout > 0 {
if err != nil { go wait.Until(func() {
xl.Error("crypto new writer error: %v", err) if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
ctl.allShutdown.Start() xl.Warn("heartbeat timeout")
return
}
for {
m, ok := <-ctl.sendCh
if !ok {
xl.Info("control writer is closing")
return
}
if err := msg.WriteMsg(encWriter, m); err != nil {
xl.Warn("write message to control connection error: %v", err)
return return
} }
}, time.Second, ctl.doneCh)
} }
} }
func (ctl *Control) reader() { // block until Control closed
func (ctl *Control) WaitClosed() {
<-ctl.doneCh
}
func (ctl *Control) worker() {
xl := ctl.xl xl := ctl.xl
defer func() {
if err := recover(); err != nil {
xl.Error("panic error: %v", err)
xl.Error(string(debug.Stack()))
}
}()
defer ctl.allShutdown.Start() go ctl.heartbeatWorker()
defer ctl.readerShutdown.Done() go ctl.msgDispatcher.Run()
encReader := crypto.NewReader(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) <-ctl.msgDispatcher.Done()
for {
m, err := msg.ReadMsg(encReader)
if err != nil {
if err == io.EOF {
xl.Debug("control connection closed")
return
}
xl.Warn("read error: %v", err)
ctl.conn.Close() ctl.conn.Close()
return
}
ctl.readCh <- m
}
}
func (ctl *Control) stoper() {
xl := ctl.xl
defer func() {
if err := recover(); err != nil {
xl.Error("panic error: %v", err)
xl.Error(string(debug.Stack()))
}
}()
ctl.allShutdown.WaitStart()
ctl.conn.Close()
ctl.readerShutdown.WaitDone()
close(ctl.readCh)
ctl.managerShutdown.WaitDone()
close(ctl.sendCh)
ctl.writerShutdown.WaitDone()
ctl.mu.Lock() ctl.mu.Lock()
defer ctl.mu.Unlock() defer ctl.mu.Unlock()
@ -419,136 +348,104 @@ func (ctl *Control) stoper() {
}() }()
} }
ctl.allShutdown.Done()
xl.Info("client exit success")
metrics.Server.CloseClient() metrics.Server.CloseClient()
xl.Info("client exit success")
close(ctl.doneCh)
} }
// block until Control closed func (ctl *Control) registerMsgHandlers() {
func (ctl *Control) WaitClosed() { ctl.msgDispatcher.RegisterHandler(&msg.NewProxy{}, ctl.handleNewProxy)
ctl.mu.RLock() ctl.msgDispatcher.RegisterHandler(&msg.Ping{}, ctl.handlePing)
started := ctl.started ctl.msgDispatcher.RegisterHandler(&msg.NatHoleVisitor{}, msg.AsyncHandler(ctl.handleNatHoleVisitor))
ctl.mu.RUnlock() ctl.msgDispatcher.RegisterHandler(&msg.NatHoleClient{}, msg.AsyncHandler(ctl.handleNatHoleClient))
ctl.msgDispatcher.RegisterHandler(&msg.NatHoleReport{}, msg.AsyncHandler(ctl.handleNatHoleReport))
if !started { ctl.msgDispatcher.RegisterHandler(&msg.CloseProxy{}, ctl.handleCloseProxy)
ctl.allShutdown.Done()
return
}
ctl.allShutdown.WaitDone()
} }
func (ctl *Control) manager() { func (ctl *Control) handleNewProxy(m msg.Message) {
xl := ctl.xl xl := ctl.xl
defer func() { inMsg := m.(*msg.NewProxy)
if err := recover(); err != nil {
xl.Error("panic error: %v", err)
xl.Error(string(debug.Stack()))
}
}()
defer ctl.allShutdown.Start()
defer ctl.managerShutdown.Done()
var heartbeatCh <-chan time.Time
// Don't need application heartbeat if TCPMux is enabled,
// yamux will do same thing.
if !lo.FromPtr(ctl.serverCfg.Transport.TCPMux) && ctl.serverCfg.Transport.HeartbeatTimeout > 0 {
heartbeat := time.NewTicker(time.Second)
defer heartbeat.Stop()
heartbeatCh = heartbeat.C
}
for {
select {
case <-heartbeatCh:
if time.Since(ctl.lastPing) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warn("heartbeat timeout")
return
}
case rawMsg, ok := <-ctl.readCh:
if !ok {
return
}
switch m := rawMsg.(type) {
case *msg.NewProxy:
content := &plugin.NewProxyContent{ content := &plugin.NewProxyContent{
User: plugin.UserInfo{ User: plugin.UserInfo{
User: ctl.loginMsg.User, User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas, Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID, RunID: ctl.loginMsg.RunID,
}, },
NewProxy: *m, NewProxy: *inMsg,
} }
var remoteAddr string var remoteAddr string
retContent, err := ctl.pluginManager.NewProxy(content) retContent, err := ctl.pluginManager.NewProxy(content)
if err == nil { if err == nil {
m = &retContent.NewProxy inMsg = &retContent.NewProxy
remoteAddr, err = ctl.RegisterProxy(m) remoteAddr, err = ctl.RegisterProxy(inMsg)
} }
// register proxy in this control // register proxy in this control
resp := &msg.NewProxyResp{ resp := &msg.NewProxyResp{
ProxyName: m.ProxyName, ProxyName: inMsg.ProxyName,
} }
if err != nil { if err != nil {
xl.Warn("new proxy [%s] type [%s] error: %v", m.ProxyName, m.ProxyType, err) xl.Warn("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", m.ProxyName), resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)) err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
} else { } else {
resp.RemoteAddr = remoteAddr resp.RemoteAddr = remoteAddr
xl.Info("new proxy [%s] type [%s] success", m.ProxyName, m.ProxyType) xl.Info("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
metrics.Server.NewProxy(m.ProxyName, m.ProxyType) metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType)
} }
ctl.sendCh <- resp _ = ctl.msgDispatcher.Send(resp)
case *msg.NatHoleVisitor: }
go ctl.HandleNatHoleVisitor(m)
case *msg.NatHoleClient: func (ctl *Control) handlePing(m msg.Message) {
go ctl.HandleNatHoleClient(m) xl := ctl.xl
case *msg.NatHoleReport: inMsg := m.(*msg.Ping)
go ctl.HandleNatHoleReport(m)
case *msg.CloseProxy:
_ = ctl.CloseProxy(m)
xl.Info("close proxy [%s] success", m.ProxyName)
case *msg.Ping:
content := &plugin.PingContent{ content := &plugin.PingContent{
User: plugin.UserInfo{ User: plugin.UserInfo{
User: ctl.loginMsg.User, User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas, Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID, RunID: ctl.loginMsg.RunID,
}, },
Ping: *m, Ping: *inMsg,
} }
retContent, err := ctl.pluginManager.Ping(content) retContent, err := ctl.pluginManager.Ping(content)
if err == nil { if err == nil {
m = &retContent.Ping inMsg = &retContent.Ping
err = ctl.authVerifier.VerifyPing(m) err = ctl.authVerifier.VerifyPing(inMsg)
} }
if err != nil { if err != nil {
xl.Warn("received invalid ping: %v", err) xl.Warn("received invalid ping: %v", err)
ctl.sendCh <- &msg.Pong{ _ = ctl.msgDispatcher.Send(&msg.Pong{
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)), Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
} })
return return
} }
ctl.lastPing = time.Now() ctl.lastPing.Store(time.Now())
xl.Debug("receive heartbeat") xl.Debug("receive heartbeat")
ctl.sendCh <- &msg.Pong{} _ = ctl.msgDispatcher.Send(&msg.Pong{})
}
}
}
} }
func (ctl *Control) HandleNatHoleVisitor(m *msg.NatHoleVisitor) { func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter, ctl.loginMsg.User) inMsg := m.(*msg.NatHoleVisitor)
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
} }
func (ctl *Control) HandleNatHoleClient(m *msg.NatHoleClient) { func (ctl *Control) handleNatHoleClient(m msg.Message) {
ctl.rc.NatHoleController.HandleClient(m, ctl.msgTransporter) inMsg := m.(*msg.NatHoleClient)
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
} }
func (ctl *Control) HandleNatHoleReport(m *msg.NatHoleReport) { func (ctl *Control) handleNatHoleReport(m msg.Message) {
ctl.rc.NatHoleController.HandleReport(m) inMsg := m.(*msg.NatHoleReport)
ctl.rc.NatHoleController.HandleReport(inMsg)
}
func (ctl *Control) handleCloseProxy(m msg.Message) {
xl := ctl.xl
inMsg := m.(*msg.CloseProxy)
_ = ctl.CloseProxy(inMsg)
xl.Info("close proxy [%s] success", inMsg.ProxyName)
} }
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {

View File

@ -516,13 +516,14 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
} }
} }
func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err error) { func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error {
// If client's RunID is empty, it's a new client, we just create a new controller. // If client's RunID is empty, it's a new client, we just create a new controller.
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one. // Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
var err error
if loginMsg.RunID == "" { if loginMsg.RunID == "" {
loginMsg.RunID, err = util.RandID() loginMsg.RunID, err = util.RandID()
if err != nil { if err != nil {
return return err
} }
} }
@ -534,11 +535,16 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err
ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch) ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch)
// Check auth. // Check auth.
if err = svr.authVerifier.VerifyLogin(loginMsg); err != nil { if err := svr.authVerifier.VerifyLogin(loginMsg); err != nil {
return return err
} }
ctl := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg) ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg)
if err != nil {
xl.Warn("create new controller error: %v", err)
// don't return detailed errors to client
return fmt.Errorf("unexpect error when creating new controller")
}
if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil { if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
oldCtl.WaitClosed() oldCtl.WaitClosed()
} }
@ -553,7 +559,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err
ctl.WaitClosed() ctl.WaitClosed()
svr.ctlManager.Del(loginMsg.RunID, ctl) svr.ctlManager.Del(loginMsg.RunID, ctl)
}() }()
return return nil
} }
// RegisterWorkConn register a new work connection to control and proxies need it. // RegisterWorkConn register a new work connection to control and proxies need it.