-
-
Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
| using System; | |
| using System.Collections.Generic; | |
| using System.Collections.Immutable; | |
| using System.Linq; | |
| using Microsoft.CodeAnalysis; | |
| using Microsoft.CodeAnalysis.Diagnostics; | |
| using Microsoft.CodeAnalysis.Operations; | |
| // CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs | |
/// <summary>
/// Diagnostic analyzer that reports a warning when analyzed C# code uses a type whose
/// display name is in the forbidden-type set (by default <c>System.Console</c>,
/// <c>System.Environment</c>, <c>System.IntPtr</c> and <c>System.Type</c>).
/// </summary>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    /// <summary>Id of the reported diagnostic; equal to the analyzer's class name.</summary>
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    const string Category = "API Usage";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Title = "Forbidden Type Analyzer";

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";

    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    // Kept below Description: static initializers run in textual order and Rule reads it.
    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        Description);

    // Display names of the forbidden types, compared ordinally.
    readonly HashSet<string> _forbiddenTypeNames;

    /// <summary>Creates an analyzer that forbids the default deny list.</summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        _forbiddenTypeNames = new HashSet<string>(forbiddenTypeNames, StringComparer.Ordinal);
    }

    /// <inheritdoc />
    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    /// <summary>
    /// Turns on concurrent execution, enables analysis and reporting for generated code,
    /// and registers <see cref="AnalyzeOperation"/> for every operation kind it understands.
    /// </summary>
    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        compilationContext.ConfigureGeneratedCodeAnalysis(
            GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);

        compilationContext.RegisterOperationAction(
            AnalyzeOperation,
            OperationKind.ObjectCreation,
            OperationKind.Invocation,
            OperationKind.EventReference,
            OperationKind.FieldReference,
            OperationKind.MethodReference,
            OperationKind.PropertyReference,
            OperationKind.ArrayCreation,
            OperationKind.AddressOf,
            OperationKind.Conversion,
            OperationKind.UnaryOperator,
            OperationKind.BinaryOperator,
            OperationKind.Increment,
            OperationKind.Decrement);
    }

    // Operation callback: resolve the type the operation touches and check it.
    void AnalyzeOperation(OperationAnalysisContext context)
    {
        var operation = context.Operation;
        var accessedType = GetAccessedType(operation);
        VerifyType(context.ReportDiagnostic, accessedType, operation.Syntax);
    }

    // Maps an operation to the type it uses: the created type for creations and address-of,
    // the containing type of the target member or operator method otherwise. The result is
    // null when a conversion/unary/binary/increment operation has no operator method.
    static ITypeSymbol? GetAccessedType(IOperation operation)
    {
        switch (operation)
        {
            case IObjectCreationOperation objectCreation:
                return objectCreation.Type;
            case IInvocationOperation invocation:
                return invocation.TargetMethod.ContainingType;
            case IMemberReferenceOperation memberReference:
                return memberReference.Member.ContainingType;
            case IArrayCreationOperation arrayCreation:
                return arrayCreation.Type;
            case IAddressOfOperation addressOf:
                return addressOf.Type;
            case IConversionOperation conversion:
                return conversion.OperatorMethod?.ContainingType;
            case IUnaryOperation unary:
                return unary.OperatorMethod?.ContainingType;
            case IBinaryOperation binary:
                return binary.OperatorMethod?.ContainingType;
            case IIncrementOrDecrementOperation incrementOrDecrement:
                return incrementOrDecrement.OperatorMethod?.ContainingType;
            default:
                throw new NotImplementedException($"Unhandled OperationKind: {operation.Kind}");
        }
    }

    /// <summary>
    /// Checks <paramref name="type"/>, its type arguments and each of its containing types
    /// against the forbidden set, reporting a diagnostic at <paramref name="syntaxNode"/>
    /// for the first forbidden type found.
    /// </summary>
    /// <returns><c>false</c> if a forbidden type was found; otherwise <c>true</c>.</returns>
    bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        var current = type;
        while (true)
        {
            if (!VerifyTypeArguments(reportDiagnostic, current, syntaxNode, out var definition))
            {
                return false;
            }

            if (definition is null)
            {
                // No definition is produced for arrays and pointers (their element type was
                // already verified) or when the incoming type was null.
                return true;
            }

            var displayName = definition.ToString();
            if (displayName is null)
            {
                return true;
            }

            if (_forbiddenTypeNames.Contains(displayName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), displayName));
                return false;
            }

            // Move outwards so that types nested inside a forbidden type are caught too.
            current = definition.ContainingType;
            if (current is null)
            {
                return true;
            }
        }
    }

    /// <summary>
    /// Verifies the types nested inside <paramref name="type"/>: the concrete type arguments
    /// of a named type, or the element / pointed-at type of an array / pointer.
    /// </summary>
    /// <param name="originalDefinition">
    /// The definition the caller should check by name: <c>ConstructedFrom</c> for named types,
    /// <c>null</c> for arrays and pointers, <c>OriginalDefinition</c> (or <c>null</c>) otherwise.
    /// </param>
    /// <returns><c>false</c> if a forbidden type was found; otherwise <c>true</c>.</returns>
    bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        if (type is IArrayTypeSymbol arrayType)
        {
            originalDefinition = null;
            return VerifyType(reportDiagnostic, arrayType.ElementType, syntaxNode);
        }

        if (type is IPointerTypeSymbol pointerType)
        {
            originalDefinition = null;
            return VerifyType(reportDiagnostic, pointerType.PointedAtType, syntaxNode);
        }

        if (type is INamedTypeSymbol namedType)
        {
            originalDefinition = namedType.ConstructedFrom;
            foreach (var argument in namedType.TypeArguments)
            {
                var kind = argument.TypeKind;
                if (kind == TypeKind.TypeParameter || kind == TypeKind.Error)
                {
                    // Type parameters and error types are not checked.
                    continue;
                }

                if (!VerifyType(reportDiagnostic, argument, syntaxNode))
                {
                    return false;
                }
            }

            return true;
        }

        originalDefinition = type?.OriginalDefinition;
        return true;
    }
}
| // UNIT TESTS | |
| using System.Collections.Immutable; | |
| using System.Threading; | |
| using System.Threading.Tasks; | |
| using Microsoft.CodeAnalysis; | |
| using Microsoft.CodeAnalysis.CSharp.Scripting; | |
| using Microsoft.CodeAnalysis.Diagnostics; | |
| using Microsoft.CodeAnalysis.Scripting; | |
| using Xunit; | |
/// <summary>
/// xUnit tests that build C# script compilations from code snippets and run
/// <see cref="ForbiddenTypeAnalyzer"/> over them.
/// </summary>
public class ForbiddenTypeAnalyzerTests
{
    [Theory]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.CommandLine;
Reply(env);
}", "System.Environment")]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.GetEnvironmentVariable(""test"");
Reply(env);
}", "System.Environment")]
    [InlineData(@"var args = Environment.CommandLine;", "System.Environment")]
    [InlineData(@"Console.WriteLine(""test"");", "System.Console")]
    [InlineData(@"var ptr = new IntPtr(32);", "System.IntPtr")]
    [InlineData(@"var type = (Type)SomeType;", "System.Type")]
    [InlineData(@"var type = (Type)SomeType; var name = type.Name;", "System.Type")]
    [InlineData(@"var type = Type.GetType(""name"");", "System.Type")]
    [InlineData(@"var pointers = new IntPtr[4];", "System.IntPtr")]
    public async Task ReturnsErrorsForForbiddenTypes(string code, string expectedForbiddenType)
    {
        var expectedMessage = $"Access to type {expectedForbiddenType} is forbidden.";

        var diagnosticResults = await RunAnalyzerAsync(code, "System.Runtime.Extensions", "System.Console");

        // Exactly one diagnostic is expected, carrying the analyzer's message and id.
        var diagnostic = Assert.Single(diagnosticResults);
        Assert.NotNull(diagnostic);
        Assert.Equal(expectedMessage, diagnostic.GetMessage());
        Assert.Equal(ForbiddenTypeAnalyzer.DiagnosticId, diagnostic.Id);
    }

    [Fact]
    public async Task DoesNotReturnsErrorsForAllowedTypes()
    {
        const string code = @"if (IsRequest) {
Reply(""test"");
}
else {
var rnd = new Random();
Reply(rnd.Next(1).ToString());
}";

        var diagnosticResults = await RunAnalyzerAsync(code, "System.Runtime.Extensions");

        Assert.Empty(diagnosticResults);
    }

    // Compiles the snippet as a C# script against IScriptGlobals with the given assembly
    // references, runs a default ForbiddenTypeAnalyzer over the compilation, and returns
    // everything GetAllDiagnosticsAsync reports.
    static async Task<ImmutableArray<Diagnostic>> RunAnalyzerAsync(string code, params string[] references)
    {
        var options = ScriptOptions.Default
            .WithImports("System")
            .WithEmitDebugInformation(true)
            .WithReferences(references)
            .WithAllowUnsafe(false);

        var compilation = CSharpScript
            .Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: options)
            .GetCompilation();

        var compilationWithAnalyzers = new CompilationWithAnalyzers(
            compilation,
            ImmutableArray.Create<DiagnosticAnalyzer>(new ForbiddenTypeAnalyzer()),
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);

        return await compilationWithAnalyzers.GetAllDiagnosticsAsync();
    }

    /// <summary>Globals type exposed to the test scripts.</summary>
    public interface IScriptGlobals
    {
        bool IsRequest { get; }

        void Reply(string reply);
    }
}
To clarify, I've updated my repro a tiny bit to make sure it imports System.IO and references System.Runtime.Extensions. Then I added the following two lines to Program.cs
await TestCode(@"var x = new StreamWriter(""some-path"");");
await TestCode(@"new StreamWriter(""some-path"").Flush();");
And the result is:
Compiling `var x = new StreamWriter("some-path");` resulted in 0 diagnostics.
Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.
This is contrary to what I said earlier (I messed up my testing). I'm not sure why my code never breaks into the RegisterOperationAction callback.
I forgot this was in the scripting dialect and not a top level function.
Yeah, I can't use C# 9 just yet. 😦
I changed the code to use the regular compilation.
var syntaxTree = CSharpSyntaxTree.ParseText(code);
var compilation = CSharpCompilation.Create(
"assemblyName",
new[] { syntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
And then ran this...
await TestCode(@"
using System.IO;
public class Test
{
public void DoStuff()
{
new StreamWriter(""test"").Flush();
}
}
");
And got the result I expected: (8,9): warning ForbiddenTypeAnalyzer: Access to type System.IO.StreamWriter is forbidden.
So it does appear to be a difference between the C# scripting dialect and C#.
Just for completeness' sake, I tried this out with CSharpSymbolIsBannedAnalyzer. Here's the symbols file I used.
T:System.Console;Don't use System.Console
T:System.Environment;Don't use System.Environment
T:System.Type;Don't use System.Type
T:System.Reflection.MemberInfo;Don't use Reflection.
T:System.IO.StreamWriter;Don't use System.IO.StreamWriter.
Here are the test cases I used (same program as above, but with the CSharpSymbolIsBannedAnalyzer analyzer swapped in).
await TestCode(@"new StreamWriter(""some-path"").Flush();");
await TestCode("var env = Environment.CommandLine;");
await TestCode("var type = (Type)SomeType; var name = type.Name;");
await TestCode(@"Console.WriteLine(""test"");");
And here are the results.
Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.
Compiling `var env = Environment.CommandLine;` resulted in 1 diagnostics.
(1,11): warning RS0030: The symbol 'Environment' is banned in this project: Don't use System.Environment
Compiling `var type = (Type)SomeType; var name = type.Name;` resulted in 1 diagnostics.
(1,39): warning RS0030: The symbol 'MemberInfo' is banned in this project: Don't use Reflection.
Compiling `Console.WriteLine("test");` resulted in 0 diagnostics.
It seems like there are cases where the CSharpSymbolIsBannedAnalyzer doesn't work with CSharpScript. I'm curious to understand why. Is it a bug? By design?
I opened a bug dotnet/roslyn-analyzers#4399
Well regardless, I'll just need to take some time this evening to debug through this instead of guessing.