From 220479c1a181d3879985843403e7f904aae1a291 Mon Sep 17 00:00:00 2001 From: Ajay Bhargav Baaskaran Date: Fri, 5 Feb 2016 16:12:07 -0800 Subject: [PATCH] [Fixes #30] Updated UID generation in DefaultClaimUidExtractor --- .../DefaultAntiforgeryTokenGenerator.cs | 52 ++++- .../Internal/DefaultClaimUidExtractor.cs | 95 +++++++-- .../Internal/IClaimUidExtractor.cs | 7 +- .../DefaultAntiforgeryTokenGeneratorTest.cs | 9 +- .../Internal/DefaultClaimUidExtractorTest.cs | 190 ++++++++++++++++-- 5 files changed, 298 insertions(+), 55 deletions(-) diff --git a/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultAntiforgeryTokenGenerator.cs b/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultAntiforgeryTokenGenerator.cs index 390a39c..872f7ed 100644 --- a/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultAntiforgeryTokenGenerator.cs +++ b/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultAntiforgeryTokenGenerator.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Security.Principal; using Microsoft.AspNetCore.Http; @@ -60,16 +61,17 @@ public AntiforgeryToken GenerateRequestToken( }; var isIdentityAuthenticated = false; - var identity = httpContext.User?.Identity as ClaimsIdentity; // populate Username and ClaimUid - if (identity != null && identity.IsAuthenticated) + var authenticatedIdentity = GetAuthenticatedIdentity(httpContext.User); + if (authenticatedIdentity != null) { isIdentityAuthenticated = true; - requestToken.ClaimUid = GetClaimUidBlob(_claimUidExtractor.ExtractClaimUid(identity)); + requestToken.ClaimUid = GetClaimUidBlob(_claimUidExtractor.ExtractClaimUid(httpContext.User)); + if (requestToken.ClaimUid == null) { - requestToken.Username = identity.Name; + requestToken.Username = authenticatedIdentity.Name; } } @@ -87,7 +89,7 @@ public AntiforgeryToken GenerateRequestToken( // Application says user is authenticated, but we have no identifier for the user. throw new InvalidOperationException( Resources.FormatAntiforgeryTokenValidator_AuthenticatedUserWithoutUsername( - identity.GetType(), + authenticatedIdentity.GetType(), nameof(IIdentity.IsAuthenticated), "true", nameof(IIdentity.Name), @@ -148,13 +150,13 @@ public bool TryValidateTokenSet( var currentUsername = string.Empty; BinaryBlob currentClaimUid = null; - var identity = httpContext.User?.Identity as ClaimsIdentity; - if (identity != null && identity.IsAuthenticated) + var authenticatedIdentity = GetAuthenticatedIdentity(httpContext.User); + if (authenticatedIdentity != null) { - currentClaimUid = GetClaimUidBlob(_claimUidExtractor.ExtractClaimUid(identity)); + currentClaimUid = GetClaimUidBlob(_claimUidExtractor.ExtractClaimUid(httpContext.User)); if (currentClaimUid == null) { - currentUsername = identity.Name ?? string.Empty; + currentUsername = authenticatedIdentity.Name ?? string.Empty; } } @@ -200,5 +202,37 @@ private static BinaryBlob GetClaimUidBlob(string base64ClaimUid) return new BinaryBlob(256, Convert.FromBase64String(base64ClaimUid)); } + + private static ClaimsIdentity GetAuthenticatedIdentity(ClaimsPrincipal claimsPrincipal) + { + if (claimsPrincipal == null) + { + return null; + } + + var identitiesList = claimsPrincipal.Identities as List; + if (identitiesList != null) + { + for (var i = 0; i < identitiesList.Count; i++) + { + if (identitiesList[i].IsAuthenticated) + { + return identitiesList[i]; + } + } + } + else + { + foreach (var identity in claimsPrincipal.Identities) + { + if (identity.IsAuthenticated) + { + return identity; + } + } + } + + return null; + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultClaimUidExtractor.cs b/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultClaimUidExtractor.cs index 7f08907..1a7ef39 100644 --- a/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultClaimUidExtractor.cs +++ b/src/Microsoft.AspNetCore.Antiforgery/Internal/DefaultClaimUidExtractor.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Generic; -using System.Linq; +using System.Diagnostics; using System.Security.Claims; using Microsoft.Extensions.ObjectPool; @@ -22,42 +22,99 @@ public DefaultClaimUidExtractor(ObjectPool pool } /// - public string ExtractClaimUid(ClaimsIdentity claimsIdentity) + public string ExtractClaimUid(ClaimsPrincipal claimsPrincipal) { - if (claimsIdentity == null || !claimsIdentity.IsAuthenticated) + Debug.Assert(claimsPrincipal != null); + + var uniqueIdentifierParameters = GetUniqueIdentifierParameters(claimsPrincipal.Identities); + if (uniqueIdentifierParameters == null) { - // Skip anonymous users + // No authenticated identities containing claims found. return null; } - var uniqueIdentifierParameters = GetUniqueIdentifierParameters(claimsIdentity); var claimUidBytes = ComputeSha256(uniqueIdentifierParameters); return Convert.ToBase64String(claimUidBytes); } - // Internal for testing - internal static IEnumerable GetUniqueIdentifierParameters(ClaimsIdentity claimsIdentity) + public static IList GetUniqueIdentifierParameters(IEnumerable claimsIdentities) { - var nameIdentifierClaim = claimsIdentity.FindFirst( - claim => string.Equals(ClaimTypes.NameIdentifier, claim.Type, StringComparison.Ordinal)); - if (nameIdentifierClaim != null && !string.IsNullOrEmpty(nameIdentifierClaim.Value)) + var identitiesList = claimsIdentities as List; + if (identitiesList == null) + { + identitiesList = new List(claimsIdentities); + } + + for (var i = 0; i < identitiesList.Count; i++) { - return new string[] + var identity = identitiesList[i]; + if (!identity.IsAuthenticated) + { + continue; + } + + var subClaim = identity.FindFirst( + claim => string.Equals("sub", claim.Type, StringComparison.Ordinal)); + if (subClaim != null && !string.IsNullOrEmpty(subClaim.Value)) + { + return new string[] + { + subClaim.Type, + subClaim.Value, + subClaim.Issuer + }; + } + + var nameIdentifierClaim = identity.FindFirst( + claim => string.Equals(ClaimTypes.NameIdentifier, claim.Type, StringComparison.Ordinal)); + if (nameIdentifierClaim != null && !string.IsNullOrEmpty(nameIdentifierClaim.Value)) { - ClaimTypes.NameIdentifier, - nameIdentifierClaim.Value - }; + return new string[] + { + nameIdentifierClaim.Type, + nameIdentifierClaim.Value, + nameIdentifierClaim.Issuer + }; + } + + var upnClaim = identity.FindFirst( + claim => string.Equals(ClaimTypes.Upn, claim.Type, StringComparison.Ordinal)); + if (upnClaim != null && !string.IsNullOrEmpty(upnClaim.Value)) + { + return new string[] + { + upnClaim.Type, + upnClaim.Value, + upnClaim.Issuer + }; + } + } + + // We do not understand any of the ClaimsIdentity instances, fallback on serializing all claims in every claims Identity. + var allClaims = new List(); + for (var i = 0; i < identitiesList.Count; i++) + { + if (identitiesList[i].IsAuthenticated) + { + allClaims.AddRange(identitiesList[i].Claims); + } + } + + if (allClaims.Count == 0) + { + // No authenticated identities containing claims found. + return null; } - // We do not understand this ClaimsIdentity, fallback on serializing the entire claims Identity. - var claims = claimsIdentity.Claims.ToList(); - claims.Sort((a, b) => string.Compare(a.Type, b.Type, StringComparison.Ordinal)); + allClaims.Sort((a, b) => string.Compare(a.Type, b.Type, StringComparison.Ordinal)); - var identifierParameters = new List(claims.Count * 2); - foreach (var claim in claims) + var identifierParameters = new List(allClaims.Count * 3); + for (var i = 0; i < allClaims.Count; i++) { + var claim = allClaims[i]; identifierParameters.Add(claim.Type); identifierParameters.Add(claim.Value); + identifierParameters.Add(claim.Issuer); } return identifierParameters; diff --git a/src/Microsoft.AspNetCore.Antiforgery/Internal/IClaimUidExtractor.cs b/src/Microsoft.AspNetCore.Antiforgery/Internal/IClaimUidExtractor.cs index d4a888d..72ab230 100644 --- a/src/Microsoft.AspNetCore.Antiforgery/Internal/IClaimUidExtractor.cs +++ b/src/Microsoft.AspNetCore.Antiforgery/Internal/IClaimUidExtractor.cs @@ -1,20 +1,21 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.Security.Claims; namespace Microsoft.AspNetCore.Antiforgery.Internal { /// - /// This interface can extract unique identifers for a claims-based identity. + /// This interface can extract unique identifers for a . /// public interface IClaimUidExtractor { /// /// Extracts claims identifier. /// - /// The . + /// The . /// The claims identifier. - string ExtractClaimUid(ClaimsIdentity identity); + string ExtractClaimUid(ClaimsPrincipal claimsPrincipal); } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultAntiforgeryTokenGeneratorTest.cs b/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultAntiforgeryTokenGeneratorTest.cs index 0e4a2a3..5509726 100644 --- a/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultAntiforgeryTokenGeneratorTest.cs +++ b/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultAntiforgeryTokenGeneratorTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Security.Cryptography; using Microsoft.AspNetCore.Http.Internal; @@ -157,7 +158,7 @@ public void GenerateRequestToken_ClaimsBasedIdentity() var expectedClaimUid = new BinaryBlob(256, data); var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(identity)) + mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) .Returns(base64ClaimUId); var tokenProvider = new DefaultAntiforgeryTokenGenerator( @@ -410,7 +411,7 @@ public void TryValidateTokenSet_UsernameMismatch(string identityUsername, string }; var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(identity)) + mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) .Returns((string)null); var tokenProvider = new DefaultAntiforgeryTokenGenerator( @@ -448,7 +449,7 @@ public void TryValidateTokenSet_ClaimUidMismatch() var differentToken = new BinaryBlob(256); var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(identity)) + mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) .Returns(Convert.ToBase64String(differentToken.GetData())); var tokenProvider = new DefaultAntiforgeryTokenGenerator( @@ -590,7 +591,7 @@ public void TryValidateTokenSet_Success_ClaimsBasedUser() }; var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(identity)) + mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) .Returns(Convert.ToBase64String(fieldtoken.ClaimUid.GetData())); var tokenProvider = new DefaultAntiforgeryTokenGenerator( diff --git a/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultClaimUidExtractorTest.cs b/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultClaimUidExtractorTest.cs index 2968358..59e2d3a 100644 --- a/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultClaimUidExtractorTest.cs +++ b/test/Microsoft.AspNetCore.Antiforgery.Test/Internal/DefaultClaimUidExtractorTest.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Security.Claims; using Microsoft.Extensions.ObjectPool; @@ -15,19 +17,6 @@ public class DefaultClaimUidExtractorTest private static readonly ObjectPool _pool = new DefaultObjectPoolProvider().Create(new AntiforgerySerializationContextPooledObjectPolicy()); - [Fact] - public void ExtractClaimUid_NullIdentity() - { - // Arrange - var extractor = new DefaultClaimUidExtractor(_pool); - - // Act - var claimUid = extractor.ExtractClaimUid(null); - - // Assert - Assert.Null(claimUid); - } - [Fact] public void ExtractClaimUid_Unauthenticated() { @@ -39,7 +28,7 @@ public void ExtractClaimUid_Unauthenticated() .Returns(false); // Act - var claimUid = extractor.ExtractClaimUid(mockIdentity.Object); + var claimUid = extractor.ExtractClaimUid(new ClaimsPrincipal(mockIdentity.Object)); // Assert Assert.Null(claimUid); @@ -52,21 +41,22 @@ public void ExtractClaimUid_ClaimsIdentity() var mockIdentity = new Mock(); mockIdentity.Setup(o => o.IsAuthenticated) .Returns(true); + mockIdentity.Setup(o => o.Claims).Returns(new Claim[] { new Claim(ClaimTypes.Name, "someName") }); var extractor = new DefaultClaimUidExtractor(_pool); // Act - var claimUid = extractor.ExtractClaimUid(mockIdentity.Object); + var claimUid = extractor.ExtractClaimUid(new ClaimsPrincipal(mockIdentity.Object )); // Assert Assert.NotNull(claimUid); - Assert.Equal("47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", claimUid); + Assert.Equal("yhXE+2v4zSXHtRHmzm4cmrhZca2J0g7yTUwtUerdeF4=", claimUid); } [Fact] public void DefaultUniqueClaimTypes_NotPresent_SerializesAllClaimTypes() { - var identity = new ClaimsIdentity(); + var identity = new ClaimsIdentity("someAuthentication"); identity.AddClaim(new Claim(ClaimTypes.Email, "someone@antifrogery.com")); identity.AddClaim(new Claim(ClaimTypes.GivenName, "some")); identity.AddClaim(new Claim(ClaimTypes.Surname, "one")); @@ -79,7 +69,7 @@ public void DefaultUniqueClaimTypes_NotPresent_SerializesAllClaimTypes() var claimsIdentity = (ClaimsIdentity)identity; // Act - var identiferParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(claimsIdentity) + var identiferParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { claimsIdentity }) .ToArray(); var claims = claimsIdentity.Claims.ToList(); claims.Sort((a, b) => string.Compare(a.Type, b.Type, StringComparison.Ordinal)); @@ -90,6 +80,7 @@ public void DefaultUniqueClaimTypes_NotPresent_SerializesAllClaimTypes() { Assert.Equal(identiferParameters[index++], claim.Type); Assert.Equal(identiferParameters[index++], claim.Value); + Assert.Equal(identiferParameters[index++], claim.Issuer); } } @@ -97,18 +88,177 @@ public void DefaultUniqueClaimTypes_NotPresent_SerializesAllClaimTypes() public void DefaultUniqueClaimTypes_Present() { // Arrange - var identity = new ClaimsIdentity(); + var identity = new ClaimsIdentity("someAuthentication"); identity.AddClaim(new Claim("fooClaim", "fooClaimValue")); identity.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(identity); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + + // Assert + Assert.Equal(new string[] + { + ClaimTypes.NameIdentifier, + "nameIdentifierValue", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_PrefersSubClaimOverNameIdentifierAndUpn() + { + // Arrange + var identity = new ClaimsIdentity("someAuthentication"); + identity.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); + identity.AddClaim(new Claim("sub", "subClaimValue")); + identity.AddClaim(new Claim(ClaimTypes.Upn, "upnClaimValue")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + + // Assert + Assert.Equal(new string[] + { + "sub", + "subClaimValue", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_PrefersNameIdentifierOverUpn() + { + // Arrange + var identity = new ClaimsIdentity("someAuthentication"); + identity.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); + identity.AddClaim(new Claim(ClaimTypes.Upn, "upnClaimValue")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); // Assert Assert.Equal(new string[] { ClaimTypes.NameIdentifier, "nameIdentifierValue", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_UsesUpnIfPresent() + { + // Arrange + var identity = new ClaimsIdentity("someAuthentication"); + identity.AddClaim(new Claim("fooClaim", "fooClaimValue")); + identity.AddClaim(new Claim(ClaimTypes.Upn, "upnClaimValue")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + + // Assert + Assert.Equal(new string[] + { + ClaimTypes.Upn, + "upnClaimValue", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_MultipleIdentities_UsesOnlyAuthenticatedIdentities() + { + // Arrange + var identity1 = new ClaimsIdentity(); // no authentication + identity1.AddClaim(new Claim("sub", "subClaimValue")); + var identity2 = new ClaimsIdentity("someAuthentication"); + identity2.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity1, identity2 }); + + // Assert + Assert.Equal(new string[] + { + ClaimTypes.NameIdentifier, + "nameIdentifierValue", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_NoKnownClaimTypesFound_SortsAndReturnsAllClaimsFromAuthenticatedIdentities() + { + // Arrange + var identity1 = new ClaimsIdentity(); // no authentication + identity1.AddClaim(new Claim("sub", "subClaimValue")); + var identity2 = new ClaimsIdentity("someAuthentication"); + identity2.AddClaim(new Claim(ClaimTypes.Email, "email@domain.com")); + var identity3 = new ClaimsIdentity("someAuthentication"); + identity3.AddClaim(new Claim(ClaimTypes.Country, "countryValue")); + var identity4 = new ClaimsIdentity("someAuthentication"); + identity4.AddClaim(new Claim(ClaimTypes.Name, "claimName")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters( + new ClaimsIdentity[] { identity1, identity2, identity3, identity4 }); + + // Assert + Assert.Equal(new List + { + ClaimTypes.Country, + "countryValue", + "LOCAL AUTHORITY", + ClaimTypes.Email, + "email@domain.com", + "LOCAL AUTHORITY", + ClaimTypes.Name, + "claimName", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_PrefersNameFromFirstIdentity_OverSubFromSecondIdentity() + { + // Arrange + var identity1 = new ClaimsIdentity("someAuthentication"); + identity1.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); + var identity2 = new ClaimsIdentity("someAuthentication"); + identity2.AddClaim(new Claim("sub", "subClaimValue")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters( + new ClaimsIdentity[] { identity1, identity2 }); + + // Assert + Assert.Equal(new string[] + { + ClaimTypes.NameIdentifier, + "nameIdentifierValue", + "LOCAL AUTHORITY", + }, uniqueIdentifierParameters); + } + + [Fact] + public void GetUniqueIdentifierParameters_PrefersUpnFromFirstIdentity_OverNameFromSecondIdentity() + { + // Arrange + var identity1 = new ClaimsIdentity("someAuthentication"); + identity1.AddClaim(new Claim(ClaimTypes.Upn, "upnValue")); + var identity2 = new ClaimsIdentity("someAuthentication"); + identity2.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); + + // Act + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters( + new ClaimsIdentity[] { identity1, identity2 }); + + // Assert + Assert.Equal(new string[] + { + ClaimTypes.Upn, + "upnValue", + "LOCAL AUTHORITY", }, uniqueIdentifierParameters); } }