From f62cd91f09cc975f88660a2cd70232cf1533c353 Mon Sep 17 00:00:00 2001
From: foresturquhart <forest.urquhart@gmail.com>
Date: Fri, 7 Feb 2025 11:16:38 +0000
Subject: [PATCH] Change from verifying hosted domains to verifying claims

---
 pkg/auth/auth.go        |  2 +-
 pkg/auth/oidc.go        | 51 ++++++++++++++++++++++++-----------------
 pkg/auth/oidc_test.go   |  6 ++---
 pkg/config/v1/server.go | 11 ++++-----
 4 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go
index 176b4d9f..ca9cfc51 100644
--- a/pkg/auth/auth.go
+++ b/pkg/auth/auth.go
@@ -51,7 +51,7 @@ func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) {
 		authVerifier = NewTokenAuth(cfg.AdditionalScopes, cfg.Token)
 	case v1.AuthMethodOIDC:
 		tokenVerifier := NewTokenVerifier(cfg.OIDC)
-		authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedHostedDomains)
+		authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedClaims)
 	}
 	return authVerifier
 }
diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go
index ed5bb543..6b926d29 100644
--- a/pkg/auth/oidc.go
+++ b/pkg/auth/oidc.go
@@ -20,6 +20,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"slices"
+	"strconv"
 	"strings"
 
 	"github.com/coreos/go-oidc/v3/oidc"
@@ -111,8 +112,8 @@ type OidcAuthConsumer struct {
 	verifier          TokenVerifier
 	subjectsFromLogin []string
 
-	// allowedHostedDomains specifies a list of allowed hosted domains for the "hd" claim in the token.
-	allowedHostedDomains []string
+	// allowedClaims specifies a map of allowed claims for the OIDC token.
+	allowedClaims map[string]string
 }
 
 func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@@ -129,19 +130,19 @@ func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
 	return provider.Verifier(&verifierConf)
 }
 
-func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedHostedDomains []string) *OidcAuthConsumer {
+func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedClaims map[string]string) *OidcAuthConsumer {
 	return &OidcAuthConsumer{
 		additionalAuthScopes: additionalAuthScopes,
 		verifier:             verifier,
 		subjectsFromLogin:    []string{},
-		allowedHostedDomains: allowedHostedDomains,
+		allowedClaims:        allowedClaims,
 	}
 }
 
 func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
-	// Verify hosted domain (hd claim).
-	if len(auth.allowedHostedDomains) > 0 {
-		// Decode token without verifying signature to retrieved 'hd' claim.
+	// Verify allowed claims if configured.
+	if len(auth.allowedClaims) > 0 {
+		// Decode token without verifying signature.
 		parts := strings.Split(loginMsg.PrivilegeKey, ".")
 		if len(parts) != 3 {
 			return fmt.Errorf("invalid OIDC token format")
@@ -157,24 +158,32 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
 			return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err)
 		}
 
-		hd, ok := claims["hd"].(string)
-		if !ok {
-			return fmt.Errorf("OIDC token missing required 'hd' claim")
-		}
-
-		found := false
-		for _, domain := range auth.allowedHostedDomains {
-			if hd == domain {
-				found = true
-				break
+		// Iterate over allowed claims and attempt to verify.
+		for claimName, expectedValue := range auth.allowedClaims {
+			claimValue, ok := claims[claimName]
+			if !ok {
+				return fmt.Errorf("OIDC token missing required claim: %s", claimName)
+			}
+
+			if strClaimValue, ok := claimValue.(string); ok {
+				if strClaimValue != expectedValue {
+					return fmt.Errorf("OIDC token claim '%s' value [%s] does not match expected value [%s]", claimName, strClaimValue, expectedValue)
+				}
+			} else if intClaimValue, ok := claimValue.(int); ok {
+				expectedIntValue, err := strconv.Atoi(expectedValue)
+				if err != nil {
+					return fmt.Errorf("OIDC token claim '%s' is number, expected value [%s] not parseable", claimName, expectedValue)
+				}
+				if intClaimValue != expectedIntValue {
+					return fmt.Errorf("OIDC token claim '%s' value [%d] does not match expected value [%d]", claimName, intClaimValue, expectedIntValue)
+				}
+			} else {
+				return fmt.Errorf("claim %s is of unsupported type", claimName)
 			}
-		}
-		if !found {
-			return fmt.Errorf("OIDC token 'hd' claim [%s] is not in allowed list", hd)
 		}
 	}
 
-	// If hd check passes, proceed with standard verification.
+	// If claim verification passes, proceed with standard verification.
 	token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey)
 	if err != nil {
 		return fmt.Errorf("invalid OIDC token in login: %v", err)
diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go
index 66ff7fb9..dff2b5bc 100644
--- a/pkg/auth/oidc_test.go
+++ b/pkg/auth/oidc_test.go
@@ -23,7 +23,7 @@ func (m *mockTokenVerifier) Verify(ctx context.Context, subject string) (*oidc.I
 
 func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{})
 	err := consumer.VerifyPing(&msg.Ping{
 		PrivilegeKey: "ping-without-login",
 		Timestamp:    time.Now().UnixMilli(),
@@ -34,7 +34,7 @@ func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
 
 func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{})
 	err := consumer.VerifyLogin(&msg.Login{
 		PrivilegeKey: "ping-after-login",
 	})
@@ -49,7 +49,7 @@ func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
 
 func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{})
 	err := consumer.VerifyLogin(&msg.Login{
 		PrivilegeKey: "login-with-first-subject",
 	})
diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go
index 93281cf5..e233d110 100644
--- a/pkg/config/v1/server.go
+++ b/pkg/config/v1/server.go
@@ -147,15 +147,14 @@ type AuthOIDCServerConfig struct {
 	// SkipIssuerCheck specifies whether to skip checking if the OIDC token's
 	// issuer claim matches the issuer specified in OidcIssuer.
 	SkipIssuerCheck bool `json:"skipIssuerCheck,omitempty"`
-	// AllowedHostedDomains specifies a list of allowed hosted domains for the
-	// "hd" claim in the token.
-	AllowedHostedDomains []string `json:"allowedHostedDomains,omitempty"`
+	// AllowedClaims specifies a map of allowed claims for the OIDC token.
+	AllowedClaims map[string]string `json:"allowedClaims,omitempty"`
 }
 
 func (c *AuthOIDCServerConfig) Complete() {
-	// Ensure AllowedHostedDomains is an empty slice and not nil
-	if c.AllowedHostedDomains == nil {
-		c.AllowedHostedDomains = []string{}
+	// Ensure AllowedClaims is at least an empty map and not nil
+	if c.AllowedClaims == nil {
+		c.AllowedClaims = map[string]string{}
 	}
 }