# Source code for scenic.syntax.relations


"""Extracting relations (for later pruning) from the syntax of requirements."""

import math
from collections import defaultdict
from ast import Compare, BinOp, Eq, NotEq, Lt, LtE, Gt, GtE, Call, Add, Sub, Expression, Name

from scenic.core.distributions import needsSampling
from scenic.core.object_types import Point, Object
from scenic.core.utils import InvalidScenarioError, InconsistentScenarioError

def inferRelationsFrom(reqNode, namespace, ego, line):
    """Record the object relations implied by a single requirement.

    A `RequirementMatcher` is built over ``namespace`` and handed, together
    with the requirement's syntax tree, to each inference pass in turn:
    first relative headings, then distances.
    """
    matcher = RequirementMatcher(namespace)
    passes = (inferRelativeHeadingRelations, inferDistanceRelations)
    for infer in passes:
        infer(matcher, reqNode, ego, line)
def inferRelativeHeadingRelations(matcher, reqNode, ego, line):
    """Infer bounds on relative headings from a requirement.

    Every bound on ``RelativeHeading(X)`` found in ``reqNode`` (where ``X``
    evaluates to an `Object`) is clamped to ``[-pi, pi]`` and recorded as a
    `RelativeHeadingRelation` on ``ego._relations``; the converse relation
    (negated, swapped bounds) is recorded on the target's ``_relations``.

    Args:
        matcher: object providing ``matchUnaryFunction`` and ``matchBounds``.
        reqNode: syntax tree node of the requirement.
        ego: the ego object, or None if it has not been assigned.
        line: line number of the requirement, used in error messages.

    Raises:
        InvalidScenarioError: if a relative heading bound is found but
            ``ego`` is None.
    """
    rhMatcher = lambda node: matcher.matchUnaryFunction('RelativeHeading', node)
    allBounds = matcher.matchBounds(reqNode, rhMatcher)
    for target, bounds in allBounds.items():
        if not isinstance(target, Object):
            continue
        assert target is not ego
        if ego is None:
            # The f-prefix is required so the actual line number is reported.
            raise InvalidScenarioError(
                f'relative heading w.r.t. unassigned ego on line {line}')
        lower, upper = bounds
        # Clamp to the range of meaningful heading values.
        if lower < -math.pi:
            lower = -math.pi
        if upper > math.pi:
            upper = math.pi
        if lower == -math.pi and upper == math.pi:
            continue    # skip trivial bounds
        rel = RelativeHeadingRelation(target, lower, upper)
        ego._relations.append(rel)
        # Converse relation as seen from the target: negate and swap bounds.
        conv = RelativeHeadingRelation(ego, -upper, -lower)
        target._relations.append(conv)
def inferDistanceRelations(matcher, reqNode, ego, line):
    """Infer bounds on distances from a requirement.

    Every bound on ``DistanceFrom(X)`` found in ``reqNode`` (where ``X``
    evaluates to an `Object`) with a finite upper bound is recorded as a
    `DistanceRelation` on ``ego._relations``, and the same bounds are
    recorded on the target's ``_relations`` with ``ego`` as the target.
    Negative lower bounds are raised to 0.

    Args:
        matcher: object providing ``matchUnaryFunction`` and ``matchBounds``.
        reqNode: syntax tree node of the requirement.
        ego: the ego object, or None if it has not been assigned.
        line: line number of the requirement, used in error messages.

    Raises:
        InvalidScenarioError: if a distance bound is found but ``ego`` is
            None.
    """
    distMatcher = lambda node: matcher.matchUnaryFunction('DistanceFrom', node)
    allBounds = matcher.matchBounds(reqNode, distMatcher)
    for target, bounds in allBounds.items():
        if not isinstance(target, Object):
            continue
        assert target is not ego
        if ego is None:
            # The f-prefix is required so the actual line number is reported.
            raise InvalidScenarioError(
                f'distance w.r.t. unassigned ego on line {line}')
        lower, upper = bounds
        if lower < 0:
            lower = 0
        if upper == float('inf'):
            continue    # skip trivial bounds
        rel = DistanceRelation(target, lower, upper)
        ego._relations.append(rel)
        # The same bounds apply from the target's point of view.
        conv = DistanceRelation(ego, lower, upper)
        target._relations.append(conv)
class BoundRelation:
    """Base class for relations that bound some quantity of another object.

    Stores the object the relation refers to (``target``) together with the
    ``lower`` and ``upper`` bounds.
    """

    def __init__(self, target, lower, upper):
        self.target = target
        self.lower = lower
        self.upper = upper
class RelativeHeadingRelation(BoundRelation):
    """Bounds on the relative heading of another object w.r.t. this one."""
class DistanceRelation(BoundRelation):
    """Bounds on the distance of another object from this one."""
