Skip to content

Commit

Permalink
Fix requirements inside loops and functions (#316)
Browse files Browse the repository at this point in the history
* fix require statements inside loops

* fix require statements inside functions

* fix RectangularRegion with random coerced parameter

* fix require statements with random closure variables
  • Loading branch information
dfremont authored Nov 26, 2024
1 parent 4ae8ee2 commit fe28e13
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 35 deletions.
7 changes: 3 additions & 4 deletions src/scenic/core/dynamics/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs):
self._objects = [] # ordered for reproducibility
self._sampledObjects = self._objects
self._externalParameters = []
self._pendingRequirements = defaultdict(list)
self._pendingRequirements = []
self._requirements = []
# things needing to be sampled to evaluate the requirements
self._requirementDeps = set()
Expand Down Expand Up @@ -409,9 +409,8 @@ def _registerObject(self, obj):

def _addRequirement(self, ty, reqID, req, line, name, prob):
"""Save a requirement defined at compile-time for later processing."""
assert reqID not in self._pendingRequirements
preq = PendingRequirement(ty, req, line, prob, name, self._ego)
self._pendingRequirements[reqID] = preq
self._pendingRequirements.append((reqID, preq))

def _addDynamicRequirement(self, ty, req, line, name):
"""Add a requirement defined during a dynamic simulation."""
Expand All @@ -429,7 +428,7 @@ def _compileRequirements(self):
namespace = self._dummyNamespace if self._dummyNamespace else self.__dict__
requirementSyntax = self._requirementSyntax
assert requirementSyntax is not None
for reqID, requirement in self._pendingRequirements.items():
for reqID, requirement in self._pendingRequirements:
syntax = requirementSyntax[reqID] if requirementSyntax else None

# Catch the simple case where someone has most likely forgotten the "monitor"
Expand Down
4 changes: 3 additions & 1 deletion src/scenic/core/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3275,7 +3275,9 @@ def __init__(self, position, heading, width, length, name=None):
self.circumcircle = (self.position, self.radius)

super().__init__(
polygon=self._makePolygons(position, heading, width, length),
polygon=self._makePolygons(
self.position, self.heading, self.width, self.length
),
z=self.position.z,
name=name,
additionalDeps=deps,
Expand Down
79 changes: 53 additions & 26 deletions src/scenic/core/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,21 @@ def __init__(self, ty, condition, line, prob, name, ego):

# condition is an instance of Proposition. Flatten to get a list of atomic propositions.
atoms = condition.atomics()
bindings = {}
self.globalBindings = {} # bindings to global/builtin names
self.closureBindings = {} # bindings to top-level closure variables
self.cells = [] # cells used in referenced closures
atomGlobals = None
for atom in atoms:
bindings.update(getAllGlobals(atom.closure))
gbindings, cbindings, closures = getNameBindings(atom.closure)
self.globalBindings.update(gbindings)
self.closureBindings.update(cbindings)
for closure in closures:
self.cells.extend(closure.__closure__)
globs = atom.closure.__globals__
if atomGlobals is not None:
assert globs is atomGlobals
else:
atomGlobals = globs
self.bindings = bindings
self.egoObject = ego

def compile(self, namespace, scenario, syntax=None):
Expand All @@ -68,21 +73,28 @@ def compile(self, namespace, scenario, syntax=None):
While we're at it, determine whether the requirement implies any relations
we can use for pruning, and gather all of its dependencies.
"""
bindings, ego, line = self.bindings, self.egoObject, self.line
globalBindings, closureBindings = self.globalBindings, self.closureBindings
cells, ego, line = self.cells, self.egoObject, self.line
condition, ty = self.condition, self.ty

# Convert bound values to distributions as needed
for name, value in bindings.items():
bindings[name] = toDistribution(value)
for name, value in globalBindings.items():
globalBindings[name] = toDistribution(value)
for name, value in closureBindings.items():
closureBindings[name] = toDistribution(value)
cells = tuple((cell, toDistribution(cell.cell_contents)) for cell in cells)
allBindings = dict(globalBindings)
allBindings.update(closureBindings)

# Check whether requirement implies any relations used for pruning
canPrune = condition.check_constrains_sampling()
if canPrune:
relations.inferRelationsFrom(syntax, bindings, ego, line)
relations.inferRelationsFrom(syntax, allBindings, ego, line)

# Gather dependencies of the requirement
deps = set()
for value in bindings.values():
cellVals = (value for cell, value in cells)
for value in itertools.chain(allBindings.values(), cellVals):
if needsSampling(value):
deps.add(value)
if needsLazyEvaluation(value):
Expand All @@ -93,7 +105,7 @@ def compile(self, namespace, scenario, syntax=None):

# If this requirement contains the CanSee specifier, we will need to sample all objects
# to meet the dependencies.
if "CanSee" in bindings:
if "CanSee" in globalBindings:
deps.update(scenario.objects)

if ego is not None:
Expand All @@ -102,13 +114,18 @@ def compile(self, namespace, scenario, syntax=None):

# Construct closure
def closure(values, monitor=None):
# rebind any names referring to sampled objects
# rebind any names referring to sampled objects (for require statements,
# rebind all names, since we want their values at the time the requirement
# was created)
# note: need to extract namespace here rather than close over value
# from above because of https://github.com/uqfoundation/dill/issues/532
namespace = condition.atomics()[0].closure.__globals__
for name, value in bindings.items():
if value in values:
for name, value in globalBindings.items():
if ty == RequirementType.require or value in values:
namespace[name] = values[value]
for cell, value in cells:
cell.cell_contents = values[value]

# rebind ego object, which can be referred to implicitly
boundEgo = None if ego is None else values[ego]
# evaluate requirement condition, reporting errors on the correct line
Expand All @@ -132,24 +149,34 @@ def closure(values, monitor=None):
return CompiledRequirement(self, closure, deps, condition)


def getAllGlobals(req, restrictTo=None):
def getNameBindings(req, restrictTo=None):
"""Find all names the given lambda depends on, along with their current bindings."""
namespace = req.__globals__
if restrictTo is not None and restrictTo is not namespace:
return {}
return {}, {}, ()
externals = inspect.getclosurevars(req)
assert not externals.nonlocals # TODO handle these
globs = dict(externals.builtins)
for name, value in externals.globals.items():
globs[name] = value
if inspect.isfunction(value):
subglobs = getAllGlobals(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in globs:
assert value is globs[name]
else:
globs[name] = value
return globs
globalBindings = externals.builtins

closures = set()
if externals.nonlocals:
closures.add(req)

def handleFunctions(bindings):
for value in bindings.values():
if inspect.isfunction(value):
if value.__closure__ is not None:
closures.add(value)
subglobs, _, _ = getNameBindings(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in globalBindings:
assert value is globalBindings[name]
else:
globalBindings[name] = value

globalBindings.update(externals.globals)
handleFunctions(externals.globals)
handleFunctions(externals.nonlocals)
return globalBindings, externals.nonlocals, closures


class BoundRequirement:
Expand Down
9 changes: 5 additions & 4 deletions src/scenic/syntax/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,11 +1359,12 @@ def createRequirementLike(
"""Create a call to a function that implements requirement-like features, such as `record` and `terminate when`.
Args:
functionName (str): Name of the requirement-like function to call. Its signature must be `(reqId: int, body: () -> bool, lineno: int, name: str | None)`
functionName (str): Name of the requirement-like function to call. Its signature
must be `(reqId: int, body: () -> bool, lineno: int, name: str | None)`
body (ast.AST): AST node to evaluate for checking the condition
lineno (int): Line number in the source code
name (Optional[str], optional): Optional name for requirements. Defaults to None.
prob (Optional[float], optional): Optional probability for requirements. Defaults to None.
name (Optional[str]): Optional name for requirements. Defaults to None.
prob (Optional[float]): Optional probability for requirements. Defaults to None.
"""
propTransformer = PropositionTransformer(self.filename)
newBody, self.nextSyntaxId = propTransformer.transform(body, self.nextSyntaxId)
Expand All @@ -1374,7 +1375,7 @@ def createRequirementLike(
value=ast.Call(
func=ast.Name(functionName, loadCtx),
args=[
ast.Constant(requirementId), # requirement IDre
ast.Constant(requirementId), # requirement ID
newBody, # body
ast.Constant(lineno), # line number
ast.Constant(name), # requirement name
Expand Down
81 changes: 81 additions & 0 deletions tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,87 @@ def test_requirement():
assert all(0 <= x <= 10 for x in xs)


def test_requirement_in_loop():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
for i in range(2):
require ego.position[i] >= 0
"""
)
poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)]
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_requirement_in_function():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
def f(i):
require ego.position[i] >= 0
for i in range(2):
f(i)
"""
)
poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)]
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_requirement_in_function_helper():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
m = 0
def f():
assert m == 0
return ego.y + m
def g():
require ego.x < f()
g()
m = -100
"""
)
poss = [sampleEgo(scenario, maxIterations=60).position for i in range(60)]
assert all(pos.x < pos.y for pos in poss)


def test_requirement_in_function_random_local():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ 0
def f():
local = Range(0, 1)
require ego.x < local
f()
"""
)
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert all(-10 <= x <= 1 for x in xs)


def test_requirement_in_function_random_cell():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ 0
def f(i):
def g():
return i
return g
g = f(Range(0, 1)) # global function with a cell containing a random value
def h():
local = Uniform(True, False)
def inner(): # local function likewise
return local
require (g() >= 0) and ((ego.x < -5) if inner() else (ego.x > 5))
h()
"""
)
xs = [sampleEgo(scenario, maxIterations=150).position.x for i in range(60)]
assert all(x < -5 or x > 5 for x in xs)
assert any(x < -5 for x in xs)
assert any(x > 5 for x in xs)


def test_soft_requirement():
scenario = compileScenic(
"""
Expand Down

0 comments on commit fe28e13

Please sign in to comment.