Created
May 21, 2017 18:24
-
-
Save sczizzo/fb03f36f170f94974f5ed63ba1310168 to your computer and use it in GitHub Desktop.
module T - Naive typechecking for Ruby
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Namespace for the naive runtime type checker.
module T
  # Error class used by the method-wrapping `T` helper when an annotated
  # method's arguments or return value fail their type check.
  TypeError = Class.new(StandardError)
end
# Abstract parent of the built-in type matchers.
#
# Stores the type parameters it was built with and requires subclasses to
# implement #match?.
class T::Base
  # @return [Array] the type parameters given to the constructor
  attr_reader :types

  # @param types [Array] zero or more nested type matchers
  def initialize(*types)
    @types = types
  end

  # Whether +obj+ conforms to this type. Subclasses must override.
  #
  # @raise [NotImplementedError] always, in the base class
  def match?(_obj)
    raise NotImplementedError, "#{self.class} must implement #match?"
  end
end
# Matches an ::Array whose every element matches at least one of +types+.
# With no type parameters, any ::Array matches.
class T::Array < T::Base
  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    return false unless obj.is_a?(::Array)
    # An untyped T.Array accepts arbitrary contents; without this guard
    # `types.any?` would be false for every element.
    return true if types.empty?

    obj.all? { |item| types.any? { |type| type.match?(item) } }
  end
end
# Function type: every type parameter but the last describes an argument,
# the last one describes the return value.
class T::Arrow < T::Base
  # @return [Array] matchers for the positional arguments
  # @return [Object, nil] matcher for the return value (nil when built with
  #   no type parameters)
  attr_reader :arg_ts, :rval_t

  # @param types [Array] argument matchers followed by the return matcher
  def initialize(*types)
    super # keep the inherited `types` reader populated
    @arg_ts = types[0...-1]
    @rval_t = types[-1]
  end

  # Checks an array holding either just the arguments, or the arguments
  # followed by the return value.
  #
  # @param obj [Object] candidate; anything but an ::Array never matches
  # @return [Boolean]
  def match?(obj)
    return false unless obj.is_a?(::Array)

    # If we get an obj without an rval, just check args
    ignore_rval = obj.size == arg_ts.size
    args = ignore_rval ? obj : obj[0...-1]
    return false unless arg_ts.size == args.size
    return false unless arg_ts.zip(args).all? { |type, arg| type.match?(arg) }

    ignore_rval || rval_t.match?(obj.last)
  end
end
# Matches a ::Hash, optionally constraining the key and value types.
class T::Hash < T::Base
  # @return [Object, nil] key / value matchers; both nil for an untyped hash
  attr_reader :key_t, :value_t

  # @param args [Array] either nothing, or exactly (key_type, value_type)
  # @raise [ArgumentError] when given a number of types other than 0 or 2
  def initialize(*args)
    super # keep the inherited `types` reader populated
    return if args.empty?
    raise ArgumentError, 'T::Hash takes zero or two types (key, value)' unless args.size == 2

    @key_t, @value_t = args
  end

  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    return false unless obj.is_a?(::Hash)
    return true unless key_t && value_t # Don't check KV types

    obj.all? { |key, value| key_t.match?(key) && value_t.match?(value) }
  end
end
# Matches an ::Array of exactly +types.size+ elements whose element at each
# position matches the type at the same position.
class T::Tuple < T::Base
  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    return false unless obj.is_a?(::Array)
    return false unless types.size == obj.size

    types.zip(obj).all? { |type, item| type.match?(item) }
  end
end
# The unit type: matches only an empty ::Array.
class T::Unit < T::Base
  # Takes no type parameters (mirrors T::Nil). The original spelled this
  # `initialie`, so the intended constructor was never actually defined.
  def initialize
    super()
  end

  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    obj.is_a?(::Array) && obj.empty?
  end
end
# Matches only nil.
class T::Nil < T::Base
  # Takes no type parameters.
  def initialize
    super()
  end

  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    obj.nil?
  end
end
# Union type: matches when at least one of +types+ matches.
# With no type parameters nothing matches.
class T::Either < T::Base
  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    types.any? { |type| type.match?(obj) }
  end
end
# Top type: matches every object.
class T::Any < T::Base
  # @return [true]
  def match?(_obj)
    true
  end
