Skip to content

Commit

Permalink
feat: derive macro (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger-luo authored Mar 21, 2023
1 parent bc3667c commit d44c2e2
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/adt/adt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("types.jl")
include("emit.jl")
include("match.jl")
include("use.jl")
include("derive.jl")
include("print.jl")

end # ADT
92 changes: 92 additions & 0 deletions src/adt/derive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
macro derive(ex)
esc(derive_m(__module__, __source__, ex))
end

function derive_m(mod::Module, line::LineNumberNode, ex::Expr)
@switch ex begin
@case :($name:$(first::Symbol))
others = ()
@case Expr(:tuple, :($name:$(first::Symbol)), [e::Symbol for e in others]...)
@case _
error("Invalid expression")
end

expr_map((first, others...)) do rule
msg = "$(rule) is not defined"
isdefined(mod, rule) || :(error($msg))
derive_rule(getfield(mod, rule), mod, line, name)
end
end

function derive_rule(rule, m::Module, line::LineNumberNode, Self::Symbol)
msg = "derive_rule for $(rule) is not defined"
return :(error($msg))
end

macro derive_rule(jlfn::Expr)
jlfn = JLFunction(jlfn)
esc(derive_rule_m(__module__, jlfn))
end

function derive_rule_m(mod::Module, jlfn::JLFunction)
length(jlfn.args) == 3 || error("Invalid function signature")

fn_type = @match jlfn.name begin
Expr(:., path, name::QuoteNode) => begin
m = guess_module(mod, path)
isdefined(m, name.value) || error("$(jlfn.name) is not defined")
typeof(getfield(m, name.value))
end
name::Symbol => begin
isdefined(mod, name) || error("$(jlfn.name) is not defined")
typeof(getfield(mod, name))
end
_ => error("Invalid function name: $(jlfn.name)")
end

pushfirst!(jlfn.args, :(::$fn_type))
jlfn.name = GlobalRef(@__MODULE__, :derive_rule)
return codegen_ast(jlfn)
end

@derive_rule function hash(m::Module, line::LineNumberNode, Self::Symbol)
isdefined(m, Self) || error("$(Self) is not defined")
quote
function $Base.hash(x::$Self, h::UInt)
type = $ADT.variant_type(x)
h = hash(type, h)
for idx in $ADT.variant_masks(x)
h = hash($Base.getfield(x, idx), h)
end
return h
end
end
end

@derive_rule function isequal(m::Module, line::LineNumberNode, Self::Symbol)
isdefined(m, Self) || error("$(Self) is not defined")
quote
function $Base.isequal(lhs::$Self, rhs::$Self)
$ADT.variant_type(lhs) == $ADT.variant_type(rhs) || return false

for idx in $ADT.variant_masks(lhs) # mask is the same for both
isequal($Base.getfield(lhs, idx), $Base.getfield(rhs, idx)) || return false
end
return true
end
end
end

@derive_rule function ==(m::Module, line::LineNumberNode, Self::Symbol)
isdefined(m, Self) || error("$(Self) is not defined")
quote
function $Base.:(==)(lhs::$Self, rhs::$Self)
$ADT.variant_type(lhs) == $ADT.variant_type(rhs) || return false

for idx in $ADT.variant_masks(lhs) # mask is the same for both
$Base.getfield(lhs, idx) == $Base.getfield(rhs, idx) || return false
end
return true
end
end
end
4 changes: 4 additions & 0 deletions test/adt/adt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ end
include("tree.jl")
include("tree_inline.jl")
end

@testset "derive" begin
include("derive.jl")
end
40 changes: 40 additions & 0 deletions test/adt/derive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module Derive

using Test
using MLStyle
using Expronicon
using Expronicon.ADT: ADT, @adt, @derive

@adt MyADT begin
Token
struct Message
x::Int
y::Int
end
end

@derive MyADT: hash, isequal, ==
@test_throws ErrorException begin
@derive MyADT: isless
end

@testset "hash" begin
@test hash(MyADT.Token) == hash(ADT.variant_type(MyADT.Token))
msg = MyADT.Message(1, 2)
h = hash(ADT.variant_type(msg))
h = hash(1, h)
h = hash(2, h)
@test hash(msg) == h
end

@testset "isequal" begin
@test isequal(MyADT.Message(1, 2), MyADT.Message(1, 2))
@test !isequal(MyADT.Message(1, 2), MyADT.Message(1, 3))
end

@testset "==" begin
@test MyADT.Message(1, 2) == MyADT.Message(1, 2)
@test MyADT.Message(1, 2) != MyADT.Message(1, 3)
end

end # Derive

0 comments on commit d44c2e2

Please sign in to comment.