From cda2cb151e6126fd2c7e0278505c870291c9a92e Mon Sep 17 00:00:00 2001
From: foresturquhart <forest.urquhart@gmail.com>
Date: Thu, 6 Feb 2025 17:03:25 +0000
Subject: [PATCH 1/3] Implement OIDC raw token and hd claim verification

---
 pkg/auth/auth.go        |  2 +-
 pkg/auth/oidc.go        | 56 ++++++++++++++++++++++++++++++++++++++++-
 pkg/auth/oidc_test.go   |  6 ++---
 pkg/config/v1/client.go |  3 +++
 pkg/config/v1/server.go | 10 ++++++++
 5 files changed, 72 insertions(+), 5 deletions(-)

diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go
index ae706986..176b4d9f 100644
--- a/pkg/auth/auth.go
+++ b/pkg/auth/auth.go
@@ -51,7 +51,7 @@ func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) {
 		authVerifier = NewTokenAuth(cfg.AdditionalScopes, cfg.Token)
 	case v1.AuthMethodOIDC:
 		tokenVerifier := NewTokenVerifier(cfg.OIDC)
-		authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier)
+		authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedHostedDomains)
 	}
 	return authVerifier
 }
diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go
index 40ce060f..f241ff46 100644
--- a/pkg/auth/oidc.go
+++ b/pkg/auth/oidc.go
@@ -16,8 +16,11 @@ package auth
 
 import (
 	"context"
+	"encoding/base64"
+	"encoding/json"
 	"fmt"
 	"slices"
+	"strings"
 
 	"github.com/coreos/go-oidc/v3/oidc"
 	"golang.org/x/oauth2/clientcredentials"
@@ -30,6 +33,10 @@ type OidcAuthProvider struct {
 	additionalAuthScopes []v1.AuthScope
 
 	tokenGenerator *clientcredentials.Config
+
+	// rawToken is used to specify a raw JWT token for authentication.
+	// If rawToken is not empty, it will be used directly instead of generating a new token.
+	rawToken string
 }
 
 func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) *OidcAuthProvider {
@@ -53,10 +60,17 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien
 	return &OidcAuthProvider{
 		additionalAuthScopes: additionalAuthScopes,
 		tokenGenerator:       tokenGenerator,
+		rawToken:             cfg.RawToken,
 	}
 }
 
 func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) {
+	// If a raw token is provided, use it directly.
+	if auth.rawToken != "" {
+		return auth.rawToken, nil
+	}
+
+	// Otherwise, generate a new token using the client credentials flow.
 	tokenObj, err := auth.tokenGenerator.Token(context.Background())
 	if err != nil {
 		return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err)
@@ -96,6 +110,9 @@ type OidcAuthConsumer struct {
 
 	verifier          TokenVerifier
 	subjectsFromLogin []string
+
+	// allowedHostedDomains specifies a list of allowed hosted domains for the "hd" claim in the token.
+	allowedHostedDomains []string
 }
 
 func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@@ -112,15 +129,52 @@ func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
 	return provider.Verifier(&verifierConf)
 }
 
-func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier) *OidcAuthConsumer {
+func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedHostedDomains []string) *OidcAuthConsumer {
 	return &OidcAuthConsumer{
 		additionalAuthScopes: additionalAuthScopes,
 		verifier:             verifier,
 		subjectsFromLogin:    []string{},
+		allowedHostedDomains: allowedHostedDomains,
 	}
 }
 
 func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
