Source code for scenic.syntax.relations

"""Extracting relations (for later pruning) from the syntax of requirements."""

from ast import (
    Add,
    BinOp,
    Call,
    Compare,
    Eq,
    Expression,
    Gt,
    GtE,
    Lt,
    LtE,
    Name,
    NotEq,
    Sub,
)
from collections import defaultdict
import math

from scenic.core.distributions import needsSampling
from scenic.core.errors import InconsistentScenarioError, InvalidScenarioError
from scenic.core.object_types import Object, Point


def inferRelationsFrom(reqNode, namespace, ego, line):
    """Record on the scenario's objects any relations implied by a requirement.

    A single RequirementMatcher (evaluating constants in ``namespace``) is shared
    by the relative-heading pass and then the distance pass; each pass appends
    the relations it finds to the ``_relations`` lists of the objects involved.
    """
    sharedMatcher = RequirementMatcher(namespace)
    for inferencePass in (inferRelativeHeadingRelations, inferDistanceRelations):
        inferencePass(sharedMatcher, reqNode, ego, line)
def inferRelativeHeadingRelations(matcher, reqNode, ego, line):
    """Infer bounds on relative headings from a requirement.

    For every Object ``X`` such that the requirement bounds ``RelativeHeading(X)``,
    a RelativeHeadingRelation to ``X`` is appended to ``ego._relations`` and the
    converse relation (to ``ego``) is appended to ``X._relations``.

    Args:
        matcher: RequirementMatcher used to recognize bounds in ``reqNode``.
        reqNode: AST of the requirement's condition.
        ego: the ego object, or None if it has not been assigned.
        line: source line of the requirement, used in error messages.

    Raises:
        InvalidScenarioError: if a relative-heading bound is found while ``ego``
            is None.
    """
    rhMatcher = lambda node: matcher.matchUnaryFunction("RelativeHeading", node)
    allBounds = matcher.matchBounds(reqNode, rhMatcher)
    for target, bounds in allBounds:
        if not isinstance(target, Object):
            continue
        assert target is not ego
        if ego is None:
            # f-string: the message must report the actual line number.
            raise InvalidScenarioError(
                f"relative heading w.r.t. unassigned ego on line {line}"
            )
        lower, upper = bounds
        # Clamp to [-pi, pi] (presumably the range of RelativeHeading values).
        lower = max(lower, -math.pi)
        upper = min(upper, math.pi)
        if lower == -math.pi and upper == math.pi:
            continue  # skip trivial bounds
        ego._relations.append(RelativeHeadingRelation(target, lower, upper))
        # If target's heading relative to ego lies in [lower, upper], then ego's
        # heading relative to target lies in [-upper, -lower].
        target._relations.append(RelativeHeadingRelation(ego, -upper, -lower))
def inferDistanceRelations(matcher, reqNode, ego, line):
    """Infer bounds on distances from a requirement.

    For every Object ``X`` such that the requirement bounds ``DistanceFrom(X)``
    from above, a DistanceRelation to ``X`` is appended to ``ego._relations`` and
    the same bounds (distance is symmetric) are recorded on ``X._relations``.

    Args:
        matcher: RequirementMatcher used to recognize bounds in ``reqNode``.
        reqNode: AST of the requirement's condition.
        ego: the ego object, or None if it has not been assigned.
        line: source line of the requirement, used in error messages.

    Raises:
        InvalidScenarioError: if a distance bound is found while ``ego`` is None.
    """
    distMatcher = lambda node: matcher.matchUnaryFunction("DistanceFrom", node)
    allBounds = matcher.matchBounds(reqNode, distMatcher)
    for target, bounds in allBounds:
        if not isinstance(target, Object):
            continue
        assert target is not ego
        if ego is None:
            # f-string: the message must report the actual line number.
            raise InvalidScenarioError(f"distance w.r.t. unassigned ego on line {line}")
        lower, upper = bounds
        lower = max(lower, 0)  # negative lower bounds on a distance are vacuous
        if upper == float("inf"):
            continue  # skip trivial bounds
        ego._relations.append(DistanceRelation(target, lower, upper))
        target._relations.append(DistanceRelation(ego, lower, upper))
