Compare commits

..

1 Commits

4 changed files with 30 additions and 38 deletions

View File

@ -51,7 +51,7 @@ func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) {
authVerifier = NewTokenAuth(cfg.AdditionalScopes, cfg.Token) authVerifier = NewTokenAuth(cfg.AdditionalScopes, cfg.Token)
case v1.AuthMethodOIDC: case v1.AuthMethodOIDC:
tokenVerifier := NewTokenVerifier(cfg.OIDC) tokenVerifier := NewTokenVerifier(cfg.OIDC)
authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedClaims) authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedHostedDomains)
} }
return authVerifier return authVerifier
} }

View File

@ -20,7 +20,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"slices" "slices"
"strconv"
"strings" "strings"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
@ -112,8 +111,8 @@ type OidcAuthConsumer struct {
verifier TokenVerifier verifier TokenVerifier
subjectsFromLogin []string subjectsFromLogin []string
// allowedClaims specifies a map of allowed claims for the OIDC token. // allowedHostedDomains specifies a list of allowed hosted domains for the "hd" claim in the token.
allowedClaims map[string]string allowedHostedDomains []string
} }
func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier { func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@ -130,19 +129,19 @@ func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
return provider.Verifier(&verifierConf) return provider.Verifier(&verifierConf)
} }
func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedClaims map[string]string) *OidcAuthConsumer { func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedHostedDomains []string) *OidcAuthConsumer {
return &OidcAuthConsumer{ return &OidcAuthConsumer{
additionalAuthScopes: additionalAuthScopes, additionalAuthScopes: additionalAuthScopes,
verifier: verifier, verifier: verifier,
subjectsFromLogin: []string{}, subjectsFromLogin: []string{},
allowedClaims: allowedClaims, allowedHostedDomains: allowedHostedDomains,
} }
} }
func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) { func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
// Verify allowed claims if configured. // Verify hosted domain (hd claim).
if len(auth.allowedClaims) > 0 { if len(auth.allowedHostedDomains) > 0 {
// Decode token without verifying signature. // Decode token without verifying signature to retrieved 'hd' claim.
parts := strings.Split(loginMsg.PrivilegeKey, ".") parts := strings.Split(loginMsg.PrivilegeKey, ".")
if len(parts) != 3 { if len(parts) != 3 {
return fmt.Errorf("invalid OIDC token format") return fmt.Errorf("invalid OIDC token format")
@ -158,32 +157,24 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err) return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err)
} }
// Iterate over allowed claims and attempt to verify. hd, ok := claims["hd"].(string)
for claimName, expectedValue := range auth.allowedClaims { if !ok {
claimValue, ok := claims[claimName] return fmt.Errorf("OIDC token missing required 'hd' claim")
if !ok { }
return fmt.Errorf("OIDC token missing required claim: %s", claimName)
}
if strClaimValue, ok := claimValue.(string); ok { found := false
if strClaimValue != expectedValue { for _, domain := range auth.allowedHostedDomains {
return fmt.Errorf("OIDC token claim '%s' value [%s] does not match expected value [%s]", claimName, strClaimValue, expectedValue) if hd == domain {
} found = true
} else if intClaimValue, ok := claimValue.(int); ok { break
expectedIntValue, err := strconv.Atoi(expectedValue)
if err != nil {
return fmt.Errorf("OIDC token claim '%s' is number, expected value [%s] not parseable", claimName, expectedValue)
}
if intClaimValue != expectedIntValue {
return fmt.Errorf("OIDC token claim '%s' value [%d] does not match expected value [%d]", claimName, intClaimValue, expectedIntValue)
}
} else {
return fmt.Errorf("claim %s is of unsupported type", claimName)
} }
} }
if !found {
return fmt.Errorf("OIDC token 'hd' claim [%s] is not in allowed list", hd)
}
} }
// If claim verification passes, proceed with standard verification. // If hd check passes, proceed with standard verification.
token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey) token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey)
if err != nil { if err != nil {
return fmt.Errorf("invalid OIDC token in login: %v", err) return fmt.Errorf("invalid OIDC token in login: %v", err)

View File

@ -23,7 +23,7 @@ func (m *mockTokenVerifier) Verify(ctx context.Context, subject string) (*oidc.I
func TestPingWithEmptySubjectFromLoginFails(t *testing.T) { func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
r := require.New(t) r := require.New(t)
consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{}) consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
err := consumer.VerifyPing(&msg.Ping{ err := consumer.VerifyPing(&msg.Ping{
PrivilegeKey: "ping-without-login", PrivilegeKey: "ping-without-login",
Timestamp: time.Now().UnixMilli(), Timestamp: time.Now().UnixMilli(),
@ -34,7 +34,7 @@ func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) { func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
r := require.New(t) r := require.New(t)
consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{}) consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
err := consumer.VerifyLogin(&msg.Login{ err := consumer.VerifyLogin(&msg.Login{
PrivilegeKey: "ping-after-login", PrivilegeKey: "ping-after-login",
}) })
@ -49,7 +49,7 @@ func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) { func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
r := require.New(t) r := require.New(t)
consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{}) consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
err := consumer.VerifyLogin(&msg.Login{ err := consumer.VerifyLogin(&msg.Login{
PrivilegeKey: "login-with-first-subject", PrivilegeKey: "login-with-first-subject",
}) })

View File

@ -147,14 +147,15 @@ type AuthOIDCServerConfig struct {
// SkipIssuerCheck specifies whether to skip checking if the OIDC token's // SkipIssuerCheck specifies whether to skip checking if the OIDC token's
// issuer claim matches the issuer specified in OidcIssuer. // issuer claim matches the issuer specified in OidcIssuer.
SkipIssuerCheck bool `json:"skipIssuerCheck,omitempty"` SkipIssuerCheck bool `json:"skipIssuerCheck,omitempty"`
// AllowedClaims specifies a map of allowed claims for the OIDC token. // AllowedHostedDomains specifies a list of allowed hosted domains for the
AllowedClaims map[string]string `json:"allowedClaims,omitempty"` // "hd" claim in the token.
AllowedHostedDomains []string `json:"allowedHostedDomains,omitempty"`
} }
func (c *AuthOIDCServerConfig) Complete() { func (c *AuthOIDCServerConfig) Complete() {
// Ensure AllowedClaims is at least an empty map and not nil // Ensure AllowedHostedDomains is an empty slice and not nil
if c.AllowedClaims == nil { if c.AllowedHostedDomains == nil {
c.AllowedClaims = map[string]string{} c.AllowedHostedDomains = []string{}
} }
} }