Zig MLIR Lexer
@leizaf · Created August 30, 2024 06:20
const std = @import("std");
const cc = std.ascii.control_code;
const startsWith = std.mem.startsWith;

fn isDigitNotZero(c: u8) bool {
    return switch (c) {
        '1'...'9' => true,
        else => false,
    };
}
pub const Token = struct {
    tag: Tag,
    loc: Loc,

    const kwList = [_]struct { [:0]const u8, Tag }{
        .{ "affine_map", .kw_affine_map },
        .{ "affine_set", .kw_affine_set },
        .{ "array", .kw_array },
        .{ "attributes", .kw_attributes },
        .{ "bf16", .kw_bf16 },
        .{ "ceildiv", .kw_ceildiv },
        .{ "complex", .kw_complex },
        .{ "dense", .kw_dense },
        .{ "dense_resource", .kw_dense_resource },
        .{ "distinct", .kw_distinct },
        .{ "f16", .kw_f16 },
        .{ "f32", .kw_f32 },
        .{ "f64", .kw_f64 },
        .{ "f80", .kw_f80 },
        .{ "f8E5M2", .kw_f8E5M2 },
        .{ "f8E4M3FN", .kw_f8E4M3FN },
        .{ "f8E5M2FNUZ", .kw_f8E5M2FNUZ },
        .{ "f8E4M3FNUZ", .kw_f8E4M3FNUZ },
        .{ "f8E4M3B11FNUZ", .kw_f8E4M3B11FNUZ },
        .{ "f128", .kw_f128 },
        .{ "false", .kw_false },
        .{ "floordiv", .kw_floordiv },
        .{ "for", .kw_for },
        .{ "func", .kw_func },
        .{ "index", .kw_index },
        .{ "loc", .kw_loc },
        .{ "max", .kw_max },
        .{ "memref", .kw_memref },
        .{ "min", .kw_min },
        .{ "mod", .kw_mod },
        .{ "none", .kw_none },
        .{ "offset", .kw_offset },
        .{ "size", .kw_size },
        .{ "sparse", .kw_sparse },
        .{ "step", .kw_step },
        .{ "strided", .kw_strided },
        .{ "symbol", .kw_symbol },
        .{ "tensor", .kw_tensor },
        .{ "tf32", .kw_tf32 },
        .{ "to", .kw_to },
        .{ "true", .kw_true },
        .{ "tuple", .kw_tuple },
        .{ "type", .kw_type },
        .{ "unit", .kw_unit },
        .{ "vector", .kw_vector },
    };
    const keywords = std.ComptimeStringMap(Tag, kwList);
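    // Note: std.ComptimeStringMap was renamed to std.StaticStringMap
    // (constructed via initComptime) in Zig 0.13, so this file targets
    // the pre-0.13 std API.

    /// Keyword lookup, e.g. getKw("func") == .kw_func; returns null for
    /// non-keywords such as "foo".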
    pub fn getKw(s: []const u8) ?Tag {
        return keywords.get(s);
    }
    pub const Tag = union(enum) {
        eof,
        // Identifiers
        id_bare,
        id_prefixed,
        // Literals
        lit_decimal,
        lit_hex,
        lit_float,
        lit_string,
        lit_int_type,
        // Punctuation
        pn_arrow,
        pn_at,
        pn_colon,
        pn_comma,
        pn_ellipsis,
        pn_equal,
        pn_greater,
        pn_l_brace,
        pn_l_paren,
        pn_l_square,
        pn_less,
        pn_minus,
        pn_plus,
        pn_question,
        pn_r_brace,
        pn_r_paren,
        pn_r_square,
        pn_star,
        pn_v_bar,
        // Keywords
        kw_affine_map,
        kw_affine_set,
        kw_array,
        kw_attributes,
        kw_bf16,
        kw_ceildiv,
        kw_complex,
        kw_dense,
        kw_dense_resource,
        kw_distinct,
        kw_f16,
        kw_f32,
        kw_f64,
        kw_f80,
        kw_f8E5M2,
        kw_f8E4M3FN,
        kw_f8E5M2FNUZ,
        kw_f8E4M3FNUZ,
        kw_f8E4M3B11FNUZ,
        kw_f128,
        kw_false,
        kw_floordiv,
        kw_for,
        kw_func,
        kw_index,
        kw_loc,
        kw_max,
        kw_memref,
        kw_min,
        kw_mod,
        kw_none,
        kw_offset,
        kw_size,
        kw_sparse,
        kw_step,
        kw_strided,
        kw_symbol,
        kw_tensor,
        kw_tf32,
        kw_to,
        kw_true,
        kw_tuple,
        kw_type,
        kw_unit,
        kw_vector,
        // Illegal tokens
        illegal_char,
        illegal_string,
        illegal_ellipsis,
        illegal_suffix_id,
    };

    pub const Loc = struct {
        start: usize,
        end: usize,
    };
};
pub const Tokenizer = struct {
    const Self = @This();

    buffer: [:0]const u8,
    index: usize = 0,

    pub fn init(buffer: [:0]const u8) Tokenizer {
        return Self{
            .buffer = buffer,
        };
    }
    const State = enum {
        start,
        comment,
        ellipsis,
        minus,
        /// Lex a bare identifier or keyword that starts with a letter
        ///
        /// bare-id ::= (letter|[_]) (letter|digit|[_$.])*
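        /// e.g. "foo", "_bar", "foo42bar_$.-" (note that this lexer also
        /// accepts '-', which the grammar line above omits)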
        ///
        bare_id_or_kw,
        /// Lex a suffix-id (which should follow a prefix sigil such as
        /// `%`, `#`, `^`, or `!`)
        ///
        /// suffix-id ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
        /// id-punct ::= `$` | `.` | `_` | `-`
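        /// e.g. the "42" of "%42" or the "arg0" of "%arg0"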
        ///
        suffix_id,
        /// Lex an int-type literal
        ///
        /// integer-type ::= `[su]?i[1-9][0-9]*`
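        /// e.g. "i32", "si8", "ui64" (but not "i0", which lexes as a
        /// bare identifier)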
        ///
        int_type,
        /// Lex a decimal integer literal
        ///
        /// decimal-literal ::= digit+
        /// digit ::= [0-9]
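        /// e.g. "0", "42", "1024"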
        ///
        decimal,
        /// Lex a float literal
        ///
        /// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
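        /// e.g. "1.0", "42.5e-3" (a leading sign is lexed as a separate
        /// pn_minus/pn_plus token)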
        ///
        float,
        float_exp,
        float_end,
        /// Lex a hexadecimal literal
        ///
        /// hexadecimal-literal ::= `0x` hex_digit+
        /// hex_digit ::= [0-9a-fA-F]
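        /// e.g. "0x1F", "0xdeadbeef"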
        ///
        maybe_hex,
        hex,
        /// Lex a string literal
        ///
        /// string-literal ::= `"` [^"\n\f\v\r]* `"`
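        /// e.g. "hello"; a backslash-escaped `"` does not terminate the
        /// literal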
        ///
        string,
    };
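    /// True if `s` is "i", "si", or "ui" followed by a nonzero digit,
    /// i.e. the start of an integer-type literal such as "i32" or "ui64".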
    fn isIntType(s: []const u8) bool {
        const intTypePrefixes = [3][]const u8{ "i", "si", "ui" };
        inline for (intTypePrefixes) |prefix| {
            if (s.len > prefix.len and startsWith(u8, s, prefix) and isDigitNotZero(s[prefix.len])) {
                return true;
            }
        }
        return false;
    }
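    /// True for characters allowed after the first character of a bare or
    /// suffix identifier: letters, digits, '$', '.', '_', and '-'.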
    fn isBareIdentChar(c: u8) bool {
        return switch (c) {
            '0'...'9',
            'a'...'z',
            'A'...'Z',
            '$',
            '.',
            '_',
            '-',
            => true,
            else => false,
        };
    }
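    /// Scan and return the next token. A NUL byte or the end of the
    /// buffer yields an .eof token.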
    pub fn next(self: *Self) Token {
        var result = Token{
            .tag = .eof,
            .loc = .{
                .start = self.index,
                .end = undefined,
            },
        };
        var state = State.start;
        while (self.index <= self.buffer.len) : (self.index += 1) {
            const c = self.buffer[self.index];
            if (state == .start) {
                result.loc.start = self.index;
            }
            switch (state) {
                State.start => switch (c) {
                    0 => {
                        if (self.index != self.buffer.len) {
                            // An embedded NUL byte also terminates
                            // tokenization, as a one-byte eof token.
                            result.tag = .eof;
                            result.loc.start = self.index;
                            self.index += 1;
                            result.loc.end = self.index;
                            return result;
                        }
                        break;
                    },
                    // Skip whitespace; loc.end is recomputed after the
                    // loop, so there is nothing to record here.
                    ' ', '\n', '\t', '\r', '\x0b', '\x0c' => {},
                    '/' => state = State.comment,
                    '.' => state = State.ellipsis,
                    '-' => state = State.minus,
                    '0' => state = State.maybe_hex,
                    '1'...'9' => state = State.decimal,
                    '"' => state = State.string,
                    '#', '%', '^', '!' => state = State.suffix_id,
                    '_' => state = State.bare_id_or_kw,
                    'a'...'z', 'A'...'Z' => {
                        state = State.bare_id_or_kw;
                        continue;
                    },
                    '@' => {
                        result.tag = .pn_at;
                        self.index += 1;
                        break;
                    },
                    ':' => {
                        result.tag = .pn_colon;
                        self.index += 1;
                        break;
                    },
                    ',' => {
                        result.tag = .pn_comma;
                        self.index += 1;
                        break;
                    },
                    '=' => {
                        result.tag = .pn_equal;
                        self.index += 1;
                        break;
                    },
                    '>' => {
                        result.tag = .pn_greater;
                        self.index += 1;
                        break;
                    },
                    '{' => {
                        result.tag = .pn_l_brace;
                        self.index += 1;
                        break;
                    },
                    '(' => {
                        result.tag = .pn_l_paren;
                        self.index += 1;
                        break;
                    },
                    '[' => {
                        result.tag = .pn_l_square;
                        self.index += 1;
                        break;
                    },
                    '<' => {
                        result.tag = .pn_less;
                        self.index += 1;
                        break;
                    },
                    '+' => {
                        result.tag = .pn_plus;
                        self.index += 1;
                        break;
                    },
                    '?' => {
                        result.tag = .pn_question;
                        self.index += 1;
                        break;
                    },
                    '}' => {
                        result.tag = .pn_r_brace;
                        self.index += 1;
                        break;
                    },
                    ')' => {
                        result.tag = .pn_r_paren;
                        self.index += 1;
                        break;
                    },
                    ']' => {
                        result.tag = .pn_r_square;
                        self.index += 1;
                        break;
                    },
                    '*' => {
                        result.tag = .pn_star;
                        self.index += 1;
                        break;
                    },
                    '|' => {
                        result.tag = .pn_v_bar;
                        self.index += 1;
                        break;
                    },
                    else => {
                        result.tag = .illegal_char;
                        self.index += 1;
                        break;
                    },
                },
                State.comment => switch (self.buffer[result.loc.start + 1]) {
                    '/' => {
                        if (c == '\n' or c == '\r') {
                            state = State.start;
                        } else if (c == 0 and self.index == self.buffer.len) {
                            // Input that ends inside a line comment yields
                            // eof at buffer.len instead of running past it.
                            result.loc.start = self.index;
                            break;
                        }
                    },
                    else => {
                        result.tag = .illegal_char;
                        self.index += 1;
                        break;
                    },
                },
                State.ellipsis => switch (c) {
                    '.' => {},
                    else => {
                        if (self.index - result.loc.start == 3) {
                            // Exactly "..."; the current character is not
                            // consumed and will start the next token.
                            result.tag = .pn_ellipsis;
                            break;
                        }
                        result.tag = .illegal_ellipsis;
                        break;
                    },
                },
                State.minus => switch (c) {
                    '>' => {
                        result.tag = .pn_arrow;
                        self.index += 1;
                        break;
                    },
                    else => {
                        result.tag = .pn_minus;
                        break;
                    },
                },
                State.bare_id_or_kw => {
                    const s = self.buffer[result.loc.start .. self.index + 1];
                    if (isIntType(s)) {
                        state = State.int_type;
                        continue;
                    }
                    if (isBareIdentChar(c)) continue;
                    result.tag = if (Token.getKw(s[0 .. s.len - 1])) |kw| kw else .id_bare;
                    break;
                },
                State.suffix_id => {
                    // The first character after the prefix sigil decides the
                    // form: digit+ or an identifier-like sequence.
                    switch (self.buffer[result.loc.start + 1]) {
                        '0'...'9' => if (std.ascii.isDigit(c)) continue,
                        'a'...'z', 'A'...'Z', '$', '.', '_', '-' => if (isBareIdentChar(c)) continue,
                        else => {
                            result.tag = .illegal_suffix_id;
                            break;
                        },
                    }
                    result.tag = .id_prefixed;
                    break;
                },
                State.int_type => switch (c) {
                    '0'...'9' => {},
                    else => {
                        result.tag = .lit_int_type;
                        break;
                    },
                },
                State.decimal => switch (c) {
                    '0'...'9' => {},
                    '.' => state = State.float,
                    else => {
                        result.tag = .lit_decimal;
                        break;
                    },
                },
                State.float => switch (c) {
                    '0'...'9' => {},
                    'e', 'E' => state = State.float_exp,
                    else => {
                        // Re-examine this character in float_end, which
                        // terminates the literal on it.
                        state = State.float_end;
                        self.index -= 1;
                    },
                },
                State.float_exp => switch (c) {
                    '0'...'9', '-', '+' => state = State.float_end,
                    else => {
                        // No exponent follows the 'e'/'E': rewind so the
                        // literal ends before it and the 'e' is relexed as
                        // the start of the next token.
                        state = State.float_end;
                        self.index -= 2;
                    },
                },
                State.float_end => switch (c) {
                    '0'...'9' => {},
                    else => {
                        result.tag = .lit_float;
                        break;
                    },
                },
                State.maybe_hex => {
                    const dist = self.index - result.loc.start;
                    if (dist == 1 and c == 'x') continue;
                    if (dist == 2 and std.ascii.isHex(c)) {
                        state = State.hex;
                        continue;
                    }
                    // Not `0x` plus a hex digit: back up and relex from
                    // the '0' as a decimal literal.
                    self.index = result.loc.start;
                    state = State.decimal;
                },
                State.hex => switch (c) {
                    '0'...'9', 'a'...'f', 'A'...'F' => {},
                    else => {
                        result.tag = .lit_hex;
                        break;
                    },
                },
                State.string => switch (c) {
                    '\n', cc.vt, cc.ff, '\r', 0 => {
                        result.tag = .illegal_string;
                        break;
                    },
                    '"' => {
                        if (self.buffer[self.index - 1] == '\\') continue;
                        result.tag = .lit_string;
                        self.index += 1;
                        break;
                    },
                    else => {},
                },
            }
        }
        result.loc.end = self.index;
        return result;
    }
};
test "keywords" {
for (Token.kwList) |kv| {
try testTokenize(kv[0], &.{kv[1]});
}
}
test "single character punctuation" {
try testTokenize(
\\@
\\:
\\,
\\=
\\>
\\{
\\(
\\[
\\<
\\+
\\?
\\}
\\)
\\]
\\*
\\|
, &.{
.pn_at,
.pn_colon,
.pn_comma,
.pn_equal,
.pn_greater,
.pn_l_brace,
.pn_l_paren,
.pn_l_square,
.pn_less,
.pn_plus,
.pn_question,
.pn_r_brace,
.pn_r_paren,
.pn_r_square,
.pn_star,
.pn_v_bar,
});
}
test "skip line comment" {
try testTokenize(
\\// line comment
\\{}
, &.{
.pn_l_brace,
.pn_r_brace,
});
}
test "ellipsis" {
try testTokenize(".. ... .... ", &.{ .illegal_ellipsis, .pn_ellipsis, .illegal_ellipsis });
}
test "minus" {
try testTokenize("-> -", &.{ .pn_arrow, .pn_minus });
}
test "newline in string literal" {
try testTokenize(
\\"
\\"
, &.{ .illegal_string, .illegal_string });
}
test "identifiers" {
try testTokenizeLiteral(
\\foo
\\_bar
\\foo42bar_$.-
, &.{
.id_bare,
.id_bare,
.id_bare,
}, &.{
"foo",
"_bar",
"foo42bar_$.-",
});
try testTokenizeLiteral(
\\@foo
\\@"bar"
, &.{
.pn_at,
.id_bare,
.pn_at,
.lit_string,
}, &.{
"@",
"foo",
"@",
"\"bar\"",
});
try testTokenizeLiteral(
\\%42bar
, &.{
.id_prefixed,
.id_bare,
}, &.{
"%42",
"bar",
});
}
test "int-type literal" {
try testTokenizeLiteral(
\\i123
\\si456
\\ui789
\\i0
, &.{
.lit_int_type,
.lit_int_type,
.lit_int_type,
.id_bare,
}, &.{
"i123",
"si456",
"ui789",
"i0",
});
}
test "general" {
try testTokenize("func.func @simple(i64, i1) -> i64 {}", &.{
.id_bare,
.pn_at,
.id_bare,
.pn_l_paren,
.lit_int_type,
.pn_comma,
.lit_int_type,
.pn_r_paren,
.pn_arrow,
.lit_int_type,
.pn_l_brace,
.pn_r_brace,
});
try testTokenize(
\\^bb2:
\\ %b = arith.addi %a, %a : i64
\\ cf.br ^bb3(%b: i64)
, &.{
.id_prefixed,
.pn_colon,
.id_prefixed,
.pn_equal,
.id_bare,
.id_prefixed,
.pn_comma,
.id_prefixed,
.pn_colon,
.lit_int_type,
.id_bare,
.id_prefixed,
.pn_l_paren,
.id_prefixed,
.pn_colon,
.lit_int_type,
.pn_r_paren,
});
}
fn testTokenize(source: [:0]const u8, expected_tags: []const Token.Tag) !void {
var tokenizer = Tokenizer.init(source);
for (expected_tags) |expected_token_tag| {
const token = tokenizer.next();
try std.testing.expectEqual(expected_token_tag, token.tag);
}
const last_token = tokenizer.next();
try std.testing.expectEqual(Token.Tag.eof, last_token.tag);
try std.testing.expectEqual(source.len, last_token.loc.start);
try std.testing.expectEqual(source.len, last_token.loc.end);
}
fn testTokenizeLiteral(source: [:0]const u8, expected_tags: []const Token.Tag, expected_literals: []const []const u8) !void {
    var tokenizer = Tokenizer.init(source);
    for (expected_tags, expected_literals) |t, l| {
        const token = tokenizer.next();
        try std.testing.expectEqual(t, token.tag);
        // expectEqualStrings takes (expected, actual): the expected literal
        // first, then the slice the token actually covers.
        try std.testing.expectEqualStrings(l, source[token.loc.start..token.loc.end]);
    }
    const last_token = tokenizer.next();
    try std.testing.expectEqual(Token.Tag.eof, last_token.tag);
    try std.testing.expectEqual(source.len, last_token.loc.start);
    try std.testing.expectEqual(source.len, last_token.loc.end);
}
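
For reference, here is a minimal sketch of how the tokenizer might be driven, assuming it is appended to this same file; the sample IR string and the printing loop are illustrative assumptions, not part of the gist:

pub fn main() !void {
    // Tokenize a small MLIR snippet and print each token's tag and the
    // source text it covers, stopping at the synthetic eof token.
    const source: [:0]const u8 = "func.func @simple(i64) -> i64 {}";
    var tokenizer = Tokenizer.init(source);
    while (true) {
        const token = tokenizer.next();
        if (token.tag == .eof) break;
        std.debug.print("{s} '{s}'\n", .{
            @tagName(token.tag),
            source[token.loc.start..token.loc.end],
        });
    }
}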