class RequirementMatcher:
    """Pattern-matcher extracting numeric bounds from requirement syntax trees.

    Expressions are evaluated in ``namespace`` to recognise constants and
    the arguments of the functions being bounded.
    """

    def __init__(self, namespace):
        self.namespace = namespace

    def inconsistencyError(self, node, message):
        """Raise an `InconsistentScenarioError` at the line of ``node``."""
        raise InconsistentScenarioError(node.lineno, message)

    def matchUnaryFunction(self, name, node):
        """Match a call to a specified unary function, returning the value of its argument.

        Returns None if ``node`` is not a call of the form ``name(arg)`` with
        exactly one positional and no keyword arguments, or if the argument
        cannot be evaluated.
        """
        if not (isinstance(node, Call) and isinstance(node.func, Name)
                and node.func.id == name):
            return None
        if len(node.args) != 1:
            return None
        if len(node.keywords) != 0:
            return None
        return self.matchValue(node.args[0])

    def matchBounds(self, node, matchAtom):
        """Match upper/lower bounds on something matched by the given function.

        Returns a dict of all bounds found, mapping the bounded quantity to a pair
        (low, high) of lower/upper bounds.
        """
        if not isinstance(node, Compare):
            return {}
        bounds = defaultdict(lambda: (float('-inf'), float('inf')))
        first = node.left
        # Walk the comparison chain pairwise: a < b < c gives (a, b), (b, c).
        for second, op in zip(node.comparators, node.ops):
            lower, upper, target = self.matchBoundsInner(first, second, op, matchAtom)
            first = second
            if target is None:
                continue
            # Keep the tightest bounds seen so far for this quantity.
            bestLower, bestUpper = bounds[target]
            if lower is not None and lower > bestLower:
                bestLower = lower
            if upper is not None and upper < bestUpper:
                bestUpper = upper
            bounds[target] = (bestLower, bestUpper)
        return bounds

    def matchBoundsInner(self, left, right, op, matchAtom):
        """Extract bounds from a single comparison operator.

        Returns a triple ``(lower, upper, target)``; ``target`` is None when
        no bound could be extracted, and either bound may be None.
        """
        # Reduce > and >= to < and <=
        if isinstance(op, Gt):
            return self.matchBoundsInner(right, left, Lt(), matchAtom)
        elif isinstance(op, GtE):
            return self.matchBoundsInner(right, left, LtE(), matchAtom)

        # Only <, <= and == bound a quantity; operators such as !=, in, or is
        # must not be mistaken for inequalities.
        if not isinstance(op, (Lt, LtE, Eq)):
            return None, None, None

        # Try matching a constant lower bound on the atom or its absolute value
        lconst = self.matchConstant(left)
        if isinstance(lconst, (int, float)):
            target = matchAtom(right)
            if target is not None:  # CONST op QUANTITY
                if isinstance(op, Eq):
                    return (lconst, lconst, target)
                return (lconst, None, target)
            bounds = self.matchAbsBounds(right, lconst, op, False, matchAtom)
            if bounds is not None:  # CONST op abs(QUANTITY [+/- CONST])
                return bounds

        # Try matching a constant upper bound on the atom or its absolute value
        rconst = self.matchConstant(right)
        if isinstance(rconst, (int, float)):
            target = matchAtom(left)
            if target is not None:  # QUANTITY op CONST
                if isinstance(op, Eq):
                    return (rconst, rconst, target)
                return (None, rconst, target)
            bounds = self.matchAbsBounds(left, rconst, op, True, matchAtom)
            if bounds is not None:  # abs(QUANTITY [+/- CONST]) op CONST
                return bounds

        return None, None, None

    def matchAbsBounds(self, node, const, op, isUpperBound, matchAtom):
        """Extract bounds on an atom from a comparison involving its absolute value.

        Returns a triple ``(lower, upper, target)``, or None if ``node`` is
        not a recognised ``abs(...)`` expression. Raises (through
        `inconsistencyError`) if the absolute value is compared against a
        negative constant.
        """
        if not (isinstance(node, Call) and isinstance(node.func, Name)
                and node.func.id == 'abs'):
            return None     # not an invocation of abs
        if not isUpperBound and not isinstance(op, Eq):
            return None     # lower bounds on abs value don't bound underlying quantity
        if const < 0:
            self.inconsistencyError(node, 'absolute value cannot be negative')
        if len(node.args) != 1:
            return None     # malformed call to abs; nothing to infer
        arg = node.args[0]
        target = matchAtom(arg)
        if target is not None:  # abs(QUANTITY) </= CONST
            return (-const, const, target)
        elif isinstance(arg, BinOp) and isinstance(arg.op, (Add, Sub)):
            # abs(X +/- Y) </= CONST
            match = None
            slconst = self.matchConstant(arg.left)
            target = matchAtom(arg.right)
            if isinstance(slconst, (int, float)) and target is not None:
                # abs(CONST +/- QUANTITY) </= CONST
                match = slconst
            else:
                srconst = self.matchConstant(arg.right)
                target = matchAtom(arg.left)
                if isinstance(srconst, (int, float)) and target is not None:
                    # abs(QUANTITY +/- CONST) </= CONST
                    match = srconst
            if match is not None:
                if isinstance(arg.op, Add):
                    # abs(QUANTITY + CONST) </= CONST
                    return (-const - match, const - match, target)
                else:
                    # abs(QUANTITY - CONST) </= CONST
                    return (-const + match, const + match, target)
        return None

    def matchConstant(self, node):
        """Match constant values, i.e. values known prior to sampling."""
        value = self.matchValue(node)
        return None if needsSampling(value) else value

    def matchValue(self, node):
        """Match any expression which can be evaluated, returning its value.

        This method could have undesirable side-effects, but conditions in
        requirements should not have side-effects to begin with.

        Returns None if the expression cannot be compiled or evaluated.
        """
        # NOTE(review): this evals code taken from the requirement itself;
        # any failure is deliberately treated as "no match".
        try:
            code = compile(Expression(node), '<internal>', 'eval')
            value = eval(code, dict(self.namespace))
        except Exception:
            return None
        return value