-
-
Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
using System; | |
using System.Collections.Generic; | |
using System.Collections.Immutable; | |
using System.Linq; | |
using Microsoft.CodeAnalysis; | |
using Microsoft.CodeAnalysis.Diagnostics; | |
using Microsoft.CodeAnalysis.Operations; | |
// CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs | |
/// <summary>
/// Reports a warning whenever code uses a type whose fully qualified name is on a deny list.
/// </summary>
/// <remarks>
/// Adapted from the roslyn-analyzers SymbolIsBannedAnalyzer (see the credit comment above).
/// A "use" is detected through the operations registered in <see cref="Initialize"/>:
/// object and array creation, method invocations, event/field/method/property references,
/// address-of, and user-defined conversions and operators. When a type is checked, its generic
/// type arguments, array element type, pointed-at type and containing types are checked too.
/// NOTE(review): in the discussion accompanying this file, code compiled with CSharpScript
/// (e.g. a bare <c>Console.WriteLine(...)</c> global statement) produced no diagnostics, while the
/// same code in a regular compilation did; script global statements may therefore go unreported.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    /// <summary>The ID of the diagnostic reported for a forbidden type.</summary>
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";
    const string Title = "Forbidden Type Analyzer";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Category = "API Usage";

    // Format used to turn a type symbol into the string compared against the deny list.
    // ITypeSymbol.ToString() cannot be used for this: it renders special types as C# keywords
    // ("string", "int", "object"), so entries like "System.String" would never match.
    // This format omits "global::", qualifies with namespaces and containing types
    // ("Outer.Inner"), and keeps type parameters on generic definitions ("List<T>").
    static readonly SymbolDisplayFormat TypeNameFormat = new SymbolDisplayFormat(
        globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted,
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters);

    // Types forbidden by the parameterless constructor.
    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        description: Description);

    // Fully qualified names (as rendered by TypeNameFormat) that must not be used; compared ordinally.
    readonly HashSet<string> _forbiddenTypeNames;

    /// <summary>
    /// Creates an analyzer that forbids the default deny list:
    /// System.Console, System.Environment, System.IntPtr and System.Type.
    /// </summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    /// <summary>Creates an analyzer that forbids the given types.</summary>
    /// <param name="forbiddenTypeNames">
    /// Fully qualified type names such as "System.IO.StreamWriter". Nested types are written with
    /// '.' between the containing and nested type names; names are compared case-sensitively.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="forbiddenTypeNames"/> is null.</exception>
    public ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        if (forbiddenTypeNames is null)
        {
            throw new ArgumentNullException(nameof(forbiddenTypeNames));
        }

        // The HashSet constructor is used instead of Enumerable.ToHashSet, which is not
        // available on netstandard2.0 (the usual target framework for analyzers).
        _forbiddenTypeNames = new HashSet<string>(forbiddenTypeNames, StringComparer.Ordinal);
    }

    /// <inheritdoc/>
    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    /// <summary>
    /// Registers one operation callback that maps each supported operation to the type it uses
    /// and checks that type against the deny list.
    /// </summary>
    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        // Analyze generated code as well, and report diagnostics found there.
        compilationContext.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
        compilationContext.RegisterOperationAction(context =>
        {
            var type = context.Operation switch
            {
                IObjectCreationOperation objectCreation => objectCreation.Type,
                IInvocationOperation invocationOperation => invocationOperation.TargetMethod.ContainingType,
                // Covers event, field, method and property references.
                IMemberReferenceOperation memberReference => memberReference.Member.ContainingType,
                IArrayCreationOperation arrayCreation => arrayCreation.Type,
                IAddressOfOperation addressOf => addressOf.Type,
                // Conversions and operators only yield a type when they are user-defined;
                // OperatorMethod is null for built-in ones, and null is treated as "nothing to check".
                IConversionOperation conversion => conversion.OperatorMethod?.ContainingType,
                IUnaryOperation unary => unary.OperatorMethod?.ContainingType,
                IBinaryOperation binary => binary.OperatorMethod?.ContainingType,
                IIncrementOrDecrementOperation incrementOrDecrement => incrementOrDecrement.OperatorMethod?.ContainingType,
                // Every OperationKind registered below is matched by one of the arms above.
                _ => throw new NotImplementedException($"Unhandled OperationKind: {context.Operation.Kind}")
            };
            VerifyType(context.ReportDiagnostic, type, context.Operation.Syntax);
        },
        OperationKind.ObjectCreation,
        OperationKind.Invocation,
        OperationKind.EventReference,
        OperationKind.FieldReference,
        OperationKind.MethodReference,
        OperationKind.PropertyReference,
        OperationKind.ArrayCreation,
        OperationKind.AddressOf,
        OperationKind.Conversion,
        OperationKind.UnaryOperator,
        OperationKind.BinaryOperator,
        OperationKind.Increment,
        OperationKind.Decrement);
    }

    /// <summary>
    /// Checks <paramref name="type"/>, its type arguments / element type, and each of its
    /// containing types against the deny list, reporting at most one diagnostic.
    /// </summary>
    /// <param name="reportDiagnostic">Callback used to report a diagnostic.</param>
    /// <param name="type">The type to check; null means there is nothing to check.</param>
    /// <param name="syntaxNode">Syntax whose location the diagnostic is reported at.</param>
    /// <returns>false if a diagnostic was reported; otherwise true.</returns>
    private bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        do
        {
            if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
            {
                return false;
            }

            if (type is null)
            {
                // Type will be null for arrays and pointers (their element type was already checked).
                return true;
            }

            var typeName = type.ToDisplayString(TypeNameFormat);
            if (_forbiddenTypeNames.Contains(typeName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), typeName));
                return false;
            }

            // Walk outward so that a nested type of a forbidden type is also reported.
            type = type.ContainingType;
        }
        while (!(type is null));

        return true;
    }

    /// <summary>
    /// Checks the parts of <paramref name="type"/> that are not the type itself: the type
    /// arguments of a generic named type, an array's element type, or a pointer's pointed-at type.
    /// </summary>
    /// <param name="originalDefinition">
    /// The symbol the caller should check next: <c>ConstructedFrom</c> for named types,
    /// null for arrays and pointers (already fully checked here), and <c>OriginalDefinition</c>
    /// for anything else (null when <paramref name="type"/> is null).
    /// </param>
    /// <returns>false if a diagnostic was reported; otherwise true.</returns>
    private bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        switch (type)
        {
            case INamedTypeSymbol namedTypeSymbol:
                originalDefinition = namedTypeSymbol.ConstructedFrom;
                foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                {
                    // Unsubstituted type parameters and error types have nothing to look up.
                    if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                        typeArgument.TypeKind != TypeKind.Error &&
                        !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                    {
                        return false;
                    }
                }
                break;

            case IArrayTypeSymbol arrayTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);

            case IPointerTypeSymbol pointerTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);

            default:
                originalDefinition = type?.OriginalDefinition;
                break;
        }

        return true;
    }
}
// UNIT TESTS | |
using System.Collections.Immutable; | |
using System.Threading; | |
using System.Threading.Tasks; | |
using Microsoft.CodeAnalysis; | |
using Microsoft.CodeAnalysis.CSharp.Scripting; | |
using Microsoft.CodeAnalysis.Diagnostics; | |
using Microsoft.CodeAnalysis.Scripting; | |
using Xunit; | |
/// <summary>
/// Tests for <see cref="ForbiddenTypeAnalyzer"/> that compile snippets as C# scripts
/// against <see cref="IScriptGlobals"/> and inspect the resulting diagnostics.
/// </summary>
public class ForbiddenTypeAnalyzerTests
{
    [Theory]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.CommandLine;
Reply(env);
}", "System.Environment")]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.GetEnvironmentVariable(""test"");
Reply(env);
}", "System.Environment")]
    [InlineData(@"var args = Environment.CommandLine;", "System.Environment")]
    [InlineData(@"Console.WriteLine(""test"");", "System.Console")]
    [InlineData(@"var ptr = new IntPtr(32);", "System.IntPtr")]
    [InlineData(@"var type = (Type)SomeType;", "System.Type")]
    [InlineData(@"var type = (Type)SomeType; var name = type.Name;", "System.Type")]
    [InlineData(@"var type = Type.GetType(""name"");", "System.Type")]
    [InlineData(@"var pointers = new IntPtr[4];", "System.IntPtr")]
    public async Task ReturnsErrorsForForbiddenTypes(string code, string expectedForbiddenType)
    {
        var diagnosticResults = await AnalyzeScriptAsync(code, "System.Runtime.Extensions", "System.Console");

        // Assert.Single already fails on a missing diagnostic, so no separate null check is needed.
        var diagnostic = Assert.Single(diagnosticResults);
        Assert.Equal(
            $"Access to type {expectedForbiddenType} is forbidden.",
            diagnostic.GetMessage());
        Assert.Equal(ForbiddenTypeAnalyzer.DiagnosticId, diagnostic.Id);
    }

    [Fact]
    public async Task DoesNotReturnsErrorsForAllowedTypes()
    {
        const string code = @"if (IsRequest) {
Reply(""test"");
}
else {
var rnd = new Random();
Reply(rnd.Next(1).ToString());
}";

        var diagnosticResults = await AnalyzeScriptAsync(code, "System.Runtime.Extensions");

        Assert.Empty(diagnosticResults);
    }

    /// <summary>
    /// Compiles <paramref name="code"/> as a C# script (importing System, with
    /// <see cref="IScriptGlobals"/> as the globals type) and runs <see cref="ForbiddenTypeAnalyzer"/> on it.
    /// </summary>
    /// <param name="code">Script source to compile.</param>
    /// <param name="references">Assembly names to reference from the script.</param>
    /// <returns>
    /// All diagnostics from CompilationWithAnalyzers.GetAllDiagnosticsAsync, i.e. compiler
    /// diagnostics as well as analyzer diagnostics.
    /// </returns>
    static async Task<ImmutableArray<Diagnostic>> AnalyzeScriptAsync(string code, params string[] references)
    {
        var options = ScriptOptions.Default
            .WithImports("System")
            .WithEmitDebugInformation(true)
            .WithReferences(references)
            .WithAllowUnsafe(false);
        var script = CSharpScript.Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: options);
        var compilation = script.GetCompilation();

        var analyzers = ImmutableArray.Create<DiagnosticAnalyzer>(new ForbiddenTypeAnalyzer());
        var compilationWithAnalyzers = new CompilationWithAnalyzers(
            compilation,
            analyzers,
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);

        return await compilationWithAnalyzers.GetAllDiagnosticsAsync();
    }

    /// <summary>Members available to the script snippets as globals.</summary>
    public interface IScriptGlobals
    {
        bool IsRequest { get; }
        void Reply(string reply);
    }
}
Regardless, I'll need to take some time this evening to debug through this instead of guessing.
To clarify, I've updated my repro slightly so that it imports `System.IO` and references `System.Runtime.Extensions`. Then I added the following two lines to `Program.cs`:
await TestCode(@"var x = new StreamWriter(""some-path"");");
await TestCode(@"new StreamWriter(""some-path"").Flush();");
And the result is
Compiling `var x = new StreamWriter("some-path");` resulted in 0 diagnostics.
Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.
Contrary to what I said earlier (I messed up my testing), I'm not sure why my code never breaks into the `RegisterOperationAction` callback.
I forgot this was in the scripting dialect and not a top level function.
Yeah, I can't use C# 9 just yet. 😦
I changed the code to use a regular (non-script) compilation:
var syntaxTree = CSharpSyntaxTree.ParseText(code);
var compilation = CSharpCompilation.Create(
"assemblyName",
new[] { syntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
And then ran this...
await TestCode(@"
using System.IO;
public class Test
{
public void DoStuff()
{
new StreamWriter(""test"").Flush();
}
}
");
And got the result I expected: `(8,9): warning ForbiddenTypeAnalyzer: Access to type System.IO.StreamWriter is forbidden.`
So it does appear to be a difference between the C# scripting dialect and C#.
For completeness' sake, I also tried this with `CSharpSymbolIsBannedAnalyzer`. Here's the banned-symbols file I used:
T:System.Console;Don't use System.Console
T:System.Environment;Don't use System.Environment
T:System.Type;Don't use System.Type
T:System.Reflection.MemberInfo;Don't use Reflection.
T:System.IO.StreamWriter;Don't use System.IO.StreamWriter.
Here are the test cases I used (the same program as above, but with `CSharpSymbolIsBannedAnalyzer` swapped in):
await TestCode(@"new StreamWriter(""some-path"").Flush();");
await TestCode("var env = Environment.CommandLine;");
await TestCode("var type = (Type)SomeType; var name = type.Name;");
await TestCode(@"Console.WriteLine(""test"");");
And here are the results.
Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.
Compiling `var env = Environment.CommandLine;` resulted in 1 diagnostics.
(1,11): warning RS0030: The symbol 'Environment' is banned in this project: Don't use System.Environment
Compiling `var type = (Type)SomeType; var name = type.Name;` resulted in 1 diagnostics.
(1,39): warning RS0030: The symbol 'MemberInfo' is banned in this project: Don't use Reflection.
Compiling `Console.WriteLine("test");` resulted in 0 diagnostics.
It seems there are cases where `CSharpSymbolIsBannedAnalyzer` doesn't work with `CSharpScript`. I'm curious to understand why. Is it a bug, or is it by design?
I opened a bug: dotnet/roslyn-analyzers#4399.
I forgot this was in the scripting dialect and not a top-level function. @jaredpar, what are the language rules for these cases?