From 7526d7a69af9de505f96ccec5ef47b68666a9201 Mon Sep 17 00:00:00 2001 From: fatedier Date: Thu, 25 Dec 2025 00:53:08 +0800 Subject: [PATCH] refactor: separate auth config from runtime and defer token resolution (#5105) --- client/control.go | 12 ++--- client/proxy/proxy.go | 5 +- client/proxy/proxy_manager.go | 7 ++- client/proxy/proxy_wrapper.go | 3 +- client/proxy/sudp.go | 2 +- client/proxy/udp.go | 2 +- client/service.go | 12 ++--- cmd/frpc/sub/proxy.go | 6 ++- cmd/frpc/sub/root.go | 1 + cmd/frps/root.go | 9 +++- cmd/frps/verify.go | 5 +- pkg/auth/auth.go | 63 ++++++++++++++++++++++++ pkg/config/v1/client.go | 13 ----- pkg/config/v1/client_test.go | 71 ++------------------------- pkg/config/v1/server.go | 14 ------ pkg/config/v1/server_test.go | 71 ++------------------------- pkg/config/v1/validation/client.go | 27 +++++----- pkg/config/v1/validation/server.go | 8 ++- pkg/config/v1/validation/validator.go | 28 +++++++++++ server/control.go | 7 ++- server/proxy/http.go | 2 +- server/proxy/proxy.go | 6 ++- server/proxy/udp.go | 2 +- server/service.go | 15 ++++-- 24 files changed, 185 insertions(+), 206 deletions(-) create mode 100644 pkg/config/v1/validation/validator.go diff --git a/client/control.go b/client/control.go index c18ae07c..0f48c36d 100644 --- a/client/control.go +++ b/client/control.go @@ -43,8 +43,8 @@ type SessionContext struct { Conn net.Conn // Indicates whether the connection is encrypted. ConnEncrypted bool - // Sets authentication based on selected method - AuthSetter auth.Setter + // Auth runtime used for login, heartbeats, and encryption. + Auth *auth.ClientAuth // Connector is used to create new connections, which could be real TCP connections or virtual streams. Connector Connector // Virtual net controller @@ -91,7 +91,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro ctl.lastPong.Store(time.Now()) if sessionCtx.ConnEncrypted { - cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, []byte(sessionCtx.Common.Auth.Token)) + cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey()) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro ctl.registerMsgHandlers() ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher) - ctl.pm = proxy.NewManager(ctl.ctx, sessionCtx.Common, ctl.msgTransporter, sessionCtx.VnetController) + ctl.pm = proxy.NewManager(ctl.ctx, sessionCtx.Common, sessionCtx.Auth.EncryptionKey(), ctl.msgTransporter, sessionCtx.VnetController) ctl.vm = visitor.NewManager(ctl.ctx, sessionCtx.RunID, sessionCtx.Common, ctl.connectServer, ctl.msgTransporter, sessionCtx.VnetController) return ctl, nil @@ -133,7 +133,7 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) { m := &msg.NewWorkConn{ RunID: ctl.sessionCtx.RunID, } - if err = ctl.sessionCtx.AuthSetter.SetNewWorkConn(m); err != nil { + if err = ctl.sessionCtx.Auth.Setter.SetNewWorkConn(m); err != nil { xl.Warnf("error during NewWorkConn authentication: %v", err) workConn.Close() return @@ -243,7 +243,7 @@ func (ctl *Control) heartbeatWorker() { sendHeartBeat := func() (bool, error) { xl.Debugf("send heartbeat to server") pingMsg := &msg.Ping{} - if err := ctl.sessionCtx.AuthSetter.SetPing(pingMsg); err != nil { + if err := ctl.sessionCtx.Auth.Setter.SetPing(pingMsg); err != nil { xl.Warnf("error during ping authentication: %v, skip sending ping message", err) return false, err } diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index 876ca579..8faff38d 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -57,6 +57,7 @@ func NewProxy( ctx context.Context, pxyConf v1.ProxyConfigurer, clientCfg *v1.ClientCommonConfig, + encryptionKey []byte, msgTransporter transport.MessageTransporter, vnetController *vnet.Controller, ) (pxy Proxy) { @@ -69,6 +70,7 @@ func NewProxy( baseProxy := BaseProxy{ baseCfg: pxyConf.GetBaseConfig(), clientCfg: clientCfg, + encryptionKey: encryptionKey, limiter: limiter, msgTransporter: msgTransporter, vnetController: vnetController, @@ -86,6 +88,7 @@ func NewProxy( type BaseProxy struct { baseCfg *v1.ProxyBaseConfig clientCfg *v1.ClientCommonConfig + encryptionKey []byte msgTransporter transport.MessageTransporter vnetController *vnet.Controller limiter *rate.Limiter @@ -129,7 +132,7 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { return } } - pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Auth.Token)) + pxy.HandleTCPWorkConnection(conn, m, pxy.encryptionKey) } // Common handler for tcp work connections. diff --git a/client/proxy/proxy_manager.go b/client/proxy/proxy_manager.go index ea5cc553..4615e9a2 100644 --- a/client/proxy/proxy_manager.go +++ b/client/proxy/proxy_manager.go @@ -40,7 +40,8 @@ type Manager struct { closed bool mu sync.RWMutex - clientCfg *v1.ClientCommonConfig + encryptionKey []byte + clientCfg *v1.ClientCommonConfig ctx context.Context } @@ -48,6 +49,7 @@ type Manager struct { func NewManager( ctx context.Context, clientCfg *v1.ClientCommonConfig, + encryptionKey []byte, msgTransporter transport.MessageTransporter, vnetController *vnet.Controller, ) *Manager { @@ -56,6 +58,7 @@ func NewManager( msgTransporter: msgTransporter, vnetController: vnetController, closed: false, + encryptionKey: encryptionKey, clientCfg: clientCfg, ctx: ctx, } @@ -163,7 +166,7 @@ func (pm *Manager) UpdateAll(proxyCfgs []v1.ProxyConfigurer) { for _, cfg := range proxyCfgs { name := cfg.GetBaseConfig().Name if _, ok := pm.proxies[name]; !ok { - pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter, pm.vnetController) + pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.encryptionKey, pm.HandleEvent, pm.msgTransporter, pm.vnetController) if pm.inWorkConnCallback != nil { pxy.SetInWorkConnCallback(pm.inWorkConnCallback) } diff --git a/client/proxy/proxy_wrapper.go b/client/proxy/proxy_wrapper.go index f3f17e2b..4698320a 100644 --- a/client/proxy/proxy_wrapper.go +++ b/client/proxy/proxy_wrapper.go @@ -92,6 +92,7 @@ func NewWrapper( ctx context.Context, cfg v1.ProxyConfigurer, clientCfg *v1.ClientCommonConfig, + encryptionKey []byte, eventHandler event.Handler, msgTransporter transport.MessageTransporter, vnetController *vnet.Controller, @@ -122,7 +123,7 @@ func NewWrapper( xl.Tracef("enable health check monitor") } - pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, pw.msgTransporter, pw.vnetController) + pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, encryptionKey, pw.msgTransporter, pw.vnetController) return pw } diff --git a/client/proxy/sudp.go b/client/proxy/sudp.go index 13741d0d..3a7af19c 100644 --- a/client/proxy/sudp.go +++ b/client/proxy/sudp.go @@ -91,7 +91,7 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) { }) } if pxy.cfg.Transport.UseEncryption { - rwc, err = libio.WithEncryption(rwc, []byte(pxy.clientCfg.Auth.Token)) + rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) if err != nil { conn.Close() xl.Errorf("create encryption stream error: %v", err) diff --git a/client/proxy/udp.go b/client/proxy/udp.go index b70ffe4a..1fca9904 100644 --- a/client/proxy/udp.go +++ b/client/proxy/udp.go @@ -102,7 +102,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) { }) } if pxy.cfg.Transport.UseEncryption { - rwc, err = libio.WithEncryption(rwc, []byte(pxy.clientCfg.Auth.Token)) + rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) if err != nil { conn.Close() xl.Errorf("create encryption stream error: %v", err) diff --git a/client/service.go b/client/service.go index 819a2bc5..b282163e 100644 --- a/client/service.go +++ b/client/service.go @@ -111,8 +111,8 @@ type Service struct { // Uniq id got from frps, it will be attached to loginMsg. runID string - // Sets authentication based on selected method - authSetter auth.Setter + // Auth runtime and encryption materials + auth *auth.ClientAuth // web server for admin UI and apis webServer *httppkg.Server @@ -155,14 +155,14 @@ func NewService(options ServiceOptions) (*Service, error) { webServer = ws } - authSetter, err := auth.NewAuthSetter(options.Common.Auth) + authRuntime, err := auth.BuildClientAuth(&options.Common.Auth) if err != nil { return nil, err } s := &Service{ ctx: context.Background(), - authSetter: authSetter, + auth: authRuntime, webServer: webServer, common: options.Common, configFilePath: options.ConfigFilePath, @@ -296,7 +296,7 @@ func (svr *Service) login() (conn net.Conn, connector Connector, err error) { } // Add auth - if err = svr.authSetter.SetLogin(loginMsg); err != nil { + if err = svr.auth.Setter.SetLogin(loginMsg); err != nil { return } @@ -350,7 +350,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE RunID: svr.runID, Conn: conn, ConnEncrypted: connEncrypted, - AuthSetter: svr.authSetter, + Auth: svr.auth, Connector: connector, VnetController: svr.vnetController, } diff --git a/cmd/frpc/sub/proxy.go b/cmd/frpc/sub/proxy.go index 0748a8b1..ef7fe67f 100644 --- a/cmd/frpc/sub/proxy.go +++ b/cmd/frpc/sub/proxy.go @@ -80,7 +80,8 @@ func NewProxyCommand(name string, c v1.ProxyConfigurer, clientCfg *v1.ClientComm } unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe) - if _, err := validation.ValidateClientCommonConfig(clientCfg, unsafeFeatures); err != nil { + validator := validation.NewConfigValidator(unsafeFeatures) + if _, err := validator.ValidateClientCommonConfig(clientCfg); err != nil { fmt.Println(err) os.Exit(1) } @@ -110,7 +111,8 @@ func NewVisitorCommand(name string, c v1.VisitorConfigurer, clientCfg *v1.Client os.Exit(1) } unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe) - if _, err := validation.ValidateClientCommonConfig(clientCfg, unsafeFeatures); err != nil { + validator := validation.NewConfigValidator(unsafeFeatures) + if _, err := validator.ValidateClientCommonConfig(clientCfg); err != nil { fmt.Println(err) os.Exit(1) } diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index 0750562b..1c2d8d5e 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -142,6 +142,7 @@ func runClient(cfgFilePath string, unsafeFeatures *security.UnsafeFeatures) erro if err != nil { return err } + return startService(cfg, proxyCfgs, visitorCfgs, unsafeFeatures, cfgFilePath) } diff --git a/cmd/frps/root.go b/cmd/frps/root.go index c1bfc880..e6ab008a 100644 --- a/cmd/frps/root.go +++ b/cmd/frps/root.go @@ -18,12 +18,14 @@ import ( "context" "fmt" "os" + "strings" "github.com/spf13/cobra" "github.com/fatedier/frp/pkg/config" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/config/v1/validation" + "github.com/fatedier/frp/pkg/policy/security" "github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/server" @@ -33,6 +35,7 @@ var ( cfgFile string showVersion bool strictConfigMode bool + allowUnsafe []string serverCfg v1.ServerConfig ) @@ -41,6 +44,8 @@ func init() { rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps") rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", true, "strict config parsing mode, unknown fields will cause errors") + rootCmd.PersistentFlags().StringSliceVarP(&allowUnsafe, "allow-unsafe", "", []string{}, + fmt.Sprintf("allowed unsafe features, one or more of: %s", strings.Join(security.ServerUnsafeFeatures, ", "))) config.RegisterServerConfigFlags(rootCmd, &serverCfg) } @@ -77,7 +82,9 @@ var rootCmd = &cobra.Command{ svrCfg = &serverCfg } - warning, err := validation.ValidateServerConfig(svrCfg) + unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe) + validator := validation.NewConfigValidator(unsafeFeatures) + warning, err := validator.ValidateServerConfig(svrCfg) if warning != nil { fmt.Printf("WARNING: %v\n", warning) } diff --git a/cmd/frps/verify.go b/cmd/frps/verify.go index 33ad3f63..7ddef1ab 100644 --- a/cmd/frps/verify.go +++ b/cmd/frps/verify.go @@ -22,6 +22,7 @@ import ( "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config/v1/validation" + "github.com/fatedier/frp/pkg/policy/security" ) func init() { @@ -42,7 +43,9 @@ var verifyCmd = &cobra.Command{ os.Exit(1) } - warning, err := validation.ValidateServerConfig(svrCfg) + unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe) + validator := validation.NewConfigValidator(unsafeFeatures) + warning, err := validator.ValidateServerConfig(svrCfg) if warning != nil { fmt.Printf("WARNING: %v\n", warning) } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 64462a20..366b62ef 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -15,6 +15,7 @@ package auth import ( + "context" "fmt" v1 "github.com/fatedier/frp/pkg/config/v1" @@ -27,6 +28,39 @@ type Setter interface { SetNewWorkConn(*msg.NewWorkConn) error } +type ClientAuth struct { + Setter Setter + key []byte +} + +func (a *ClientAuth) EncryptionKey() []byte { + return a.key +} + +// BuildClientAuth resolves any dynamic auth values and returns a prepared auth runtime. +// Caller must run validation before calling this function. +func BuildClientAuth(cfg *v1.AuthClientConfig) (*ClientAuth, error) { + if cfg == nil { + return nil, fmt.Errorf("auth config is nil") + } + resolved := *cfg + if resolved.Method == v1.AuthMethodToken && resolved.TokenSource != nil { + token, err := resolved.TokenSource.Resolve(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to resolve auth.tokenSource: %w", err) + } + resolved.Token = token + } + setter, err := NewAuthSetter(resolved) + if err != nil { + return nil, err + } + return &ClientAuth{ + Setter: setter, + key: []byte(resolved.Token), + }, nil +} + func NewAuthSetter(cfg v1.AuthClientConfig) (authProvider Setter, err error) { switch cfg.Method { case v1.AuthMethodToken: @@ -52,6 +86,35 @@ type Verifier interface { VerifyNewWorkConn(*msg.NewWorkConn) error } +type ServerAuth struct { + Verifier Verifier + key []byte +} + +func (a *ServerAuth) EncryptionKey() []byte { + return a.key +} + +// BuildServerAuth resolves any dynamic auth values and returns a prepared auth runtime. +// Caller must run validation before calling this function. +func BuildServerAuth(cfg *v1.AuthServerConfig) (*ServerAuth, error) { + if cfg == nil { + return nil, fmt.Errorf("auth config is nil") + } + resolved := *cfg + if resolved.Method == v1.AuthMethodToken && resolved.TokenSource != nil { + token, err := resolved.TokenSource.Resolve(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to resolve auth.tokenSource: %w", err) + } + resolved.Token = token + } + return &ServerAuth{ + Verifier: NewAuthVerifier(resolved), + key: []byte(resolved.Token), + }, nil +} + func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) { switch cfg.Method { case v1.AuthMethodToken: diff --git a/pkg/config/v1/client.go b/pkg/config/v1/client.go index 61bc6ac6..2c5ccc6f 100644 --- a/pkg/config/v1/client.go +++ b/pkg/config/v1/client.go @@ -15,8 +15,6 @@ package v1 import ( - "context" - "fmt" "os" "github.com/samber/lo" @@ -198,17 +196,6 @@ type AuthClientConfig struct { func (c *AuthClientConfig) Complete() error { c.Method = util.EmptyOr(c.Method, "token") - - // Resolve tokenSource during configuration loading - if c.Method == AuthMethodToken && c.TokenSource != nil { - token, err := c.TokenSource.Resolve(context.Background()) - if err != nil { - return fmt.Errorf("failed to resolve auth.tokenSource: %w", err) - } - // Move the resolved token to the Token field and clear TokenSource - c.Token = token - c.TokenSource = nil - } return nil } diff --git a/pkg/config/v1/client_test.go b/pkg/config/v1/client_test.go index 120c4fd4..5473a5f6 100644 --- a/pkg/config/v1/client_test.go +++ b/pkg/config/v1/client_test.go @@ -15,8 +15,6 @@ package v1 import ( - "os" - "path/filepath" "testing" "github.com/samber/lo" @@ -38,68 +36,9 @@ func TestClientConfigComplete(t *testing.T) { } func TestAuthClientConfig_Complete(t *testing.T) { - // Create a temporary file for testing - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test_token") - testContent := "client-token-value" - err := os.WriteFile(testFile, []byte(testContent), 0o600) - require.NoError(t, err) - - tests := []struct { - name string - config AuthClientConfig - expectToken string - expectPanic bool - }{ - { - name: "tokenSource resolved to token", - config: AuthClientConfig{ - Method: AuthMethodToken, - TokenSource: &ValueSource{ - Type: "file", - File: &FileSource{ - Path: testFile, - }, - }, - }, - expectToken: testContent, - expectPanic: false, - }, - { - name: "direct token unchanged", - config: AuthClientConfig{ - Method: AuthMethodToken, - Token: "direct-token", - }, - expectToken: "direct-token", - expectPanic: false, - }, - { - name: "invalid tokenSource should panic", - config: AuthClientConfig{ - Method: AuthMethodToken, - TokenSource: &ValueSource{ - Type: "file", - File: &FileSource{ - Path: "/non/existent/file", - }, - }, - }, - expectPanic: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.expectPanic { - err := tt.config.Complete() - require.Error(t, err) - } else { - err := tt.config.Complete() - require.NoError(t, err) - require.Equal(t, tt.expectToken, tt.config.Token) - require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution") - } - }) - } + require := require.New(t) + cfg := &AuthClientConfig{} + err := cfg.Complete() + require.NoError(err) + require.EqualValues("token", cfg.Method) } diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go index 54aac080..a92aac97 100644 --- a/pkg/config/v1/server.go +++ b/pkg/config/v1/server.go @@ -15,9 +15,6 @@ package v1 import ( - "context" - "fmt" - "github.com/samber/lo" "github.com/fatedier/frp/pkg/config/types" @@ -138,17 +135,6 @@ type AuthServerConfig struct { func (c *AuthServerConfig) Complete() error { c.Method = util.EmptyOr(c.Method, "token") - - // Resolve tokenSource during configuration loading - if c.Method == AuthMethodToken && c.TokenSource != nil { - token, err := c.TokenSource.Resolve(context.Background()) - if err != nil { - return fmt.Errorf("failed to resolve auth.tokenSource: %w", err) - } - // Move the resolved token to the Token field and clear TokenSource - c.Token = token - c.TokenSource = nil - } return nil } diff --git a/pkg/config/v1/server_test.go b/pkg/config/v1/server_test.go index 21d18fb7..cd9381c5 100644 --- a/pkg/config/v1/server_test.go +++ b/pkg/config/v1/server_test.go @@ -15,8 +15,6 @@ package v1 import ( - "os" - "path/filepath" "testing" "github.com/samber/lo" @@ -35,68 +33,9 @@ func TestServerConfigComplete(t *testing.T) { } func TestAuthServerConfig_Complete(t *testing.T) { - // Create a temporary file for testing - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test_token") - testContent := "file-token-value" - err := os.WriteFile(testFile, []byte(testContent), 0o600) - require.NoError(t, err) - - tests := []struct { - name string - config AuthServerConfig - expectToken string - expectPanic bool - }{ - { - name: "tokenSource resolved to token", - config: AuthServerConfig{ - Method: AuthMethodToken, - TokenSource: &ValueSource{ - Type: "file", - File: &FileSource{ - Path: testFile, - }, - }, - }, - expectToken: testContent, - expectPanic: false, - }, - { - name: "direct token unchanged", - config: AuthServerConfig{ - Method: AuthMethodToken, - Token: "direct-token", - }, - expectToken: "direct-token", - expectPanic: false, - }, - { - name: "invalid tokenSource should panic", - config: AuthServerConfig{ - Method: AuthMethodToken, - TokenSource: &ValueSource{ - Type: "file", - File: &FileSource{ - Path: "/non/existent/file", - }, - }, - }, - expectPanic: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.expectPanic { - err := tt.config.Complete() - require.Error(t, err) - } else { - err := tt.config.Complete() - require.NoError(t, err) - require.Equal(t, tt.expectToken, tt.config.Token) - require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution") - } - }) - } + require := require.New(t) + cfg := &AuthServerConfig{} + err := cfg.Complete() + require.NoError(err) + require.EqualValues("token", cfg.Method) } diff --git a/pkg/config/v1/validation/client.go b/pkg/config/v1/validation/client.go index b55fece7..eb4a0253 100644 --- a/pkg/config/v1/validation/client.go +++ b/pkg/config/v1/validation/client.go @@ -27,7 +27,7 @@ import ( "github.com/fatedier/frp/pkg/policy/security" ) -func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) { +func (v *ConfigValidator) ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) { var ( warnings Warning errs error @@ -35,15 +35,15 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *securi validators := []func() (Warning, error){ func() (Warning, error) { return validateFeatureGates(c) }, - func() (Warning, error) { return validateAuthConfig(&c.Auth, unsafeFeatures) }, + func() (Warning, error) { return v.validateAuthConfig(&c.Auth) }, func() (Warning, error) { return nil, validateLogConfig(&c.Log) }, func() (Warning, error) { return nil, validateWebServerConfig(&c.WebServer) }, func() (Warning, error) { return validateTransportConfig(&c.Transport) }, func() (Warning, error) { return validateIncludeFiles(c.IncludeConfigFiles) }, } - for _, v := range validators { - w, err := v() + for _, validator := range validators { + w, err := validator() warnings = AppendError(warnings, w) errs = AppendError(errs, err) } @@ -59,7 +59,7 @@ func validateFeatureGates(c *v1.ClientCommonConfig) (Warning, error) { return nil, nil } -func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) { +func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, error) { var errs error if !slices.Contains(SupportedAuthMethods, c.Method) { errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", SupportedAuthMethods)) @@ -76,9 +76,8 @@ func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeF // Validate tokenSource if specified if c.TokenSource != nil { if c.TokenSource.Type == "exec" { - if !unsafeFeatures.IsEnabled(security.TokenSourceExec) { - errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+ - "To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec)) + if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil { + errs = AppendError(errs, err) } } if err := c.TokenSource.Validate(); err != nil { @@ -86,13 +85,13 @@ func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeF } } - if err := validateOIDCConfig(&c.OIDC, unsafeFeatures); err != nil { + if err := v.validateOIDCConfig(&c.OIDC); err != nil { errs = AppendError(errs, err) } return nil, errs } -func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.UnsafeFeatures) error { +func (v *ConfigValidator) validateOIDCConfig(c *v1.AuthOIDCClientConfig) error { if c.TokenSource == nil { return nil } @@ -104,9 +103,8 @@ func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.Uns errs = AppendError(errs, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc")) } if c.TokenSource.Type == "exec" { - if !unsafeFeatures.IsEnabled(security.TokenSourceExec) { - errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+ - "To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec)) + if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil { + errs = AppendError(errs, err) } } if err := c.TokenSource.Validate(); err != nil { @@ -167,9 +165,10 @@ func ValidateAllClientConfig( visitorCfgs []v1.VisitorConfigurer, unsafeFeatures *security.UnsafeFeatures, ) (Warning, error) { + validator := NewConfigValidator(unsafeFeatures) var warnings Warning if c != nil { - warning, err := ValidateClientCommonConfig(c, unsafeFeatures) + warning, err := validator.ValidateClientCommonConfig(c) warnings = AppendError(warnings, warning) if err != nil { return warnings, err diff --git a/pkg/config/v1/validation/server.go b/pkg/config/v1/validation/server.go index 56942272..338ecc82 100644 --- a/pkg/config/v1/validation/server.go +++ b/pkg/config/v1/validation/server.go @@ -21,9 +21,10 @@ import ( "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/policy/security" ) -func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) { +func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) { var ( warnings Warning errs error @@ -42,6 +43,11 @@ func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) { // Validate tokenSource if specified if c.Auth.TokenSource != nil { + if c.Auth.TokenSource.Type == "exec" { + if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil { + errs = AppendError(errs, err) + } + } if err := c.Auth.TokenSource.Validate(); err != nil { errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err)) } diff --git a/pkg/config/v1/validation/validator.go b/pkg/config/v1/validation/validator.go new file mode 100644 index 00000000..1cfe3b21 --- /dev/null +++ b/pkg/config/v1/validation/validator.go @@ -0,0 +1,28 @@ +package validation + +import ( + "fmt" + + "github.com/fatedier/frp/pkg/policy/security" +) + +// ConfigValidator holds the context dependencies for configuration validation. +type ConfigValidator struct { + unsafeFeatures *security.UnsafeFeatures +} + +// NewConfigValidator creates a new ConfigValidator instance. +func NewConfigValidator(unsafeFeatures *security.UnsafeFeatures) *ConfigValidator { + return &ConfigValidator{ + unsafeFeatures: unsafeFeatures, + } +} + +// ValidateUnsafeFeature checks if a specific unsafe feature is enabled. +func (v *ConfigValidator) ValidateUnsafeFeature(feature string) error { + if !v.unsafeFeatures.IsEnabled(feature) { + return fmt.Errorf("unsafe feature %q is not enabled. "+ + "To enable it, ensure it is allowed in the configuration or command line flags", feature) + } + return nil +} diff --git a/server/control.go b/server/control.go index 65d52062..af9c9de3 100644 --- a/server/control.go +++ b/server/control.go @@ -106,6 +106,8 @@ type Control struct { // verifies authentication based on selected method authVerifier auth.Verifier + // key used for connection encryption + encryptionKey []byte // other components can use this to communicate with client msgTransporter transport.MessageTransporter @@ -157,6 +159,7 @@ func NewControl( pxyManager *proxy.Manager, pluginManager *plugin.Manager, authVerifier auth.Verifier, + encryptionKey []byte, ctlConn net.Conn, ctlConnEncrypted bool, loginMsg *msg.Login, @@ -171,6 +174,7 @@ func NewControl( pxyManager: pxyManager, pluginManager: pluginManager, authVerifier: authVerifier, + encryptionKey: encryptionKey, conn: ctlConn, loginMsg: loginMsg, workConnCh: make(chan net.Conn, poolCount+10), @@ -186,7 +190,7 @@ func NewControl( ctl.lastPing.Store(time.Now()) if ctlConnEncrypted { - cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) + cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey) if err != nil { return nil, err } @@ -478,6 +482,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err GetWorkConnFn: ctl.GetWorkConn, Configurer: pxyConf, ServerCfg: ctl.serverCfg, + EncryptionKey: ctl.encryptionKey, }) if err != nil { return remoteAddr, err diff --git a/server/proxy/http.go b/server/proxy/http.go index 9a02dcdd..2c4f1fd4 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -165,7 +165,7 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err var rwc io.ReadWriteCloser = tmpConn if pxy.cfg.Transport.UseEncryption { - rwc, err = libio.WithEncryption(rwc, []byte(pxy.serverCfg.Auth.Token)) + rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) if err != nil { xl.Errorf("create encryption stream error: %v", err) return diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index c7c18c32..564eca28 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -68,6 +68,7 @@ type BaseProxy struct { poolCount int getWorkConnFn GetWorkConnFn serverCfg *v1.ServerConfig + encryptionKey []byte limiter *rate.Limiter userInfo plugin.UserInfo loginMsg *msg.Login @@ -213,7 +214,6 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) { xl := xlog.FromContextSafe(pxy.Context()) defer userConn.Close() - serverCfg := pxy.serverCfg cfg := pxy.configurer.GetBaseConfig() // server plugin hook rc := pxy.GetResourceController() @@ -240,7 +240,7 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) { xl.Tracef("handler user tcp connection, use_encryption: %t, use_compression: %t", cfg.Transport.UseEncryption, cfg.Transport.UseCompression) if cfg.Transport.UseEncryption { - local, err = libio.WithEncryption(local, []byte(serverCfg.Auth.Token)) + local, err = libio.WithEncryption(local, pxy.encryptionKey) if err != nil { xl.Errorf("create encryption stream error: %v", err) return @@ -279,6 +279,7 @@ type Options struct { GetWorkConnFn GetWorkConnFn Configurer v1.ProxyConfigurer ServerCfg *v1.ServerConfig + EncryptionKey []byte } func NewProxy(ctx context.Context, options *Options) (pxy Proxy, err error) { @@ -298,6 +299,7 @@ func NewProxy(ctx context.Context, options *Options) (pxy Proxy, err error) { poolCount: options.PoolCount, getWorkConnFn: options.GetWorkConnFn, serverCfg: options.ServerCfg, + encryptionKey: options.EncryptionKey, limiter: limiter, xl: xl, ctx: xlog.NewContext(ctx, xl), diff --git a/server/proxy/udp.go b/server/proxy/udp.go index 53a07d52..3751dc9b 100644 --- a/server/proxy/udp.go +++ b/server/proxy/udp.go @@ -205,7 +205,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) { var rwc io.ReadWriteCloser = workConn if pxy.cfg.Transport.UseEncryption { - rwc, err = libio.WithEncryption(rwc, []byte(pxy.serverCfg.Auth.Token)) + rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) if err != nil { xl.Errorf("create encryption stream error: %v", err) workConn.Close() diff --git a/server/service.go b/server/service.go index 1fe882d2..de3af837 100644 --- a/server/service.go +++ b/server/service.go @@ -113,8 +113,8 @@ type Service struct { sshTunnelGateway *ssh.Gateway - // Verifies authentication based on selected method - authVerifier auth.Verifier + // Auth runtime and encryption materials + auth *auth.ServerAuth tlsConfig *tls.Config @@ -149,6 +149,11 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) { } } + authRuntime, err := auth.BuildServerAuth(&cfg.Auth) + if err != nil { + return nil, err + } + svr := &Service{ ctlManager: NewControlManager(), pxyManager: proxy.NewManager(), @@ -160,7 +165,7 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) { }, sshTunnelListener: netpkg.NewInternalListener(), httpVhostRouter: vhost.NewRouters(), - authVerifier: auth.NewAuthVerifier(cfg.Auth), + auth: authRuntime, webServer: webServer, tlsConfig: tlsConfig, cfg: cfg, @@ -586,7 +591,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch) // Check auth. - authVerifier := svr.authVerifier + authVerifier := svr.auth.Verifier if internal && loginMsg.ClientSpec.AlwaysAuthPass { authVerifier = auth.AlwaysPassVerifier } @@ -595,7 +600,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter } // TODO(fatedier): use SessionContext - ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, ctlConn, !internal, loginMsg, svr.cfg) + ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg) if err != nil { xl.Warnf("create new controller error: %v", err) // don't return detailed errors to client