+	// Decode token without verifying signature to retrieved 'hd' claim.
+	parts := strings.Split(loginMsg.PrivilegeKey, ".")
+	if len(parts) != 3 {
+		return fmt.Errorf("invalid OIDC token format")
+	}
+
+	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
+	if err != nil {
+		return fmt.Errorf("invalid OIDC token: failed to decode payload: %v", err)
+	}
+
+	var claims map[string]any
+	if err := json.Unmarshal(payload, &claims); err != nil {
+		return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err)
+	}
+
+	// Verify hosted domain (hd claim).
+	if len(auth.allowedHostedDomains) > 0 {
+		hd, ok := claims["hd"].(string)
+		if !ok {
+			return fmt.Errorf("OIDC token missing required 'hd' claim")
+		}
+
+		found := false
+		for _, domain := range auth.allowedHostedDomains {
+			if hd == domain {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return fmt.Errorf("OIDC token 'hd' claim [%s] is not in allowed list", hd)
+		}
+	}
+
+	// If hd check passes, proceed with standard verification.
 	token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey)
 	if err != nil {
 		return fmt.Errorf("invalid OIDC token in login: %v", err)
diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go
index 58054186..66ff7fb9 100644
--- a/pkg/auth/oidc_test.go
+++ b/pkg/auth/oidc_test.go
@@ -23,7 +23,7 @@ func (m *mockTokenVerifier) Verify(ctx context.Context, subject string) (*oidc.I
 
 func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
 	err := consumer.VerifyPing(&msg.Ping{
 		PrivilegeKey: "ping-without-login",
 		Timestamp:    time.Now().UnixMilli(),
@@ -34,7 +34,7 @@ func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
 
 func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
 	err := consumer.VerifyLogin(&msg.Login{
 		PrivilegeKey: "ping-after-login",
 	})
@@ -49,7 +49,7 @@ func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
 
 func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
 	err := consumer.VerifyLogin(&msg.Login{
 		PrivilegeKey: "login-with-first-subject",
 	})
diff --git a/pkg/config/v1/client.go b/pkg/config/v1/client.go
index d43ec1bc..1ad194f1 100644
--- a/pkg/config/v1/client.go
+++ b/pkg/config/v1/client.go
@@ -203,4 +203,7 @@ type AuthOIDCClientConfig struct {
 	// AdditionalEndpointParams specifies additional parameters to be sent
 	// this field will be transfer to map[string][]string in OIDC token generator.
 	AdditionalEndpointParams map[string]string `json:"additionalEndpointParams,omitempty"`
+	// RawToken specifies a raw JWT token to use for authentication, bypassing
+	// the OIDC flow.
+	RawToken string `json:"rawToken,omitempty"`
 }
diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go
index 3108cd34..93281cf5 100644
--- a/pkg/config/v1/server.go
+++ b/pkg/config/v1/server.go
@@ -147,6 +147,16 @@ type AuthOIDCServerConfig struct {
 	// SkipIssuerCheck specifies whether to skip checking if the OIDC token's
 	// issuer claim matches the issuer specified in OidcIssuer.
 	SkipIssuerCheck bool `json:"skipIssuerCheck,omitempty"`
+	// AllowedHostedDomains specifies a list of allowed hosted domains for the
+	// "hd" claim in the token.
+	AllowedHostedDomains []string `json:"allowedHostedDomains,omitempty"`
+}
+
+func (c *AuthOIDCServerConfig) Complete() {
+	// Ensure AllowedHostedDomains is an empty slice and not nil
+	if c.AllowedHostedDomains == nil {
+		c.AllowedHostedDomains = []string{}
+	}
 }
 
 type ServerTransportConfig struct {

From b499412aee00f709bb71b50b65f4272fec2b9897 Mon Sep 17 00:00:00 2001
From: foresturquhart <forest.urquhart@gmail.com>
Date: Thu, 6 Feb 2025 17:46:46 +0000
Subject: [PATCH 2/3] Wrap new VerifyLogin logic in allowedHostedDomains length
 check

---
 pkg/auth/oidc.go | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go
index f241ff46..ed5bb543 100644
--- a/pkg/auth/oidc.go
+++ b/pkg/auth/oidc.go
@@ -139,24 +139,24 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri
 }
 
 func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
-	// Decode token without verifying signature to retrieved 'hd' claim.
-	parts := strings.Split(loginMsg.PrivilegeKey, ".")
-	if len(parts) != 3 {
-		return fmt.Errorf("invalid OIDC token format")
-	}
-
-	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
-	if err != nil {
-		return fmt.Errorf("invalid OIDC token: failed to decode payload: %v", err)
-	}
-
-	var claims map[string]any
-	if err := json.Unmarshal(payload, &claims); err != nil {
-		return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err)
-	}
-
 	// Verify hosted domain (hd claim).
 	if len(auth.allowedHostedDomains) > 0 {
+		// Decode token without verifying signature to retrieved 'hd' claim.
+		parts := strings.Split(loginMsg.PrivilegeKey, ".")
+		if len(parts) != 3 {
+			return fmt.Errorf("invalid OIDC token format")
+		}
+
+		payload, err := base64.RawURLEncoding.DecodeString(parts[1])
+		if err != nil {
+			return fmt.Errorf("invalid OIDC token: failed to decode payload: %v", err)
+		}
+
+		var claims map[string]any
+		if err := json.Unmarshal(payload, &claims); err != nil {
+			return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err)
+		}
+
 		hd, ok := claims["hd"].(string)
 		if !ok {
 			return fmt.Errorf("OIDC token missing required 'hd' claim")

From f62cd91f09cc975f88660a2cd70232cf1533c353 Mon Sep 17 00:00:00 2001
From: foresturquhart <forest.urquhart@gmail.com>
Date: Fri, 7 Feb 2025 11:16:38 +0000
Subject: [PATCH 3/3] Change from verifying hosted domains to verifying claims

---
 pkg/auth/auth.go        |  2 +-
 pkg/auth/oidc.go        | 51 ++++++++++++++++++++++++-----------------
 pkg/auth/oidc_test.go   |  6 ++---
 pkg/config/v1/server.go | 11 ++++-----
 4 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go
index 176b4d9f..ca9cfc51 100644
--- a/pkg/auth/auth.go
+++ b/pkg/auth/auth.go
@@ -51,7 +51,7 @@ func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) {
 		authVerifier = NewTokenAuth(cfg.AdditionalScopes, cfg.Token)
 	case v1.AuthMethodOIDC:
 		tokenVerifier := NewTokenVerifier(cfg.OIDC)
-		authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedHostedDomains)
+		authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedClaims)
 	}
 	return authVerifier
 }
diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go
index ed5bb543..6b926d29 100644
--- a/pkg/auth/oidc.go
+++ b/pkg/auth/oidc.go
@@ -20,6 +20,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"slices"
+	"strconv"
 	"strings"
 
 	"github.com/coreos/go-oidc/v3/oidc"
@@ -111,8 +112,8 @@ type OidcAuthConsumer struct {
 	verifier          TokenVerifier
 	subjectsFromLogin []string
 
-	// allowedHostedDomains specifies a list of allowed hosted domains for the "hd" claim in the token.
-	allowedHostedDomains []string
+	// allowedClaims specifies a map of allowed claims for the OIDC token.
+	allowedClaims map[string]string
 }
 
 func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@@ -129,19 +130,19 @@ func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
 	return provider.Verifier(&verifierConf)
 }
 
-func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedHostedDomains []string) *OidcAuthConsumer {
+func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedClaims map[string]string) *OidcAuthConsumer {
 	return &OidcAuthConsumer{
 		additionalAuthScopes: additionalAuthScopes,
 		verifier:             verifier,
 		subjectsFromLogin:    []string{},
-		allowedHostedDomains: allowedHostedDomains,
+		allowedClaims:        allowedClaims,
 	}
 }
 
 func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
-	// Verify hosted domain (hd claim).
-	if len(auth.allowedHostedDomains) > 0 {
-		// Decode token without verifying signature to retrieved 'hd' claim.
+	// Verify allowed claims if configured.
+	if len(auth.allowedClaims) > 0 {
+		// Decode token without verifying signature.
 		parts := strings.Split(loginMsg.PrivilegeKey, ".")
 		if len(parts) != 3 {
 			return fmt.Errorf("invalid OIDC token format")
@@ -157,24 +158,32 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
 			return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err)
 		}
 
-		hd, ok := claims["hd"].(string)
-		if !ok {
-			return fmt.Errorf("OIDC token missing required 'hd' claim")
-		}
-
-		found := false
-		for _, domain := range auth.allowedHostedDomains {
-			if hd == domain {
-				found = true
-				break
+		// Iterate over allowed claims and attempt to verify.
+		for claimName, expectedValue := range auth.allowedClaims {
+			claimValue, ok := claims[claimName]
+			if !ok {
+				return fmt.Errorf("OIDC token missing required claim: %s", claimName)
+			}
+
+			if strClaimValue, ok := claimValue.(string); ok {
+				if strClaimValue != expectedValue {
+					return fmt.Errorf("OIDC token claim '%s' value [%s] does not match expected value [%s]", claimName, strClaimValue, expectedValue)
+				}
+			} else if intClaimValue, ok := claimValue.(int); ok {
+				expectedIntValue, err := strconv.Atoi(expectedValue)
+				if err != nil {
+					return fmt.Errorf("OIDC token claim '%s' is number, expected value [%s] not parseable", claimName, expectedValue)
+				}
+				if intClaimValue != expectedIntValue {
+					return fmt.Errorf("OIDC token claim '%s' value [%d] does not match expected value [%d]", claimName, intClaimValue, expectedIntValue)
+				}
+			} else {
+				return fmt.Errorf("claim %s is of unsupported type", claimName)
 			}
-		}
-		if !found {
-			return fmt.Errorf("OIDC token 'hd' claim [%s] is not in allowed list", hd)
 		}
 	}
 
-	// If hd check passes, proceed with standard verification.
+	// If claim verification passes, proceed with standard verification.
 	token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey)
 	if err != nil {
 		return fmt.Errorf("invalid OIDC token in login: %v", err)
diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go
index 66ff7fb9..dff2b5bc 100644
--- a/pkg/auth/oidc_test.go
+++ b/pkg/auth/oidc_test.go
@@ -23,7 +23,7 @@ func (m *mockTokenVerifier) Verify(ctx context.Context, subject string) (*oidc.I
 
 func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{})
 	err := consumer.VerifyPing(&msg.Ping{
 		PrivilegeKey: "ping-without-login",
 		Timestamp:    time.Now().UnixMilli(),
@@ -34,7 +34,7 @@ func TestPingWithEmptySubjectFromLoginFails(t *testing.T) {
 
 func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{})
 	err := consumer.VerifyLogin(&msg.Login{
 		PrivilegeKey: "ping-after-login",
 	})
@@ -49,7 +49,7 @@ func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) {
 
 func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
 	r := require.New(t)
-	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{})
+	consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{})
 	err := consumer.VerifyLogin(&msg.Login{
 		PrivilegeKey: "login-with-first-subject",
 	})
diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go
index 93281cf5..e233d110 100644
--- a/pkg/config/v1/server.go
+++ b/pkg/config/v1/server.go
@@ -147,15 +147,14 @@ type AuthOIDCServerConfig struct {
 	// SkipIssuerCheck specifies whether to skip checking if the OIDC token's
 	// issuer claim matches the issuer specified in OidcIssuer.
 	SkipIssuerCheck bool `json:"skipIssuerCheck,omitempty"`
-	// AllowedHostedDomains specifies a list of allowed hosted domains for the
-	// "hd" claim in the token.
-	AllowedHostedDomains []string `json:"allowedHostedDomains,omitempty"`
+	// AllowedClaims specifies a map of allowed claims for the OIDC token.
+	AllowedClaims map[string]string `json:"allowedClaims,omitempty"`
 }
 
 func (c *AuthOIDCServerConfig) Complete() {
-	// Ensure AllowedHostedDomains is an empty slice and not nil
-	if c.AllowedHostedDomains == nil {
-		c.AllowedHostedDomains = []string{}
+	// Ensure AllowedClaims is at least an empty map and not nil
+	if c.AllowedClaims == nil {
+		c.AllowedClaims = map[string]string{}
 	}
 }