Skip to content

Instantly share code, notes, and snippets.

@haacked
Last active May 14, 2023 18:38
Show Gist options
  • Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Roslyn Analyzer to warn about access to forbidden types
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
// CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs
/// <summary>
/// Roslyn analyzer that reports a warning whenever analyzed C# code accesses a type
/// whose fully qualified name is on a deny list (object creation, member access,
/// invocation, array creation, address-of, and user-defined operators/conversions).
/// </summary>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    /// <summary>The id attached to every diagnostic produced by this analyzer.</summary>
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";
    const string Title = "Forbidden Type Analyzer";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Category = "API Usage";

    // Format used to turn a type symbol into the name that is looked up in the deny list.
    // It is spelled out explicitly instead of relying on ITypeSymbol.ToString(): the default
    // display format abbreviates special types to C# keywords and may append nullable
    // modifiers, so a deny-list entry such as "System.IntPtr" could fail to match when the
    // compiler renders the type as "nint". This format always yields the namespace-qualified
    // name (e.g. "System.IntPtr", "System.Collections.Generic.List<T>").
    static readonly SymbolDisplayFormat TypeNameFormat = new SymbolDisplayFormat(
        globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted,
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters | SymbolDisplayGenericsOptions.IncludeVariance,
        miscellaneousOptions: SymbolDisplayMiscellaneousOptions.EscapeKeywordIdentifiers | SymbolDisplayMiscellaneousOptions.UseErrorTypeSymbolName);

    readonly HashSet<string> _forbiddenTypeNames;

    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        Description);

    /// <summary>Creates an analyzer that uses the built-in default deny list.</summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    /// <summary>Creates an analyzer for the given set of fully qualified type names.</summary>
    /// <param name="forbiddenTypeNames">Fully qualified names of the types to reject.</param>
    ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        // The HashSet constructor is used rather than Enumerable.ToHashSet so the analyzer
        // also compiles against target frameworks that lack that extension method.
        _forbiddenTypeNames = new HashSet<string>(forbiddenTypeNames, StringComparer.Ordinal);
    }

    /// <inheritdoc />
    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    /// <summary>
    /// Enables concurrent execution, opts in to analyzing generated code, and registers
    /// the operation callback for every operation kind this analyzer inspects.
    /// </summary>
    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        compilationContext.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
        compilationContext.RegisterOperationAction(
            AnalyzeOperation,
            OperationKind.ObjectCreation,
            OperationKind.Invocation,
            OperationKind.EventReference,
            OperationKind.FieldReference,
            OperationKind.MethodReference,
            OperationKind.PropertyReference,
            OperationKind.ArrayCreation,
            OperationKind.AddressOf,
            OperationKind.Conversion,
            OperationKind.UnaryOperator,
            OperationKind.BinaryOperator,
            OperationKind.Increment,
            OperationKind.Decrement);
    }

    /// <summary>
    /// Picks the type associated with the operation (the created type, or the type that
    /// declares the referenced member / operator method) and checks it against the deny list.
    /// </summary>
    void AnalyzeOperation(OperationAnalysisContext context)
    {
        var type = context.Operation switch
        {
            IObjectCreationOperation objectCreation => objectCreation.Type,
            IInvocationOperation invocationOperation => invocationOperation.TargetMethod.ContainingType,
            IMemberReferenceOperation memberReference => memberReference.Member.ContainingType,
            IArrayCreationOperation arrayCreation => arrayCreation.Type,
            IAddressOfOperation addressOf => addressOf.Type,
            // Built-in conversions and operators have no operator method, so the type is null
            // and nothing is reported for them.
            IConversionOperation conversion => conversion.OperatorMethod?.ContainingType,
            IUnaryOperation unary => unary.OperatorMethod?.ContainingType,
            IBinaryOperation binary => binary.OperatorMethod?.ContainingType,
            IIncrementOrDecrementOperation incrementOrDecrement => incrementOrDecrement.OperatorMethod?.ContainingType,
            // Only reached if a kind is registered in Initialize without a matching arm here.
            _ => throw new NotImplementedException($"Unhandled OperationKind: {context.Operation.Kind}")
        };

        VerifyType(context.ReportDiagnostic, type, context.Operation.Syntax);
    }

    /// <summary>
    /// Checks a type, its type arguments / element type, and each of its containing types
    /// against the deny list, reporting a diagnostic at <paramref name="syntaxNode"/> for the
    /// first forbidden type found.
    /// </summary>
    /// <returns><c>false</c> if a forbidden type was found and reported; otherwise <c>true</c>.</returns>
    bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        do
        {
            if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
            {
                return false;
            }

            if (type is null)
            {
                // Type will be null for arrays and pointers (their element / pointed-at type
                // has already been verified), and when there was no type to begin with.
                return true;
            }

            var typeName = type.ToDisplayString(TypeNameFormat);
            if (_forbiddenTypeNames.Contains(typeName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), typeName));
                return false;
            }

            // Walk outward so a type nested inside a forbidden type is rejected as well.
            type = type.ContainingType;
        }
        while (!(type is null));

        return true;
    }

    /// <summary>
    /// Verifies the types nested inside <paramref name="type"/>: generic type arguments,
    /// the array element type, or the pointed-at type.
    /// </summary>
    /// <param name="originalDefinition">
    /// Receives the unconstructed definition the caller should check by name, or <c>null</c>
    /// for arrays and pointers (and when <paramref name="type"/> is <c>null</c>).
    /// </param>
    /// <returns><c>false</c> if a forbidden nested type was found and reported; otherwise <c>true</c>.</returns>
    bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        switch (type)
        {
            case INamedTypeSymbol namedTypeSymbol:
                originalDefinition = namedTypeSymbol.ConstructedFrom;
                foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                {
                    // Unbound type parameters and unresolved types are skipped.
                    if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                        typeArgument.TypeKind != TypeKind.Error &&
                        !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                    {
                        return false;
                    }
                }

                break;

            case IArrayTypeSymbol arrayTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);

            case IPointerTypeSymbol pointerTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);

            default:
                originalDefinition = type?.OriginalDefinition;
                break;
        }

        return true;
    }
}
// UNIT TESTS
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Scripting;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Scripting;
using Xunit;
/// <summary>
/// Tests that run <see cref="ForbiddenTypeAnalyzer"/> over small C# scripts and inspect
/// the diagnostics that come back.
/// </summary>
public class ForbiddenTypeAnalyzerTests
{
    [Theory]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.CommandLine;
Reply(env);
}", "System.Environment")]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.GetEnvironmentVariable(""test"");
Reply(env);
}", "System.Environment")]
    [InlineData(@"var args = Environment.CommandLine;", "System.Environment")]
    [InlineData(@"Console.WriteLine(""test"");", "System.Console")]
    [InlineData(@"var ptr = new IntPtr(32);", "System.IntPtr")]
    [InlineData(@"var type = (Type)SomeType;", "System.Type")]
    [InlineData(@"var type = (Type)SomeType; var name = type.Name;", "System.Type")]
    [InlineData(@"var type = Type.GetType(""name"");", "System.Type")]
    [InlineData(@"var pointers = new IntPtr[4];", "System.IntPtr")]
    public async Task ReturnsErrorsForForbiddenTypes(string code, string expectedForbiddenType)
    {
        var reported = await AnalyzeScriptAsync(code, "System.Runtime.Extensions", "System.Console");

        // Exactly one diagnostic is expected, and it must name the forbidden type.
        var only = Assert.Single(reported);
        Assert.NotNull(only);
        Assert.Equal(
            $"Access to type {expectedForbiddenType} is forbidden.",
            only.GetMessage());
        Assert.Equal(ForbiddenTypeAnalyzer.DiagnosticId, only.Id);
    }

    [Fact]
    public async Task DoesNotReturnsErrorsForAllowedTypes()
    {
        const string code = @"if (IsRequest) {
Reply(""test"");
}
else {
var rnd = new Random();
Reply(rnd.Next(1).ToString());
}";

        var reported = await AnalyzeScriptAsync(code, "System.Runtime.Extensions");

        Assert.Empty(reported);
    }

    /// <summary>
    /// Builds a C# script compilation for <paramref name="code"/> (with
    /// <see cref="IScriptGlobals"/> as the globals type and the given assembly references),
    /// attaches a <see cref="ForbiddenTypeAnalyzer"/>, and returns all resulting diagnostics.
    /// </summary>
    static async Task<ImmutableArray<Diagnostic>> AnalyzeScriptAsync(string code, params string[] references)
    {
        var scriptOptions = ScriptOptions.Default
            .WithImports("System")
            .WithEmitDebugInformation(true)
            .WithReferences(references)
            .WithAllowUnsafe(false);

        var scriptCompilation = CSharpScript
            .Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: scriptOptions)
            .GetCompilation();

        var withAnalyzers = new CompilationWithAnalyzers(
            scriptCompilation,
            ImmutableArray.Create<DiagnosticAnalyzer>(new ForbiddenTypeAnalyzer()),
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);

        return await withAnalyzers.GetAllDiagnosticsAsync();
    }

    /// <summary>Globals made available to the scripts compiled by these tests.</summary>
    public interface IScriptGlobals
    {
        bool IsRequest { get; }

        void Reply(string reply);
    }
}
@haacked
Copy link
Author

haacked commented Nov 1, 2020

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment