mirror of
https://github.com/fatedier/frp.git
synced 2026-01-11 22:23:12 +00:00
refactor: separate auth config from runtime and defer token resolution (#5105)
This commit is contained in:
@@ -43,8 +43,8 @@ type SessionContext struct {
|
|||||||
Conn net.Conn
|
Conn net.Conn
|
||||||
// Indicates whether the connection is encrypted.
|
// Indicates whether the connection is encrypted.
|
||||||
ConnEncrypted bool
|
ConnEncrypted bool
|
||||||
// Sets authentication based on selected method
|
// Auth runtime used for login, heartbeats, and encryption.
|
||||||
AuthSetter auth.Setter
|
Auth *auth.ClientAuth
|
||||||
// Connector is used to create new connections, which could be real TCP connections or virtual streams.
|
// Connector is used to create new connections, which could be real TCP connections or virtual streams.
|
||||||
Connector Connector
|
Connector Connector
|
||||||
// Virtual net controller
|
// Virtual net controller
|
||||||
@@ -91,7 +91,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
|||||||
ctl.lastPong.Store(time.Now())
|
ctl.lastPong.Store(time.Now())
|
||||||
|
|
||||||
if sessionCtx.ConnEncrypted {
|
if sessionCtx.ConnEncrypted {
|
||||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, []byte(sessionCtx.Common.Auth.Token))
|
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -102,7 +102,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
|||||||
ctl.registerMsgHandlers()
|
ctl.registerMsgHandlers()
|
||||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||||
|
|
||||||
ctl.pm = proxy.NewManager(ctl.ctx, sessionCtx.Common, ctl.msgTransporter, sessionCtx.VnetController)
|
ctl.pm = proxy.NewManager(ctl.ctx, sessionCtx.Common, sessionCtx.Auth.EncryptionKey(), ctl.msgTransporter, sessionCtx.VnetController)
|
||||||
ctl.vm = visitor.NewManager(ctl.ctx, sessionCtx.RunID, sessionCtx.Common,
|
ctl.vm = visitor.NewManager(ctl.ctx, sessionCtx.RunID, sessionCtx.Common,
|
||||||
ctl.connectServer, ctl.msgTransporter, sessionCtx.VnetController)
|
ctl.connectServer, ctl.msgTransporter, sessionCtx.VnetController)
|
||||||
return ctl, nil
|
return ctl, nil
|
||||||
@@ -133,7 +133,7 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
|||||||
m := &msg.NewWorkConn{
|
m := &msg.NewWorkConn{
|
||||||
RunID: ctl.sessionCtx.RunID,
|
RunID: ctl.sessionCtx.RunID,
|
||||||
}
|
}
|
||||||
if err = ctl.sessionCtx.AuthSetter.SetNewWorkConn(m); err != nil {
|
if err = ctl.sessionCtx.Auth.Setter.SetNewWorkConn(m); err != nil {
|
||||||
xl.Warnf("error during NewWorkConn authentication: %v", err)
|
xl.Warnf("error during NewWorkConn authentication: %v", err)
|
||||||
workConn.Close()
|
workConn.Close()
|
||||||
return
|
return
|
||||||
@@ -243,7 +243,7 @@ func (ctl *Control) heartbeatWorker() {
|
|||||||
sendHeartBeat := func() (bool, error) {
|
sendHeartBeat := func() (bool, error) {
|
||||||
xl.Debugf("send heartbeat to server")
|
xl.Debugf("send heartbeat to server")
|
||||||
pingMsg := &msg.Ping{}
|
pingMsg := &msg.Ping{}
|
||||||
if err := ctl.sessionCtx.AuthSetter.SetPing(pingMsg); err != nil {
|
if err := ctl.sessionCtx.Auth.Setter.SetPing(pingMsg); err != nil {
|
||||||
xl.Warnf("error during ping authentication: %v, skip sending ping message", err)
|
xl.Warnf("error during ping authentication: %v, skip sending ping message", err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ func NewProxy(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
pxyConf v1.ProxyConfigurer,
|
pxyConf v1.ProxyConfigurer,
|
||||||
clientCfg *v1.ClientCommonConfig,
|
clientCfg *v1.ClientCommonConfig,
|
||||||
|
encryptionKey []byte,
|
||||||
msgTransporter transport.MessageTransporter,
|
msgTransporter transport.MessageTransporter,
|
||||||
vnetController *vnet.Controller,
|
vnetController *vnet.Controller,
|
||||||
) (pxy Proxy) {
|
) (pxy Proxy) {
|
||||||
@@ -69,6 +70,7 @@ func NewProxy(
|
|||||||
baseProxy := BaseProxy{
|
baseProxy := BaseProxy{
|
||||||
baseCfg: pxyConf.GetBaseConfig(),
|
baseCfg: pxyConf.GetBaseConfig(),
|
||||||
clientCfg: clientCfg,
|
clientCfg: clientCfg,
|
||||||
|
encryptionKey: encryptionKey,
|
||||||
limiter: limiter,
|
limiter: limiter,
|
||||||
msgTransporter: msgTransporter,
|
msgTransporter: msgTransporter,
|
||||||
vnetController: vnetController,
|
vnetController: vnetController,
|
||||||
@@ -86,6 +88,7 @@ func NewProxy(
|
|||||||
type BaseProxy struct {
|
type BaseProxy struct {
|
||||||
baseCfg *v1.ProxyBaseConfig
|
baseCfg *v1.ProxyBaseConfig
|
||||||
clientCfg *v1.ClientCommonConfig
|
clientCfg *v1.ClientCommonConfig
|
||||||
|
encryptionKey []byte
|
||||||
msgTransporter transport.MessageTransporter
|
msgTransporter transport.MessageTransporter
|
||||||
vnetController *vnet.Controller
|
vnetController *vnet.Controller
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
@@ -129,7 +132,7 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Auth.Token))
|
pxy.HandleTCPWorkConnection(conn, m, pxy.encryptionKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Common handler for tcp work connections.
|
// Common handler for tcp work connections.
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type Manager struct {
|
|||||||
closed bool
|
closed bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
encryptionKey []byte
|
||||||
clientCfg *v1.ClientCommonConfig
|
clientCfg *v1.ClientCommonConfig
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -48,6 +49,7 @@ type Manager struct {
|
|||||||
func NewManager(
|
func NewManager(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
clientCfg *v1.ClientCommonConfig,
|
clientCfg *v1.ClientCommonConfig,
|
||||||
|
encryptionKey []byte,
|
||||||
msgTransporter transport.MessageTransporter,
|
msgTransporter transport.MessageTransporter,
|
||||||
vnetController *vnet.Controller,
|
vnetController *vnet.Controller,
|
||||||
) *Manager {
|
) *Manager {
|
||||||
@@ -56,6 +58,7 @@ func NewManager(
|
|||||||
msgTransporter: msgTransporter,
|
msgTransporter: msgTransporter,
|
||||||
vnetController: vnetController,
|
vnetController: vnetController,
|
||||||
closed: false,
|
closed: false,
|
||||||
|
encryptionKey: encryptionKey,
|
||||||
clientCfg: clientCfg,
|
clientCfg: clientCfg,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
@@ -163,7 +166,7 @@ func (pm *Manager) UpdateAll(proxyCfgs []v1.ProxyConfigurer) {
|
|||||||
for _, cfg := range proxyCfgs {
|
for _, cfg := range proxyCfgs {
|
||||||
name := cfg.GetBaseConfig().Name
|
name := cfg.GetBaseConfig().Name
|
||||||
if _, ok := pm.proxies[name]; !ok {
|
if _, ok := pm.proxies[name]; !ok {
|
||||||
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter, pm.vnetController)
|
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.encryptionKey, pm.HandleEvent, pm.msgTransporter, pm.vnetController)
|
||||||
if pm.inWorkConnCallback != nil {
|
if pm.inWorkConnCallback != nil {
|
||||||
pxy.SetInWorkConnCallback(pm.inWorkConnCallback)
|
pxy.SetInWorkConnCallback(pm.inWorkConnCallback)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ func NewWrapper(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
cfg v1.ProxyConfigurer,
|
cfg v1.ProxyConfigurer,
|
||||||
clientCfg *v1.ClientCommonConfig,
|
clientCfg *v1.ClientCommonConfig,
|
||||||
|
encryptionKey []byte,
|
||||||
eventHandler event.Handler,
|
eventHandler event.Handler,
|
||||||
msgTransporter transport.MessageTransporter,
|
msgTransporter transport.MessageTransporter,
|
||||||
vnetController *vnet.Controller,
|
vnetController *vnet.Controller,
|
||||||
@@ -122,7 +123,7 @@ func NewWrapper(
|
|||||||
xl.Tracef("enable health check monitor")
|
xl.Tracef("enable health check monitor")
|
||||||
}
|
}
|
||||||
|
|
||||||
pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, pw.msgTransporter, pw.vnetController)
|
pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, encryptionKey, pw.msgTransporter, pw.vnetController)
|
||||||
return pw
|
return pw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
if pxy.cfg.Transport.UseEncryption {
|
if pxy.cfg.Transport.UseEncryption {
|
||||||
rwc, err = libio.WithEncryption(rwc, []byte(pxy.clientCfg.Auth.Token))
|
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
xl.Errorf("create encryption stream error: %v", err)
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
if pxy.cfg.Transport.UseEncryption {
|
if pxy.cfg.Transport.UseEncryption {
|
||||||
rwc, err = libio.WithEncryption(rwc, []byte(pxy.clientCfg.Auth.Token))
|
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
xl.Errorf("create encryption stream error: %v", err)
|
||||||
|
|||||||
@@ -111,8 +111,8 @@ type Service struct {
|
|||||||
// Uniq id got from frps, it will be attached to loginMsg.
|
// Uniq id got from frps, it will be attached to loginMsg.
|
||||||
runID string
|
runID string
|
||||||
|
|
||||||
// Sets authentication based on selected method
|
// Auth runtime and encryption materials
|
||||||
authSetter auth.Setter
|
auth *auth.ClientAuth
|
||||||
|
|
||||||
// web server for admin UI and apis
|
// web server for admin UI and apis
|
||||||
webServer *httppkg.Server
|
webServer *httppkg.Server
|
||||||
@@ -155,14 +155,14 @@ func NewService(options ServiceOptions) (*Service, error) {
|
|||||||
webServer = ws
|
webServer = ws
|
||||||
}
|
}
|
||||||
|
|
||||||
authSetter, err := auth.NewAuthSetter(options.Common.Auth)
|
authRuntime, err := auth.BuildClientAuth(&options.Common.Auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &Service{
|
s := &Service{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
authSetter: authSetter,
|
auth: authRuntime,
|
||||||
webServer: webServer,
|
webServer: webServer,
|
||||||
common: options.Common,
|
common: options.Common,
|
||||||
configFilePath: options.ConfigFilePath,
|
configFilePath: options.ConfigFilePath,
|
||||||
@@ -296,7 +296,7 @@ func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add auth
|
// Add auth
|
||||||
if err = svr.authSetter.SetLogin(loginMsg); err != nil {
|
if err = svr.auth.Setter.SetLogin(loginMsg); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
|||||||
RunID: svr.runID,
|
RunID: svr.runID,
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
ConnEncrypted: connEncrypted,
|
ConnEncrypted: connEncrypted,
|
||||||
AuthSetter: svr.authSetter,
|
Auth: svr.auth,
|
||||||
Connector: connector,
|
Connector: connector,
|
||||||
VnetController: svr.vnetController,
|
VnetController: svr.vnetController,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,7 +80,8 @@ func NewProxyCommand(name string, c v1.ProxyConfigurer, clientCfg *v1.ClientComm
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
|
unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
|
||||||
if _, err := validation.ValidateClientCommonConfig(clientCfg, unsafeFeatures); err != nil {
|
validator := validation.NewConfigValidator(unsafeFeatures)
|
||||||
|
if _, err := validator.ValidateClientCommonConfig(clientCfg); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -110,7 +111,8 @@ func NewVisitorCommand(name string, c v1.VisitorConfigurer, clientCfg *v1.Client
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
|
unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
|
||||||
if _, err := validation.ValidateClientCommonConfig(clientCfg, unsafeFeatures); err != nil {
|
validator := validation.NewConfigValidator(unsafeFeatures)
|
||||||
|
if _, err := validator.ValidateClientCommonConfig(clientCfg); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -142,6 +142,7 @@ func runClient(cfgFilePath string, unsafeFeatures *security.UnsafeFeatures) erro
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return startService(cfg, proxyCfgs, visitorCfgs, unsafeFeatures, cfgFilePath)
|
return startService(cfg, proxyCfgs, visitorCfgs, unsafeFeatures, cfgFilePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config"
|
"github.com/fatedier/frp/pkg/config"
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||||
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
"github.com/fatedier/frp/pkg/util/version"
|
"github.com/fatedier/frp/pkg/util/version"
|
||||||
"github.com/fatedier/frp/server"
|
"github.com/fatedier/frp/server"
|
||||||
@@ -33,6 +35,7 @@ var (
|
|||||||
cfgFile string
|
cfgFile string
|
||||||
showVersion bool
|
showVersion bool
|
||||||
strictConfigMode bool
|
strictConfigMode bool
|
||||||
|
allowUnsafe []string
|
||||||
|
|
||||||
serverCfg v1.ServerConfig
|
serverCfg v1.ServerConfig
|
||||||
)
|
)
|
||||||
@@ -41,6 +44,8 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps")
|
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps")
|
||||||
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps")
|
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps")
|
||||||
rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", true, "strict config parsing mode, unknown fields will cause errors")
|
rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", true, "strict config parsing mode, unknown fields will cause errors")
|
||||||
|
rootCmd.PersistentFlags().StringSliceVarP(&allowUnsafe, "allow-unsafe", "", []string{},
|
||||||
|
fmt.Sprintf("allowed unsafe features, one or more of: %s", strings.Join(security.ServerUnsafeFeatures, ", ")))
|
||||||
|
|
||||||
config.RegisterServerConfigFlags(rootCmd, &serverCfg)
|
config.RegisterServerConfigFlags(rootCmd, &serverCfg)
|
||||||
}
|
}
|
||||||
@@ -77,7 +82,9 @@ var rootCmd = &cobra.Command{
|
|||||||
svrCfg = &serverCfg
|
svrCfg = &serverCfg
|
||||||
}
|
}
|
||||||
|
|
||||||
warning, err := validation.ValidateServerConfig(svrCfg)
|
unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
|
||||||
|
validator := validation.NewConfigValidator(unsafeFeatures)
|
||||||
|
warning, err := validator.ValidateServerConfig(svrCfg)
|
||||||
if warning != nil {
|
if warning != nil {
|
||||||
fmt.Printf("WARNING: %v\n", warning)
|
fmt.Printf("WARNING: %v\n", warning)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config"
|
"github.com/fatedier/frp/pkg/config"
|
||||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||||
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -42,7 +43,9 @@ var verifyCmd = &cobra.Command{
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
warning, err := validation.ValidateServerConfig(svrCfg)
|
unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
|
||||||
|
validator := validation.NewConfigValidator(unsafeFeatures)
|
||||||
|
warning, err := validator.ValidateServerConfig(svrCfg)
|
||||||
if warning != nil {
|
if warning != nil {
|
||||||
fmt.Printf("WARNING: %v\n", warning)
|
fmt.Printf("WARNING: %v\n", warning)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
@@ -27,6 +28,39 @@ type Setter interface {
|
|||||||
SetNewWorkConn(*msg.NewWorkConn) error
|
SetNewWorkConn(*msg.NewWorkConn) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClientAuth struct {
|
||||||
|
Setter Setter
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ClientAuth) EncryptionKey() []byte {
|
||||||
|
return a.key
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClientAuth resolves any dynamic auth values and returns a prepared auth runtime.
|
||||||
|
// Caller must run validation before calling this function.
|
||||||
|
func BuildClientAuth(cfg *v1.AuthClientConfig) (*ClientAuth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("auth config is nil")
|
||||||
|
}
|
||||||
|
resolved := *cfg
|
||||||
|
if resolved.Method == v1.AuthMethodToken && resolved.TokenSource != nil {
|
||||||
|
token, err := resolved.TokenSource.Resolve(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
||||||
|
}
|
||||||
|
resolved.Token = token
|
||||||
|
}
|
||||||
|
setter, err := NewAuthSetter(resolved)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ClientAuth{
|
||||||
|
Setter: setter,
|
||||||
|
key: []byte(resolved.Token),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewAuthSetter(cfg v1.AuthClientConfig) (authProvider Setter, err error) {
|
func NewAuthSetter(cfg v1.AuthClientConfig) (authProvider Setter, err error) {
|
||||||
switch cfg.Method {
|
switch cfg.Method {
|
||||||
case v1.AuthMethodToken:
|
case v1.AuthMethodToken:
|
||||||
@@ -52,6 +86,35 @@ type Verifier interface {
|
|||||||
VerifyNewWorkConn(*msg.NewWorkConn) error
|
VerifyNewWorkConn(*msg.NewWorkConn) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ServerAuth struct {
|
||||||
|
Verifier Verifier
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ServerAuth) EncryptionKey() []byte {
|
||||||
|
return a.key
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildServerAuth resolves any dynamic auth values and returns a prepared auth runtime.
|
||||||
|
// Caller must run validation before calling this function.
|
||||||
|
func BuildServerAuth(cfg *v1.AuthServerConfig) (*ServerAuth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("auth config is nil")
|
||||||
|
}
|
||||||
|
resolved := *cfg
|
||||||
|
if resolved.Method == v1.AuthMethodToken && resolved.TokenSource != nil {
|
||||||
|
token, err := resolved.TokenSource.Resolve(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
||||||
|
}
|
||||||
|
resolved.Token = token
|
||||||
|
}
|
||||||
|
return &ServerAuth{
|
||||||
|
Verifier: NewAuthVerifier(resolved),
|
||||||
|
key: []byte(resolved.Token),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) {
|
func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) {
|
||||||
switch cfg.Method {
|
switch cfg.Method {
|
||||||
case v1.AuthMethodToken:
|
case v1.AuthMethodToken:
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@@ -198,17 +196,6 @@ type AuthClientConfig struct {
|
|||||||
|
|
||||||
func (c *AuthClientConfig) Complete() error {
|
func (c *AuthClientConfig) Complete() error {
|
||||||
c.Method = util.EmptyOr(c.Method, "token")
|
c.Method = util.EmptyOr(c.Method, "token")
|
||||||
|
|
||||||
// Resolve tokenSource during configuration loading
|
|
||||||
if c.Method == AuthMethodToken && c.TokenSource != nil {
|
|
||||||
token, err := c.TokenSource.Resolve(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
|
||||||
}
|
|
||||||
// Move the resolved token to the Token field and clear TokenSource
|
|
||||||
c.Token = token
|
|
||||||
c.TokenSource = nil
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@@ -38,68 +36,9 @@ func TestClientConfigComplete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthClientConfig_Complete(t *testing.T) {
|
func TestAuthClientConfig_Complete(t *testing.T) {
|
||||||
// Create a temporary file for testing
|
require := require.New(t)
|
||||||
tmpDir := t.TempDir()
|
cfg := &AuthClientConfig{}
|
||||||
testFile := filepath.Join(tmpDir, "test_token")
|
err := cfg.Complete()
|
||||||
testContent := "client-token-value"
|
require.NoError(err)
|
||||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
require.EqualValues("token", cfg.Method)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
config AuthClientConfig
|
|
||||||
expectToken string
|
|
||||||
expectPanic bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "tokenSource resolved to token",
|
|
||||||
config: AuthClientConfig{
|
|
||||||
Method: AuthMethodToken,
|
|
||||||
TokenSource: &ValueSource{
|
|
||||||
Type: "file",
|
|
||||||
File: &FileSource{
|
|
||||||
Path: testFile,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectToken: testContent,
|
|
||||||
expectPanic: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "direct token unchanged",
|
|
||||||
config: AuthClientConfig{
|
|
||||||
Method: AuthMethodToken,
|
|
||||||
Token: "direct-token",
|
|
||||||
},
|
|
||||||
expectToken: "direct-token",
|
|
||||||
expectPanic: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid tokenSource should panic",
|
|
||||||
config: AuthClientConfig{
|
|
||||||
Method: AuthMethodToken,
|
|
||||||
TokenSource: &ValueSource{
|
|
||||||
Type: "file",
|
|
||||||
File: &FileSource{
|
|
||||||
Path: "/non/existent/file",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectPanic: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.expectPanic {
|
|
||||||
err := tt.config.Complete()
|
|
||||||
require.Error(t, err)
|
|
||||||
} else {
|
|
||||||
err := tt.config.Complete()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, tt.expectToken, tt.config.Token)
|
|
||||||
require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,9 +15,6 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config/types"
|
"github.com/fatedier/frp/pkg/config/types"
|
||||||
@@ -138,17 +135,6 @@ type AuthServerConfig struct {
|
|||||||
|
|
||||||
func (c *AuthServerConfig) Complete() error {
|
func (c *AuthServerConfig) Complete() error {
|
||||||
c.Method = util.EmptyOr(c.Method, "token")
|
c.Method = util.EmptyOr(c.Method, "token")
|
||||||
|
|
||||||
// Resolve tokenSource during configuration loading
|
|
||||||
if c.Method == AuthMethodToken && c.TokenSource != nil {
|
|
||||||
token, err := c.TokenSource.Resolve(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
|
||||||
}
|
|
||||||
// Move the resolved token to the Token field and clear TokenSource
|
|
||||||
c.Token = token
|
|
||||||
c.TokenSource = nil
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@@ -35,68 +33,9 @@ func TestServerConfigComplete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthServerConfig_Complete(t *testing.T) {
|
func TestAuthServerConfig_Complete(t *testing.T) {
|
||||||
// Create a temporary file for testing
|
require := require.New(t)
|
||||||
tmpDir := t.TempDir()
|
cfg := &AuthServerConfig{}
|
||||||
testFile := filepath.Join(tmpDir, "test_token")
|
err := cfg.Complete()
|
||||||
testContent := "file-token-value"
|
require.NoError(err)
|
||||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
require.EqualValues("token", cfg.Method)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
config AuthServerConfig
|
|
||||||
expectToken string
|
|
||||||
expectPanic bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "tokenSource resolved to token",
|
|
||||||
config: AuthServerConfig{
|
|
||||||
Method: AuthMethodToken,
|
|
||||||
TokenSource: &ValueSource{
|
|
||||||
Type: "file",
|
|
||||||
File: &FileSource{
|
|
||||||
Path: testFile,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectToken: testContent,
|
|
||||||
expectPanic: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "direct token unchanged",
|
|
||||||
config: AuthServerConfig{
|
|
||||||
Method: AuthMethodToken,
|
|
||||||
Token: "direct-token",
|
|
||||||
},
|
|
||||||
expectToken: "direct-token",
|
|
||||||
expectPanic: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid tokenSource should panic",
|
|
||||||
config: AuthServerConfig{
|
|
||||||
Method: AuthMethodToken,
|
|
||||||
TokenSource: &ValueSource{
|
|
||||||
Type: "file",
|
|
||||||
File: &FileSource{
|
|
||||||
Path: "/non/existent/file",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectPanic: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.expectPanic {
|
|
||||||
err := tt.config.Complete()
|
|
||||||
require.Error(t, err)
|
|
||||||
} else {
|
|
||||||
err := tt.config.Complete()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, tt.expectToken, tt.config.Token)
|
|
||||||
require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import (
|
|||||||
"github.com/fatedier/frp/pkg/policy/security"
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
|
func (v *ConfigValidator) ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) {
|
||||||
var (
|
var (
|
||||||
warnings Warning
|
warnings Warning
|
||||||
errs error
|
errs error
|
||||||
@@ -35,15 +35,15 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *securi
|
|||||||
|
|
||||||
validators := []func() (Warning, error){
|
validators := []func() (Warning, error){
|
||||||
func() (Warning, error) { return validateFeatureGates(c) },
|
func() (Warning, error) { return validateFeatureGates(c) },
|
||||||
func() (Warning, error) { return validateAuthConfig(&c.Auth, unsafeFeatures) },
|
func() (Warning, error) { return v.validateAuthConfig(&c.Auth) },
|
||||||
func() (Warning, error) { return nil, validateLogConfig(&c.Log) },
|
func() (Warning, error) { return nil, validateLogConfig(&c.Log) },
|
||||||
func() (Warning, error) { return nil, validateWebServerConfig(&c.WebServer) },
|
func() (Warning, error) { return nil, validateWebServerConfig(&c.WebServer) },
|
||||||
func() (Warning, error) { return validateTransportConfig(&c.Transport) },
|
func() (Warning, error) { return validateTransportConfig(&c.Transport) },
|
||||||
func() (Warning, error) { return validateIncludeFiles(c.IncludeConfigFiles) },
|
func() (Warning, error) { return validateIncludeFiles(c.IncludeConfigFiles) },
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range validators {
|
for _, validator := range validators {
|
||||||
w, err := v()
|
w, err := validator()
|
||||||
warnings = AppendError(warnings, w)
|
warnings = AppendError(warnings, w)
|
||||||
errs = AppendError(errs, err)
|
errs = AppendError(errs, err)
|
||||||
}
|
}
|
||||||
@@ -59,7 +59,7 @@ func validateFeatureGates(c *v1.ClientCommonConfig) (Warning, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
|
func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, error) {
|
||||||
var errs error
|
var errs error
|
||||||
if !slices.Contains(SupportedAuthMethods, c.Method) {
|
if !slices.Contains(SupportedAuthMethods, c.Method) {
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", SupportedAuthMethods))
|
errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", SupportedAuthMethods))
|
||||||
@@ -76,9 +76,8 @@ func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeF
|
|||||||
// Validate tokenSource if specified
|
// Validate tokenSource if specified
|
||||||
if c.TokenSource != nil {
|
if c.TokenSource != nil {
|
||||||
if c.TokenSource.Type == "exec" {
|
if c.TokenSource.Type == "exec" {
|
||||||
if !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
|
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||||
errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+
|
errs = AppendError(errs, err)
|
||||||
"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := c.TokenSource.Validate(); err != nil {
|
if err := c.TokenSource.Validate(); err != nil {
|
||||||
@@ -86,13 +85,13 @@ func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeF
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateOIDCConfig(&c.OIDC, unsafeFeatures); err != nil {
|
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
||||||
errs = AppendError(errs, err)
|
errs = AppendError(errs, err)
|
||||||
}
|
}
|
||||||
return nil, errs
|
return nil, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.UnsafeFeatures) error {
|
func (v *ConfigValidator) validateOIDCConfig(c *v1.AuthOIDCClientConfig) error {
|
||||||
if c.TokenSource == nil {
|
if c.TokenSource == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -104,9 +103,8 @@ func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.Uns
|
|||||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc"))
|
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc"))
|
||||||
}
|
}
|
||||||
if c.TokenSource.Type == "exec" {
|
if c.TokenSource.Type == "exec" {
|
||||||
if !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
|
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||||
errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+
|
errs = AppendError(errs, err)
|
||||||
"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := c.TokenSource.Validate(); err != nil {
|
if err := c.TokenSource.Validate(); err != nil {
|
||||||
@@ -167,9 +165,10 @@ func ValidateAllClientConfig(
|
|||||||
visitorCfgs []v1.VisitorConfigurer,
|
visitorCfgs []v1.VisitorConfigurer,
|
||||||
unsafeFeatures *security.UnsafeFeatures,
|
unsafeFeatures *security.UnsafeFeatures,
|
||||||
) (Warning, error) {
|
) (Warning, error) {
|
||||||
|
validator := NewConfigValidator(unsafeFeatures)
|
||||||
var warnings Warning
|
var warnings Warning
|
||||||
if c != nil {
|
if c != nil {
|
||||||
warning, err := ValidateClientCommonConfig(c, unsafeFeatures)
|
warning, err := validator.ValidateClientCommonConfig(c)
|
||||||
warnings = AppendError(warnings, warning)
|
warnings = AppendError(warnings, warning)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return warnings, err
|
return warnings, err
|
||||||
|
|||||||
@@ -21,9 +21,10 @@ import (
|
|||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||||
var (
|
var (
|
||||||
warnings Warning
|
warnings Warning
|
||||||
errs error
|
errs error
|
||||||
@@ -42,6 +43,11 @@ func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
|||||||
|
|
||||||
// Validate tokenSource if specified
|
// Validate tokenSource if specified
|
||||||
if c.Auth.TokenSource != nil {
|
if c.Auth.TokenSource != nil {
|
||||||
|
if c.Auth.TokenSource.Type == "exec" {
|
||||||
|
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||||
|
errs = AppendError(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
if err := c.Auth.TokenSource.Validate(); err != nil {
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||||
}
|
}
|
||||||
|
|||||||
28
pkg/config/v1/validation/validator.go
Normal file
28
pkg/config/v1/validation/validator.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigValidator holds the context dependencies for configuration validation.
|
||||||
|
type ConfigValidator struct {
|
||||||
|
unsafeFeatures *security.UnsafeFeatures
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigValidator creates a new ConfigValidator instance.
|
||||||
|
func NewConfigValidator(unsafeFeatures *security.UnsafeFeatures) *ConfigValidator {
|
||||||
|
return &ConfigValidator{
|
||||||
|
unsafeFeatures: unsafeFeatures,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUnsafeFeature checks if a specific unsafe feature is enabled.
|
||||||
|
func (v *ConfigValidator) ValidateUnsafeFeature(feature string) error {
|
||||||
|
if !v.unsafeFeatures.IsEnabled(feature) {
|
||||||
|
return fmt.Errorf("unsafe feature %q is not enabled. "+
|
||||||
|
"To enable it, ensure it is allowed in the configuration or command line flags", feature)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -106,6 +106,8 @@ type Control struct {
|
|||||||
|
|
||||||
// verifies authentication based on selected method
|
// verifies authentication based on selected method
|
||||||
authVerifier auth.Verifier
|
authVerifier auth.Verifier
|
||||||
|
// key used for connection encryption
|
||||||
|
encryptionKey []byte
|
||||||
|
|
||||||
// other components can use this to communicate with client
|
// other components can use this to communicate with client
|
||||||
msgTransporter transport.MessageTransporter
|
msgTransporter transport.MessageTransporter
|
||||||
@@ -157,6 +159,7 @@ func NewControl(
|
|||||||
pxyManager *proxy.Manager,
|
pxyManager *proxy.Manager,
|
||||||
pluginManager *plugin.Manager,
|
pluginManager *plugin.Manager,
|
||||||
authVerifier auth.Verifier,
|
authVerifier auth.Verifier,
|
||||||
|
encryptionKey []byte,
|
||||||
ctlConn net.Conn,
|
ctlConn net.Conn,
|
||||||
ctlConnEncrypted bool,
|
ctlConnEncrypted bool,
|
||||||
loginMsg *msg.Login,
|
loginMsg *msg.Login,
|
||||||
@@ -171,6 +174,7 @@ func NewControl(
|
|||||||
pxyManager: pxyManager,
|
pxyManager: pxyManager,
|
||||||
pluginManager: pluginManager,
|
pluginManager: pluginManager,
|
||||||
authVerifier: authVerifier,
|
authVerifier: authVerifier,
|
||||||
|
encryptionKey: encryptionKey,
|
||||||
conn: ctlConn,
|
conn: ctlConn,
|
||||||
loginMsg: loginMsg,
|
loginMsg: loginMsg,
|
||||||
workConnCh: make(chan net.Conn, poolCount+10),
|
workConnCh: make(chan net.Conn, poolCount+10),
|
||||||
@@ -186,7 +190,7 @@ func NewControl(
|
|||||||
ctl.lastPing.Store(time.Now())
|
ctl.lastPing.Store(time.Now())
|
||||||
|
|
||||||
if ctlConnEncrypted {
|
if ctlConnEncrypted {
|
||||||
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token))
|
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -478,6 +482,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
|||||||
GetWorkConnFn: ctl.GetWorkConn,
|
GetWorkConnFn: ctl.GetWorkConn,
|
||||||
Configurer: pxyConf,
|
Configurer: pxyConf,
|
||||||
ServerCfg: ctl.serverCfg,
|
ServerCfg: ctl.serverCfg,
|
||||||
|
EncryptionKey: ctl.encryptionKey,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return remoteAddr, err
|
return remoteAddr, err
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err
|
|||||||
|
|
||||||
var rwc io.ReadWriteCloser = tmpConn
|
var rwc io.ReadWriteCloser = tmpConn
|
||||||
if pxy.cfg.Transport.UseEncryption {
|
if pxy.cfg.Transport.UseEncryption {
|
||||||
rwc, err = libio.WithEncryption(rwc, []byte(pxy.serverCfg.Auth.Token))
|
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
xl.Errorf("create encryption stream error: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ type BaseProxy struct {
|
|||||||
poolCount int
|
poolCount int
|
||||||
getWorkConnFn GetWorkConnFn
|
getWorkConnFn GetWorkConnFn
|
||||||
serverCfg *v1.ServerConfig
|
serverCfg *v1.ServerConfig
|
||||||
|
encryptionKey []byte
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
userInfo plugin.UserInfo
|
userInfo plugin.UserInfo
|
||||||
loginMsg *msg.Login
|
loginMsg *msg.Login
|
||||||
@@ -213,7 +214,6 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) {
|
|||||||
xl := xlog.FromContextSafe(pxy.Context())
|
xl := xlog.FromContextSafe(pxy.Context())
|
||||||
defer userConn.Close()
|
defer userConn.Close()
|
||||||
|
|
||||||
serverCfg := pxy.serverCfg
|
|
||||||
cfg := pxy.configurer.GetBaseConfig()
|
cfg := pxy.configurer.GetBaseConfig()
|
||||||
// server plugin hook
|
// server plugin hook
|
||||||
rc := pxy.GetResourceController()
|
rc := pxy.GetResourceController()
|
||||||
@@ -240,7 +240,7 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) {
|
|||||||
xl.Tracef("handler user tcp connection, use_encryption: %t, use_compression: %t",
|
xl.Tracef("handler user tcp connection, use_encryption: %t, use_compression: %t",
|
||||||
cfg.Transport.UseEncryption, cfg.Transport.UseCompression)
|
cfg.Transport.UseEncryption, cfg.Transport.UseCompression)
|
||||||
if cfg.Transport.UseEncryption {
|
if cfg.Transport.UseEncryption {
|
||||||
local, err = libio.WithEncryption(local, []byte(serverCfg.Auth.Token))
|
local, err = libio.WithEncryption(local, pxy.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
xl.Errorf("create encryption stream error: %v", err)
|
||||||
return
|
return
|
||||||
@@ -279,6 +279,7 @@ type Options struct {
|
|||||||
GetWorkConnFn GetWorkConnFn
|
GetWorkConnFn GetWorkConnFn
|
||||||
Configurer v1.ProxyConfigurer
|
Configurer v1.ProxyConfigurer
|
||||||
ServerCfg *v1.ServerConfig
|
ServerCfg *v1.ServerConfig
|
||||||
|
EncryptionKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxy(ctx context.Context, options *Options) (pxy Proxy, err error) {
|
func NewProxy(ctx context.Context, options *Options) (pxy Proxy, err error) {
|
||||||
@@ -298,6 +299,7 @@ func NewProxy(ctx context.Context, options *Options) (pxy Proxy, err error) {
|
|||||||
poolCount: options.PoolCount,
|
poolCount: options.PoolCount,
|
||||||
getWorkConnFn: options.GetWorkConnFn,
|
getWorkConnFn: options.GetWorkConnFn,
|
||||||
serverCfg: options.ServerCfg,
|
serverCfg: options.ServerCfg,
|
||||||
|
encryptionKey: options.EncryptionKey,
|
||||||
limiter: limiter,
|
limiter: limiter,
|
||||||
xl: xl,
|
xl: xl,
|
||||||
ctx: xlog.NewContext(ctx, xl),
|
ctx: xlog.NewContext(ctx, xl),
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
|
|||||||
|
|
||||||
var rwc io.ReadWriteCloser = workConn
|
var rwc io.ReadWriteCloser = workConn
|
||||||
if pxy.cfg.Transport.UseEncryption {
|
if pxy.cfg.Transport.UseEncryption {
|
||||||
rwc, err = libio.WithEncryption(rwc, []byte(pxy.serverCfg.Auth.Token))
|
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
xl.Errorf("create encryption stream error: %v", err)
|
||||||
workConn.Close()
|
workConn.Close()
|
||||||
|
|||||||
@@ -113,8 +113,8 @@ type Service struct {
|
|||||||
|
|
||||||
sshTunnelGateway *ssh.Gateway
|
sshTunnelGateway *ssh.Gateway
|
||||||
|
|
||||||
// Verifies authentication based on selected method
|
// Auth runtime and encryption materials
|
||||||
authVerifier auth.Verifier
|
auth *auth.ServerAuth
|
||||||
|
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
@@ -149,6 +149,11 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authRuntime, err := auth.BuildServerAuth(&cfg.Auth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
svr := &Service{
|
svr := &Service{
|
||||||
ctlManager: NewControlManager(),
|
ctlManager: NewControlManager(),
|
||||||
pxyManager: proxy.NewManager(),
|
pxyManager: proxy.NewManager(),
|
||||||
@@ -160,7 +165,7 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
|
|||||||
},
|
},
|
||||||
sshTunnelListener: netpkg.NewInternalListener(),
|
sshTunnelListener: netpkg.NewInternalListener(),
|
||||||
httpVhostRouter: vhost.NewRouters(),
|
httpVhostRouter: vhost.NewRouters(),
|
||||||
authVerifier: auth.NewAuthVerifier(cfg.Auth),
|
auth: authRuntime,
|
||||||
webServer: webServer,
|
webServer: webServer,
|
||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@@ -586,7 +591,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
|||||||
ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch)
|
ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch)
|
||||||
|
|
||||||
// Check auth.
|
// Check auth.
|
||||||
authVerifier := svr.authVerifier
|
authVerifier := svr.auth.Verifier
|
||||||
if internal && loginMsg.ClientSpec.AlwaysAuthPass {
|
if internal && loginMsg.ClientSpec.AlwaysAuthPass {
|
||||||
authVerifier = auth.AlwaysPassVerifier
|
authVerifier = auth.AlwaysPassVerifier
|
||||||
}
|
}
|
||||||
@@ -595,7 +600,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(fatedier): use SessionContext
|
// TODO(fatedier): use SessionContext
|
||||||
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, ctlConn, !internal, loginMsg, svr.cfg)
|
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("create new controller error: %v", err)
|
xl.Warnf("create new controller error: %v", err)
|
||||||
// don't return detailed errors to client
|
// don't return detailed errors to client
|
||||||
|
|||||||
Reference in New Issue
Block a user