# Source code for scenic.core.sample_checking

"""The SampleChecker class and it's implementations."""

from abc import ABC, abstractmethod
from collections import deque
import time

from scenic.core.distributions import RejectionException
from scenic.core.requirements import BlanketCollisionRequirement, IntersectionRequirement


class SampleChecker(ABC):
    """Abstract base class for checking samples against a set of requirements.

    Requirements are supplied exactly once via :meth:`setRequirements`.
    Subclasses implement :meth:`checkRequirementsInner`; callers go through
    :meth:`checkRequirements`.
    """

    def __init__(self):
        # Stays None until setRequirements() is called (which may happen once).
        self.requirements = None

    def setRequirements(self, requirements):
        """Store *requirements* as a tuple; may only be called once."""
        assert self.requirements is None
        frozen = tuple(requirements)
        self.requirements = frozen

    @abstractmethod
    def checkRequirementsInner(self, sample):
        """Subclass hook: check *sample* against ``self.requirements``.

        If this raises ``RejectionException``, :meth:`checkRequirements`
        catches it and returns the exception object instead.
        """

    def checkRequirements(self, sample):
        """Check *sample* against the stored requirements.

        Returns:
            Whatever :meth:`checkRequirementsInner` returns, or the caught
            ``RejectionException`` instance if that method raised one.
        """
        assert self.requirements is not None
        try:
            verdict = self.checkRequirementsInner(sample)
        except RejectionException as rejection:
            return rejection
        else:
            return verdict


class BasicChecker(SampleChecker):
    """Basic requirement checker.

    Evaluates requirements in order, with a tiny bit of tuning: optional
    requirements are ignored, except that an optional
    ``BlanketCollisionRequirement`` is kept when *initialCollisionCheck* is
    truthy and at least three ``IntersectionRequirement`` objects are present.

    Args:
        initialCollisionCheck: Enables keeping the optional blanket collision
            requirement under the condition described above.
    """

    def __init__(self, initialCollisionCheck):
        super().__init__()
        self.initialCollisionCheck = initialCollisionCheck

    def setRequirements(self, requirements):
        """Filter out optional requirements (see class docstring), then store."""
        # Materialize first: the intersection count needs a second pass over
        # the requirements. Done inside the loop, that pass would consume the
        # rest of a one-shot iterator (e.g. a generator) and silently drop it.
        requirements = tuple(requirements)
        # Loop-invariant, so compute it once rather than per optional requirement.
        numIntersections = sum(
            isinstance(r, IntersectionRequirement) for r in requirements
        )
        # NOTE(review): the threshold of 3 is a tuning constant; its rationale
        # is not visible here.
        keepBlanketCollision = self.initialCollisionCheck and numIntersections >= 3

        target_reqs = []
        for req in requirements:
            if not req.optional:
                target_reqs.append(req)
            elif keepBlanketCollision and isinstance(req, BlanketCollisionRequirement):
                # Basic checker ignores optional requirements except this case.
                target_reqs.append(req)
        super().setRequirements(target_reqs)

    def checkRequirementsInner(self, sample):
        """Return the violation message of the first active requirement that
        *sample* falsifies, in stored order, or None if there is none."""
        for req in self.requirements:
            if req.active and req.falsifiedBy(sample):
                return req.violationMsg
        return None
class WeightedAcceptanceChecker(SampleChecker):
    """Picks the requirement with the lowest time-weighted acceptance chance.

    For each requirement, a sliding window holds its last *bufferSize*
    evaluations as ``(accepted, seconds)`` pairs. Active requirements are
    checked in ascending order of (mean acceptance) * (mean evaluation time).
    Requirements that tend to reject, and cheap ones, therefore run first.

    Incentivizes exploration by initializing all buffer values to 0: an
    untried requirement scores 0 and sorts to the front.

    Args:
        bufferSize: Max samples to use when calculating time-weighted
            rejection chance. Must be at least 1.

    Raises:
        ValueError: If *bufferSize* is less than 1.
    """

    def __init__(self, bufferSize=10):
        super().__init__()
        if bufferSize < 1:
            # A zero-length window would make updateMetrics pop from an empty
            # deque and getWeightedAcceptanceProb divide by zero.
            raise ValueError(f"bufferSize must be at least 1, got {bufferSize}")
        self.bufferSize = bufferSize
        self.buffers = None
        self.bufferSums = None

    def setRequirements(self, requirements):
        """Store *requirements* and give each one a zero-filled metrics window."""
        super().setRequirements(requirements)
        self.buffers = {
            req: deque([(0, 0)] * self.bufferSize) for req in self.requirements
        }
        # Running (acceptances, seconds) totals over each window, kept so that
        # scoring does not have to re-sum the buffer.
        self.bufferSums = {req: (0, 0) for req in self.requirements}

    def checkRequirementsInner(self, sample):
        """Evaluate requirements in score order, recording per-requirement metrics.

        Returns the violation message of the first requirement *sample*
        falsifies, or None if none does.
        """
        for req in self.sortedRequirements():
            # Evaluate the requirement with timing info.
            start = time.perf_counter()
            rejected = req.falsifiedBy(sample)
            # Create metrics (Accepted, Time Taken)
            metrics = (int(not rejected), time.perf_counter() - start)
            self.updateMetrics(req, metrics)
            if rejected:
                return req.violationMsg
        return None

    def sortedRequirements(self):
        """Return the active requirements in ascending score order."""
        # Extract and sort active requirements (sort is stable, so ties keep
        # their original order).
        reqs = [req for req in self.requirements if req.active]
        reqs.sort(key=self.getWeightedAcceptanceProb)
        # Drop optional requirements at the end of the list. Presumably an
        # optional requirement cannot reject anything once every mandatory one
        # has passed. NOTE(review): confirm the semantics of `optional`.
        while reqs and reqs[-1].optional:
            reqs.pop()
        return reqs

    def updateMetrics(self, req, new_metrics):
        """Push *new_metrics* into *req*'s window, evicting the oldest entry."""
        # Update buffer
        target_buffer = self.buffers[req]
        old_metrics = target_buffer.popleft()
        target_buffer.append(new_metrics)

        # Unpack values
        sum_acc, sum_time = self.bufferSums[req]
        old_acc, old_time = old_metrics
        new_acc, new_time = new_metrics

        # Update sums incrementally. NOTE(review): the float time total can
        # accumulate rounding drift over very long runs.
        sum_acc += new_acc - old_acc
        sum_time += new_time - old_time
        self.bufferSums[req] = (sum_acc, sum_time)

    def getWeightedAcceptanceProb(self, req):
        """Return mean acceptance rate times mean evaluation time over the window."""
        sum_acc, sum_time = self.bufferSums[req]
        return (sum_acc / self.bufferSize) * (sum_time / self.bufferSize)