mirror of
https://github.com/fatedier/frp.git
synced 2026-01-11 22:23:12 +00:00
sshTunnelGateway refactor (#3784)
This commit is contained in:
@@ -16,30 +16,22 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/crypto"
|
||||
libdial "github.com/fatedier/golib/net/dial"
|
||||
fmux "github.com/hashicorp/yamux"
|
||||
quic "github.com/quic-go/quic-go"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/assets"
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
@@ -75,6 +67,9 @@ type Service struct {
|
||||
// call cancel to stop service
|
||||
cancel context.CancelFunc
|
||||
gracefulDuration time.Duration
|
||||
|
||||
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
|
||||
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
|
||||
}
|
||||
|
||||
func NewService(
|
||||
@@ -84,15 +79,24 @@ func NewService(
|
||||
cfgFile string,
|
||||
) *Service {
|
||||
return &Service{
|
||||
authSetter: auth.NewAuthSetter(cfg.Auth),
|
||||
cfg: cfg,
|
||||
cfgFile: cfgFile,
|
||||
pxyCfgs: pxyCfgs,
|
||||
visitorCfgs: visitorCfgs,
|
||||
ctx: context.Background(),
|
||||
authSetter: auth.NewAuthSetter(cfg.Auth),
|
||||
cfg: cfg,
|
||||
cfgFile: cfgFile,
|
||||
pxyCfgs: pxyCfgs,
|
||||
visitorCfgs: visitorCfgs,
|
||||
ctx: context.Background(),
|
||||
connectorCreator: NewConnector,
|
||||
}
|
||||
}
|
||||
|
||||
func (svr *Service) SetConnectorCreator(h func(context.Context, *v1.ClientCommonConfig) Connector) {
|
||||
svr.connectorCreator = h
|
||||
}
|
||||
|
||||
func (svr *Service) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||
svr.inWorkConnCallback = cb
|
||||
}
|
||||
|
||||
func (svr *Service) GetController() *Control {
|
||||
svr.ctlMu.RLock()
|
||||
defer svr.ctlMu.RUnlock()
|
||||
@@ -101,7 +105,7 @@ func (svr *Service) GetController() *Control {
|
||||
|
||||
func (svr *Service) Run(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
svr.ctx = xlog.NewContext(ctx, xlog.New())
|
||||
svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx))
|
||||
svr.cancel = cancel
|
||||
|
||||
// set custom DNSServer
|
||||
@@ -173,21 +177,20 @@ func (svr *Service) keepControllerWorking() {
|
||||
// login creates a connection to frps and registers it self as a client
|
||||
// conn: control connection
|
||||
// session: if it's not nil, using tcp mux
|
||||
func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
||||
func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
||||
xl := xlog.FromContextSafe(svr.ctx)
|
||||
cm = NewConnectionManager(svr.ctx, svr.cfg)
|
||||
|
||||
if err = cm.OpenConnection(); err != nil {
|
||||
connector = svr.connectorCreator(svr.ctx, svr.cfg)
|
||||
if err = connector.Open(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cm.Close()
|
||||
connector.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err = cm.Connect()
|
||||
conn, err = connector.Connect()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -226,8 +229,7 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
||||
}
|
||||
|
||||
svr.runID = loginRespMsg.RunID
|
||||
xl.ResetPrefixes()
|
||||
xl.AppendPrefix(svr.runID)
|
||||
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||
|
||||
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID)
|
||||
return
|
||||
@@ -239,7 +241,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
||||
|
||||
loginFunc := func() error {
|
||||
xl.Info("try to connect to server...")
|
||||
conn, cm, err := svr.login()
|
||||
conn, connector, err := svr.login()
|
||||
if err != nil {
|
||||
xl.Warn("connect to server error: %v", err)
|
||||
if firstLoginExit {
|
||||
@@ -248,13 +250,14 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
||||
return err
|
||||
}
|
||||
|
||||
ctl, err := NewControl(svr.ctx, svr.runID, conn, cm,
|
||||
ctl, err := NewControl(svr.ctx, svr.runID, conn, connector,
|
||||
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Error("NewControl error: %v", err)
|
||||
return err
|
||||
}
|
||||
ctl.SetInWorkConnCallback(svr.inWorkConnCallback)
|
||||
|
||||
ctl.Run()
|
||||
// close and replace previous control
|
||||
@@ -314,184 +317,3 @@ func (svr *Service) stop() {
|
||||
svr.ctl = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionManager is a wrapper for establishing connections to the server.
|
||||
type ConnectionManager struct {
|
||||
ctx context.Context
|
||||
cfg *v1.ClientCommonConfig
|
||||
|
||||
muxSession *fmux.Session
|
||||
quicConn quic.Connection
|
||||
}
|
||||
|
||||
func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *ConnectionManager {
|
||||
return &ConnectionManager{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenConnection opens a underlying connection to the server.
|
||||
// The underlying connection is either a TCP connection or a QUIC connection.
|
||||
// After the underlying connection is established, you can call Connect() to get a stream.
|
||||
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
|
||||
func (cm *ConnectionManager) OpenConnection() error {
|
||||
xl := xlog.FromContextSafe(cm.ctx)
|
||||
|
||||
// special for quic
|
||||
if strings.EqualFold(cm.cfg.Transport.Protocol, "quic") {
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
sn := cm.cfg.Transport.TLS.ServerName
|
||||
if sn == "" {
|
||||
sn = cm.cfg.ServerAddr
|
||||
}
|
||||
if lo.FromPtr(cm.cfg.Transport.TLS.Enable) {
|
||||
tlsConfig, err = transport.NewClientTLSConfig(
|
||||
cm.cfg.Transport.TLS.CertFile,
|
||||
cm.cfg.Transport.TLS.KeyFile,
|
||||
cm.cfg.Transport.TLS.TrustedCaFile,
|
||||
sn)
|
||||
} else {
|
||||
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warn("fail to build tls configuration, err: %v", err)
|
||||
return err
|
||||
}
|
||||
tlsConfig.NextProtos = []string{"frp"}
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
cm.ctx,
|
||||
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
||||
tlsConfig, &quic.Config{
|
||||
MaxIdleTimeout: time.Duration(cm.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
|
||||
MaxIncomingStreams: int64(cm.cfg.Transport.QUIC.MaxIncomingStreams),
|
||||
KeepAlivePeriod: time.Duration(cm.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cm.quicConn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
if !lo.FromPtr(cm.cfg.Transport.TCPMux) {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := cm.realConnect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmuxCfg := fmux.DefaultConfig()
|
||||
fmuxCfg.KeepAliveInterval = time.Duration(cm.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
||||
fmuxCfg.LogOutput = io.Discard
|
||||
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
|
||||
session, err := fmux.Client(conn, fmuxCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cm.muxSession = session
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
|
||||
func (cm *ConnectionManager) Connect() (net.Conn, error) {
|
||||
if cm.quicConn != nil {
|
||||
stream, err := cm.quicConn.OpenStreamSync(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utilnet.QuicStreamToNetConn(stream, cm.quicConn), nil
|
||||
} else if cm.muxSession != nil {
|
||||
stream, err := cm.muxSession.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
return cm.realConnect()
|
||||
}
|
||||
|
||||
func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
||||
xl := xlog.FromContextSafe(cm.ctx)
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
tlsEnable := lo.FromPtr(cm.cfg.Transport.TLS.Enable)
|
||||
if cm.cfg.Transport.Protocol == "wss" {
|
||||
tlsEnable = true
|
||||
}
|
||||
if tlsEnable {
|
||||
sn := cm.cfg.Transport.TLS.ServerName
|
||||
if sn == "" {
|
||||
sn = cm.cfg.ServerAddr
|
||||
}
|
||||
|
||||
tlsConfig, err = transport.NewClientTLSConfig(
|
||||
cm.cfg.Transport.TLS.CertFile,
|
||||
cm.cfg.Transport.TLS.KeyFile,
|
||||
cm.cfg.Transport.TLS.TrustedCaFile,
|
||||
sn)
|
||||
if err != nil {
|
||||
xl.Warn("fail to build tls configuration, err: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
proxyType, addr, auth, err := libdial.ParseProxyURL(cm.cfg.Transport.ProxyURL)
|
||||
if err != nil {
|
||||
xl.Error("fail to parse proxy url")
|
||||
return nil, err
|
||||
}
|
||||
dialOptions := []libdial.DialOption{}
|
||||
protocol := cm.cfg.Transport.Protocol
|
||||
switch protocol {
|
||||
case "websocket":
|
||||
protocol = "tcp"
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
||||
}))
|
||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
||||
case "wss":
|
||||
protocol = "tcp"
|
||||
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
|
||||
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
|
||||
default:
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
||||
}))
|
||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
||||
}
|
||||
|
||||
if cm.cfg.Transport.ConnectServerLocalIP != "" {
|
||||
dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.Transport.ConnectServerLocalIP))
|
||||
}
|
||||
dialOptions = append(dialOptions,
|
||||
libdial.WithProtocol(protocol),
|
||||
libdial.WithTimeout(time.Duration(cm.cfg.Transport.DialServerTimeout)*time.Second),
|
||||
libdial.WithKeepAlive(time.Duration(cm.cfg.Transport.DialServerKeepAlive)*time.Second),
|
||||
libdial.WithProxy(proxyType, addr),
|
||||
libdial.WithProxyAuth(auth),
|
||||
)
|
||||
conn, err := libdial.DialContext(
|
||||
cm.ctx,
|
||||
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
||||
dialOptions...,
|
||||
)
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func (cm *ConnectionManager) Close() error {
|
||||
if cm.quicConn != nil {
|
||||
_ = cm.quicConn.CloseWithError(0, "")
|
||||
}
|
||||
if cm.muxSession != nil {
|
||||
_ = cm.muxSession.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user