sshTunnelGateway refactor (#3784)

This commit is contained in:
fatedier
2023-11-21 11:19:35 +08:00
parent 8b432e179d
commit d5b41f1e14
34 changed files with 1036 additions and 1255 deletions

View File

@@ -16,30 +16,22 @@ package client
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/fatedier/golib/crypto"
libdial "github.com/fatedier/golib/net/dial"
fmux "github.com/hashicorp/yamux"
quic "github.com/quic-go/quic-go"
"github.com/samber/lo"
"github.com/fatedier/frp/assets"
"github.com/fatedier/frp/pkg/auth"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog"
@@ -75,6 +67,9 @@ type Service struct {
// call cancel to stop service
cancel context.CancelFunc
gracefulDuration time.Duration
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
}
func NewService(
@@ -84,15 +79,24 @@ func NewService(
cfgFile string,
) *Service {
return &Service{
authSetter: auth.NewAuthSetter(cfg.Auth),
cfg: cfg,
cfgFile: cfgFile,
pxyCfgs: pxyCfgs,
visitorCfgs: visitorCfgs,
ctx: context.Background(),
authSetter: auth.NewAuthSetter(cfg.Auth),
cfg: cfg,
cfgFile: cfgFile,
pxyCfgs: pxyCfgs,
visitorCfgs: visitorCfgs,
ctx: context.Background(),
connectorCreator: NewConnector,
}
}
func (svr *Service) SetConnectorCreator(h func(context.Context, *v1.ClientCommonConfig) Connector) {
svr.connectorCreator = h
}
func (svr *Service) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
svr.inWorkConnCallback = cb
}
func (svr *Service) GetController() *Control {
svr.ctlMu.RLock()
defer svr.ctlMu.RUnlock()
@@ -101,7 +105,7 @@ func (svr *Service) GetController() *Control {
func (svr *Service) Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
svr.ctx = xlog.NewContext(ctx, xlog.New())
svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx))
svr.cancel = cancel
// set custom DNSServer
@@ -173,21 +177,20 @@ func (svr *Service) keepControllerWorking() {
// login creates a connection to frps and registers it self as a client
// conn: control connection
// session: if it's not nil, using tcp mux
func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
xl := xlog.FromContextSafe(svr.ctx)
cm = NewConnectionManager(svr.ctx, svr.cfg)
if err = cm.OpenConnection(); err != nil {
connector = svr.connectorCreator(svr.ctx, svr.cfg)
if err = connector.Open(); err != nil {
return nil, nil, err
}
defer func() {
if err != nil {
cm.Close()
connector.Close()
}
}()
conn, err = cm.Connect()
conn, err = connector.Connect()
if err != nil {
return
}
@@ -226,8 +229,7 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
}
svr.runID = loginRespMsg.RunID
xl.ResetPrefixes()
xl.AppendPrefix(svr.runID)
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID)
return
@@ -239,7 +241,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
loginFunc := func() error {
xl.Info("try to connect to server...")
conn, cm, err := svr.login()
conn, connector, err := svr.login()
if err != nil {
xl.Warn("connect to server error: %v", err)
if firstLoginExit {
@@ -248,13 +250,14 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
return err
}
ctl, err := NewControl(svr.ctx, svr.runID, conn, cm,
ctl, err := NewControl(svr.ctx, svr.runID, conn, connector,
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
if err != nil {
conn.Close()
xl.Error("NewControl error: %v", err)
return err
}
ctl.SetInWorkConnCallback(svr.inWorkConnCallback)
ctl.Run()
// close and replace previous control
@@ -314,184 +317,3 @@ func (svr *Service) stop() {
svr.ctl = nil
}
}
// ConnectionManager is a wrapper for establishing connections to the server.
type ConnectionManager struct {
ctx context.Context
cfg *v1.ClientCommonConfig
muxSession *fmux.Session
quicConn quic.Connection
}
func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *ConnectionManager {
return &ConnectionManager{
ctx: ctx,
cfg: cfg,
}
}
// OpenConnection opens a underlying connection to the server.
// The underlying connection is either a TCP connection or a QUIC connection.
// After the underlying connection is established, you can call Connect() to get a stream.
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
func (cm *ConnectionManager) OpenConnection() error {
xl := xlog.FromContextSafe(cm.ctx)
// special for quic
if strings.EqualFold(cm.cfg.Transport.Protocol, "quic") {
var tlsConfig *tls.Config
var err error
sn := cm.cfg.Transport.TLS.ServerName
if sn == "" {
sn = cm.cfg.ServerAddr
}
if lo.FromPtr(cm.cfg.Transport.TLS.Enable) {
tlsConfig, err = transport.NewClientTLSConfig(
cm.cfg.Transport.TLS.CertFile,
cm.cfg.Transport.TLS.KeyFile,
cm.cfg.Transport.TLS.TrustedCaFile,
sn)
} else {
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
}
if err != nil {
xl.Warn("fail to build tls configuration, err: %v", err)
return err
}
tlsConfig.NextProtos = []string{"frp"}
conn, err := quic.DialAddr(
cm.ctx,
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
tlsConfig, &quic.Config{
MaxIdleTimeout: time.Duration(cm.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
MaxIncomingStreams: int64(cm.cfg.Transport.QUIC.MaxIncomingStreams),
KeepAlivePeriod: time.Duration(cm.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
})
if err != nil {
return err
}
cm.quicConn = conn
return nil
}
if !lo.FromPtr(cm.cfg.Transport.TCPMux) {
return nil
}
conn, err := cm.realConnect()
if err != nil {
return err
}
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = time.Duration(cm.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
fmuxCfg.LogOutput = io.Discard
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
session, err := fmux.Client(conn, fmuxCfg)
if err != nil {
return err
}
cm.muxSession = session
return nil
}
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
func (cm *ConnectionManager) Connect() (net.Conn, error) {
if cm.quicConn != nil {
stream, err := cm.quicConn.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
return utilnet.QuicStreamToNetConn(stream, cm.quicConn), nil
} else if cm.muxSession != nil {
stream, err := cm.muxSession.OpenStream()
if err != nil {
return nil, err
}
return stream, nil
}
return cm.realConnect()
}
func (cm *ConnectionManager) realConnect() (net.Conn, error) {
xl := xlog.FromContextSafe(cm.ctx)
var tlsConfig *tls.Config
var err error
tlsEnable := lo.FromPtr(cm.cfg.Transport.TLS.Enable)
if cm.cfg.Transport.Protocol == "wss" {
tlsEnable = true
}
if tlsEnable {
sn := cm.cfg.Transport.TLS.ServerName
if sn == "" {
sn = cm.cfg.ServerAddr
}
tlsConfig, err = transport.NewClientTLSConfig(
cm.cfg.Transport.TLS.CertFile,
cm.cfg.Transport.TLS.KeyFile,
cm.cfg.Transport.TLS.TrustedCaFile,
sn)
if err != nil {
xl.Warn("fail to build tls configuration, err: %v", err)
return nil, err
}
}
proxyType, addr, auth, err := libdial.ParseProxyURL(cm.cfg.Transport.ProxyURL)
if err != nil {
xl.Error("fail to parse proxy url")
return nil, err
}
dialOptions := []libdial.DialOption{}
protocol := cm.cfg.Transport.Protocol
switch protocol {
case "websocket":
protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
}))
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
case "wss":
protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
default:
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
}))
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
}
if cm.cfg.Transport.ConnectServerLocalIP != "" {
dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.Transport.ConnectServerLocalIP))
}
dialOptions = append(dialOptions,
libdial.WithProtocol(protocol),
libdial.WithTimeout(time.Duration(cm.cfg.Transport.DialServerTimeout)*time.Second),
libdial.WithKeepAlive(time.Duration(cm.cfg.Transport.DialServerKeepAlive)*time.Second),
libdial.WithProxy(proxyType, addr),
libdial.WithProxyAuth(auth),
)
conn, err := libdial.DialContext(
cm.ctx,
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
dialOptions...,
)
return conn, err
}
func (cm *ConnectionManager) Close() error {
if cm.quicConn != nil {
_ = cm.quicConn.CloseWithError(0, "")
}
if cm.muxSession != nil {
_ = cm.muxSession.Close()
}
return nil
}