end
# A type expression: a class plus the type arguments applied to it.
#
# Instances are what `T.Something(...)` returns. They can be combined with
# operators to build hashes (>>), arrows (~, /, *), unions (+, |) and
# tuples (^, &), and are checked against values with #match? / #assert!.
class T::Klass
  # @return [Array] type arguments applied to +klass+
  # @return [Class] the wrapped class
  attr_reader :args, :klass

  # @param klass [Class] a T matcher class or any ordinary class
  # @param args [Array] type arguments
  def initialize(klass, *args)
    @klass = klass
    @args = args
  end

  # Hash type with self as key type and +other+ as value type.
  def >>(other)
    self.class.new(T::Hash, self, other)
  end

  # Arrow taking no arguments and returning self.
  def ~
    self.class.new(T::Arrow, self)
  end

  # Binary arrow: self -> other (never flattens).
  def /(other)
    self.class.new(T::Arrow, self, other)
  end

  # Arrow that flattens when self is already an arrow.
  def *(other)
    combine(T::Arrow, other)
  end

  # Binary union of self and other (never flattens).
  def +(other)
    self.class.new(T::Either, self, other)
  end

  # Union that flattens when self is already a union.
  def |(other)
    combine(T::Either, other)
  end

  # Binary tuple of self and other (never flattens).
  def ^(other)
    self.class.new(T::Tuple, self, other)
  end

  # Tuple that flattens when self is already a tuple.
  def &(other)
    combine(T::Tuple, other)
  end

  # Whether +obj+ conforms to this type expression.
  #
  # Resolution order: a built-in T matcher, then a user-supplied nested
  # `SomeClass::T` matcher, then a plain `is_a?` check.
  #
  # @return [Boolean]
  def match?(obj)
    if klassy?
      this_klass.match?(obj)
    elsif extra_klassy?
      extra_klass.match?(obj)
    else
      obj.is_a?(klass)
    end
  end

  # @raise [TypeError] when +obj+ does not match
  # @return [nil] when it does
  def assert!(obj)
    return if match?(obj)

    # NOTE(review): because this class is opened as `class T::Klass`, module T
    # is not in lexical scope and this resolves to the builtin ::TypeError,
    # not ::T::TypeError. Left as-is so existing `rescue TypeError` keeps working.
    raise TypeError,
          "object '#{obj.inspect}' does not match type '#{self.inspect}'"
  end

  private

  # Shared body of *, | and &: merge argument lists when self (and possibly
  # other) is already of +kind+, otherwise build a fresh binary +kind+.
  def combine(kind, other)
    if klass == kind && other.klass == kind
      self.class.new(kind, *args, *other.args)
    elsif klass == kind
      self.class.new(kind, *args, other)
    else
      self.class.new(kind, self, other)
    end
  end

  # Built-in matcher classes. Restricted to T::Base descendants so that
  # non-matcher classes living under T (T::TypeError, T::Klass) fall through
  # to the plain is_a? check instead of being instantiated as matchers.
  def klasses
    @klasses ||= ::T.constants
                    .map { |name| ::T.const_get(name) }
                    .grep(Class)
                    .select { |const| const <= ::T::Base }
  end

  def klassy?
    # `||=` would recompute every time the answer is false.
    return @klassy if defined?(@klassy)

    @klassy = klasses.include?(klass)
  end

  def this_klass
    @this_klass ||= klass.new(*args)
  end

  def extra_klassy?
    return @extra_klassy if defined?(@extra_klassy)

    @extra_klassy = klass.constants.include?(:T) && klass::T.is_a?(Class)
  end

  def extra_klass
    @extra_klass ||= klass::T.new(*args)
  end
end
# Global on/off switch plus the `T.SomeConstant(...)` type-expression DSL.
module T
  # Module-level instance variable rather than a class variable (@@), which
  # would be shared with anything that includes this module.
  @enabled = false

  # @return [Boolean] whether method annotations made via `T(...)` take effect
  def self.enabled?
    @enabled
  end

  def self.enable!
    @enabled = true
  end

  def self.disable!
    @enabled = false
  end

  # `T.Foo(a, b)` builds a Klass wrapping the constant Foo (looked up from T)
  # with type arguments a, b.
  #
  # @raise [NameError] when the name is not a resolvable constant
  def self.method_missing(name, *args)
    Klass.new(const_get(name.to_s), *args)
  end

  # Keeps respond_to? consistent with method_missing. The name is checked
  # first because const_defined? raises for names that cannot be constants.
  def self.respond_to_missing?(name, include_private = false)
    (name.to_s.match?(/\A[A-Z]\w*\z/) && const_defined?(name.to_s)) || super
  end
