Created
July 11, 2017 21:20
-
-
Save barchito/89bffc5d1b8ec01b785e77ebac295d81 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Throttling handler that fills in the request identity from the
/// Authorization header, the client IP and the lower-cased request path.
/// </summary>
public class CustomThrottlingHandler : ThrottlingHandler
{
    /// <summary>
    /// Builds the identity used to throttle <paramref name="request"/>.
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <returns>
    /// An identity whose client key is the first Authorization header value
    /// (or "anon" when the header is absent), together with the client IP
    /// and the lower-cased absolute path.
    /// </returns>
    protected override RequestIdentity SetIdentity(HttpRequestMessage request)
    {
        var identity = new RequestIdentity();

        // Requests without an Authorization header all share the "anon" key.
        if (request.Headers.Contains("Authorization"))
        {
            identity.ClientKey = request.Headers.GetValues("Authorization").First();
        }
        else
        {
            identity.ClientKey = "anon";
        }

        identity.ClientIp = base.GetClientIp(request).ToString();
        identity.Endpoint = request.RequestUri.AbsolutePath.ToLowerInvariant();

        return identity;
    }
}
/// <summary>
/// Web API start-up configuration: CORS, message handlers, routes and
/// exception logging.
/// </summary>
public static class WebApiConfig
{
    /// <summary>
    /// Registers CORS, the message-handler pipeline, the routes and the
    /// exception logger on <paramref name="config"/>.
    /// </summary>
    /// <param name="config">The Web API configuration to populate.</param>
    public static void Register(HttpConfiguration config)
    {
        // CORS: the allowed origins come from app settings; any header and method is accepted.
        var allowedOrigins = ConfigurationManager.AppSettings["AccessControlAllowOrigin"];
        var corsPolicy = new EnableCorsAttribute(origins: allowedOrigins, headers: "*", methods: "*");
        config.EnableCors(corsPolicy);

        // Message handlers are added in this order: token validation, tracing, throttling.
        var tokenHandler = new JsonWebTokenValidationHandler();
        tokenHandler.Audience = ConfigurationManager.AppSettings["Aud"];        // client id
        tokenHandler.SymmetricKey = ConfigurationManager.AppSettings["Secret"]; // client secret
        config.MessageHandlers.Add(tokenHandler);

        var traceHandler = new TraceMessageHandler();
        config.MessageHandlers.Add(traceHandler);

        var perSecondLimit = int.Parse(ConfigurationManager.AppSettings["MaxCallPerSecondFromIp"]);
        var throttlingHandler = new CustomThrottlingHandler();
        throttlingHandler.Policy = new ThrottlePolicy(perSecond: perSecondLimit)
        {
            IpThrottling = true
        };
        throttlingHandler.Repository = new CacheRepository();
        config.MessageHandlers.Add(throttlingHandler);

        // Routes: attribute routes first, then the conventional api/{controller}/{id} route.
        config.MapHttpAttributeRoutes();
        config.Routes.MapHttpRoute(
            name: "DefaultApi",
            routeTemplate: "api/{controller}/{id}",
            defaults: new { id = RouteParameter.Optional });

        // Exception logging.
        var exceptionLogger = new AiExceptionLogger();
        config.Services.Add(typeof(IExceptionLogger), exceptionLogger);
    }
}
/// <summary>
/// Message handler that records a telemetry trace for every request whose
/// path starts with "/api/": the query string for GET requests, the request
/// body for all other methods. Tracing is best-effort and never prevents the
/// request from being processed.
/// </summary>
public class TraceMessageHandler : DelegatingHandler
{
    // One shared client for all requests instead of a new instance per request.
    private static readonly TelemetryClient telemetry = new TelemetryClient();

