Zig MLIR Lexer
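
A hand-written lexer for MLIR's textual IR in a single Zig file. `Tokenizer.next` is a one-pass state machine over a 0-terminated buffer; each `Token` carries a tag plus the byte range (`Loc`) it spans, and the tests at the bottom of the file exercise keywords, punctuation, identifiers, and literals.
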
const std = @import("std");
const cc = std.ascii.control_code;
const startsWith = std.mem.startsWith;

fn isDigitNotZero(c: u8) bool {
    return switch (c) {
        '1'...'9' => true,
        else => false,
    };
}

pub const Token = struct {
    tag: Tag,
    loc: Loc,

    const kwList = [_]struct { [:0]const u8, Tag }{
        .{ "affine_map", .kw_affine_map },
        .{ "affine_set", .kw_affine_set },
        .{ "array", .kw_array },
        .{ "attributes", .kw_attributes },
        .{ "bf16", .kw_bf16 },
        .{ "ceildiv", .kw_ceildiv },
        .{ "complex", .kw_complex },
        .{ "dense", .kw_dense },
        .{ "dense_resource", .kw_dense_resource },
        .{ "distinct", .kw_distinct },
        .{ "f16", .kw_f16 },
        .{ "f32", .kw_f32 },
        .{ "f64", .kw_f64 },
        .{ "f80", .kw_f80 },
        .{ "f8E5M2", .kw_f8E5M2 },
        .{ "f8E4M3FN", .kw_f8E4M3FN },
        .{ "f8E5M2FNUZ", .kw_f8E5M2FNUZ },
        .{ "f8E4M3FNUZ", .kw_f8E4M3FNUZ },
        .{ "f8E4M3B11FNUZ", .kw_f8E4M3B11FNUZ },
        .{ "f128", .kw_f128 },
        .{ "false", .kw_false },
        .{ "floordiv", .kw_floordiv },
        .{ "for", .kw_for },
        .{ "func", .kw_func },
        .{ "index", .kw_index },
        .{ "loc", .kw_loc },
        .{ "max", .kw_max },
        .{ "memref", .kw_memref },
        .{ "min", .kw_min },
        .{ "mod", .kw_mod },
        .{ "none", .kw_none },
        .{ "offset", .kw_offset },
        .{ "size", .kw_size },
        .{ "sparse", .kw_sparse },
        .{ "step", .kw_step },
        .{ "strided", .kw_strided },
        .{ "symbol", .kw_symbol },
        .{ "tensor", .kw_tensor },
        .{ "tf32", .kw_tf32 },
        .{ "to", .kw_to },
        .{ "true", .kw_true },
        .{ "tuple", .kw_tuple },
        .{ "type", .kw_type },
        .{ "unit", .kw_unit },
        .{ "vector", .kw_vector },
    };

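    // NOTE: this targets a pre-0.13 Zig std; on Zig 0.13+ the equivalent is
    // std.StaticStringMap(Tag).initComptime(kwList).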
    const keywords = std.ComptimeStringMap(Tag, kwList);

    pub fn getKw(s: []const u8) ?Tag {
        return keywords.get(s);
    }

    // No tag carries a payload, so a plain enum is enough here.
    pub const Tag = enum {
        eof,
        // Identifiers
        id_bare,
        id_prefixed,
        // Literals
        lit_decimal,
        lit_hex,
        lit_float,
        lit_string,
        lit_int_type,
        // Punctuation
        pn_arrow,
        pn_at,
        pn_colon,
        pn_comma,
        pn_ellipsis,
        pn_equal,
        pn_greater,
        pn_l_brace,
        pn_l_paren,
        pn_l_square,
        pn_less,
        pn_minus,
        pn_plus,
        pn_question,
        pn_r_brace,
        pn_r_paren,
        pn_r_square,
        pn_star,
        pn_v_bar,
        // Keywords
        kw_affine_map,
        kw_affine_set,
        kw_array,
        kw_attributes,
        kw_bf16,
        kw_ceildiv,
        kw_complex,
        kw_dense,
        kw_dense_resource,
        kw_distinct,
        kw_f16,
        kw_f32,
        kw_f64,
        kw_f80,
        kw_f8E5M2,
        kw_f8E4M3FN,
        kw_f8E5M2FNUZ,
        kw_f8E4M3FNUZ,
        kw_f8E4M3B11FNUZ,
        kw_f128,
        kw_false,
        kw_floordiv,
        kw_for,
        kw_func,
        kw_index,
        kw_loc,
        kw_max,
        kw_memref,
        kw_min,
        kw_mod,
        kw_none,
        kw_offset,
        kw_size,
        kw_sparse,
        kw_step,
        kw_strided,
        kw_symbol,
        kw_tensor,
        kw_tf32,
        kw_to,
        kw_true,
        kw_tuple,
        kw_type,
        kw_unit,
        kw_vector,
        // Errors
        illegal_char,
        illegal_string,
        illegal_ellipsis,
        illegal_suffix_id,
    };

    pub const Loc = struct {
        start: usize,
        end: usize,
    };
};

pub const Tokenizer = struct {
    const Self = @This();

    buffer: [:0]const u8,
    index: usize = 0,

    pub fn init(buffer: [:0]const u8) Tokenizer {
        return Self{
            .buffer = buffer,
        };
    }

    const State = enum {
        start,
        comment,
        ellipsis,
        minus,
        /// Lex a bare identifier or keyword that starts with a letter
        ///
        /// bare-id ::= (letter|[_]) (letter|digit|[_$.])*
        bare_id_or_kw,
        /// Lex a suffix-id (which should follow a prefix)
        ///
        /// suffix-id ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
        /// id-punct ::= `$` | `.` | `_` | `-`
        suffix_id,
        /// Lex an integer-type literal
        ///
        /// integer-type ::= `[su]?i[1-9][0-9]*`
        int_type,
        /// Lex a decimal integer literal
        ///
        /// decimal-literal ::= digit+
        /// digit ::= [0-9]
        decimal,
        /// Lex a float literal
        ///
        /// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
        float,
        float_exp,
        float_end,
        /// Lex a hexadecimal literal
        ///
        /// hexadecimal-literal ::= `0x` hex_digit+
        /// hex_digit ::= [0-9a-fA-F]
        maybe_hex,
        hex,
        /// Lex a string literal
        ///
        /// string-literal ::= `"` [^"\n\f\v\r]* `"`
        string,
    };

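    /// True once the text seen so far begins like an integer type:
    /// an `i`, `si`, or `ui` prefix followed by a nonzero digit.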
    fn isIntType(s: []const u8) bool {
        const intTypePrefixes = [3][]const u8{ "i", "si", "ui" };
        inline for (intTypePrefixes) |prefix| {
            if (s.len > prefix.len and startsWith(u8, s, prefix) and isDigitNotZero(s[prefix.len])) {
                return true;
            }
        }
        return false;
    }

    fn isBareIdentChar(c: u8) bool {
        return switch (c) {
            '0'...'9',
            'a'...'z',
            'A'...'Z',
            '$',
            '.',
            '_',
            '-',
            => true,
            else => false,
        };
    }

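    /// Scan and return the next token; once the input is exhausted,
    /// every further call returns an .eof token.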
    pub fn next(self: *Self) Token {
        var result = Token{
            .tag = .eof,
            .loc = .{
                .start = self.index,
                .end = undefined,
            },
        };
        var state = State.start;
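        // Walk the buffer one byte at a time. Because the buffer is
        // 0-terminated, the loop may read one past the end; the start
        // state turns that sentinel into the final .eof token.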
        while (self.index <= self.buffer.len) : (self.index += 1) {
            const c = self.buffer[self.index];
            if (state == .start) {
                result.loc.start = self.index;
            }
            switch (state) {
                State.start => switch (c) {
                    0 => {
                        if (self.index != self.buffer.len) {
                            // A stray NUL before the end of input is consumed
                            // and returned as its own .eof token.
                            self.index += 1;
                            result.loc.end = self.index;
                            return result;
                        }
                        break;
                    },
                    // Skip whitespace
                    ' ', '\n', '\t', '\r', '\x0b', '\x0c' => {},
                    '/' => state = State.comment,
                    '.' => state = State.ellipsis,
                    '-' => state = State.minus,
                    '0' => state = State.maybe_hex,
                    '1'...'9' => state = State.decimal,
                    '"' => state = State.string,
                    '#', '%', '^', '!' => state = State.suffix_id,
                    '_', 'a'...'z', 'A'...'Z' => state = State.bare_id_or_kw,
                    '@' => {
                        result.tag = .pn_at;
                        self.index += 1;
                        break;
                    },
                    ':' => {
                        result.tag = .pn_colon;
                        self.index += 1;
                        break;
                    },
                    ',' => {
                        result.tag = .pn_comma;
                        self.index += 1;
                        break;
                    },
                    '=' => {
                        result.tag = .pn_equal;
                        self.index += 1;
                        break;
                    },
                    '>' => {
                        result.tag = .pn_greater;
                        self.index += 1;
                        break;
                    },
                    '{' => {
                        result.tag = .pn_l_brace;
                        self.index += 1;
                        break;
                    },
                    '(' => {
                        result.tag = .pn_l_paren;
                        self.index += 1;
                        break;
                    },
                    '[' => {
                        result.tag = .pn_l_square;
                        self.index += 1;
                        break;
                    },
                    '<' => {
                        result.tag = .pn_less;
                        self.index += 1;
                        break;
                    },
                    '+' => {
                        result.tag = .pn_plus;
                        self.index += 1;
                        break;
                    },
                    '?' => {
                        result.tag = .pn_question;
                        self.index += 1;
                        break;
                    },
                    '}' => {
                        result.tag = .pn_r_brace;
                        self.index += 1;
                        break;
                    },
                    ')' => {
                        result.tag = .pn_r_paren;
                        self.index += 1;
                        break;
                    },
                    ']' => {
                        result.tag = .pn_r_square;
                        self.index += 1;
                        break;
                    },
                    '*' => {
                        result.tag = .pn_star;
                        self.index += 1;
                        break;
                    },
                    '|' => {
                        result.tag = .pn_v_bar;
                        self.index += 1;
                        break;
                    },
                    else => {
                        result.tag = .illegal_char;
                        self.index += 1;
                        break;
                    },
                },
                // Only `//` line comments exist; any other character after a
                // leading '/' makes the token illegal.
                State.comment => switch (self.buffer[result.loc.start + 1]) {
                    '/' => if (c == '\n' or c == '\r' or c == 0) {
                        // Rewind one byte so the terminator is re-scanned from
                        // the start state; this also lets a comment that runs
                        // into the 0 sentinel produce a clean .eof token.
                        state = State.start;
                        self.index -= 1;
                    },
                    else => {
                        result.tag = .illegal_char;
                        if (c != 0) self.index += 1;
                        break;
                    },
                },
                State.ellipsis => switch (c) {
                    '.' => {},
                    else => {
                        // Exactly three dots make an ellipsis; the current
                        // (non-dot) byte is not consumed.
                        if (self.index - result.loc.start == 3) {
                            result.tag = .pn_ellipsis;
                            break;
                        }
                        result.tag = .illegal_ellipsis;
                        break;
                    },
                },
                State.minus => switch (c) {
                    '>' => {
                        result.tag = .pn_arrow;
                        self.index += 1;
                        break;
                    },
                    else => {
                        result.tag = .pn_minus;
                        break;
                    },
                },
                State.bare_id_or_kw => {
                    const s = self.buffer[result.loc.start .. self.index + 1];
                    if (isIntType(s)) {
                        state = State.int_type;
                        continue;
                    }
                    if (isBareIdentChar(c)) continue;
                    result.tag = Token.getKw(s[0 .. s.len - 1]) orelse .id_bare;
                    break;
                },
                State.suffix_id => {
                    switch (self.buffer[result.loc.start + 1]) {
                        '0'...'9' => if (std.ascii.isDigit(c)) continue,
                        'a'...'z', 'A'...'Z', '$', '.', '_', '-' => if (isBareIdentChar(c)) continue,
                        else => {
                            result.tag = .illegal_suffix_id;
                            break;
                        },
                    }
                    result.tag = .id_prefixed;
                    break;
                },
                State.int_type => switch (c) {
                    '0'...'9' => {},
                    else => {
                        result.tag = .lit_int_type;
                        break;
                    },
                },
                State.decimal => switch (c) {
                    '0'...'9' => {},
                    '.' => state = State.float,
                    else => {
                        result.tag = .lit_decimal;
                        break;
                    },
                },
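                // The float states sometimes need to re-examine a byte they
                // have already seen; stepping self.index back cancels out
                // the loop's increment.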
                State.float => switch (c) {
                    '0'...'9' => {},
                    'e', 'E' => state = State.float_exp,
                    else => {
                        state = State.float_end;
                        self.index -= 1;
                    },
                },
                State.float_exp => switch (c) {
                    '0'...'9', '-', '+' => state = State.float_end,
                    else => {
                        // No exponent after all: back up so the literal
                        // ends just before the 'e'/'E'.
                        state = State.float_end;
                        self.index -= 2;
                    },
                },
                State.float_end => switch (c) {
                    '0'...'9' => {},
                    else => {
                        result.tag = .lit_float;
                        break;
                    },
                },
                State.maybe_hex => {
                    // Either the `0x` of a hexadecimal literal, or a plain
                    // decimal that merely starts with '0': rewind and re-lex
                    // as decimal when no hex digit follows the `0x`.
                    const dist = self.index - result.loc.start;
                    if (dist == 1 and c == 'x') continue;
                    if (dist == 2 and std.ascii.isHex(c)) {
                        state = State.hex;
                        continue;
                    }
                    self.index = result.loc.start;
                    state = State.decimal;
                },
                State.hex => switch (c) {
                    '0'...'9', 'a'...'f', 'A'...'F' => {},
                    else => {
                        result.tag = .lit_hex;
                        break;
                    },
                },
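                // NOTE: the escape check below looks only one byte back, so a
                // string ending in an escaped backslash (`"...\\"`) is misread
                // as having an escaped closing quote.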
                State.string => switch (c) {
                    '\n', cc.vt, cc.ff, '\r', 0 => {
                        result.tag = .illegal_string;
                        break;
                    },
                    '"' => {
                        if (self.buffer[self.index - 1] == '\\') continue;
                        result.tag = .lit_string;
                        self.index += 1;
                        break;
                    },
                    else => {},
                },
            }
        }
        result.loc.end = self.index;
        return result;
    }
};

test "keywords" { | |
for (Token.kwList) |kv| { | |
try testTokenize(kv[0], &.{kv[1]}); | |
} | |
} | |
test "single character punctuation" { | |
try testTokenize( | |
\\@ | |
\\: | |
\\, | |
\\= | |
\\> | |
\\{ | |
\\( | |
\\[ | |
\\< | |
\\+ | |
\\? | |
\\} | |
\\) | |
\\] | |
\\* | |
\\| | |
, &.{ | |
.pn_at, | |
.pn_colon, | |
.pn_comma, | |
.pn_equal, | |
.pn_greater, | |
.pn_l_brace, | |
.pn_l_paren, | |
.pn_l_square, | |
.pn_less, | |
.pn_plus, | |
.pn_question, | |
.pn_r_brace, | |
.pn_r_paren, | |
.pn_r_square, | |
.pn_star, | |
.pn_v_bar, | |
}); | |
} | |
test "skip line comment" { | |
try testTokenize( | |
\\// line comment | |
\\{} | |
, &.{ | |
.pn_l_brace, | |
.pn_r_brace, | |
}); | |
} | |
test "ellipsis" { | |
try testTokenize(".. ... .... ", &.{ .illegal_ellipsis, .pn_ellipsis, .illegal_ellipsis }); | |
} | |
test "minus" { | |
try testTokenize("-> -", &.{ .pn_arrow, .pn_minus }); | |
} | |
test "newline in string literal" { | |
try testTokenize( | |
\\" | |
\\" | |
, &.{ .illegal_string, .illegal_string }); | |
} | |
test "identifiers" { | |
try testTokenizeLiteral( | |
\\foo | |
\\_bar | |
\\foo42bar_$.- | |
, &.{ | |
.id_bare, | |
.id_bare, | |
.id_bare, | |
}, &.{ | |
"foo", | |
"_bar", | |
"foo42bar_$.-", | |
}); | |
try testTokenizeLiteral( | |
\\@foo | |
\\@"bar" | |
, &.{ | |
.pn_at, | |
.id_bare, | |
.pn_at, | |
.lit_string, | |
}, &.{ | |
"@", | |
"foo", | |
"@", | |
"\"bar\"", | |
}); | |
try testTokenizeLiteral( | |
\\%42bar | |
, &.{ | |
.id_prefixed, | |
.id_bare, | |
}, &.{ | |
"%42", | |
"bar", | |
}); | |
} | |
test "int-type literal" { | |
try testTokenizeLiteral( | |
\\i123 | |
\\si456 | |
\\ui789 | |
\\i0 | |
, &.{ | |
.lit_int_type, | |
.lit_int_type, | |
.lit_int_type, | |
.id_bare, | |
}, &.{ | |
"i123", | |
"si456", | |
"ui789", | |
"i0", | |
}); | |
} | |
test "general" { | |
try testTokenize("func.func @simple(i64, i1) -> i64 {}", &.{ | |
.id_bare, | |
.pn_at, | |
.id_bare, | |
.pn_l_paren, | |
.lit_int_type, | |
.pn_comma, | |
.lit_int_type, | |
.pn_r_paren, | |
.pn_arrow, | |
.lit_int_type, | |
.pn_l_brace, | |
.pn_r_brace, | |
}); | |
try testTokenize( | |
\\^bb2: | |
\\ %b = arith.addi %a, %a : i64 | |
\\ cf.br ^bb3(%b: i64) | |
, &.{ | |
.id_prefixed, | |
.pn_colon, | |
.id_prefixed, | |
.pn_equal, | |
.id_bare, | |
.id_prefixed, | |
.pn_comma, | |
.id_prefixed, | |
.pn_colon, | |
.lit_int_type, | |
.id_bare, | |
.id_prefixed, | |
.pn_l_paren, | |
.id_prefixed, | |
.pn_colon, | |
.lit_int_type, | |
.pn_r_paren, | |
}); | |
} | |
fn testTokenize(source: [:0]const u8, expected_tags: []const Token.Tag) !void { | |
var tokenizer = Tokenizer.init(source); | |
for (expected_tags) |expected_token_tag| { | |
const token = tokenizer.next(); | |
try std.testing.expectEqual(expected_token_tag, token.tag); | |
} | |
const last_token = tokenizer.next(); | |
try std.testing.expectEqual(Token.Tag.eof, last_token.tag); | |
try std.testing.expectEqual(source.len, last_token.loc.start); | |
try std.testing.expectEqual(source.len, last_token.loc.end); | |
} | |
fn testTokenizeLiteral(source: [:0]const u8, expected_tags: []const Token.Tag, expected_literals: []const []const u8) !void { | |
var tokenizer = Tokenizer.init(source); | |
for (expected_tags, expected_literals) |t, l| { | |
const token = tokenizer.next(); | |
try std.testing.expectEqual(t, token.tag); | |
try std.testing.expectEqualStrings(source[token.loc.start..token.loc.end], l); | |
} | |
const last_token = tokenizer.next(); | |
try std.testing.expectEqual(Token.Tag.eof, last_token.tag); | |
try std.testing.expectEqual(source.len, last_token.loc.start); | |
try std.testing.expectEqual(source.len, last_token.loc.end); | |
} |
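
As a usage sketch (not part of the gist): a small driver can walk a source string and print each token's tag and text. The file name `lexer.zig` below is an assumption about where the gist is saved.

const std = @import("std");
const Tokenizer = @import("lexer.zig").Tokenizer; // assumed file name

pub fn main() void {
    const source: [:0]const u8 = "%0 = arith.constant 42 : i32";
    var tokenizer = Tokenizer.init(source);
    while (true) {
        const token = tokenizer.next();
        if (token.tag == .eof) break;
        // Loc holds byte offsets, so slicing the source recovers the lexeme.
        std.debug.print("{s}\t'{s}'\n", .{ @tagName(token.tag), source[token.loc.start..token.loc.end] });
    }
}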