import ast
import copy
import os
import random
import re
import yaml
from . import trace
class ConstraintError(Exception):
    """
    Exception raised by the constraint module.

    Raised when a constraint definition cannot be interpreted, e.g. an
    oracle or ``require`` expression uses an unsupported construct.
    """
class ConstraintManager(object):
    """
    Manager class contains and manipulates all constraints.

    Constraints are loaded once at construction time; :meth:`constrain`
    then applies them to an item and records the outcome of each
    constraint, keyed by constraint name, in ``self.status``.
    """

    def __init__(self, provider):
        """
        :param provider: Object handed to every constraint. Its ``path``
                         attribute names the directory whose ``oracles``
                         sub-directory holds the constraint YAML files.
        """
        self.provider = provider
        path = os.path.join(provider.path, 'oracles')
        self.constraints = self._load_constraints(path)
        self.item = None
        self.status = {}

    def _load_constraints(self, path):
        """
        Load constraints from a directory containing YAML files.

        Every file found below ``path`` (recursively) is parsed as YAML and
        is expected to contain a list of constraint dictionaries.

        :param path: Directory to load constraint YAML file from.
        :return: List of :class:`Constraint` instances.
        """
        cstrs = []
        for root, _, files in os.walk(path):
            for fname in files:
                fpath = os.path.join(root, fname)
                with open(fpath) as fp:
                    # safe_load instead of a bare yaml.load(fp): the latter
                    # can construct arbitrary Python objects and is rejected
                    # outright (missing Loader argument) by PyYAML >= 6.
                    loaded = yaml.safe_load(fp)
                # An empty file parses to None; there is nothing to add.
                if loaded:
                    cstrs.extend(loaded)
        return [Constraint.from_dict(self.provider, c) for c in cstrs]

    def _assumption_valid(self, constraint):
        """
        Check whether the assumption of a constraint is valid.

        The ``require`` attribute must be a single comparison between two
        bare names, e.g. ``other_constraint is success``. If the left name
        is a key of ``self.status`` it is replaced by that recorded status
        before the (case-insensitive) comparison against the right name.

        :param constraint: The constraint whose assumption to be checked.
        :return: ``True`` if the constraint has no requirement or the
                 requirement holds, ``False`` otherwise.
        :raise ConstraintError: If the requirement is not a single
                                ``is`` / ``is not`` comparison of names.
        """
        if constraint.require is None:
            return True

        body = ast.parse(constraint.require).body
        if len(body) != 1 or not isinstance(body[0], ast.Expr):
            raise ConstraintError(
                'Requirement must be a single expression: %s' %
                constraint.require)
        compare = body[0].value
        if not isinstance(compare, ast.Compare) or len(compare.ops) != 1:
            raise ConstraintError(
                'Requirement must be a single comparison: %s' %
                constraint.require)

        left = compare.left
        right = compare.comparators[0]
        if not (isinstance(left, ast.Name) and isinstance(right, ast.Name)):
            raise ConstraintError(
                'Requirement operands must be plain names: %s' %
                constraint.require)

        # A known constraint name stands for its recorded status; any other
        # name is compared literally.
        actual = self.status.get(left.id, left.id)
        expected = right.id

        op = compare.ops[0].__class__.__name__
        if op == 'Is':
            return actual.lower() == expected.lower()
        if op == 'IsNot':
            return actual.lower() != expected.lower()
        raise ConstraintError('Operator %s is not handled' % op)

    def constrain(self, item):
        """
        Apply constraints to an item.

        Every constraint starts as ``'untouched'``. Constraints are visited
        in list order: one whose requirement holds is applied and its result
        recorded, any other is recorded as ``'skipped'``.

        :param item: Item for constraints to apply on.
        """
        self.item = item
        self.status = {c.name: 'untouched'
                       for c in self.constraints}
        # NOTE(review): a requirement referring to a constraint that comes
        # later in the list is evaluated against 'untouched', so the
        # requiring constraint is skipped -- confirm this order dependence
        # is intended.
        while any(s == 'untouched' for s in self.status.values()):
            for constraint in self.constraints:
                if self._assumption_valid(constraint):
                    result = constraint.apply(item)
                else:
                    result = 'skipped'
                self.status[constraint.name] = result
