From cfd3922cabfdfca0b417651ee4cea5a649cbebe5 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 9 Jan 2025 17:28:11 +0100 Subject: [PATCH] Add basic code for binding partition revalidation (#56649) This adds the binding partition revalidation code from #54654. This is the last piece of that PR that hasn't been merged yet - however the TODO in that PR still stands for future work. This PR itself adds a callback that gets triggered by deleting a binding. It will then walk all code in the system and invalidate code instances of Methods whose lowered source referenced the given global. This walk is quite slow. Future work will add backedges and optimizations to make this faster, but the basic functionality should be in place with this PR. --- base/Base_compiler.jl | 1 + base/invalidation.jl | 111 ++++++++++++++++++++++++++++++++++++++++++ src/gf.c | 5 ++ src/module.c | 26 +++++++++- test/rebinding.jl | 7 +++ 5 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 base/invalidation.jl diff --git a/base/Base_compiler.jl b/base/Base_compiler.jl index 91f327980389d..db3ebb0232e38 100644 --- a/base/Base_compiler.jl +++ b/base/Base_compiler.jl @@ -255,6 +255,7 @@ include("ordering.jl") using .Order include("coreir.jl") +include("invalidation.jl") # For OS specific stuff # We need to strcat things here, before strings are really defined diff --git a/base/invalidation.jl b/base/invalidation.jl new file mode 100644 index 0000000000000..40a010ab7361c --- /dev/null +++ b/base/invalidation.jl @@ -0,0 +1,111 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +struct GlobalRefIterator + mod::Module +end +IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown() +globalrefs(mod::Module) = GlobalRefIterator(mod) + +function iterate(gri::GlobalRefIterator, i = 1) + m = gri.mod + table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m) + i == length(table) && return nothing + b = table[i] + b === nothing && return iterate(gri, i+1) + return ((b::Core.Binding).globalref, i+1) +end + +const TYPE_TYPE_MT = Type.body.name.mt +const NONFUNCTION_MT = Core.MethodTable.name.mt +function foreach_module_mtable(visit, m::Module, world::UInt) + for gb in globalrefs(m) + binding = gb.binding + bpart = lookup_binding_partition(world, binding) + if is_defined_const_binding(binding_kind(bpart)) + v = partition_restriction(bpart) + uw = unwrap_unionall(v) + name = gb.name + if isa(uw, DataType) + tn = uw.name + if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt) + # this is the original/primary binding for the type (name/wrapper) + mt = tn.mt + if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT + @assert mt.module === m + visit(mt) || return false + end + end + elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name + # this is the original/primary binding for the submodule + foreach_module_mtable(visit, v, world) || return false + elseif isa(v, Core.MethodTable) && v.module === m && v.name === name + # this is probably an external method table here, so let's + # assume so as there is no way to precisely distinguish them + visit(v) || return false + end + end + end + return true +end + +function foreach_reachable_mtable(visit, world::UInt) + visit(TYPE_TYPE_MT) || return + visit(NONFUNCTION_MT) || return + for mod in loaded_modules_array() + foreach_module_mtable(visit, mod, world) + end +end + +function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo) + found_any = false + labelchangemap = nothing + stmts = src.code + isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name + isgr(g) = false + for i = 1:length(stmts) + stmt = stmts[i] + if isgr(stmt) + found_any = true + continue + end + for ur in Compiler.userefs(stmt) + arg = ur[] + # If any of the GlobalRefs in this stmt match the one that + # we are about, we need to move out all GlobalRefs to preserve + # effect order, in case we later invalidate a different GR + if isa(arg, GlobalRef) + if isgr(arg) + @assert !isa(stmt, PhiNode) + found_any = true + break + end + end + end + end + return found_any +end + +function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt) + valid_in_valuepos = false + foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable + for method in MethodList(mt) + if isdefined(method, :source) + src = _uncompressed_ir(method) + old_stmts = src.code + if should_invalidate_code_for_globalref(gr, src) + for mi in specializations(method) + ci = mi.cache + while true + if ci.max_world > new_max_world + ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) + end + isdefined(ci, :next) || break + ci = ci.next + end + end + end + end + end + return true + end +end diff --git a/src/gf.c b/src/gf.c index 080d1ebd52ba8..ba28edfbeeff7 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1867,6 +1867,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo JL_UNLOCK(&replaced_mi->def.method->writelock); } +JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world) +{ + invalidate_code_instance(replaced, max_world, 1); +} + static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) { jl_array_t *backedges = replaced_mi->backedges; if (backedges) { diff --git a/src/module.c b/src/module.c index 1b4c5bd78f667..004371b9144b2 100644 --- a/src/module.c +++ b/src/module.c @@ -1032,6 +1032,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var jl_gc_wb(bpart, val); } +void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world) +{ + static jl_value_t *invalidate_code_for_globalref = NULL; + if (invalidate_code_for_globalref == NULL && jl_base_module != NULL) + invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!")); + if (!invalidate_code_for_globalref) + jl_error("Binding invalidation is not permitted during bootstrap."); + if (jl_generating_output()) + jl_error("Binding invalidation is not permitted during image generation."); + jl_value_t *boxed_world = jl_box_ulong(new_world); + JL_GC_PUSH1(&boxed_world); + jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world); + JL_GC_POP(); +} + extern jl_mutex_t world_counter_lock; JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) { @@ -1046,9 +1061,11 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) JL_LOCK(&world_counter_lock); jl_task_t *ct = jl_current_task; + size_t last_world = ct->world_age; size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter); - // TODO: Trigger invalidation here - (void)ct; + ct->world_age = jl_typeinf_world; + jl_invalidate_binding_refs(gr, new_max_world); + ct->world_age = last_world; jl_atomic_store_release(&bpart->max_world, new_max_world); jl_atomic_store_release(&jl_world_counter, new_max_world + 1); JL_UNLOCK(&world_counter_lock); @@ -1334,6 +1351,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod) jl_array_ptr_1d_push(jl_module_init_order, mod); } +JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m) +{ + return jl_atomic_load_relaxed(&m->bindings); +} + JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod) { if (!jl_generating_output() || jl_options.incremental) { diff --git a/test/rebinding.jl b/test/rebinding.jl index c93c34be7a75c..ad0ad1fc1643d 100644 --- a/test/rebinding.jl +++ b/test/rebinding.jl @@ -33,4 +33,11 @@ module Rebinding @test Base.@world(Foo, defined_world_age) == typeof(x) @test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x) @test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x) + + # Test invalidation (const -> undefined) + const delete_me = 1 + f_return_delete_me() = delete_me + @test f_return_delete_me() == 1 + Base.delete_binding(@__MODULE__, :delete_me) + @test_throws UndefVarError f_return_delete_me() end