    /// <summary>
    /// Traces the request (when it targets "/api/") and forwards it to the
    /// inner handler.
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="cancellationToken">Token forwarded to the inner handler.</param>
    /// <returns>The response produced by the inner handler.</returns>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        try
        {
            if (request.RequestUri.AbsolutePath.StartsWith("/api/", StringComparison.Ordinal))
            {
                if (request.Method == HttpMethod.Get)
                {
                    // A query string may repeat a key (?a=1&a=2). Group by key and join the
                    // values so building the dictionary cannot throw on duplicates.
                    var query = request.GetQueryNameValuePairs()
                        .GroupBy(pair => pair.Key)
                        .ToDictionary(group => group.Key, group => string.Join(",", group.Select(pair => pair.Value)));
                    telemetry.TrackTrace("Query Request", query);
                }
                else
                {
                    // Content is null for requests sent without a body.
                    var body = request.Content == null
                        ? string.Empty
                        : await request.Content.ReadAsStringAsync();
                    var properties = new Dictionary<string, string> { { "props", body } };
                    telemetry.TrackTrace("Data Request", properties);
                }
            }
        }
        catch (Exception ex)
        {
            // Tracing must never fail the request, but the failure should not vanish either.
            System.Diagnostics.Trace.TraceWarning("Request tracing failed: {0}", ex);
        }

        return await base.SendAsync(request, cancellationToken);
    }
}
/// <summary>
/// Message handler that validates a JSON Web Token taken from the
/// Authorization header and, on success, installs the resulting principal
/// on the current thread and on <c>HttpContext.Current</c>.
/// Requests without a usable Authorization header are forwarded unchanged.
/// </summary>
public class JsonWebTokenValidationHandler : DelegatingHandler
{
    private const string BearerPrefix = "Bearer ";
    private const string RejectionMessage = "Wrong token, or you are not authorized.";

    /// <summary>Base64url-encoded key passed to token validation.</summary>
    public string SymmetricKey { get; set; }

    /// <summary>Expected audience passed to token validation.</summary>
    public string Audience { get; set; }

    /// <summary>Expected issuer passed to token validation.</summary>
    public string Issuer { get; set; }

    /// <summary>
    /// Extracts the token from the request's Authorization header.
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="token">
    /// The header value with a leading "Bearer " removed, or <c>null</c> when
    /// the method returns <c>false</c>.
    /// </param>
    /// <returns>
    /// <c>true</c> when exactly one Authorization value is present; otherwise <c>false</c>.
    /// </returns>
    private static bool TryRetrieveToken(HttpRequestMessage request, out string token)
    {
        token = null;

        IEnumerable<string> authzHeaders;
        if (!request.Headers.TryGetValues("Authorization", out authzHeaders))
        {
            return false;
        }

        // Materialize once: the sequence would otherwise be enumerated for the count and
        // again for the element, and an empty sequence would throw on element access.
        var values = authzHeaders.ToList();
        if (values.Count != 1)
        {
            // Fail if there is not exactly one Authorization header value.
            return false;
        }

        // Strip the bearer scheme prefix. Authentication scheme names are
        // case-insensitive, so "bearer " is accepted as well.
        var headerValue = values[0];
        token = headerValue.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase)
            ? headerValue.Substring(BearerPrefix.Length)
            : headerValue;
        return true;
    }

    /// <summary>
    /// Builds the error response returned when token validation fails,
    /// including the CORS origin header read from app settings.
    /// </summary>
    private static HttpResponseMessage CreateRejection(HttpRequestMessage request, HttpStatusCode statusCode)
    {
        var response = request.CreateErrorResponse(statusCode, RejectionMessage);
        response.Headers.Add("Access-Control-Allow-Origin", ConfigurationManager.AppSettings["AccessControlAllowOrigin"]);
        return response;
    }

    /// <summary>
    /// Validates the token (when one is present) and either forwards the
    /// request to the inner handler or short-circuits with an error response.
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="cancellationToken">Token forwarded to the inner handler.</param>
    /// <returns>
    /// 401 when the token is rejected by validation, 500 when validation
    /// fails with any other exception, otherwise the inner handler's response.
    /// </returns>
    protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        string token;
        HttpResponseMessage errorResponse = null;

        if (TryRetrieveToken(request, out token))
        {
            try
            {
                Thread.CurrentPrincipal = JsonWebToken.ValidateToken(
                    token,
                    this.SymmetricKey,
                    audience: this.Audience,
                    checkExpiration: true,
                    issuer: this.Issuer);

                if (HttpContext.Current != null)
                {
                    HttpContext.Current.User = new UserPrincipal(Thread.CurrentPrincipal.Identity);
                }
            }
            catch (Jose.JoseException)
            {
                errorResponse = CreateRejection(request, HttpStatusCode.Unauthorized);
            }
            catch (JsonWebToken.TokenValidationException)
            {
                errorResponse = CreateRejection(request, HttpStatusCode.Unauthorized);
            }
            catch (Exception)
            {
                errorResponse = CreateRejection(request, HttpStatusCode.InternalServerError);
            }
        }

        return errorResponse != null
            ? Task.FromResult(errorResponse)
            : base.SendAsync(request, cancellationToken);
    }
}
/// <summary>
/// Decodes a signed JSON Web Token, checks its audience, expiration and
/// issuer claims, and converts the payload into a <see cref="ClaimsPrincipal"/>.
/// </summary>
public static class JsonWebToken
{
    private const string NameClaimType = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name";
    private const string RoleClaimType = "http://schemas.microsoft.com/ws/2008/06/identity/claims/role";
    private const string ActorClaimType = "http://schemas.xmlsoap.org/ws/2009/09/identity/claims/actor";
    private const string DefaultIssuer = "LOCAL AUTHORITY";
    private const string StringClaimValueType = "http://www.w3.org/2001/XMLSchema#string";

    // Claim types tried, in this order, as the source of the user name claim.
    private static readonly string[] claimTypesForUserName = { "name", "email", "user_id", "sub", "user_metadata" };

    // Payload entries that are not exposed as claims.
    private static readonly string[] claimsToExclude = { "iss", "sub", "aud", "exp", "iat", "identities" };

    /// <summary>
    /// Decodes <paramref name="token"/> with the given key and validates its payload.
    /// </summary>
    /// <param name="token">The compact serialized token.</param>
    /// <param name="secretKey">Base64url-encoded key handed to the decoder.</param>
    /// <param name="audience">
    /// Expected audience. Only checked when non-empty and the payload has an "aud" entry.
    /// </param>
    /// <param name="checkExpiration">
    /// When <c>true</c> and the payload has an "exp" entry, tokens whose
    /// expiration is not in the future are rejected.
    /// </param>
    /// <param name="issuer">
    /// Expected issuer. Only checked when non-empty and the payload has an "iss"
    /// entry; when empty, the payload's issuer is used for the created claims.
    /// </param>
    /// <returns>A principal holding the claims built from the payload.</returns>
    /// <exception cref="TokenValidationException">
    /// The audience or issuer does not match, or the token is expired.
    /// </exception>
    public static ClaimsPrincipal ValidateToken(string token, string secretKey, string audience = null, bool checkExpiration = false, string issuer = null)
    {
        var payloadJson = JWT.Decode(token, Base64UrlDecode(secretKey));

        // A serializer per call: instance members are not documented as thread-safe,
        // and this method runs concurrently for parallel requests.
        var payloadData = new JavaScriptSerializer().Deserialize<Dictionary<string, object>>(payloadJson);

        // audience check
        object aud;
        if (!string.IsNullOrEmpty(audience) && payloadData.TryGetValue("aud", out aud))
        {
            if (!AudienceMatches(aud, audience))
            {
                throw new TokenValidationException(string.Format("Audience mismatch. Expected: '{0}' and got: '{1}'", audience, aud));
            }
        }

        // expiration check
        object exp;
        if (checkExpiration && payloadData.TryGetValue("exp", out exp))
        {
            // Convert rather than parse the string form: the value may arrive as any
            // numeric type (or a numeric string) and must not depend on the thread culture.
            long expSeconds = Convert.ToInt64(exp, System.Globalization.CultureInfo.InvariantCulture);
            DateTime validTo = FromUnixTime(expSeconds);
            DateTime now = DateTime.UtcNow;
            if (validTo <= now)
            {
                throw new TokenValidationException(
                    string.Format("Token is expired. Expiration: '{0}'. Current: '{1}'", validTo, now));
            }
        }

        // issuer check
        object iss;
        if (payloadData.TryGetValue("iss", out iss))
        {
            var tokenIssuer = iss == null ? null : iss.ToString();
            if (!string.IsNullOrEmpty(issuer))
            {
                if (!string.Equals(tokenIssuer, issuer, StringComparison.Ordinal))
                {
                    throw new TokenValidationException(string.Format("Token issuer mismatch. Expected: '{0}' and got: '{1}'", issuer, iss));
                }
            }
            else
            {
                // if issuer is not specified, set issuer with jwt[iss]
                issuer = tokenIssuer;
            }
        }

        return new ClaimsPrincipal(ClaimsIdentityFromJwt(payloadData, issuer));
    }

    /// <summary>
    /// Tells whether the payload's "aud" value names the expected audience.
    /// The value may be a single string or an array of strings.
    /// </summary>
    private static bool AudienceMatches(object aud, string expected)
    {
        if (aud == null)
        {
            return false;
        }

        var single = aud as string;
        if (single != null)
        {
            return single.Equals(expected, StringComparison.Ordinal);
        }

        var many = aud as IEnumerable;
        if (many != null)
        {
            return many.Cast<object>().Any(item => item != null && item.ToString().Equals(expected, StringComparison.Ordinal));
        }

        return aud.ToString().Equals(expected, StringComparison.Ordinal);
    }

    /// <summary>
    /// Flattens the payload into claims, adds the name-identifier and name
    /// claims, and removes the excluded payload entries.
    /// </summary>
    /// <param name="jwtData">The deserialized token payload.</param>
    /// <param name="issuer">Issuer for the claims; the default issuer is used when null.</param>
    private static List<Claim> ClaimsFromJwt(IDictionary<string, object> jwtData, string issuer)
    {
        var list = new List<Claim>();
        issuer = issuer ?? DefaultIssuer;

        foreach (KeyValuePair<string, object> pair in jwtData)
        {
            var claimType = pair.Key;
            if (pair.Value == null)
            {
                // JSON null has no string form to put in a claim.
                continue;
            }

            // Arrays become one claim per element.
            var items = pair.Value as ArrayList;
            if (items != null)
            {
                foreach (var item in items)
                {
                    if (item != null)
                    {
                        list.Add(new Claim(claimType, item.ToString(), StringClaimValueType, issuer, issuer));
                    }
                }
                continue;
            }

            // user_metadata objects become one claim per property plus one claim holding
            // the whole object. Anything else, including a user_metadata value that is
            // not an object, is stored through its string form.
            var metadata = pair.Value as Dictionary<string, object>;
            if (metadata != null && claimType.Equals("user_metadata", StringComparison.InvariantCultureIgnoreCase))
            {
                foreach (var item in metadata)
                {
                    // NOTE(review): key and value are not JSON-escaped here; the format is
                    // kept as-is because consumers may depend on it.
                    list.Add(new Claim(claimType, $"{{\"{item.Key}\":\"{item.Value}\"}}", StringClaimValueType, issuer, issuer));
                }
                list.Add(new Claim(claimType, JsonConvert.SerializeObject(pair.Value), StringClaimValueType, issuer, issuer));
            }
            else
            {
                list.Add(new Claim(claimType, pair.Value.ToString(), StringClaimValueType, issuer, issuer));
            }
        }

        var subjectClaim = list.FirstOrDefault(c => c.Type == "sub");
        if (subjectClaim != null)
        {
            list.Add(new Claim(ClaimTypes.NameIdentifier, subjectClaim.Value, StringClaimValueType, issuer, issuer));
        }

        // set claim for user name: the first candidate type present wins
        foreach (var candidateType in claimTypesForUserName)
        {
            var source = list.FirstOrDefault(c => c.Type == candidateType);
            if (source != null)
            {
                list.Add(new Claim(NameClaimType, source.Value, StringClaimValueType, issuer, issuer));
                break;
            }
        }

        // dont include specific jwt claims
        return list.Where(c => !claimsToExclude.Contains(c.Type)).ToList();
    }

    /// <summary>
    /// Builds the identity for the payload: every claim is re-created with the
    /// identity as its subject, and "user_id" is additionally exposed as a
    /// name-identifier claim.
    /// </summary>
    /// <param name="jwtData">The deserialized token payload.</param>
    /// <param name="issuer">Issuer for the created claims.</param>
    /// <exception cref="InvalidOperationException">
    /// An actor claim is found while the identity already has an actor.
    /// </exception>
    private static ClaimsIdentity ClaimsIdentityFromJwt(IDictionary<string, object> jwtData, string issuer)
    {
        var subject = new ClaimsIdentity("Federation", NameClaimType, RoleClaimType);
        var claims = ClaimsFromJwt(jwtData, issuer);

        foreach (Claim claim in claims)
        {
            var type = claim.Type;
            if (type == ActorClaimType)
            {
                // NOTE(review): nothing in this method assigns subject.Actor, so this
                // guard looks unreachable from here — confirm before relying on it.
                if (subject.Actor != null)
                {
                    throw new InvalidOperationException(string.Format(
                        "Jwt10401: Only a single 'Actor' is supported. Found second claim of type: '{0}', value: '{1}'", "actor", claim.Value));
                }

                subject.AddClaim(new Claim(type, claim.Value, claim.ValueType, issuer, issuer, subject));
                continue;
            }

            if (type == "user_id")
            {
                subject.AddClaim(new Claim(ClaimTypes.NameIdentifier, claim.Value, claim.ValueType, issuer, issuer, subject));
            }

            subject.AddClaim(new Claim(type, claim.Value, claim.ValueType, issuer, issuer, subject));
        }

        return subject;
    }

    /// <summary>Converts seconds since 1970-01-01T00:00:00Z to a UTC <see cref="DateTime"/>.</summary>
    private static DateTime FromUnixTime(long unixTime)
    {
        var epoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
        return epoch.AddSeconds(unixTime);
    }

    /// <summary>Raised when a decoded token fails the audience, expiration or issuer check.</summary>
    public class TokenValidationException : Exception
    {
        public TokenValidationException(string message)
            : base(message)
        {
        }
    }

    /// <summary>
    /// Decodes a base64url string (with '-' and '_' in place of '+' and '/',
    /// and padding omitted) into bytes.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="arg"/> is null.</exception>
    /// <exception cref="FormatException">The input is not a valid base64url string.</exception>
    private static byte[] Base64UrlDecode(string arg)
    {
        if (arg == null)
        {
            throw new ArgumentNullException(nameof(arg));
        }

        string s = arg;
        s = s.Replace('-', '+'); // 62nd char of encoding
        s = s.Replace('_', '/'); // 63rd char of encoding
        switch (s.Length % 4) // Pad with trailing '='s
        {
            case 0: break; // No pad chars in this case
            case 2: s += "=="; break; // Two pad chars
            case 3: s += "="; break; // One pad char
            default:
                // Same exception type the standard decoder below uses for bad input.
                throw new FormatException("Illegal base64url string!");
        }
        return Convert.FromBase64String(s); // Standard base64 decoder
    }
}
/// <summary>
/// Exception logger that forwards unhandled Web API exceptions to the
/// telemetry client before running the base logger.
/// </summary>
public class AiExceptionLogger : ExceptionLogger
{
    // One shared client for all logged exceptions instead of a new instance per call.
    private static readonly TelemetryClient telemetry = new TelemetryClient();

    /// <summary>
    /// Tracks the context's exception, when there is one, and then calls the
    /// base implementation.
    /// </summary>
    /// <param name="context">The exception logger context; may be null.</param>
    public override void Log(ExceptionLoggerContext context)
    {
        if (context != null && context.Exception != null)
        {
            telemetry.TrackException(context.Exception);
        }

        base.Log(context);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment