Created
April 16, 2025 16:03
-
-
Save Triavanicus/585cbc807019c624c8fabf5b864cb84b to your computer and use it in GitHub Desktop.
Over-engineered Lua classes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
---@class Interface
---@field __type string
---@field requiredMethods table<string, table>
---@field implements fun(self, obj: table): boolean
---@class Class
---@field toTable fun(self): table
---@field sameClass fun(self, other: Class): boolean
---@field isInstance fun(self, classType: ClassT): boolean
---@field implements fun(self, interface: Interface): boolean
---@field super fun(self): table
---@class ClassT
---@field new fun(self, ...): Class
---@field super ClassT|nil
---Restore `self.__currentClass` to `prev`, then return every remaining
---argument unchanged. Forwarding through a vararg keeps all return values of
---the wrapped call; assigning them to a single local would keep only the first.
local function restoreCurrentClass(self, prev, ...)
    self.__currentClass = prev
    return ...
end

---Wrap `func` so that `self.__currentClass` is `cls` while it runs.
---
---The previous value of `self.__currentClass` is restored afterwards, and all
---of `func`'s return values are passed through. The wrapper is registered in
---`_G.__original_methods`, mapped to the innermost user-written function, so
---argument-count checks can see the real signature.
---NOTE(review): if `func` raises, `self.__currentClass` is not restored.
---@param cls table class to expose as `self.__currentClass` during the call
---@param func function method to wrap; called as `func(self, ...)`
---@return function wrapped
local function wrapMethod(cls, func)
    local originals = _G.__original_methods
    if not originals then
        -- Weak *keys*: an entry disappears once its wrapper is unreachable.
        -- With weak values, the strongly held key would keep the value alive
        -- via its upvalue, so no entry would ever be collected.
        originals = setmetatable({}, { __mode = "k" })
        _G.__original_methods = originals
    end

    local wrapped = function(self, ...)
        local prev = self.__currentClass
        self.__currentClass = cls
        return restoreCurrentClass(self, prev, func(self, ...))
    end

    -- When `func` is itself a wrapper (e.g. an inherited method being
    -- re-wrapped), record the innermost function rather than the wrapper,
    -- whose `(self, ...)` signature would report 0 parameters.
    originals[wrapped] = originals[func] or func
    return wrapped
end
---Replace every function-valued field of `cls` (except `new` and `super`)
---with the result of `wrapMethod(cls, fn)`.
---Only existing keys are reassigned, which is permitted during `pairs`.
---@param cls table
local function wrapMethods(cls)
    for name, value in pairs(cls) do
        local excluded = (name == "new") or (name == "super")
        if not excluded and type(value) == "function" then
            cls[name] = wrapMethod(cls, value)
        end
    end
end
---Report whether `class` appears on the `super` chain that starts at
---`obj`'s metatable. The search is repeated for each metatable reached by
---following `getmetatable` upward from there.
---@param obj any
---@param class table
---@return boolean
local function isInstance(obj, class)
    local meta = getmetatable(obj)
    while meta do
        local ancestor = meta
        repeat
            if ancestor == class then
                return true
            end
            ancestor = ancestor.super
        until not ancestor
        meta = getmetatable(meta)
    end
    return false
end
---Report whether two objects belong to the same class, by comparing the
---`__type` tag stored on their metatables.
---Always returns a boolean. Objects whose metatable carries no `__type`
---(i.e. not created by this class system) are never considered the same class.
---@param obj1 any
---@param obj2 any
---@return boolean
local function sameClass(obj1, obj2)
    local mt1, mt2 = getmetatable(obj1), getmetatable(obj2)
    if not mt1 or not mt2 then
        return false
    end
    local type1 = mt1.__type
    -- Guard against nil == nil matching two unrelated metatables.
    return type1 ~= nil and type1 == mt2.__type
end
---Restore `obj.__currentClass` to `prev` and forward every remaining argument,
---so a super call returns all of the parent method's results.
local function restoreAfterSuperCall(obj, prev, ...)
    obj.__currentClass = prev
    return ...
end

---Build a proxy for calling the parent-class versions of `obj`'s methods.
---
---The parent is `obj.__currentClass.super`. Each function looked up on the
---proxy is bound to `obj`: any receiver passed to it is ignored. While the
---call runs, `obj.__currentClass` is set to the parent, and it is restored
---afterwards. Non-function members of the parent are returned as-is.
---@param obj table instance whose `__currentClass` is set by a wrapped method
---@return table proxy
local function getSuper(obj)
    local currentClass = obj.__currentClass
    if not currentClass then
        error("super() called outside a wrapped method")
    end
    local parent = currentClass.super
    if not parent then
        error("No super exists for class: " .. tostring(currentClass.__type))
    end

    local proxy = { __parent = parent, __obj = obj }
    return setmetatable(proxy, {
        __index = function(_, key)
            local member = parent[key]
            if type(member) ~= "function" then
                return member
            end
            return function(_, ...)
                local prev = obj.__currentClass
                obj.__currentClass = parent
                return restoreAfterSuperCall(obj, prev, member(obj, ...))
            end
        end,
    })
end
---Produce a plain-data copy of `obj`.
---
---Function values and the `__currentClass` and `super` keys are dropped, and
---nested tables are copied recursively (keys are copied by reference). A table
---met again while it is still being copied (a cycle) maps to the copy in
---progress instead of recursing forever. Subtables shared without a cycle
---still get an independent copy at each occurrence.
---@param obj table
---@param seen? table internal: tables currently being copied -> their copies
---@return table
local function toTable(obj, seen)
    seen = seen or {}
    local inProgress = seen[obj]
    if inProgress then
        return inProgress
    end

    local result = {}
    seen[obj] = result
    for k, v in pairs(obj) do
        if type(v) ~= "function" and k ~= "__currentClass" and k ~= "super" then
            if type(v) == "table" then
                result[k] = toTable(v, seen)
            else
                result[k] = v
            end
        end
    end
    -- Only ancestors on the current path count as "seen", which keeps
    -- acyclic shared subtables copied independently.
    seen[obj] = nil
    return result
end
---Generate a random 16-character alphanumeric tag used as a class's `__type`.
---Draws one `math.random` index per character, in order.
---@return string
local function generateClassName()
    local alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    local picked = {}
    for slot = 1, 16 do
        local index = math.random(1, #alphabet)
        picked[slot] = alphabet:sub(index, index)
    end
    return table.concat(picked)
end
---Return the number of declared parameters of `func`, excluding the leading
---`self`. A `...` vararg is not counted.
---
---Relies on `debug.getinfo(func, "u").nparams`. When the runtime does not
---report it (e.g. stock Lua 5.1), this raises a descriptive error instead of
---failing with an arithmetic-on-nil error.
---@param func function
---@return integer
local function countParameters(func)
    local info = debug.getinfo(func, "u")
    local nparams = info and info.nparams
    if nparams == nil then
        error("countParameters: debug.getinfo does not report 'nparams' on this Lua runtime", 2)
    end
    return nparams - 1 -- Subtract 1 for self parameter
end
---Build an interface descriptor from a stub table.
---
---Every function-valued field of `stub` becomes a required method. Its
---expected argument count comes from `countParameters`, which excludes
---`self`. The returned table has `__type`, `requiredMethods` and an
---`implements(obj)` duck-typing check.
---@param stub table table of stub methods, e.g. `function IShape:area() end`
---@return Interface
local function Interface(stub)
    -- Best-effort naming: if `stub` is held in a variable of the caller's
    -- environment, use that variable's name. `getfenv` exists only on Lua
    -- 5.1-style runtimes; elsewhere fall back to scanning `_G`.
    local interfaceName = "Interface"
    local env = getfenv and getfenv(2) or _G
    for name, value in pairs(env) do
        if value == stub then
            interfaceName = name
            break
        end
    end

    local interface = {
        __type = interfaceName,
        requiredMethods = {},
    }

    -- Convert stub object to required methods.
    for methodName, method in pairs(stub) do
        if type(method) == "function" then
            interface.requiredMethods[methodName] = {
                name = methodName,
                argCount = countParameters(method),
                stub = method,
            }
        end
    end

    ---True when `obj` has every required method with the expected number of
    ---parameters. Structural check only; no class registration is needed.
    function interface:implements(obj)
        -- May be nil if no method has been wrapped yet.
        local originals = _G.__original_methods
        for methodName, methodInfo in pairs(self.requiredMethods) do
            local method = obj[methodName]
            if type(method) ~= "function" then
                return false
            end
            -- Peel off every wrapper layer (inherited methods can be wrapped
            -- more than once) so the count reflects the user's signature.
            local originalMethod = method
            while originals and originals[originalMethod] do
                originalMethod = originals[originalMethod]
            end
            if countParameters(originalMethod) ~= methodInfo.argCount then
                return false
            end
        end
        return true
    end

    return interface
end
---Report whether `obj`'s class, or any class reachable through its `super`
---chain, lists `interface` among the values of its `__interfaces` table.
---@param obj any
---@param interface Interface
---@return boolean
local function implements(obj, interface)
    local klass = getmetatable(obj)
    if not klass then
        return false
    end
    repeat
        local registered = klass.__interfaces
        if registered then
            for _, candidate in pairs(registered) do
                if candidate == interface then
                    return true
                end
            end
        end
        klass = klass.super
    until not klass
    return false
end
---Return "<short_src>:<linedefined>" for a Lua function, for error messages.
---@param func function
---@return string
local function describeDefinition(func)
    local info = debug.getinfo(func, "S")
    return info.short_src .. ":" .. info.linedefined
end

---Follow `_G.__original_methods` through every wrapper layer back to the
---innermost function. An inherited method can be wrapped twice, and a
---single lookup would stop at a `(self, ...)` wrapper.
---@param func function
---@return function
local function innermostMethod(func)
    local originals = _G.__original_methods
    while originals and originals[func] do
        func = originals[func]
    end
    return func
end

---Check that `cls` implements each interface in `cls.__interfaces_to_validate`.
---
---For each interface not already in `cls.__interfaces`, every required method
---must exist on `cls` with the expected parameter count. On success the
---interface is added to `cls.__interfaces`, the pending list is cleared, and
---`cls.__interfaces_validated` is set. A single interface (rather than a list)
---is accepted as shorthand.
---Raises (level 0, message only) on a missing method or wrong argument count,
---and raises "Invalid interface provided" for entries without `requiredMethods`.
---@param cls table
local function validateInterfaces(cls)
    local pending = cls.__interfaces_to_validate
    if not pending then
        -- Nothing declared: mark done so callers need not ask again.
        cls.__interfaces_validated = true
        return
    end
    -- `ipairs` over a lone interface table iterates nothing and would
    -- silently skip validation, so wrap it in a list.
    if pending.requiredMethods then
        pending = { pending }
    end

    -- Skip interfaces already registered (e.g. inherited from a base class).
    local toCheck = {}
    for _, interface in ipairs(pending) do
        if not cls.__interfaces[interface] then
            toCheck[#toCheck + 1] = interface
        end
    end

    for _, interface in ipairs(toCheck) do
        if not interface.requiredMethods then
            error("Invalid interface provided")
        end
        for methodName, methodInfo in pairs(interface.requiredMethods) do
            local method = cls[methodName]
            if type(method) ~= "function" then
                error("Class does not implement required method: " .. methodName ..
                    "\nClass defined at: " .. cls.__definition_location ..
                    "\nInterface defined at: " .. describeDefinition(methodInfo.stub), 0)
            end
            local originalMethod = innermostMethod(method)
            local actualArgCount = countParameters(originalMethod)
            if actualArgCount ~= methodInfo.argCount then
                error("Method " ..
                    methodName ..
                    " has incorrect number of arguments. Expected " .. methodInfo.argCount .. ", got " .. actualArgCount ..
                    "\nMethod defined at: " .. describeDefinition(originalMethod) ..
                    "\nInterface defined at: " .. describeDefinition(methodInfo.stub), 0)
            end
        end
        cls.__interfaces[interface] = interface
    end

    cls.__interfaces_to_validate = nil
    cls.__interfaces_validated = true
end
---Create a new class, optionally inheriting from `base`.
---
---Instances are created with `cls:new(...)`, which calls `init(...)` when it
---is present. Each instance gets its own `super`, `sameClass`, `isInstance`
---and `implements` helpers. Functions assigned to the class after creation
---(other than `new`/`super`) are wrapped by `wrapMethod`, which makes
---`self:super()` usable inside them.
---
---Inheritance is live: missing keys are looked up on `base` through
---`__index`, so base members added later are visible too. Because inherited
---methods are not copied into the subclass, an override is a new key that
---goes through `__newindex` and gets wrapped.
---@param base? table parent class
---@param interfaces? Interface[] interfaces checked when the first instance is created
---@return table cls
local function Class(base, interfaces)
    local cls = {}
    cls.__index = cls
    cls.__type = generateClassName()
    cls.__interfaces = {}
    -- `false` rather than nil: the class falls back to `base` for missing
    -- keys, and a nil here would expose the base's own pending list.
    cls.__interfaces_to_validate = interfaces or false
    cls.__interfaces_validated = false

    -- Record where the class was defined (used in interface error messages).
    local callerInfo = debug.getinfo(2, "Sl")
    cls.__definition_location = callerInfo.short_src .. ":" .. callerInfo.currentline

    if base then
        cls.super = base

        -- Inherit every interface already registered anywhere up the chain.
        local ancestor = base
        while ancestor do
            if ancestor.__interfaces then
                for interface in pairs(ancestor.__interfaces) do
                    cls.__interfaces[interface] = interface
                end
            end
            ancestor = ancestor.super
        end

        -- Lua reads metamethods from a metatable with a raw lookup, so
        -- `__index` inheritance does not reach them. Copy function-valued
        -- "__" keys down explicitly. `__index` is excluded because
        -- `cls.__index = cls` must stay in place for method lookup.
        for key, value in pairs(base) do
            if type(value) == "function" and type(key) == "string"
                and key:sub(1, 2) == "__" and key ~= "__index" then
                rawset(cls, key, value)
            end
        end
    end

    ---@generic T
    ---@param self T
    ---@return T
    function cls:new(...)
        -- Validate interfaces only once, when the first object is created.
        if not self.__interfaces_validated then
            validateInterfaces(self)
        end
        local obj = setmetatable({}, self)
        obj.super = function() return getSuper(obj) end
        obj.sameClass = function(_, other) return sameClass(obj, other) end
        obj.isInstance = function(_, class_type) return isInstance(obj, class_type) end
        obj.implements = function(_, interface) return implements(obj, interface) end
        -- Set the class context explicitly around `init` so `self:super()`
        -- works there even if `init` itself is not wrapped.
        obj.__currentClass = cls
        if obj.init then
            obj:init(...)
        end
        obj.__currentClass = nil
        return obj
    end

    -- Only classes that do not already inherit `toTable` define it. Subclasses
    -- pick it up through `__index`, so their overrides are wrapped normally.
    if not (base and base.toTable) then
        function cls:toTable()
            return toTable(self)
        end
    end

    -- One metatable for the class: `__index` keeps the inheritance chain
    -- (the original code replaced it here and lost it), and `__newindex`
    -- wraps methods assigned to the class from now on.
    setmetatable(cls, {
        __index = base,
        __newindex = function(t, k, v)
            if type(v) == "function" and k ~= "new" and k ~= "super" then
                rawset(t, k, wrapMethod(t, v))
            else
                rawset(t, k, v)
            end
        end,
    })

    -- Wrap functions stored before the metatable existed (toTable and any
    -- copied metamethods); `new` is deliberately left unwrapped.
    wrapMethods(cls)
    return cls
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment