Skip to content

Instantly share code, notes, and snippets.

@mtherien
Created May 7, 2020 19:56
Show Gist options
  • Save mtherien/39803dd113180bccd7915f97e2ccb7d8 to your computer and use it in GitHub Desktop.
Save mtherien/39803dd113180bccd7915f97e2ccb7d8 to your computer and use it in GitHub Desktop.
DbContext extension methods to get the primary key values of an entity
// Source: https://stackoverflow.com/questions/30688909/how-to-get-primary-key-value-with-entity-framework-core
/// <summary>
/// Extension methods for reading the primary key values of an entity through a <see cref="DbContext"/>
/// or an <see cref="EntityEntry"/>.
/// </summary>
/// <remarks>
/// Key properties are taken from <see cref="EntityEntry.Metadata"/>, i.e. the entity type the context uses
/// for that entry, rather than by looking the CLR type up in a model-independent static cache.
/// Values are read through the change tracker (<c>entry.Property(name).CurrentValue</c>), which also
/// works for key properties that have no public CLR property.
/// </remarks>
public static class DbContextKeyExtensions
{
    /// <summary>
    /// Formats the primary key of <paramref name="entry"/> as <c>(Name)=[Value]</c> pairs joined by <c>"; "</c>,
    /// e.g. <c>(Id)=[42]</c>, or <c>(OrderId)=[1]; (LineNo)=[3]</c> for a composite key.
    /// </summary>
    /// <param name="entry">The entry whose key is formatted.</param>
    /// <returns>The formatted key.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entry"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The entry's entity type has no primary key.</exception>
    public static string KeyValuesAsString(this EntityEntry entry)
    {
        if (entry == null)
        {
            throw new ArgumentNullException(nameof(entry));
        }

        return string.Join("; ", entry.KeyValuesOf().Select(v => $"({v.Key})=[{v.Value}]"));
    }

    /// <summary>
    /// Returns the primary key of <paramref name="entry"/> as (property name, current value) pairs,
    /// in the order the key's properties are defined.
    /// </summary>
    /// <param name="entry">The entry whose key is read.</param>
    /// <returns>A lazily evaluated sequence; values are read when it is enumerated.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entry"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The entry's entity type has no primary key.</exception>
    public static IEnumerable<KeyValuePair<string, object>> KeyValuesOf(this EntityEntry entry)
    {
        // Deliberately not an iterator method: the null check and key lookup run at call time,
        // instead of being deferred to the first MoveNext().
        if (entry == null)
        {
            throw new ArgumentNullException(nameof(entry));
        }

        var keyProperties = GetKeyProperties(entry);
        return keyProperties.Select(property =>
            new KeyValuePair<string, object>(property.Name, entry.Property(property.Name).CurrentValue));
    }

    /// <summary>
    /// Returns the primary key values of <paramref name="entity"/> as seen by <paramref name="context"/>.
    /// </summary>
    /// <param name="context">The context whose model defines the key.</param>
    /// <param name="entity">The entity whose key is read; it does not have to be tracked.</param>
    /// <returns>The key values in key-property order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The entity type has no primary key.</exception>
    public static IEnumerable<object> KeyOf<TEntity>(this DbContext context, TEntity entity)
        where TEntity : class
    {
        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }

        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        return context.Entry(entity).KeyOf();
    }

    /// <summary>
    /// Returns the single-property primary key of <paramref name="entity"/>, cast to <typeparamref name="TKey"/>.
    /// </summary>
    /// <param name="context">The context whose model defines the key.</param>
    /// <param name="entity">The entity whose key is read; it does not have to be tracked.</param>
    /// <returns>The key value.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The entity type has no primary key, or the key is composite.</exception>
    /// <exception cref="InvalidCastException">The key value is not a <typeparamref name="TKey"/>.</exception>
    public static TKey KeyOf<TEntity, TKey>(this DbContext context, TEntity entity)
        where TEntity : class
    {
        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }

        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        // Share the single-key validation with the EntityEntry overload.
        return context.Entry(entity).KeyOf<TKey>();
    }

    /// <summary>
    /// Returns the primary key values of <paramref name="entry"/> in key-property order.
    /// </summary>
    /// <param name="entry">The entry whose key is read.</param>
    /// <returns>A lazily evaluated sequence; values are read when it is enumerated.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entry"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The entry's entity type has no primary key.</exception>
    public static IEnumerable<object> KeyOf(this EntityEntry entry)
    {
        if (entry == null)
        {
            throw new ArgumentNullException(nameof(entry));
        }

        var keyProperties = GetKeyProperties(entry);

        // Read through the change tracker rather than reflection on entry.Entity: reflection here
        // previously resolved against typeof(object) and therefore always yielded null.
        return keyProperties.Select(property => entry.Property(property.Name).CurrentValue);
    }

    /// <summary>
    /// Returns the single-property primary key of <paramref name="entry"/>, cast to <typeparamref name="TKey"/>.
    /// </summary>
    /// <param name="entry">The entry whose key is read.</param>
    /// <returns>The key value.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entry"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The entity type has no primary key, or the key is composite.</exception>
    /// <exception cref="InvalidCastException">The key value is not a <typeparamref name="TKey"/>.</exception>
    public static TKey KeyOf<TKey>(this EntityEntry entry)
    {
        if (entry == null)
        {
            throw new ArgumentNullException(nameof(entry));
        }

        var keyParts = entry.KeyOf().ToArray();

        // GetKeyProperties guarantees at least one part, so more than one means a composite key.
        if (keyParts.Length > 1)
        {
            throw new InvalidOperationException($"Key is composite and has '{keyParts.Length}' parts.");
        }

        return (TKey)keyParts[0];
    }

    /// <summary>
    /// Returns the primary key properties of the entity type backing <paramref name="entry"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The entity type has no primary key (for example, an entity type configured as keyless).
    /// </exception>
    private static IReadOnlyList<IProperty> GetKeyProperties(EntityEntry entry)
    {
        var primaryKey = entry.Metadata.FindPrimaryKey();
        if (primaryKey == null || primaryKey.Properties.Count == 0)
        {
            throw new InvalidOperationException($"Entity type '{entry.Metadata.Name}' does not define a primary key.");
        }

        return primaryKey.Properties;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment