Skip to content

Commit

Permalink
fix require statements with random closure variables
Browse files Browse the repository at this point in the history
  • Loading branch information
dfremont committed Nov 25, 2024
1 parent 1a50ecf commit ceda6c5
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 24 deletions.
67 changes: 43 additions & 24 deletions src/scenic/core/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,21 @@ def __init__(self, ty, condition, line, prob, name, ego):

# condition is an instance of Proposition. Flatten to get a list of atomic propositions.
atoms = condition.atomics()
bindings = {}
self.globalBindings = {} # bindings to global/builtin names
self.closureBindings = {} # bindings to top-level closure variables
self.cells = [] # cells used in referenced closures
atomGlobals = None
for atom in atoms:
bindings.update(getNameBindings(atom.closure))
gbindings, cbindings, closures = getNameBindings(atom.closure)
self.globalBindings.update(gbindings)
self.closureBindings.update(cbindings)
for closure in closures:
self.cells.extend(closure.__closure__)
globs = atom.closure.__globals__
if atomGlobals is not None:
assert globs is atomGlobals
else:
atomGlobals = globs
self.bindings = bindings
self.egoObject = ego

def compile(self, namespace, scenario, syntax=None):
Expand All @@ -68,21 +73,28 @@ def compile(self, namespace, scenario, syntax=None):
While we're at it, determine whether the requirement implies any relations
we can use for pruning, and gather all of its dependencies.
"""
bindings, ego, line = self.bindings, self.egoObject, self.line
globalBindings, closureBindings = self.globalBindings, self.closureBindings
cells, ego, line = self.cells, self.egoObject, self.line
condition, ty = self.condition, self.ty

# Convert bound values to distributions as needed
for name, value in bindings.items():
bindings[name] = toDistribution(value)
for name, value in globalBindings.items():
globalBindings[name] = toDistribution(value)
for name, value in closureBindings.items():
closureBindings[name] = toDistribution(value)
cells = tuple((cell, toDistribution(cell.cell_contents)) for cell in cells)
allBindings = dict(globalBindings)
allBindings.update(closureBindings)

# Check whether requirement implies any relations used for pruning
canPrune = condition.check_constrains_sampling()
if canPrune:
relations.inferRelationsFrom(syntax, bindings, ego, line)
relations.inferRelationsFrom(syntax, allBindings, ego, line)

# Gather dependencies of the requirement
deps = set()
for value in bindings.values():
cellVals = (value for cell, value in cells)
for value in itertools.chain(allBindings.values(), cellVals):
if needsSampling(value):
deps.add(value)
if needsLazyEvaluation(value):
Expand All @@ -93,7 +105,7 @@ def compile(self, namespace, scenario, syntax=None):

# If this requirement contains the CanSee specifier, we will need to sample all objects
# to meet the dependencies.
if "CanSee" in bindings:
if "CanSee" in globalBindings:
deps.update(scenario.objects)

if ego is not None:
Expand All @@ -108,9 +120,12 @@ def closure(values, monitor=None):
# note: need to extract namespace here rather than close over value
# from above because of https://github.com/uqfoundation/dill/issues/532
namespace = condition.atomics()[0].closure.__globals__
for name, value in bindings.items():
for name, value in globalBindings.items():
if ty == RequirementType.require or value in values:
namespace[name] = values[value]
for cell, value in cells:
cell.cell_contents = values[value]

# rebind ego object, which can be referred to implicitly
boundEgo = None if ego is None else values[ego]
# evaluate requirement condition, reporting errors on the correct line
Expand Down Expand Up @@ -138,26 +153,30 @@ def getNameBindings(req, restrictTo=None):
"""Find all names the given lambda depends on, along with their current bindings."""
namespace = req.__globals__
if restrictTo is not None and restrictTo is not namespace:
return {}
return {}, {}, ()
externals = inspect.getclosurevars(req)
allBindings = dict(externals.builtins)
globalBindings = externals.builtins

def addBindings(bindings):
for name, value in bindings.items():
allBindings[name] = value
closures = set()
if externals.nonlocals:
closures.add(req)

def handleFunctions(bindings):
for value in bindings.values():
if inspect.isfunction(value):
subglobs = getNameBindings(value, restrictTo=namespace)
if value.__closure__ is not None:
closures.add(value)
subglobs, _, _ = getNameBindings(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in allBindings:
assert value is allBindings[name]
if name in globalBindings:
assert value is globalBindings[name]
else:
allBindings[name] = value
globalBindings[name] = value

addBindings(externals.globals)
if restrictTo is None:
# At the top level, include nonlocal variables captured in the closure
addBindings(externals.nonlocals)
return allBindings
globalBindings.update(externals.globals)
handleFunctions(externals.globals)
handleFunctions(externals.nonlocals)
return globalBindings, externals.nonlocals, closures


class BoundRequirement:
Expand Down
37 changes: 37 additions & 0 deletions tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,43 @@ def g():
assert all(pos.x < pos.y for pos in poss)


def test_requirement_in_function_random_local():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ 0
def f():
local = Range(0, 1)
require ego.x < local
f()
"""
)
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert all(-10 <= x <= 1 for x in xs)


def test_requirement_in_function_random_cell():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ 0
def f(i):
def g():
return i
return g
g = f(Range(0, 1)) # global function with a cell containing a random value
def h():
local = Uniform(True, False)
def inner(): # local function likewise
return local
require (g() >= 0) and ((ego.x < -5) if inner() else (ego.x > 5))
h()
"""
)
xs = [sampleEgo(scenario, maxIterations=150).position.x for i in range(60)]
assert all(x < -5 or x > 5 for x in xs)
assert any(x < -5 for x in xs)
assert any(x > 5 for x in xs)


def test_soft_requirement():
scenario = compileScenic(
"""
Expand Down

0 comments on commit ceda6c5

Please sign in to comment.