Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix requirements inside loops and functions #316

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/scenic/core/dynamics/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs):
self._objects = [] # ordered for reproducibility
self._sampledObjects = self._objects
self._externalParameters = []
self._pendingRequirements = defaultdict(list)
self._pendingRequirements = []
self._requirements = []
# things needing to be sampled to evaluate the requirements
self._requirementDeps = set()
Expand Down Expand Up @@ -409,9 +409,8 @@ def _registerObject(self, obj):

def _addRequirement(self, ty, reqID, req, line, name, prob):
"""Save a requirement defined at compile-time for later processing."""
assert reqID not in self._pendingRequirements
preq = PendingRequirement(ty, req, line, prob, name, self._ego)
self._pendingRequirements[reqID] = preq
self._pendingRequirements.append((reqID, preq))

def _addDynamicRequirement(self, ty, req, line, name):
"""Add a requirement defined during a dynamic simulation."""
Expand All @@ -429,7 +428,7 @@ def _compileRequirements(self):
namespace = self._dummyNamespace if self._dummyNamespace else self.__dict__
requirementSyntax = self._requirementSyntax
assert requirementSyntax is not None
for reqID, requirement in self._pendingRequirements.items():
for reqID, requirement in self._pendingRequirements:
syntax = requirementSyntax[reqID] if requirementSyntax else None

# Catch the simple case where someone has most likely forgotten the "monitor"
Expand Down
4 changes: 3 additions & 1 deletion src/scenic/core/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3275,7 +3275,9 @@ def __init__(self, position, heading, width, length, name=None):
self.circumcircle = (self.position, self.radius)

super().__init__(
polygon=self._makePolygons(position, heading, width, length),
polygon=self._makePolygons(
self.position, self.heading, self.width, self.length
),
z=self.position.z,
name=name,
additionalDeps=deps,
Expand Down
79 changes: 53 additions & 26 deletions src/scenic/core/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,21 @@ def __init__(self, ty, condition, line, prob, name, ego):

# condition is an instance of Proposition. Flatten to get a list of atomic propositions.
atoms = condition.atomics()
bindings = {}
self.globalBindings = {} # bindings to global/builtin names
self.closureBindings = {} # bindings to top-level closure variables
self.cells = [] # cells used in referenced closures
atomGlobals = None
for atom in atoms:
bindings.update(getAllGlobals(atom.closure))
gbindings, cbindings, closures = getNameBindings(atom.closure)
self.globalBindings.update(gbindings)
self.closureBindings.update(cbindings)
for closure in closures:
self.cells.extend(closure.__closure__)
globs = atom.closure.__globals__
if atomGlobals is not None:
assert globs is atomGlobals
else:
atomGlobals = globs
self.bindings = bindings
self.egoObject = ego

def compile(self, namespace, scenario, syntax=None):
Expand All @@ -68,21 +73,28 @@ def compile(self, namespace, scenario, syntax=None):
While we're at it, determine whether the requirement implies any relations
we can use for pruning, and gather all of its dependencies.
"""
bindings, ego, line = self.bindings, self.egoObject, self.line
globalBindings, closureBindings = self.globalBindings, self.closureBindings
cells, ego, line = self.cells, self.egoObject, self.line
condition, ty = self.condition, self.ty

# Convert bound values to distributions as needed
for name, value in bindings.items():
bindings[name] = toDistribution(value)
for name, value in globalBindings.items():
globalBindings[name] = toDistribution(value)
for name, value in closureBindings.items():
closureBindings[name] = toDistribution(value)
cells = tuple((cell, toDistribution(cell.cell_contents)) for cell in cells)
allBindings = dict(globalBindings)
allBindings.update(closureBindings)

# Check whether requirement implies any relations used for pruning
canPrune = condition.check_constrains_sampling()
if canPrune:
relations.inferRelationsFrom(syntax, bindings, ego, line)
relations.inferRelationsFrom(syntax, allBindings, ego, line)

# Gather dependencies of the requirement
deps = set()
for value in bindings.values():
cellVals = (value for cell, value in cells)
for value in itertools.chain(allBindings.values(), cellVals):
if needsSampling(value):
deps.add(value)
if needsLazyEvaluation(value):
Expand All @@ -93,7 +105,7 @@ def compile(self, namespace, scenario, syntax=None):

# If this requirement contains the CanSee specifier, we will need to sample all objects
# to meet the dependencies.
if "CanSee" in bindings:
if "CanSee" in globalBindings:
deps.update(scenario.objects)

if ego is not None:
Expand All @@ -102,13 +114,18 @@ def compile(self, namespace, scenario, syntax=None):

# Construct closure
def closure(values, monitor=None):
# rebind any names referring to sampled objects
# rebind any names referring to sampled objects (for require statements,
# rebind all names, since we want their values at the time the requirement
# was created)
# note: need to extract namespace here rather than close over value
# from above because of https://github.com/uqfoundation/dill/issues/532
namespace = condition.atomics()[0].closure.__globals__
for name, value in bindings.items():
if value in values:
for name, value in globalBindings.items():
if ty == RequirementType.require or value in values:
namespace[name] = values[value]
for cell, value in cells:
cell.cell_contents = values[value]

# rebind ego object, which can be referred to implicitly
boundEgo = None if ego is None else values[ego]
# evaluate requirement condition, reporting errors on the correct line
Expand All @@ -132,24 +149,34 @@ def closure(values, monitor=None):
return CompiledRequirement(self, closure, deps, condition)


def getAllGlobals(req, restrictTo=None):
def getNameBindings(req, restrictTo=None):
"""Find all names the given lambda depends on, along with their current bindings."""
namespace = req.__globals__
if restrictTo is not None and restrictTo is not namespace:
return {}
return {}, {}, ()
externals = inspect.getclosurevars(req)
assert not externals.nonlocals # TODO handle these
globs = dict(externals.builtins)
for name, value in externals.globals.items():
globs[name] = value
if inspect.isfunction(value):
subglobs = getAllGlobals(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in globs:
assert value is globs[name]
else:
globs[name] = value
return globs
globalBindings = externals.builtins

closures = set()
if externals.nonlocals:
closures.add(req)

def handleFunctions(bindings):
for value in bindings.values():
if inspect.isfunction(value):
if value.__closure__ is not None:
closures.add(value)
subglobs, _, _ = getNameBindings(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in globalBindings:
assert value is globalBindings[name]
else:
globalBindings[name] = value

globalBindings.update(externals.globals)
handleFunctions(externals.globals)
handleFunctions(externals.nonlocals)
return globalBindings, externals.nonlocals, closures


class BoundRequirement:
Expand Down
9 changes: 5 additions & 4 deletions src/scenic/syntax/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,11 +1359,12 @@ def createRequirementLike(
"""Create a call to a function that implements requirement-like features, such as `record` and `terminate when`.

Args:
functionName (str): Name of the requirement-like function to call. Its signature must be `(reqId: int, body: () -> bool, lineno: int, name: str | None)`
functionName (str): Name of the requirement-like function to call. Its signature
must be `(reqId: int, body: () -> bool, lineno: int, name: str | None)`
body (ast.AST): AST node to evaluate for checking the condition
lineno (int): Line number in the source code
name (Optional[str], optional): Optional name for requirements. Defaults to None.
prob (Optional[float], optional): Optional probability for requirements. Defaults to None.
name (Optional[str]): Optional name for requirements. Defaults to None.
prob (Optional[float]): Optional probability for requirements. Defaults to None.
"""
propTransformer = PropositionTransformer(self.filename)
newBody, self.nextSyntaxId = propTransformer.transform(body, self.nextSyntaxId)
Expand All @@ -1374,7 +1375,7 @@ def createRequirementLike(
value=ast.Call(
func=ast.Name(functionName, loadCtx),
args=[
ast.Constant(requirementId), # requirement IDre
ast.Constant(requirementId), # requirement ID
newBody, # body
ast.Constant(lineno), # line number
ast.Constant(name), # requirement name
Expand Down
81 changes: 81 additions & 0 deletions tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,87 @@ def test_requirement():
assert all(0 <= x <= 10 for x in xs)


def test_requirement_in_loop():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
for i in range(2):
require ego.position[i] >= 0
"""
)
poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)]
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_requirement_in_function():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
def f(i):
require ego.position[i] >= 0
for i in range(2):
f(i)
"""
)
poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)]
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_requirement_in_function_helper():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
m = 0
def f():
assert m == 0
return ego.y + m
def g():
require ego.x < f()
g()
m = -100
"""
)
poss = [sampleEgo(scenario, maxIterations=60).position for i in range(60)]
assert all(pos.x < pos.y for pos in poss)


def test_requirement_in_function_random_local():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ 0
def f():
local = Range(0, 1)
require ego.x < local
f()
"""
)
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert all(-10 <= x <= 1 for x in xs)


def test_requirement_in_function_random_cell():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ 0
def f(i):
def g():
return i
return g
g = f(Range(0, 1)) # global function with a cell containing a random value
def h():
local = Uniform(True, False)
def inner(): # local function likewise
return local
require (g() >= 0) and ((ego.x < -5) if inner() else (ego.x > 5))
h()
"""
)
xs = [sampleEgo(scenario, maxIterations=150).position.x for i in range(60)]
assert all(x < -5 or x > 5 for x in xs)
assert any(x < -5 for x in xs)
assert any(x > 5 for x in xs)


def test_soft_requirement():
scenario = compileScenic(
"""
Expand Down
Loading