mirror of
https://github.com/fatedier/frp.git
synced 2026-01-10 10:13:16 +00:00
add tokenSource support for auth configuration (#4865)
This commit is contained in:
@@ -212,7 +212,9 @@ func LoadServerConfig(path string, strict bool) (*v1.ServerConfig, bool, error)
|
||||
}
|
||||
}
|
||||
if svrCfg != nil {
|
||||
svrCfg.Complete()
|
||||
if err := svrCfg.Complete(); err != nil {
|
||||
return nil, isLegacyFormat, err
|
||||
}
|
||||
}
|
||||
return svrCfg, isLegacyFormat, nil
|
||||
}
|
||||
@@ -280,7 +282,9 @@ func LoadClientConfig(path string, strict bool) (
|
||||
}
|
||||
|
||||
if cliCfg != nil {
|
||||
cliCfg.Complete()
|
||||
if err := cliCfg.Complete(); err != nil {
|
||||
return nil, nil, nil, isLegacyFormat, err
|
||||
}
|
||||
}
|
||||
for _, c := range proxyCfgs {
|
||||
c.Complete(cliCfg.User)
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -77,18 +79,21 @@ type ClientCommonConfig struct {
|
||||
IncludeConfigFiles []string `json:"includes,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ClientCommonConfig) Complete() {
|
||||
func (c *ClientCommonConfig) Complete() error {
|
||||
c.ServerAddr = util.EmptyOr(c.ServerAddr, "0.0.0.0")
|
||||
c.ServerPort = util.EmptyOr(c.ServerPort, 7000)
|
||||
c.LoginFailExit = util.EmptyOr(c.LoginFailExit, lo.ToPtr(true))
|
||||
c.NatHoleSTUNServer = util.EmptyOr(c.NatHoleSTUNServer, "stun.easyvoip.com:3478")
|
||||
|
||||
c.Auth.Complete()
|
||||
if err := c.Auth.Complete(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Log.Complete()
|
||||
c.Transport.Complete()
|
||||
c.WebServer.Complete()
|
||||
|
||||
c.UDPPacketSize = util.EmptyOr(c.UDPPacketSize, 1500)
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientTransportConfig struct {
|
||||
@@ -184,12 +189,27 @@ type AuthClientConfig struct {
|
||||
// Token specifies the authorization token used to create keys to be sent
|
||||
// to the server. The server must have a matching token for authorization
|
||||
// to succeed. By default, this value is "".
|
||||
Token string `json:"token,omitempty"`
|
||||
OIDC AuthOIDCClientConfig `json:"oidc,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
// TokenSource specifies a dynamic source for the authorization token.
|
||||
// This is mutually exclusive with Token field.
|
||||
TokenSource *ValueSource `json:"tokenSource,omitempty"`
|
||||
OIDC AuthOIDCClientConfig `json:"oidc,omitempty"`
|
||||
}
|
||||
|
||||
func (c *AuthClientConfig) Complete() {
|
||||
func (c *AuthClientConfig) Complete() error {
|
||||
c.Method = util.EmptyOr(c.Method, "token")
|
||||
|
||||
// Resolve tokenSource during configuration loading
|
||||
if c.Method == AuthMethodToken && c.TokenSource != nil {
|
||||
token, err := c.TokenSource.Resolve(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
||||
}
|
||||
// Move the resolved token to the Token field and clear TokenSource
|
||||
c.Token = token
|
||||
c.TokenSource = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthOIDCClientConfig struct {
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -24,7 +26,8 @@ import (
|
||||
func TestClientConfigComplete(t *testing.T) {
|
||||
require := require.New(t)
|
||||
c := &ClientConfig{}
|
||||
c.Complete()
|
||||
err := c.Complete()
|
||||
require.NoError(err)
|
||||
|
||||
require.EqualValues("token", c.Auth.Method)
|
||||
require.Equal(true, lo.FromPtr(c.Transport.TCPMux))
|
||||
@@ -33,3 +36,70 @@ func TestClientConfigComplete(t *testing.T) {
|
||||
require.Equal(true, lo.FromPtr(c.Transport.TLS.DisableCustomTLSFirstByte))
|
||||
require.NotEmpty(c.NatHoleSTUNServer)
|
||||
}
|
||||
|
||||
func TestAuthClientConfig_Complete(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test_token")
|
||||
testContent := "client-token-value"
|
||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AuthClientConfig
|
||||
expectToken string
|
||||
expectPanic bool
|
||||
}{
|
||||
{
|
||||
name: "tokenSource resolved to token",
|
||||
config: AuthClientConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: testFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectToken: testContent,
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "direct token unchanged",
|
||||
config: AuthClientConfig{
|
||||
Method: AuthMethodToken,
|
||||
Token: "direct-token",
|
||||
},
|
||||
expectToken: "direct-token",
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tokenSource should panic",
|
||||
config: AuthClientConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "/non/existent/file",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectPanic: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectPanic {
|
||||
err := tt.config.Complete()
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
err := tt.config.Complete()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectToken, tt.config.Token)
|
||||
require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
@@ -98,8 +101,10 @@ type ServerConfig struct {
|
||||
HTTPPlugins []HTTPPluginOptions `json:"httpPlugins,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ServerConfig) Complete() {
|
||||
c.Auth.Complete()
|
||||
func (c *ServerConfig) Complete() error {
|
||||
if err := c.Auth.Complete(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Log.Complete()
|
||||
c.Transport.Complete()
|
||||
c.WebServer.Complete()
|
||||
@@ -120,17 +125,31 @@ func (c *ServerConfig) Complete() {
|
||||
c.UserConnTimeout = util.EmptyOr(c.UserConnTimeout, 10)
|
||||
c.UDPPacketSize = util.EmptyOr(c.UDPPacketSize, 1500)
|
||||
c.NatHoleAnalysisDataReserveHours = util.EmptyOr(c.NatHoleAnalysisDataReserveHours, 7*24)
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthServerConfig struct {
|
||||
Method AuthMethod `json:"method,omitempty"`
|
||||
AdditionalScopes []AuthScope `json:"additionalScopes,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TokenSource *ValueSource `json:"tokenSource,omitempty"`
|
||||
OIDC AuthOIDCServerConfig `json:"oidc,omitempty"`
|
||||
}
|
||||
|
||||
func (c *AuthServerConfig) Complete() {
|
||||
func (c *AuthServerConfig) Complete() error {
|
||||
c.Method = util.EmptyOr(c.Method, "token")
|
||||
|
||||
// Resolve tokenSource during configuration loading
|
||||
if c.Method == AuthMethodToken && c.TokenSource != nil {
|
||||
token, err := c.TokenSource.Resolve(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve auth.tokenSource: %w", err)
|
||||
}
|
||||
// Move the resolved token to the Token field and clear TokenSource
|
||||
c.Token = token
|
||||
c.TokenSource = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthOIDCServerConfig struct {
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -24,9 +26,77 @@ import (
|
||||
func TestServerConfigComplete(t *testing.T) {
|
||||
require := require.New(t)
|
||||
c := &ServerConfig{}
|
||||
c.Complete()
|
||||
err := c.Complete()
|
||||
require.NoError(err)
|
||||
|
||||
require.EqualValues("token", c.Auth.Method)
|
||||
require.Equal(true, lo.FromPtr(c.Transport.TCPMux))
|
||||
require.Equal(true, lo.FromPtr(c.DetailedErrorsToClient))
|
||||
}
|
||||
|
||||
func TestAuthServerConfig_Complete(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test_token")
|
||||
testContent := "file-token-value"
|
||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AuthServerConfig
|
||||
expectToken string
|
||||
expectPanic bool
|
||||
}{
|
||||
{
|
||||
name: "tokenSource resolved to token",
|
||||
config: AuthServerConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: testFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectToken: testContent,
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "direct token unchanged",
|
||||
config: AuthServerConfig{
|
||||
Method: AuthMethodToken,
|
||||
Token: "direct-token",
|
||||
},
|
||||
expectToken: "direct-token",
|
||||
expectPanic: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tokenSource should panic",
|
||||
config: AuthServerConfig{
|
||||
Method: AuthMethodToken,
|
||||
TokenSource: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "/non/existent/file",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectPanic: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectPanic {
|
||||
err := tt.config.Complete()
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
err := tt.config.Complete()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectToken, tt.config.Token)
|
||||
require.Nil(t, tt.config.TokenSource, "TokenSource should be cleared after resolution")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,18 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig) (Warning, error) {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||
}
|
||||
|
||||
// Validate token/tokenSource mutual exclusivity
|
||||
if c.Auth.Token != "" && c.Auth.TokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.Auth.TokenSource != nil {
|
||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateLogConfig(&c.Log); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,18 @@ func ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||
}
|
||||
|
||||
// Validate token/tokenSource mutual exclusivity
|
||||
if c.Auth.Token != "" && c.Auth.TokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.Auth.TokenSource != nil {
|
||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateLogConfig(&c.Log); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
|
||||
93
pkg/config/v1/value_source.go
Normal file
93
pkg/config/v1/value_source.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright 2025 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValueSource provides a way to dynamically resolve configuration values
|
||||
// from various sources like files, environment variables, or external services.
|
||||
type ValueSource struct {
|
||||
Type string `json:"type"`
|
||||
File *FileSource `json:"file,omitempty"`
|
||||
}
|
||||
|
||||
// FileSource specifies how to load a value from a file.
|
||||
type FileSource struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// Validate validates the ValueSource configuration.
|
||||
func (v *ValueSource) Validate() error {
|
||||
if v == nil {
|
||||
return errors.New("valueSource cannot be nil")
|
||||
}
|
||||
|
||||
switch v.Type {
|
||||
case "file":
|
||||
if v.File == nil {
|
||||
return errors.New("file configuration is required when type is 'file'")
|
||||
}
|
||||
return v.File.Validate()
|
||||
default:
|
||||
return fmt.Errorf("unsupported value source type: %s (only 'file' is supported)", v.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve resolves the value from the configured source.
|
||||
func (v *ValueSource) Resolve(ctx context.Context) (string, error) {
|
||||
if err := v.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch v.Type {
|
||||
case "file":
|
||||
return v.File.Resolve(ctx)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported value source type: %s", v.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the FileSource configuration.
|
||||
func (f *FileSource) Validate() error {
|
||||
if f == nil {
|
||||
return errors.New("fileSource cannot be nil")
|
||||
}
|
||||
|
||||
if f.Path == "" {
|
||||
return errors.New("file path cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve reads and returns the content from the specified file.
|
||||
func (f *FileSource) Resolve(_ context.Context) (string, error) {
|
||||
if err := f.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(f.Path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file %s: %v", f.Path, err)
|
||||
}
|
||||
|
||||
// Trim whitespace, which is important for file-based tokens
|
||||
return strings.TrimSpace(string(content)), nil
|
||||
}
|
||||
246
pkg/config/v1/value_source_test.go
Normal file
246
pkg/config/v1/value_source_test.go
Normal file
@@ -0,0 +1,246 @@
|
||||
// Copyright 2025 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValueSource_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vs *ValueSource
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil valueSource",
|
||||
vs: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported type",
|
||||
vs: &ValueSource{
|
||||
Type: "unsupported",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "file type without file config",
|
||||
vs: &ValueSource{
|
||||
Type: "file",
|
||||
File: nil,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid file type with absolute path",
|
||||
vs: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "/tmp/test",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid file type with relative path",
|
||||
vs: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "configs/token",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.vs.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValueSource.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSource_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fs *FileSource
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil fileSource",
|
||||
fs: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
fs: &FileSource{
|
||||
Path: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "relative path (allowed)",
|
||||
fs: &FileSource{
|
||||
Path: "relative/path",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path",
|
||||
fs: &FileSource{
|
||||
Path: "/absolute/path",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.fs.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FileSource.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSource_Resolve(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test_token")
|
||||
testContent := "test-token-value\n\t "
|
||||
expectedContent := "test-token-value"
|
||||
|
||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fs *FileSource
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid file path",
|
||||
fs: &FileSource{
|
||||
Path: testFile,
|
||||
},
|
||||
want: expectedContent,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
fs: &FileSource{
|
||||
Path: "/non/existent/file",
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt (should fail validation)",
|
||||
fs: &FileSource{
|
||||
Path: "../../../etc/passwd",
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.fs.Resolve(context.Background())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FileSource.Resolve() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("FileSource.Resolve() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueSource_Resolve(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test_token")
|
||||
testContent := "test-token-value"
|
||||
|
||||
err := os.WriteFile(testFile, []byte(testContent), 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
vs *ValueSource
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid file type",
|
||||
vs: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: testFile,
|
||||
},
|
||||
},
|
||||
want: testContent,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported type",
|
||||
vs: &ValueSource{
|
||||
Type: "unsupported",
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "file type with path traversal",
|
||||
vs: &ValueSource{
|
||||
Type: "file",
|
||||
File: &FileSource{
|
||||
Path: "../../../etc/passwd",
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.vs.Resolve(ctx)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValueSource.Resolve() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ValueSource.Resolve() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -105,7 +105,10 @@ func (s *TunnelServer) Run() error {
|
||||
s.writeToClient(err.Error())
|
||||
return fmt.Errorf("parse flags from ssh client error: %v", err)
|
||||
}
|
||||
clientCfg.Complete()
|
||||
if err := clientCfg.Complete(); err != nil {
|
||||
s.writeToClient(fmt.Sprintf("failed to complete client config: %v", err))
|
||||
return fmt.Errorf("complete client config error: %v", err)
|
||||
}
|
||||
if sshConn.Permissions != nil {
|
||||
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
|
||||
}
|
||||
|
||||
@@ -37,7 +37,9 @@ type Client struct {
|
||||
|
||||
func NewClient(options ClientOptions) (*Client, error) {
|
||||
if options.Common != nil {
|
||||
options.Common.Complete()
|
||||
if err := options.Common.Complete(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ln := netpkg.NewInternalListener()
|
||||
|
||||
Reference in New Issue
Block a user