class BoundRelation:
    """Abstract relation bounding something about another object.

    Attributes:
        target: the other object the relation refers to.
        lower: lower bound on the related quantity.
        upper: upper bound on the related quantity.
    """

    def __init__(self, target, lower, upper):
        self.target = target
        self.lower, self.upper = lower, upper

    def __repr__(self):
        # Use the concrete subclass name so e.g. DistanceRelation reprs correctly.
        return (
            f"{type(self).__name__}(target={self.target!r}, "
            f"lower={self.lower!r}, upper={self.upper!r})"
        )
class RelativeHeadingRelation(BoundRelation):
    """Relation stating that ``target``'s heading relative to this object lies
    between ``lower`` and ``upper``."""
class DistanceRelation(BoundRelation):
    """Relation stating that ``target``'s distance from this object lies
    between ``lower`` and ``upper``."""
class RequirementMatcher:
    """Pattern-matches requirement ASTs to extract bounds on quantities.

    Sub-expressions are evaluated against ``namespace`` (see ``matchValue``),
    so comparisons such as ``DistanceFrom(car) < 5`` or
    ``abs(RelativeHeading(car) - 0.1) <= 0.2`` can be turned into numeric
    bounds on the matched quantity.
    """

    def __init__(self, namespace):
        # Names available when evaluating sub-expressions; copied via dict()
        # for each evaluation in matchValue.
        self.namespace = namespace

    def inconsistencyError(self, node, message):
        """Raise an InconsistentScenarioError reported at ``node``'s line."""
        raise InconsistentScenarioError(node.lineno, message)

    def matchUnaryFunction(self, name, node):
        """Match a call to a specified unary function, returning the value of its argument.

        Returns None unless ``node`` is a call ``name(arg)`` with exactly one
        positional argument and no keyword arguments; otherwise returns the
        result of ``matchValue(arg)`` (None if it cannot be evaluated).
        """
        if not (
            isinstance(node, Call) and isinstance(node.func, Name) and node.func.id == name
        ):
            return None
        if len(node.args) != 1:
            return None
        if len(node.keywords) != 0:
            return None
        return self.matchValue(node.args[0])

    def matchBounds(self, node, matchAtom):
        """Match upper/lower bounds on something matched by the given function.

        Returns a list of all bounds found, pairing the bounded quantity with a
        pair (low, high) of lower/upper bounds. Sides that are not bounded are
        -inf/inf. When a chained comparison bounds the same quantity several
        times, the tightest bounds are kept. Returns an empty list if ``node``
        is not a comparison.
        """
        if not isinstance(node, Compare):
            return []
        bounds = defaultdict(lambda: (float("-inf"), float("inf")))
        targets = {}
        first = node.left
        # A chained comparison a < b <= c is split into (a < b) and (b <= c).
        for second, op in zip(node.comparators, node.ops):
            lower, upper, target = self.matchBoundsInner(first, second, op, matchAtom)
            first = second
            if target is None:
                continue
            targetID = id(target)  # use id to support unhashable types
            targets[targetID] = target
            bestLower, bestUpper = bounds[targetID]
            if lower is not None and lower > bestLower:
                bestLower = lower
            if upper is not None and upper < bestUpper:
                bestUpper = upper
            bounds[targetID] = (bestLower, bestUpper)
        return [(target, bounds[id_]) for id_, target in targets.items()]

    def matchBoundsInner(self, left, right, op, matchAtom):
        """Extract bounds from a single comparison ``left op right``.

        Returns a triple ``(lower, upper, target)`` in which either bound may be
        None when the comparison does not constrain that side, or
        ``(None, None, None)`` when no bound on a matched quantity is found.
        Strict and non-strict inequalities yield the same closed bounds, which
        over-approximates the allowed values.
        """
        # Reduce > and >= to < and <= by swapping the operands
        if isinstance(op, Gt):
            return self.matchBoundsInner(right, left, Lt(), matchAtom)
        elif isinstance(op, GtE):
            return self.matchBoundsInner(right, left, LtE(), matchAtom)
        elif not isinstance(op, (Lt, LtE, Eq)):
            # !=, in, not in, is, is not give no interval bound; without this
            # check they would fall through below and be misread as < / <=.
            return None, None, None

        # Try matching a constant lower bound on the atom or its absolute value
        lconst = self.matchConstant(left)
        if isinstance(lconst, (int, float)):
            target = matchAtom(right)
            if target is not None:  # CONST op QUANTITY
                if isinstance(op, Eq):
                    return lconst, lconst, target
                return lconst, None, target
            bounds = self.matchAbsBounds(right, lconst, op, False, matchAtom)
            if bounds is not None:  # CONST op abs(QUANTITY [+/- CONST])
                return bounds

        # Try matching a constant upper bound on the atom or its absolute value
        rconst = self.matchConstant(right)
        if isinstance(rconst, (int, float)):
            target = matchAtom(left)
            if target is not None:  # QUANTITY op CONST
                if isinstance(op, Eq):
                    return rconst, rconst, target
                return None, rconst, target
            bounds = self.matchAbsBounds(left, rconst, op, True, matchAtom)
            if bounds is not None:  # abs(QUANTITY [+/- CONST]) op CONST
                return bounds

        return None, None, None

    def matchAbsBounds(self, node, const, op, isUpperBound, matchAtom):
        """Extract bounds on an atom from a comparison involving its absolute value.

        ``node`` is the side of the comparison that may be ``abs(...)`` and
        ``const`` the constant on the other side. ``isUpperBound`` is True for
        ``abs(...) op const`` and False for ``const op abs(...)``. Returns
        ``(lower, upper, target)`` or None if nothing matched.

        Raises:
            InconsistentScenarioError: if ``const`` is negative while bounding
                the absolute value from above (or equating it).
        """
        if not (
            isinstance(node, Call)
            and isinstance(node.func, Name)
            and node.func.id == "abs"
            and len(node.args) == 1
            and not node.keywords
        ):
            return None  # not a well-formed invocation of abs
        if not isUpperBound and not isinstance(op, Eq):
            return None  # lower bounds on abs value don't bound underlying quantity
        if const < 0:
            self.inconsistencyError(node, "absolute value cannot be negative")
        arg = node.args[0]
        target = matchAtom(arg)
        if target is not None:  # abs(QUANTITY) </= CONST
            return (-const, const, target)
        if not (isinstance(arg, BinOp) and isinstance(arg.op, (Add, Sub))):
            return None
        # abs(X +/- Y) </= CONST: look for a constant offset on either side
        offset = None
        leftConst = self.matchConstant(arg.left)
        target = matchAtom(arg.right)
        if isinstance(leftConst, (int, float)) and target is not None:
            offset = leftConst  # abs(CONST +/- QUANTITY)
        else:
            rightConst = self.matchConstant(arg.right)
            target = matchAtom(arg.left)
            if isinstance(rightConst, (int, float)) and target is not None:
                offset = rightConst  # abs(QUANTITY +/- CONST)
        if offset is None:
            return None
        if isinstance(arg.op, Add):
            # abs(QUANTITY + CONST) </= CONST
            return (-const - offset, const - offset, target)
        # abs(QUANTITY - CONST) </= CONST (|c - Q| == |Q - c|)
        return (-const + offset, const + offset, target)

    def matchConstant(self, node):
        """Match constant values, i.e. values known prior to sampling."""
        value = self.matchValue(node)
        return None if needsSampling(value) else value

    def matchValue(self, node):
        """Match any expression which can be evaluated, returning its value.

        Returns None if evaluation raises. This method could have undesirable
        side-effects, but conditions in requirements should not have
        side-effects to begin with.
        """
        try:
            code = compile(Expression(node), "<internal>", "eval")
            # NOTE: evaluates scenario-author code, which is already trusted/executed.
            value = eval(code, dict(self.namespace))
        except Exception:
            # Deliberate best-effort: anything unevaluable simply doesn't match.
            return None
        return value