end

T.enable!
# Class-body macro that wraps already-defined instance methods with runtime
# type checks. Must be called after the methods it annotates, since it
# aliases them.
#
# Does nothing when ::T.enabled? is false. For each method the original is
# kept under `__T_<name>` and the replacement checks, on every call, that
# the type is an Arrow, that the arguments match, and that arguments plus
# return value match. It also prints a trace hash via puts.
#
# @param annotation [Hash{Symbol => T::Klass}] method name => arrow type
# @raise [::T::TypeError] (at call time) when any check fails
def T(annotation)
  return unless ::T.enabled?

  annotation.each do |meth, type|
    old_meth = "__T_#{meth}".to_sym
    alias_method(old_meth, meth)

    define_method(meth) do |*args, &block|
      unless type.klass == ::T::Arrow
        raise ::T::TypeError, "given type '#{type.klass}' is not an Arrow"
      end
      unless type.match?(args)
        raise ::T::TypeError, "given type '#{type.klass}' does not match args"
      end

      rval = send(old_meth, *args, &block)

      unless type.match?([*args, rval])
        raise ::T::TypeError, "given type '#{type.klass}' does not match rval"
      end

      puts meth: meth, rval: rval, typecheck?: true
      rval
    end
  end
end
# Non-raising check: true when every object matches its paired type.
#
# @param assertions [Hash{Object => #match?}] object => type
# @return [Boolean] true for an empty hash
def T?(assertions)
  assertions.all? { |obj, type| type.match?(obj) }
end
# Raising check: calls assert! on each type with its paired object.
#
# @param assertions [Hash{Object => #assert!}] object => type
# @return [Hash] the +assertions+ argument itself when nothing raised
def T!(assertions)
  assertions.each { |obj, type| type.assert!(obj) }
end
################################################################################
## Playground ##################################################################
################################################################################

# Integer is used throughout instead of Fixnum: Fixnum was merged into
# Integer in Ruby 2.4 and the constant was removed in Ruby 3.2.

Actor = Struct.new(:name)
foo = Actor.new('foo')
T_Actor = T.Actor # Type alias
T! foo => T_Actor
T! :symbol => T.Symbol

# Tuples: fixed length, per-position types.
tup = [:a, :b]
bad_tup = [:a, 'bad']
bad_tup2 = [:a, :b, :bad]
T! tup => T.Tuple(T.Symbol, T.Symbol)
# These would raise (wrong element type / wrong length):
# T! bad_tup => T.Tuple(T.Symbol, T.Symbol)
# T! bad_tup2 => T.Tuple(T.Symbol, T.Symbol)

# An empty array satisfies the empty tuple, Unit, and an untyped Array.
unit = []
empty = []
anything = ['what', 1, StandardError, :goes, /EVER/]
T! unit => T.Tuple
T! unit => T.Unit
T! empty => T.Array
T! empty => T.Unit
T! anything => T.Array

actor = Actor.new('Tom')
T! actor => T.Actor

# Arrays: every element must match one of the listed types.
actors = [Actor.new('Joe'), Actor.new('Sue')]
T! actors => T.Array(T.Actor)
mixed = [1, :a, 'b']
T! mixed => T.Array(T.Integer, T.Symbol, T.String)

index = { foo: 'bar', top: 'bottom' }
T! index => T.Hash(T.Symbol, T.String)

ids = [[1, Actor.new('Jane')], [2, Actor.new('Jose')]]
T! ids => T.Array(T.Tuple(T.Integer, T.Actor))
T_Entry = T.Tuple(T.Integer, T.Actor)
T! ids => T.Array(T_Entry)

# Unions built with `|` flatten into a single Either.
this = 1
that = String.new
other = nil
T_Either = T.Integer | T.String | T.Nil
T! this => T_Either,
   that => T_Either,
   other => T_Either
# This would raise:
# T! :bad => T_Either

# `&` builds tuples; combined with `|` this gives a tiny tree type.
T_Leaf = T.Integer
T_Node = T_Leaf & T_Leaf
T_Tree = T_Node | T_Leaf
left = 1
right = 2
node = [left, right]
tree = node
T! left => T_Leaf,
   node => T_Node,
   tree => T_Tree