class Constraint(object):
    """
    Class for a constraint on specific option of test item.

    A constraint carries an *oracle*: a snippet of Python-syntax code made
    of ``if``, ``assert`` and ``return`` statements (or one bare
    expression). The oracle is parsed once into a list of traces, each
    being the sequence of conditions leading to one ``return``.
    """

    # Prefix used to rewrite '/'-separated paths inside an oracle into valid
    # Python identifiers ('/a/b' becomes 'DPATH_a_b').
    path_prefix = 'DPATH'

    # Maps a comparison operator class to the class of its logical negation.
    _negated_ops = {
        ast.Is: ast.IsNot,
        ast.IsNot: ast.Is,
        ast.Gt: ast.LtE,
        ast.LtE: ast.Gt,
        ast.Lt: ast.GtE,
        ast.GtE: ast.Lt,
        ast.Eq: ast.NotEq,
        ast.NotEq: ast.Eq,
        ast.In: ast.NotIn,
        ast.NotIn: ast.In,
    }

    def __init__(self, name, provider,
                 depends_on=None, require=None, oracle=None):
        """
        :param name: Unique string name of the constraint.
        :param provider: Object passed on to every trace built from the
                         oracle.
        :param depends_on: A logical expression shows prerequisite to apply
                           this constraint.
        :param require: Logical expression shows the limit of this
                        constraint.
        :param oracle: A block of code shows the details of this constraint.
        """
        self.name = name
        self.provider = provider
        self.depends_on = depends_on
        self.require = require
        self.oracle = oracle
        # Probability of picking a failing trace when both failing and
        # passing traces exist (see _choose).
        self.fail_ratio = 0.1
        self.traces = self._oracle2traces(oracle)

    @classmethod
    def from_dict(cls, provider, data):
        """
        Generate a constraint instance from a dictionary.

        :param provider: Provider object forwarded to the constructor.
        :param data: Mapping with a mandatory ``name`` key; all remaining
                     keys are passed as keyword arguments to the
                     constructor. The mapping itself is left unmodified.
        :return: New instance of this class.
        :raise KeyError: If ``data`` has no ``name`` key.
        """
        # Work on a copy so the caller's dictionary keeps its 'name' entry.
        kwargs = dict(data)
        name = kwargs.pop('name')
        return cls(name, provider, **kwargs)

    def _oracle2traces(self, oracle):
        """
        Parse an oracle into the list of traces it describes.

        :param oracle: Oracle source text; ``None`` or an empty oracle
                       yields no traces.
        :return: List of ``trace.Trace`` objects, one per ``return``
                 statement reachable in the oracle.
        :raise ConstraintError: If the oracle uses an unsupported statement,
                                test expression or operator.
        """
        if not oracle:
            return []

        traces = []
        # Conditions (and finally the return node) on the path currently
        # being walked.
        cur_trace = []

        def _translate(text):
            """
            Rewrite '/'-prefixed paths into identifiers, line by line.
            """
            def _repl(match):
                return '%s_%s' % (self.path_prefix,
                                  match.group(3).replace('/', '_'))

            # Replace a word begin with a slash
            return '\n'.join(
                re.sub(r'((^)|(?<=\W))/([\w/]+)', _repl, line)
                for line in text.splitlines())

        def _revert_compare(node):
            """
            Helper function to revert a compare node to its negation.

            NOTE: only the first operator is looked at; chained comparisons
            are not negated correctly.
            """
            rev_node = copy.deepcopy(node)
            op = rev_node.ops[0]
            negated = self._negated_ops.get(type(op))
            if negated is None:
                raise ConstraintError('Unknown operator: %s' % op)
            rev_node.ops = [negated()]
            return rev_node

        def _revert_test(node):
            """
            Helper function to revert a test node to its negation.
            """
            # Allow syntax like 'any(a is b)' or 'all(c in d)'
            if isinstance(node, ast.Call):
                rev_node = copy.deepcopy(node)
                if not isinstance(rev_node.func, ast.Name):
                    raise ConstraintError(
                        'Unknown test node type: %s' %
                        rev_node.func.__class__.__name__)
                if (len(rev_node.args) != 1 or
                        not isinstance(rev_node.args[0], ast.Compare)):
                    raise ConstraintError(
                        'Test call must take exactly one comparison')
                rev_node.args[0] = _revert_compare(rev_node.args[0])
                # De Morgan: not any(p) == all(not p) and vice versa. Other
                # function names are kept as they are.
                if rev_node.func.id == 'any':
                    rev_node.func.id = 'all'
                elif rev_node.func.id == 'all':
                    rev_node.func.id = 'any'
                return rev_node
            if isinstance(node, ast.Compare):
                return _revert_compare(node)
            raise ConstraintError(
                'Unknown test node type: %s' % node.__class__.__name__)

        def _record(return_node):
            """
            Store the current path, terminated by ``return_node``.
            """
            # Hand over a snapshot: cur_trace keeps being mutated while the
            # rest of the oracle is walked.
            traces.append(
                trace.Trace(self.provider, cur_trace + [return_node]))

        def _parse_if(node):
            cur_trace.append(node.test)
            _parse_block(node.body)
            cur_trace.pop()
            cur_trace.append(_revert_test(node.test))
            _parse_block(node.orelse)
            cur_trace.pop()

        def _parse_block(nodes):
            depth = len(cur_trace)
            for node in nodes:
                if isinstance(node, ast.If):
                    _parse_if(node)
                elif isinstance(node, ast.Assert):
                    # An assertion holds for the remainder of this block.
                    cur_trace.append(node.test)
                elif isinstance(node, ast.Return):
                    _record(node)
                else:
                    raise ConstraintError(
                        'Unknown node type: %s' % node.__class__.__name__)
            # Drop assertions collected in this block so they neither leak
            # into sibling branches nor confuse the pops in _parse_if.
            del cur_trace[depth:]

        def _parse_module(node):
            if not node.body:
                # Nothing but blank lines or comments.
                return
            # Accept a single expression to define the only valid condition
            if isinstance(node.body[0], ast.Expr):
                if len(node.body) != 1:
                    raise ConstraintError(
                        'Expression oracle must be a single expression')
                cur_trace.append(node.body[0].value)
                _record(ast.parse('return success()').body[0])
                cur_trace.pop()
            else:
                _parse_block(node.body)

        _parse_module(ast.parse(_translate(oracle)))
        return traces

    def _choose(self, fail_ratio=None):
        """
        Randomly pick one trace of this constraint.

        :param fail_ratio: Probability of picking a failing trace when both
                           failing and passing traces exist; defaults to
                           ``self.fail_ratio``.
        :return: The chosen trace.
        :raise ConstraintError: If no trace results in 'success' or 'fail'.
        """
        passes = [t for t in self.traces if t.result == 'success']
        fails = [t for t in self.traces if t.result == 'fail']
        if fail_ratio is None:
            fail_ratio = self.fail_ratio
        if not fails and not passes:
            raise ConstraintError(
                "Need return function fail() or success() in oracle '%s'" %
                self.name)
        if not fails:
            return random.choice(passes)
        if not passes:
            return random.choice(fails)
        if random.random() < fail_ratio:
            return random.choice(fails)
        return random.choice(passes)

    def apply(self, item):
        """
        Apply this constraint to an item.

        A trace is chosen at random and solved against the item; every
        solved value is written back with ``item.set`` and the failure
        patterns of the trace, if any, are added to ``item.fail_patts``.

        :param item: The item to be applied on.
        :return: Expected result of constraint item.
        """
        def _name2path(name):
            # Undo the translation done while parsing the oracle:
            # 'DPATH_a_b' -> '/a/b'. Other names are passed through.
            # NOTE: every underscore becomes a slash, so a path component
            # that itself contains '_' does not round-trip.
            if not name.startswith(self.path_prefix):
                return name
            return name[len(self.path_prefix):].replace('_', '/')

        t = self._choose()
        sols = t.solve(item)
        for name, sol in sols.items():
            item.set(_name2path(name), sol)
        patts = t.result_patts
        if patts is not None:
            if isinstance(patts, list):
                # update() accepts any iterable, whereas ``set |= list``
                # raises TypeError.
                item.fail_patts.update(patts)
            else:
                item.fail_patts.add(patts)
        return t.result

    def __repr__(self):
        return '<%s %s>' % (self.__class__.__name__, self.name)