mirror of
https://github.com/fatedier/frp.git
synced 2026-01-11 22:23:12 +00:00
refactor: separate auth config from runtime and defer token resolution (#5105)
This commit is contained in:
@@ -15,8 +15,6 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -198,17 +196,6 @@ type AuthClientConfig struct {
|
||||
|
||||
func (c *AuthClientConfig) Complete() error {
|
||||
c.Method = util.EmptyOr(c.Method, "token")
|
||||
|
||||
// Resolve tokenSource during configuration loading
|
||||
if c.Method == AuthMethodToken && c.TokenSource != nil {
|
||||
token, err := c.TokenSource.Resolve(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
||||
}
|
||||
// Move the resolved token to the Token field and clear TokenSource
|
||||
c.Token = token
|
||||
c.TokenSource = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -38,68 +36,9 @@ func TestClientConfigComplete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthClientConfig_Complete(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test_token")
|
||||
testContent := "client-token-value"
|
||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AuthClientConfig
|
||||
expectToken string
|
||||
expectPanic bool
|
||||
}{
|
||||
{
|
||||
name: "tokenSource resolved to token",
|
||||
config: AuthClientConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: testFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectToken: testContent,
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "direct token unchanged",
|
||||
config: AuthClientConfig{
|
||||
Method: AuthMethodToken,
|
||||
Token: "direct-token",
|
||||
},
|
||||
expectToken: "direct-token",
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tokenSource should panic",
|
||||
config: AuthClientConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "/non/existent/file",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectPanic: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectPanic {
|
||||
err := tt.config.Complete()
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
err := tt.config.Complete()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectToken, tt.config.Token)
|
||||
require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution")
|
||||
}
|
||||
})
|
||||
}
|
||||
require := require.New(t)
|
||||
cfg := &AuthClientConfig{}
|
||||
err := cfg.Complete()
|
||||
require.NoError(err)
|
||||
require.EqualValues("token", cfg.Method)
|
||||
}
|
||||
|
||||
@@ -15,9 +15,6 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
@@ -138,17 +135,6 @@ type AuthServerConfig struct {
|
||||
|
||||
func (c *AuthServerConfig) Complete() error {
|
||||
c.Method = util.EmptyOr(c.Method, "token")
|
||||
|
||||
// Resolve tokenSource during configuration loading
|
||||
if c.Method == AuthMethodToken && c.TokenSource != nil {
|
||||
token, err := c.TokenSource.Resolve(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
||||
}
|
||||
// Move the resolved token to the Token field and clear TokenSource
|
||||
c.Token = token
|
||||
c.TokenSource = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -35,68 +33,9 @@ func TestServerConfigComplete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthServerConfig_Complete(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test_token")
|
||||
testContent := "file-token-value"
|
||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AuthServerConfig
|
||||
expectToken string
|
||||
expectPanic bool
|
||||
}{
|
||||
{
|
||||
name: "tokenSource resolved to token",
|
||||
config: AuthServerConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: testFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectToken: testContent,
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "direct token unchanged",
|
||||
config: AuthServerConfig{
|
||||
Method: AuthMethodToken,
|
||||
Token: "direct-token",
|
||||
},
|
||||
expectToken: "direct-token",
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tokenSource should panic",
|
||||
config: AuthServerConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "/non/existent/file",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectPanic: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectPanic {
|
||||
err := tt.config.Complete()
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
err := tt.config.Complete()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectToken, tt.config.Token)
|
||||
require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution")
|
||||
}
|
||||
})
|
||||
}
|
||||
require := require.New(t)
|
||||
cfg := &AuthServerConfig{}
|
||||
err := cfg.Complete()
|
||||
require.NoError(err)
|
||||
require.EqualValues("token", cfg.Method)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
|
||||
func (v *ConfigValidator) ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) {
|
||||
var (
|
||||
warnings Warning
|
||||
errs error
|
||||
@@ -35,15 +35,15 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *securi
|
||||
|
||||
validators := []func() (Warning, error){
|
||||
func() (Warning, error) { return validateFeatureGates(c) },
|
||||
func() (Warning, error) { return validateAuthConfig(&c.Auth, unsafeFeatures) },
|
||||
func() (Warning, error) { return v.validateAuthConfig(&c.Auth) },
|
||||
func() (Warning, error) { return nil, validateLogConfig(&c.Log) },
|
||||
func() (Warning, error) { return nil, validateWebServerConfig(&c.WebServer) },
|
||||
func() (Warning, error) { return validateTransportConfig(&c.Transport) },
|
||||
func() (Warning, error) { return validateIncludeFiles(c.IncludeConfigFiles) },
|
||||
}
|
||||
|
||||
for _, v := range validators {
|
||||
w, err := v()
|
||||
for _, validator := range validators {
|
||||
w, err := validator()
|
||||
warnings = AppendError(warnings, w)
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func validateFeatureGates(c *v1.ClientCommonConfig) (Warning, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
|
||||
func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, error) {
|
||||
var errs error
|
||||
if !slices.Contains(SupportedAuthMethods, c.Method) {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", SupportedAuthMethods))
|
||||
@@ -76,9 +76,8 @@ func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeF
|
||||
// Validate tokenSource if specified
|
||||
if c.TokenSource != nil {
|
||||
if c.TokenSource.Type == "exec" {
|
||||
if !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
|
||||
errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+
|
||||
"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.TokenSource.Validate(); err != nil {
|
||||
@@ -86,13 +85,13 @@ func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeF
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateOIDCConfig(&c.OIDC, unsafeFeatures); err != nil {
|
||||
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.UnsafeFeatures) error {
|
||||
func (v *ConfigValidator) validateOIDCConfig(c *v1.AuthOIDCClientConfig) error {
|
||||
if c.TokenSource == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -104,9 +103,8 @@ func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.Uns
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc"))
|
||||
}
|
||||
if c.TokenSource.Type == "exec" {
|
||||
if !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
|
||||
errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+
|
||||
"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.TokenSource.Validate(); err != nil {
|
||||
@@ -167,9 +165,10 @@ func ValidateAllClientConfig(
|
||||
visitorCfgs []v1.VisitorConfigurer,
|
||||
unsafeFeatures *security.UnsafeFeatures,
|
||||
) (Warning, error) {
|
||||
validator := NewConfigValidator(unsafeFeatures)
|
||||
var warnings Warning
|
||||
if c != nil {
|
||||
warning, err := ValidateClientCommonConfig(c, unsafeFeatures)
|
||||
warning, err := validator.ValidateClientCommonConfig(c)
|
||||
warnings = AppendError(warnings, warning)
|
||||
if err != nil {
|
||||
return warnings, err
|
||||
|
||||
@@ -21,9 +21,10 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||
var (
|
||||
warnings Warning
|
||||
errs error
|
||||
@@ -42,6 +43,11 @@ func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.Auth.TokenSource != nil {
|
||||
if c.Auth.TokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
|
||||
28
pkg/config/v1/validation/validator.go
Normal file
28
pkg/config/v1/validation/validator.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
// ConfigValidator holds the context dependencies for configuration validation.
|
||||
type ConfigValidator struct {
|
||||
unsafeFeatures *security.UnsafeFeatures
|
||||
}
|
||||
|
||||
// NewConfigValidator creates a new ConfigValidator instance.
|
||||
func NewConfigValidator(unsafeFeatures *security.UnsafeFeatures) *ConfigValidator {
|
||||
return &ConfigValidator{
|
||||
unsafeFeatures: unsafeFeatures,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUnsafeFeature checks if a specific unsafe feature is enabled.
|
||||
func (v *ConfigValidator) ValidateUnsafeFeature(feature string) error {
|
||||
if !v.unsafeFeatures.IsEnabled(feature) {
|
||||
return fmt.Errorf("unsafe feature %q is not enabled. "+
|
||||
"To enable it, ensure it is allowed in the configuration or command line flags", feature)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user