################################################################################
## User-defined types ##########################################################
################################################################################

Error = Struct.new(:reason)
# No Error::T, so we'll only check that objects are an instance of Error

# A value-or-error container used to demonstrate user-defined type matchers.
class Result
  # @return [Object] the payload (nil for a failed result)
  # @return [Object, nil] the error (nil for a successful result)
  attr_reader :value, :error

  def initialize(value, error = nil)
    @value = value
    @error = error
  end

  # @return [Boolean] true when no error is present
  def ok?
    error.nil?
  end

  # Previously written as `@error ||= !ok?`, which overwrote the `error`
  # attribute with `false` the first time it was called on a successful
  # result.
  #
  # @return [Boolean] true when an error is present
  def error?
    !ok?
  end
end
# Magic sauce. Because Result has a nested class named T, T::Klass#match?
# delegates to it instead of doing a plain is_a? check. That lets us verify
# not only that the object is a Result, but that one of value/error matches
# its assigned type while the other is nil.
class Result::T
  # @return [#match?] matchers for Result#value and Result#error
  attr_reader :value_t, :error_t

  # @param value_t [#match?] type of a successful payload
  # @param error_t [#match?] type of an error
  def initialize(value_t, error_t)
    @value_t = value_t
    @error_t = error_t
  end

  # @param obj [Object] candidate value
  # @return [Boolean]
  def match?(obj)
    return false unless obj.is_a?(Result)

    value = obj.value
    error = obj.error
    # Plain nil? is what T.Nil checks; using it directly avoids building a
    # fresh matcher on every call.
    (value_t.match?(value) && error.nil?) ||
      (error_t.match?(error) && value.nil?)
  end
end
# A Result carrying either a String value or an Error.
T_StringResult = T.Result(T.String, T.Error)

result = Result.new('foo')
T! result => T_StringResult

error = Result.new(nil, Error.new('bad result'))
T! error => T_StringResult

# A Symbol payload does not satisfy T_StringResult; T? reports that without
# raising.
bad_result = Result.new(:bad)
puts bad_result_typechecks?: T?(bad_result => T_StringResult)
# T! bad_result => T.Result(T.Symbol, T.Error)
################################################################################
## In a project ################################################################
################################################################################

# Example domain object with a type-annotated method.
class Person
  attr_reader :name, :age

  def initialize(name, age)
    @name = name
    @age = age
  end

  # @return [Hash{Symbol => String, Integer}]
  def to_h
    { name: name, age: age }
  end

  # Annotations go after the method definitions because `T` aliases the
  # existing method. The bare `private` that used to precede this line was
  # dropped: no method was defined after it, so it had no effect.
  # Integer replaces Fixnum, which was removed in Ruby 3.2.
  T to_h: ~(T.Symbol >> (T.String | T.Integer))
end
# Example service whose methods are type-annotated with arrows.
class Greeter
  JOKY_MESSAGE = "How would you know? You're only %{age} years old!".freeze
  HOKY_MESSAGE = "If I had a dollar for each of your years, I'd have %{age} dollars".freeze

  attr_reader :message

  # @param message [String] format string using %{name} / %{age}
  def initialize(message)
    @message = message
  end

  # @param person [Person]
  # @return [String] +message+ formatted with the person's attributes
  def greet(person)
    message % person.to_h
  end

  # @param person [Person]
  # @param jokiness [Integer] above 10 selects the joky message
  # @return [String]
  def joke(person, jokiness)
    template = jokiness > 10 ? JOKY_MESSAGE : HOKY_MESSAGE
    template % person.to_h
  end

  # Annotations go after the method definitions because `T` aliases the
  # existing methods. The bare `private` that used to precede these lines
  # was dropped: no method was defined after it, so it had no effect.
  # Integer replaces Fixnum, which was removed in Ruby 3.2.
  T greet: T.Person * T.String
  T joke: T.Person * T.Integer * T.String
end
# Exercise the annotated methods; each call runs the checks installed by `T`.
me = Person.new 'Sean', 26
gg = Greeter.new 'hi %{name}'
gg.greet(me)

stan = Person.new 'Stan', 15
gg.joke(stan, 5)
gg.joke(stan, 20)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment