Source code for dice.core.trace

import ast
import builtins
import inspect
import logging
import sys

from . import symbol


logger = logging.getLogger(__name__)


class TraceError(Exception):
    """
    Exception type for errors that are specific to trace handling.
    """
class Trace(object):
    """
    Class represent a condition trace in constraint oracle code.

    It contains a list of commands, including comparisons
    (``ast.Compare``), operations (``ast.Call`` to ``any``/``all``) and
    ends with a return command (``ast.Return``).
    """

    # Sentinel returned by `_literal_value` for nodes that are not literals.
    # `None` cannot be used because it is itself a valid literal value.
    _NOT_LITERAL = object()

    def __init__(self, provider, trace_list):
        """
        :param provider: Object whose ``name`` attribute is used to locate
                         helper modules when a call in the trace is executed.
        :param trace_list: A list contains code of the trace. Its last
                           element must be an ``ast.Return`` whose value is
                           a call to a plain name.
        :raises TraceError: If the first argument of the returned call is
                            not a literal.
        """
        self.item = None
        self.provider = provider
        self.symbols = {}
        # Shallow copy so later changes to the caller's list do not leak in.
        self.trace = trace_list[:]

        ret = trace_list[-1]
        assert isinstance(ret, ast.Return)
        self.result = ret.value.func.id.lower()

        self.result_patts = None
        args = ret.value.args
        if args:
            patts = self._literal_value(args[0])
            if patts is self._NOT_LITERAL:
                raise TraceError('Unknown argument type: %s' % args[0])
            self.result_patts = patts

    def __repr__(self):
        """
        Return the repr of a list holding one short name per trace node:
        the operator class name for comparisons and the called function
        name for calls and for the final return.
        """
        lines = []
        for line in self.trace:
            if isinstance(line, ast.Compare):
                name = str(line.ops[0].__class__.__name__)
            elif isinstance(line, ast.Call):
                name = line.func.id
            else:
                name = line.value.func.id
            lines.append(name)
        return repr(lines)

    @classmethod
    def _literal_value(cls, node):
        """
        Return the Python value carried by a literal AST node.

        Literals are read through ``ast.Constant`` (``.value``). The legacy
        ``Num``/``Str`` node shapes are recognised by class name so that the
        deprecated ``ast.Num``/``ast.Str`` attributes are never touched.

        :param node: Any AST node.
        :return: The literal value, or ``cls._NOT_LITERAL`` if ``node`` is
                 not a supported literal node.
        """
        kind = type(node).__name__
        if kind == 'Constant':
            return node.value
        if kind == 'Num':
            return node.n
        if kind == 'Str':
            return node.s
        return cls._NOT_LITERAL

    @staticmethod
    def _symbol_class(sym_type, comparator):
        """
        Look up the class named ``sym_type`` in the ``symbol`` module.

        :param sym_type: Name of the symbol class, or ``None`` if no type
                         could be inferred from ``comparator``.
        :param comparator: AST node the type was inferred from; only used
                           for the error message.
        :raises TraceError: If ``sym_type`` is ``None``.
        """
        if sym_type is None:
            raise TraceError(
                'Unable to determine symbol type from comparator: %s' %
                ast.dump(comparator))
        return getattr(symbol, sym_type)

    def _exec_call(self, node):
        """
        Execute a ``pkg.func(arg, ...)`` call node and return its result.

        The function is looked up in the already imported module
        ``<provider.name>.utils.<pkg>``. Every argument must be a plain
        name; its value is fetched with ``self.item.get(name)``.

        :param node: ``ast.Call`` whose ``func`` is an attribute access on
                     a plain name.
        :raises TraceError: If an argument is not an ``ast.Name``.
        """
        func_name = node.func.attr
        pkg_name = node.func.value.id
        mod_name = '.'.join([self.provider.name, 'utils', pkg_name])
        # The module must have been imported already; it is not imported here.
        func = getattr(sys.modules[mod_name], func_name)

        args = []
        for arg in node.args:
            if not isinstance(arg, ast.Name):
                raise TraceError('Unknown argument type: %s' % arg)
            args.append(self.item.get(arg.id))
        return func(*args)

    def _proc_compare(self, node):
        """
        Apply one comparison node to the symbol named on its left side.

        The symbol is created in ``self.symbols`` on first use, with its
        type inferred from the comparator, and is then constrained
        according to the operator.

        :param node: ``ast.Compare`` with exactly one operator, one
                     comparator and a plain name on the left.
        :raises TraceError: On an unknown symbol name, a comparator whose
                            type cannot be inferred, a type mismatch with an
                            existing symbol, an unsatisfiable equality, an
                            ``in``/``not in`` whose right side is not a
                            call, or an unknown operator.
        """
        assert len(node.ops) == 1
        assert len(node.comparators) == 1
        assert isinstance(node.left, ast.Name)

        left = node.left.id
        op = node.ops[0].__class__.__name__
        comparator = node.comparators[0]

        # Names of all classes in the symbol module derived from SymbolBase.
        known_symbols = {
            name
            for name, obj in inspect.getmembers(symbol, inspect.isclass)
            if issubclass(obj, symbol.SymbolBase)
        }

        exc_types = []
        right_value = None
        sym_type = None
        call_ret = None

        if isinstance(comparator, ast.Name):
            if comparator.id not in known_symbols:
                raise TraceError("Unknown symbol '%s'" % comparator.id)
            if op == 'IsNot':
                sym_type = 'Bytes'
                exc_types.append(comparator.id)
            else:
                sym_type = comparator.id
        elif isinstance(comparator, ast.Call):
            call_ret = self._exec_call(comparator)
            test_val = call_ret
            if isinstance(call_ret, (list, tuple)):
                if not call_ret:
                    raise TraceError(
                        'Call returned an empty sequence, '
                        'unable to determine symbol type')
                # The first element decides the type for the whole sequence.
                test_val = call_ret[0]
            if isinstance(test_val, builtins.str):
                sym_type = 'Bytes'
            elif isinstance(test_val, int):
                sym_type = 'Integer'
        else:
            literal = self._literal_value(comparator)
            if isinstance(literal, builtins.str):
                sym_type = 'Bytes'
                right_value = literal
            elif (isinstance(literal, (int, float, complex)) and
                  not isinstance(literal, bool)):
                # bool is excluded: it subclasses int but is not a number
                # literal.
                sym_type = 'Integer'
                right_value = literal

        if left not in self.symbols:
            # NOTE(review): exc_types is itself a list, so this passes a
            # nested list. Kept as in the original; confirm against the
            # symbol constructors whether `exc_types=exc_types` was meant.
            self.symbols[left] = self._symbol_class(sym_type, comparator)(
                exc_types=[exc_types])
        sleft = self.symbols[left]
        sleft_type = sleft.__class__.__name__

        if op != 'IsNot':
            expected = self._symbol_class(sym_type, comparator)
            if not issubclass(sleft.__class__, expected):
                raise TraceError(
                    'Unmatched type %s(operator: %s). Should be %s' %
                    (sym_type, op, sleft_type))

        if op in ('Is', 'IsNot'):
            # Fully handled by the symbol creation / type check above.
            pass
        elif op == 'Eq':
            if sleft.scope and right_value not in sleft.scope:
                raise TraceError(
                    'Unsatisfiable condition. Need equal to "%s", '
                    'but scope is %s' % (right_value, sleft.scope)
                )
            sleft.scope = [right_value]
        elif op == 'NotEq':
            if sleft.excs is None:
                sleft.excs = []
            sleft.excs.append(right_value)
        elif op == 'Lt':
            if sleft_type == 'Integer':
                sleft.maximum = right_value - 1
        elif op == 'LtE':
            if sleft_type == 'Integer':
                sleft.maximum = right_value
        elif op == 'Gt':
            if sleft_type == 'Integer':
                sleft.minimum = right_value + 1
        elif op == 'GtE':
            if sleft_type == 'Integer':
                sleft.minimum = right_value
        elif op in ('In', 'NotIn'):
            if not isinstance(comparator, ast.Call):
                raise TraceError(
                    "Operator '%s' requires a call on the right side" % op)
            if op == 'In':
                sleft.scope = call_ret
            else:
                sleft.excs = call_ret
        else:
            raise TraceError('Unknown operator: %s' % op)

    def _proc_call(self, node):
        """
        Apply an ``any(...)`` / ``all(...)`` node to the involved symbol.

        The first call argument must be a comparison. Either its left side
        is a plain name and its right side a call returning a list or
        tuple, or its left side is a call and its right side a plain name.
        The referenced symbol must already exist in ``self.symbols``.

        :param node: ``ast.Call`` to ``any`` or ``all``.
        :raises TraceError: If the left side is neither a name nor a call,
                            or for the unsupported ``all(<call> in name)``
                            form.
        """
        func_name = node.func.id
        assert func_name in ['any', 'all']
        assert isinstance(node.args[0], ast.Compare)

        comp = node.args[0]
        op = comp.ops[0].__class__.__name__
        left = comp.left
        right = comp.comparators[0]

        if isinstance(left, ast.Name):
            sym_left = self.symbols[left.id]
            assert isinstance(right, ast.Call)
            right = self._exec_call(right)
            assert isinstance(right, (list, tuple))
            assert op in ['In', 'NotIn']
            # Each entry is (values, flag, number); how the last two fields
            # are interpreted is defined by the symbol class, not here.
            if func_name == 'all':
                if op == 'In':
                    sym_left.scopes.append((right, True, 0))
                elif op == 'NotIn':
                    sym_left.scopes.append((right, False, 1))
            elif func_name == 'any':
                if op == 'In':
                    sym_left.scopes.append((right, True, 1))
                    sym_left.scopes.append((right, False, 0))
                elif op == 'NotIn':
                    sym_left.scopes.append((right, True, 0))
                    sym_left.scopes.append((right, False, 1))
        elif isinstance(left, ast.Call):
            sym_right = self.symbols[right.id]
            left = self._exec_call(left)
            if func_name == 'all':
                if op == 'In':
                    raise TraceError(
                        "'all' with operator 'In' and a call on the left "
                        "side is not supported yet")
                elif op == 'NotIn':
                    sym_right.excludes = left
            elif func_name == 'any':
                if op == 'In':
                    pass  # TODO
        else:
            raise TraceError('Unknown left type %s' % left)

    def solve(self, item):
        """
        Generate a satisfiable random option according to this trace.

        All symbols are rebuilt from scratch on every call. Nodes are
        processed in order, and the result is produced when the return
        node is reached.

        :param item: Item to which generated option applies.
        :return: Generated random option: a dict mapping each symbol name
                 to the value of that symbol's ``model()``.
        :raises TraceError: If the trace contains an unsupported node type.
        """
        self.item = item
        self.symbols = {}

        for node in self.trace:
            if isinstance(node, ast.Compare):
                self._proc_compare(node)
            elif isinstance(node, ast.Call):
                self._proc_call(node)
            elif isinstance(node, ast.Return):
                return {
                    name: sym.model()
                    for name, sym in self.symbols.items()
                }
            else:
                raise TraceError('Unknown node type: %s' % type(node))