Skip to content

Instantly share code, notes, and snippets.

@Triavanicus
Created April 16, 2025 16:03
Show Gist options
  • Save Triavanicus/585cbc807019c624c8fabf5b864cb84b to your computer and use it in GitHub Desktop.
Save Triavanicus/585cbc807019c624c8fabf5b864cb84b to your computer and use it in GitHub Desktop.
Over-engineered lua classes
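---@class Interface
---@field __type string
---@field requiredMethods table<string, {name: string, argCount: integer, stub: function}>
---@field implements fun(self, obj: table): boolean
---@class Class
---@field super fun(self): table
---@field toTable fun(self): table
---@field sameClass fun(self, other: Class): boolean
---@field isInstance fun(self, classType: ClassT): boolean
---@field implements fun(self, interface: Interface): boolean
---@class ClassT
---@field new fun(self, ...): Class
---@field super ClassT?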
--- Restore the receiver's previously active class, then pass every remaining
-- argument through as return values (so multi-value results are preserved).
local function restoreActiveClass(self, prev, ...)
  self.__currentClass = prev
  return ...
end

--- Wrap `func` so that while it runs, `self.__currentClass` is `cls`.
-- `super()` (via getSuper) uses `__currentClass` to find the parent class.
-- All return values of `func` are forwarded.
-- The wrapper is recorded in `_G.__original_methods` (wrapper -> original,
-- unwrapped function) so interface checks can count the real parameters.
-- NOTE(review): if `func` raises, `__currentClass` is not restored.
-- @tparam table cls class to mark as active during the call
-- @tparam function func method body; called as `func(self, ...)`
-- @treturn function the wrapper
local function wrapMethod(cls, func)
  local wrapped = function(self, ...)
    local prev = self.__currentClass
    self.__currentClass = cls
    return restoreActiveClass(self, prev, func(self, ...))
  end
  -- Weak keys: an entry disappears once its wrapper is unreachable. (Weak
  -- values never freed anything: the wrapper, a strong key, keeps the
  -- original alive through its upvalue.)
  if not _G.__original_methods then
    _G.__original_methods = setmetatable({}, { __mode = "k" })
  end
  local registry = _G.__original_methods
  -- When wrapping an existing wrapper, map straight to the innermost original.
  registry[wrapped] = registry[func] or func
  return wrapped
end
--- Replace each function field of `cls` with a wrapper bound to `cls`, except
-- the constructor `new` and the `super` link.
-- Only keys that already exist are reassigned, which is allowed while
-- traversing with `pairs`.
local function wrapMethods(cls)
  for name, member in pairs(cls) do
    local excluded = name == "new" or name == "super"
    if not excluded and type(member) == "function" then
      cls[name] = wrapMethod(cls, member)
    end
  end
end
--- Report whether `obj` is an instance of `class` or of one of its subclasses.
-- Starting at obj's metatable, follow the `super` links looking for `class`.
-- If that fails, repeat the walk from the metatable's own metatable, and so on.
-- @treturn boolean
local function isInstance(obj, class)
  local level = getmetatable(obj)
  while level do
    local ancestor = level
    repeat
      if ancestor == class then
        return true
      end
      ancestor = ancestor.super
    until not ancestor
    level = getmetatable(level)
  end
  return false
end
--- Report whether two values are instances of the same class.
-- Both values must have a table metatable carrying the same non-nil `__type`
-- tag (the random name that Class assigns).
-- @treturn boolean
local function sameClass(obj1, obj2)
  local mt1, mt2 = getmetatable(obj1), getmetatable(obj2)
  -- A non-table metatable (e.g. a protected `__metatable` value) is not a class.
  if type(mt1) ~= "table" or type(mt2) ~= "table" then
    return false
  end
  local tag = mt1.__type
  -- Without the nil check, any two values whose metatables lack `__type`
  -- (strings, ad-hoc setmetatable objects) would compare as the same class.
  return tag ~= nil and tag == mt2.__type
end
--- Build a proxy that resolves fields on the parent of the class currently
-- active on `obj` (`obj.__currentClass`).
-- Function fields are returned as callables that run the parent's method with
-- `obj` as `self` and the parent marked active. All of their return values are
-- forwarded. Non-function fields are returned as-is from the parent.
-- NOTE(review): if the parent method raises, `obj.__currentClass` is not restored.
-- @raise when called outside a wrapped method, or when the class has no parent
local function getSuper(obj)
  local currentClass = obj.__currentClass
  if not currentClass then
    error("super() called outside a wrapped method")
  end
  local parent = currentClass.super
  if not parent then
    error("No super exists for class: " .. tostring(currentClass.__type))
  end

  -- Restore the previously active class and pass every result through.
  local function finish(prev, ...)
    obj.__currentClass = prev
    return ...
  end

  local proxy = { __parent = parent, __obj = obj }
  return setmetatable(proxy, {
    __index = function(_, key)
      local member = parent[key]
      if type(member) ~= "function" then
        return member
      end
      -- The first argument (the proxy itself under `:` syntax) is ignored.
      return function(_, ...)
        local prev = obj.__currentClass
        obj.__currentClass = parent
        return finish(prev, member(obj, ...))
      end
    end,
  })
end
--- Deep-copy an object's data fields into a plain table.
-- Functions and the bookkeeping keys `__currentClass` and `super` are
-- skipped at every level. Nested tables are copied recursively.
-- A table that refers back to one of its own ancestors maps to that
-- ancestor's copy, so cycles are reproduced instead of recursing forever.
-- A table shared by unrelated branches is still copied once per occurrence.
-- @tparam table obj
-- @param _visiting internal: copies of the tables on the current path
-- @treturn table
local function toTable(obj, _visiting)
  _visiting = _visiting or {}
  local result = {}
  _visiting[obj] = result
  for k, v in pairs(obj) do
    if type(v) ~= "function" and k ~= "__currentClass" and k ~= "super" then
      if type(v) == "table" then
        result[k] = _visiting[v] or toTable(v, _visiting)
      else
        result[k] = v
      end
    end
  end
  -- Leaving this node: siblings that reference it get their own copy.
  _visiting[obj] = nil
  return result
end
--- Produce a random 16-character alphanumeric name, used as a class's
-- `__type` tag. Characters are drawn with `math.random`, so the result depends
-- on the current random seed.
local function generateClassName()
  local alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
  local picked = {}
  for slot = 1, 16 do
    local pos = math.random(1, #alphabet)
    picked[slot] = alphabet:sub(pos, pos)
  end
  return table.concat(picked)
end
--- Number of declared parameters of `func`, excluding the leading `self`.
-- Reads `nparams` from `debug.getinfo(func, "u")`, which does not count a
-- trailing `...`.
local function countParameters(func)
  local declared = debug.getinfo(func, "u").nparams
  return declared - 1
end
--- Create an interface from a stub table.
-- Each function field of `stub` declares a required method. Its arity is the
-- number of parameters after `self`.
-- The interface's `__type` is the name of the first entry in the caller's
-- environment that already holds `stub`, or "Interface" if none does.
-- @tparam table stub
-- @treturn table interface with `__type`, `requiredMethods` and `implements`
local function Interface(stub)
  local interfaceName = "Interface"
  -- `getfenv` exists only in Lua 5.1/LuaJIT; elsewhere fall back to the globals.
  local env = getfenv and getfenv(2) or _G
  for name, value in pairs(env) do
    if value == stub then
      interfaceName = name
      break
    end
  end

  local interface = {
    __type = interfaceName,
    requiredMethods = {},
  }
  for methodName, method in pairs(stub) do
    if type(method) == "function" then
      interface.requiredMethods[methodName] = {
        name = methodName,
        argCount = countParameters(method),
        stub = method,
      }
    end
  end

  --- Structural check: does `obj` expose every required method as a function
  -- with the same number of parameters after `self`?
  function interface:implements(obj)
    -- The registry only exists once some method has been wrapped.
    local registry = _G.__original_methods
    for methodName, methodInfo in pairs(self.requiredMethods) do
      local method = obj[methodName]
      if type(method) ~= "function" then
        return false
      end
      -- Wrappers take `(self, ...)`; count the wrapped original's parameters.
      local original = registry and registry[method] or method
      if countParameters(original) ~= methodInfo.argCount then
        return false
      end
    end
    return true
  end

  return interface
end
--- Report whether `obj`'s class, or any ancestor reached through `super`,
-- lists `interface` among the values of its `__interfaces` table.
-- @treturn boolean
local function implements(obj, interface)
  local klass = getmetatable(obj)
  while klass do
    local recorded = klass.__interfaces
    if recorded then
      for _, candidate in pairs(recorded) do
        if candidate == interface then
          return true
        end
      end
    end
    klass = klass.super
  end
  return false
end
--- "short_src:linedefined" for `func`, used in interface error messages.
local function functionLocation(func)
  local info = debug.getinfo(func, "S")
  return info.short_src .. ":" .. info.linedefined
end

--- Check `cls` against the interfaces queued in `cls.__interfaces_to_validate`.
-- Each queued interface not yet in `cls.__interfaces` must have every required
-- method present on `cls` as a function with the same parameter count after
-- `self`. Each interface that passes is recorded in `cls.__interfaces`. Once
-- all pass, the queue is cleared and `cls.__interfaces_validated` is set.
-- @raise "Invalid interface provided" for a value without `requiredMethods`;
--   otherwise a level-0 error describing the first missing or mis-sized method
local function validateInterfaces(cls)
  local queued = cls.__interfaces_to_validate
  if not queued then return end

  -- Only interfaces not already recorded need checking.
  local pending = {}
  for _, interface in ipairs(queued) do
    if not cls.__interfaces[interface] then
      pending[#pending + 1] = interface
    end
  end

  -- The wrapper -> original registry may not exist yet; guard before indexing.
  local registry = _G.__original_methods
  for _, interface in ipairs(pending) do
    if not interface.requiredMethods then
      error("Invalid interface provided")
    end
    for methodName, methodInfo in pairs(interface.requiredMethods) do
      local method = cls[methodName]
      if type(method) ~= "function" then
        error("Class does not implement required method: " .. methodName ..
          "\nClass defined at: " .. cls.__definition_location ..
          "\nInterface defined at: " .. functionLocation(methodInfo.stub), 0)
      end
      -- Wrappers take `(self, ...)`; count the wrapped original's parameters.
      local original = registry and registry[method] or method
      local actualArgCount = countParameters(original)
      if actualArgCount ~= methodInfo.argCount then
        error("Method " ..
          methodName ..
          " has incorrect number of arguments. Expected " .. methodInfo.argCount .. ", got " .. actualArgCount ..
          "\nMethod defined at: " .. functionLocation(original) ..
          "\nInterface defined at: " .. functionLocation(methodInfo.stub), 0)
      end
    end
    cls.__interfaces[interface] = interface
  end

  cls.__interfaces_to_validate = nil
  cls.__interfaces_validated = true
end
--- Add to `cls.__interfaces` every interface recorded on `base` or on any
-- ancestor reachable through `super`.
local function inheritInterfaces(cls, base)
  local ancestor = base
  while ancestor do
    if ancestor.__interfaces then
      for interface in pairs(ancestor.__interfaces) do
        cls.__interfaces[interface] = interface
      end
    end
    ancestor = ancestor.super
  end
end

--- Copy `base`'s function-valued `__*` keys that `cls` lacks.
-- Lua looks metamethods such as `__tostring` up raw on the metatable, not
-- through `__index`, so they must be present on the subclass itself.
-- NOTE(review): a subclass that later reassigns one of these copied keys does
-- so raw (the key exists), so that override is not wrapped.
local function copyInheritedMetamethods(cls, base)
  for k, v in pairs(base) do
    if type(v) == "function" and type(k) == "string" and k:sub(1, 2) == "__"
        and rawget(cls, k) == nil then
      rawset(cls, k, v)
    end
  end
end

--- Create a class, optionally inheriting from `base` and declaring `interfaces`.
-- Instances are created with `cls:new(...)`, which calls `init` if present.
-- Functions assigned to the class later are wrapped so that `self:super()`
-- works inside them.
-- Other fields, and methods added to `base` at any time, are found through
-- `__index` on the class's metatable.
-- @tparam table|nil base parent class
-- @tparam table|nil interfaces list of interfaces, checked on the first `new`
-- @treturn table the class
local function Class(base, interfaces)
  local cls = {}
  cls.__index = cls
  cls.__type = generateClassName()
  cls.__interfaces = {}
  -- `false` rather than nil, so the lookup never falls through `__index` to
  -- the base class's own pending list.
  cls.__interfaces_to_validate = interfaces or false
  cls.__interfaces_validated = false
  -- Record where Class() was called, for interface error messages.
  local callerInfo = debug.getinfo(2, "Sl")
  cls.__definition_location = callerInfo.short_src .. ":" .. callerInfo.currentline

  if base then
    cls.super = base
    inheritInterfaces(cls, base)
  end

  ---@generic T
  ---@param self T
  ---@return T
  function cls:new(...)
    -- Validate interfaces once, when the first object is created.
    if not self.__interfaces_validated then
      validateInterfaces(self)
    end
    local obj = setmetatable({}, self)
    obj.super = function() return getSuper(obj) end
    obj.sameClass = function(_, other) return sameClass(obj, other) end
    obj.isInstance = function(_, class_type) return isInstance(obj, class_type) end
    obj.implements = function(_, interface) return implements(obj, interface) end
    -- Mark this class active while init runs.
    obj.__currentClass = cls
    if obj.init then
      obj:init(...)
    end
    obj.__currentClass = nil
    return obj
  end

  function cls:toTable()
    return toTable(self)
  end

  -- A single metatable provides both inheritance (`__index`) and automatic
  -- wrapping of functions assigned to keys the class does not yet define.
  -- Regular base methods are NOT copied in: a raw copy would make a later
  -- `function Sub:method()` override a raw assignment that skips wrapping.
  setmetatable(cls, {
    __index = base,
    __newindex = function(t, k, v)
      if type(v) == "function" and k ~= "new" and k ~= "super" then
        rawset(t, k, wrapMethod(t, v))
      else
        rawset(t, k, v)
      end
    end,
  })
  wrapMethods(cls)
  -- Copied after wrapping so base's wrappers are not wrapped a second time.
  if base then
    copyInheritedMetamethods(cls, base)
  end
  return cls
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment