Last active
May 14, 2023 18:38
-
-
Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Roslyn Analyzer to warn about access to forbidden types
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Collections.Immutable; | |
using System.Linq; | |
using Microsoft.CodeAnalysis; | |
using Microsoft.CodeAnalysis.Diagnostics; | |
using Microsoft.CodeAnalysis.Operations; | |
// CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs | |
/// <summary>
/// Warns wherever C# code uses a type from a fixed deny list.
/// </summary>
/// <remarks>
/// A type counts as used when code creates an instance or an array of it, invokes or references one of
/// its members (field, property, event or method), takes an address whose pointer type points at it,
/// calls a user-defined operator or conversion it declares, or explicitly casts / <c>as</c>-converts a
/// value to it. Type arguments, array element types, pointed-at types and containing types are checked
/// as well, so e.g. <c>new List&lt;Type&gt;()</c> or a type nested inside a forbidden type is reported.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";
    const string Title = "Forbidden Type Analyzer";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Category = "API Usage";

    // Entries are compared (ordinally) against TypeNameFormat output, e.g. "System.Console", or
    // "System.Collections.Generic.List<T>" for a generic type.
    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    // ISymbol.ToString() uses the C# error-message format, which replaces System types with their
    // keyword aliases ("string", "object", "int", ...), so an entry such as "System.String" could never
    // match. This format always writes the namespace-qualified name (without "global::"), includes
    // containing types ("Outer.Inner") and keeps type parameters ("List<T>").
    static readonly SymbolDisplayFormat TypeNameFormat = new SymbolDisplayFormat(
        globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted,
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters);

    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        Description);

    readonly HashSet<string> _forbiddenTypeNames;

    /// <summary>Creates an analyzer that forbids the default deny list of types.</summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    /// <param name="forbiddenTypeNames">Names in <see cref="TypeNameFormat"/> form; compared ordinally.</param>
    /// <exception cref="ArgumentNullException"><paramref name="forbiddenTypeNames"/> is null.</exception>
    ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        if (forbiddenTypeNames is null)
        {
            throw new ArgumentNullException(nameof(forbiddenTypeNames));
        }

        // The HashSet constructor is used instead of Enumerable.ToHashSet, which is not part of
        // netstandard2.0 (the framework Roslyn analyzers are usually built against).
        _forbiddenTypeNames = new HashSet<string>(forbiddenTypeNames, StringComparer.Ordinal);
    }

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        compilationContext.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);

        compilationContext.RegisterOperationAction(
            AnalyzeOperation,
            OperationKind.ObjectCreation,
            OperationKind.Invocation,
            OperationKind.EventReference,
            OperationKind.FieldReference,
            OperationKind.MethodReference,
            OperationKind.PropertyReference,
            OperationKind.ArrayCreation,
            OperationKind.AddressOf,
            OperationKind.Conversion,
            OperationKind.UnaryOperator,
            OperationKind.BinaryOperator,
            OperationKind.Increment,
            OperationKind.Decrement);
    }

    /// <summary>
    /// Checks the type(s) used by one of the operation kinds registered in <see cref="Initialize"/>.
    /// </summary>
    void AnalyzeOperation(OperationAnalysisContext context)
    {
        var operation = context.Operation;

        // The type each operation uses: the created type, the member's declaring type, or the declaring
        // type of a user-defined operator (null when a built-in operator or conversion is used).
        var type = operation switch
        {
            IObjectCreationOperation objectCreation => objectCreation.Type,
            IInvocationOperation invocationOperation => invocationOperation.TargetMethod.ContainingType,
            IMemberReferenceOperation memberReference => memberReference.Member.ContainingType,
            IArrayCreationOperation arrayCreation => arrayCreation.Type,
            IAddressOfOperation addressOf => addressOf.Type,
            IConversionOperation conversion => conversion.OperatorMethod?.ContainingType,
            IUnaryOperation unary => unary.OperatorMethod?.ContainingType,
            IBinaryOperation binary => binary.OperatorMethod?.ContainingType,
            IIncrementOrDecrementOperation incrementOrDecrement => incrementOrDecrement.OperatorMethod?.ContainingType,
            // Every kind registered in Initialize is matched above; reaching this is a registration bug.
            _ => throw new NotImplementedException($"Unhandled OperationKind: {operation.Kind}")
        };

        VerifyType(context.ReportDiagnostic, type, operation.Syntax);

        // A cast or 'as' written in source names its target type, so that type is used even when no
        // operator method is involved (e.g. "(Type)value"). Skip it when it is the operator's declaring
        // type, which was just verified, to avoid a duplicate report.
        if (operation is IConversionOperation { IsImplicit: false } explicitConversion
            && !SymbolEqualityComparer.Default.Equals(explicitConversion.Type, type))
        {
            VerifyType(context.ReportDiagnostic, explicitConversion.Type, operation.Syntax);
        }
    }

    /// <summary>
    /// Reports <paramref name="type"/> at <paramref name="syntaxNode"/> if it, a type it is composed of
    /// (type argument, array element, pointed-at type) or one of its containing types is forbidden.
    /// </summary>
    /// <returns><c>true</c> if nothing was reported; <c>false</c> as soon as one diagnostic is reported.</returns>
    bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        do
        {
            // Verifies the component types first and yields the symbol whose own name is tested below.
            if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
            {
                return false;
            }

            if (type is null)
            {
                // Arrays and pointers (and a null input) have no name of their own to test; their
                // element / pointed-at type was already verified above.
                return true;
            }

            var typeName = type.ToDisplayString(TypeNameFormat);
            if (_forbiddenTypeNames.Contains(typeName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), typeName));
                return false;
            }

            // A nested type is forbidden when any type enclosing it is.
            type = type.ContainingType;
        }
        while (!(type is null));

        return true;
    }

    /// <summary>
    /// Verifies the types that <paramref name="type"/> is built from and returns the symbol whose own
    /// name should be tested next.
    /// </summary>
    /// <param name="originalDefinition">
    /// The generic definition a named type was constructed from, the original definition of any other
    /// type, or <c>null</c> for arrays, pointers and a <c>null</c> <paramref name="type"/>.
    /// </param>
    /// <returns><c>true</c> if nothing was reported; <c>false</c> once a diagnostic has been reported.</returns>
    bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        switch (type)
        {
            case INamedTypeSymbol namedTypeSymbol:
                originalDefinition = namedTypeSymbol.ConstructedFrom;
                foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                {
                    // Unsubstituted type parameters and unresolved (error) types have nothing to check.
                    if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                        typeArgument.TypeKind != TypeKind.Error &&
                        !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                    {
                        return false;
                    }
                }

                break;

            case IArrayTypeSymbol arrayTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);

            case IPointerTypeSymbol pointerTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);

            default:
                originalDefinition = type?.OriginalDefinition;
                break;
        }

        return true;
    }
}
// UNIT TESTS | |
using System.Collections.Immutable; | |
using System.Threading; | |
using System.Threading.Tasks; | |
using Microsoft.CodeAnalysis; | |
using Microsoft.CodeAnalysis.CSharp.Scripting; | |
using Microsoft.CodeAnalysis.Diagnostics; | |
using Microsoft.CodeAnalysis.Scripting; | |
using Xunit; | |
/// <summary>
/// Compiles C# script snippets and checks which diagnostics <see cref="ForbiddenTypeAnalyzer"/> reports.
/// </summary>
public class ForbiddenTypeAnalyzerTests
{
    [Theory]
    [InlineData(@"if (IsRequest) {
    Reply(""test"");
}
else {
    var env = Environment.CommandLine;
    Reply(env);
}", "System.Environment")]
    [InlineData(@"if (IsRequest) {
    Reply(""test"");
}
else {
    var env = Environment.GetEnvironmentVariable(""test"");
    Reply(env);
}", "System.Environment")]
    [InlineData(@"var args = Environment.CommandLine;", "System.Environment")]
    [InlineData(@"Console.WriteLine(""test"");", "System.Console")]
    [InlineData(@"var ptr = new IntPtr(32);", "System.IntPtr")]
    [InlineData(@"var type = (Type)SomeType;", "System.Type")]
    [InlineData(@"var type = (Type)SomeType; var name = type.Name;", "System.Type")]
    [InlineData(@"var type = Type.GetType(""name"");", "System.Type")]
    [InlineData(@"var pointers = new IntPtr[4];", "System.IntPtr")]
    public async Task ReturnsErrorsForForbiddenTypes(string code, string expectedForbiddenType)
    {
        var diagnostics = await AnalyzeScriptAsync(code);

        var diagnostic = Assert.Single(diagnostics);
        Assert.Equal(ForbiddenTypeAnalyzer.DiagnosticId, diagnostic.Id);
        Assert.Equal(
            $"Access to type {expectedForbiddenType} is forbidden.",
            diagnostic.GetMessage());
    }

    [Fact]
    public async Task DoesNotReturnsErrorsForAllowedTypes()
    {
        const string code = @"if (IsRequest) {
    Reply(""test"");
}
else {
    var rnd = new Random();
    Reply(rnd.Next(1).ToString());
}";

        var diagnostics = await AnalyzeScriptAsync(code);

        Assert.Empty(diagnostics);
    }

    /// <summary>
    /// Compiles <paramref name="code"/> as a script whose globals are <see cref="IScriptGlobals"/>,
    /// asserts that it compiles without errors, and returns only the diagnostics reported by
    /// <see cref="ForbiddenTypeAnalyzer"/> (compiler diagnostics are excluded).
    /// </summary>
    static async Task<ImmutableArray<Diagnostic>> AnalyzeScriptAsync(string code)
    {
        var options = ScriptOptions.Default
            .WithImports("System")
            .WithEmitDebugInformation(true)
            .WithReferences("System.Runtime.Extensions", "System.Console")
            .WithAllowUnsafe(false);
        var script = CSharpScript.Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: options);
        var compilation = script.GetCompilation();

        // A snippet that does not compile would make the analyzer assertions meaningless, so fail
        // here with the compiler errors listed (everything that is not an error is dropped first).
        Assert.Empty(compilation.GetDiagnostics().RemoveAll(d => d.Severity != DiagnosticSeverity.Error));

        var compilationWithAnalyzers = new CompilationWithAnalyzers(
            compilation,
            ImmutableArray.Create<DiagnosticAnalyzer>(new ForbiddenTypeAnalyzer()),
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);

        return await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync();
    }

    /// <summary>Members that the test scripts can reference directly.</summary>
    public interface IScriptGlobals
    {
        bool IsRequest { get; }

        void Reply(string reply);

        // Gives the cast snippets ("(Type)SomeType") a valid operand so that they compile.
        object SomeType { get; }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I opened a bug dotnet/roslyn-analyzers#4399