Skip to content
Snippets Groups Projects
Commit 0f3cc5a1 authored by chsieh16's avatar chsieh16
Browse files

Use sympy ITE expressions for representing candidate dtree

parent 48b398f2
No related branches found
No related tags found
No related merge requests found
......@@ -4,9 +4,10 @@ import itertools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, MutableSet, Tuple
import numpy as np
import sympy
from learner_base import LearnerBase
......@@ -15,11 +16,11 @@ class DTreeLearner(LearnerBase):
def __init__(self, state_dim: int, perc_dim: int,
timeout: int = 10000) -> None:
super().__init__()
self.debug_neg_conc= set()
self.debug_neg_perc= set()
self.debug_neg_conc = set() # type: MutableSet[Tuple[float,...]]
self.debug_neg_perc = set() # type: MutableSet[Tuple[float,...]]
self._state_dim: int = state_dim
self._perc_dim: int = perc_dim
self.count_neg_dup =0
self.count_neg_dup = 0
self._s2f_func = lambda x: x
# Given a base or derived feature name,
......@@ -90,7 +91,7 @@ class DTreeLearner(LearnerBase):
@staticmethod
def _compose_s2f_functions(s2f_func_list):
def composed_func(sample):
return sum((list(f(sample)) for f in s2f_func_list),[])
return sum((list(f(sample)) for f in s2f_func_list), [])
return composed_func
@staticmethod
......@@ -130,18 +131,15 @@ class DTreeLearner(LearnerBase):
def add_negative_examples(self, *args) -> None:
for samp in args:
if samp in self.debug_neg_conc:
self.count_neg_dup+=1
print("repeated negative example: "+ str(samp))
raise ValueError()
self.count_neg_dup += 1
raise ValueError("repeated negative example: " + str(samp))
perc_samp = tuple(self._s2f_func(samp))
print(tuple(perc_samp))
if perc_samp in self.debug_neg_perc:
print("repeated negative example: "+ str(perc_samp))
raise ValueError()
raise ValueError("repeated negative example: " + str(perc_samp))
self.debug_neg_perc.add(perc_samp)
self.debug_neg_conc.add(samp)
print(f"number of negative duplicate {self.count_neg_dup}")
feature_vec_list = [self._s2f_func(sample) for sample in args]
......@@ -152,93 +150,57 @@ class DTreeLearner(LearnerBase):
with open(self.data_file, 'a') as d_file:
data_out = csv.writer(d_file)
for f in feature_vec_list:
print(f)
data_out.writerow(itertools.chain(f, [label]))
row = itertools.chain(f, [label]) # append label at the end of each row
data_out.writerow(row)
def learn(self) -> List[Tuple]:
def learn(self) -> sympy.logic.boolalg.Boolean:
res = os.popen(self.exec).read()
print(res)
assert os.path.exists(self.tree_out), "if learned successfully" \
f"there should be a json file in {self.dir_name}"
dnf = self.get_pre_from_json(self.tree_out)
ite_expr = self.get_pre_from_json(self.tree_out)
os.remove(self.tree_out) # Remove the generated json to avoid reusing old trees
ret_dnf: List[Tuple] = []
for conjunct in dnf:
ret_trans = ()
ret_coeffs_list = []
ret_cut_list = []
if len(conjunct) == 0:
# Conjunction of zero clauses is defined as True
# return 0*(z-(0*x-0)) <= inf which is equivalent to True
a_mat = np.zeros(shape=(self.perc_dim, self.state_dim))
b_vec = np.zeros(self.perc_dim)
coeffs_mat = np.zeros(shape=(1, self.perc_dim))
cut_vec = np.array([np.inf])
ret_dnf.append((a_mat, b_vec, coeffs_mat, cut_vec))
continue
# else:
for pred in conjunct:
var, op, cut = pred
coeff_arr = np.zeros(self.perc_dim)
# Convert from dictionary to coefficients
# XXX May use sparse matrix from scipy
for basevar, coeff in self._var_coeff_map[var].items():
trans, j = self._basevar_trans_map[basevar]
coeff_arr[j] = coeff
if not ret_trans:
ret_trans = trans
elif ret_trans != trans:
raise NotImplementedError(
"Not supporting mixing affine transformations in one conjunct.")
if op == "<=":
pass
elif op == ">": # TODO deal with strict unequality
coeff_arr = -coeff_arr
cut = -cut
else:
raise ValueError(f"Unknown operator '{op}'")
ret_coeffs_list.append(coeff_arr)
ret_cut_list.append(cut)
ret_coeffs_mat = np.stack(ret_coeffs_list)
ret_cut_vec = np.array(ret_cut_list)
assert ret_trans
ret_dnf.append(ret_trans + (ret_coeffs_mat, ret_cut_vec))
return ret_dnf
def get_pre_from_json(self, path):
ite_expr = self._subs_basevar_w_states(ite_expr)
return ite_expr
def _subs_basevar_w_states(self, ite_expr) -> sympy.logic.boolalg.Boolean:
state_vars = sympy.symbols([f"x[{i}]" for i in range(self.state_dim)])
state_vec = sympy.Matrix(state_vars)
perc_vars = sympy.symbols([f"z[{i}]" for i in range(self.perc_dim)])
perc_vec = sympy.Matrix(perc_vars)
subs_basevar = []
for basevar, (trans, j) in self._basevar_trans_map.items():
a_mat, b_vec = trans
a_mat, b_vec = sympy.Matrix(a_mat), sympy.Matrix(b_vec)
expanded_basevar = (perc_vec - (a_mat @ state_vec + b_vec))[j]
subs_basevar.append((basevar, expanded_basevar))
return ite_expr.subs(subs_basevar)
@staticmethod
def get_pre_from_json(path):
try:
with open(path) as json_file:
tree = json.load(json_file)
return self.parse_tree(tree)
return DTreeLearner.parse_tree(tree)
except json.JSONDecodeError:
raise ValueError(f"cannot parse {path} as a json file")
def parse_tree(self, tree) -> Optional[List[List]]:
@staticmethod
def parse_tree(tree) -> sympy.logic.boolalg.Boolean:
if tree['children'] is None:
# At a leaf node, return the clause
if tree['classification']:
return [[]] # Non-none value represtns a True leaf node
return sympy.true # True leaf node
else:
return None
return sympy.false # False leaf node
elif len(tree['children']) == 2:
# Post-order traversal
left = self.parse_tree(tree['children'][0])
right = self.parse_tree(tree['children'][1])
if left is None and right is None:
return None
res_left = []
if left is not None:
res_left = [[(tree['attribute'], "<=", tree['cut'])] + conjunct
for conjunct in left]
res_right = []
if right is not None:
res_right = [[(tree['attribute'], ">", tree['cut'])] + conjunct
for conjunct in right]
assert res_left or res_right
return res_left + res_right
left = DTreeLearner.parse_tree(tree['children'][0])
right = DTreeLearner.parse_tree(tree['children'][1])
# Create an ITE expression tree
cond = sympy.sympify(f"{tree['attribute']} <= {tree['cut']}")
return sympy.logic.boolalg.ITE(cond, left, right)
else:
raise ValueError("error parsing the json object as a binary decision tree)")
......@@ -248,10 +210,9 @@ def construct_sample_to_feature_func(a_mat: np.ndarray, b_vec: np.ndarray):
def sample_to_feature_vec(sample):
assert len(sample) == state_dim + perc_dim
state = sample[0: state_dim]
perc = sample[state_dim: state_dim+perc_dim]
perc_bar = perc - (np.dot(state, a_mat.T) + b_vec)
return perc_bar
state = np.array(sample[0: state_dim])
perc = np.array(sample[state_dim: state_dim+perc_dim])
return perc - (a_mat @ state + b_vec)
return sample_to_feature_vec
......@@ -309,6 +270,36 @@ def test_sample_to_feature():
assert np.array_equal(feature_vec, np.array([1., 1.]))
def test_parse_json():
json_obj = json.loads("""
{"attribute":"((1*fvar0_A0) + (1*fvar1_A0))","cut":-0.01318414,"classification":0,
"children":[
{"attribute":"fvar1_A0","cut":0.01403625,"classification":0,
"children":[{"attribute":"","cut":0,"classification":true,"children":null},
{"attribute":"","cut":0,"classification":false,"children":null}]
},
{"attribute":"fvar1_A1","cut":-0.003193465,"classification":0,
"children":[{"attribute":"","cut":0,"classification":true,"children":null},
{"attribute":"","cut":0,"classification":false,"children":null}]
}
]
}""")
tree = DTreeLearner.parse_tree(json_obj)
a_mat_0 = np.array([[0., -1., 0.],
[0., 0., -1.]])
b_vec_0 = np.zeros(2)
a_mat_1 = np.array([[0., -0.75, 0.],
[0., 0., -1.25]])
b_vec_1 = np.zeros(2)
learner = DTreeLearner(state_dim=3, perc_dim=2)
learner.set_grammar([(a_mat_0, b_vec_0), (a_mat_1, b_vec_1)])
print(learner._subs_basevar_w_states(tree))
if __name__ == "__main__":
# test_sample_to_feature()
test_dtree_learner()
test_parse_json()
......@@ -4,6 +4,7 @@ from typing import Dict, Hashable, List, Literal, Sequence, Tuple
import gurobipy as gp
import numpy as np
import sympy
import z3
import dreal
......@@ -13,7 +14,7 @@ class TeacherBase(abc.ABC):
pass
@abc.abstractmethod
def check(self, candidate) -> z3.CheckSatResult:
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
raise NotImplementedError
@abc.abstractmethod
......@@ -81,7 +82,9 @@ class DRealTeacherBase(TeacherBase):
def ctrl_dim(self) -> int:
return len(self._control)
def _gen_cand_pred(self, candidate):
def _gen_cand_pred(self, candidate: sympy.logic.boolalg.Boolean):
raise NotImplementedError("TODO convert from sympy")
a_mat, b_vec, coeff_mat, cut_vec = candidate # TODO parse candidate
# VARIABLE ALIASES
......@@ -104,15 +107,13 @@ class DRealTeacherBase(TeacherBase):
def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
self._set_var_bound(self._old_state, lb, ub)
def check(self, candidate) -> z3.CheckSatResult:
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
bound_preds = [
dreal.And(lb_i <= var_i, var_i <= ub_i)
for var_i, (lb_i, ub_i) in self._var_bounds.items()
]
cand_pred = self._gen_cand_pred(candidate)
print(cand_pred)
query = dreal.And(*bound_preds, *self._not_inv_cons)
print(query)
query = dreal.And(*bound_preds, cand_pred, *self._not_inv_cons)
res = dreal.CheckSatisfiability(query, self._delta)
self._models.clear()
......@@ -121,11 +122,8 @@ class DRealTeacherBase(TeacherBase):
tuple(res[x_i] for x_i in self._old_state) + tuple(res[z_i] for z_i in self._percept)
)
return z3.sat
elif not res:
return z3.unsat
else:
print("dReal result", res)
raise NotImplementedError
return z3.unsat
def dump_model(self, basename: str = "") -> None:
print([
......@@ -196,7 +194,9 @@ class GurobiTeacherBase(TeacherBase):
def _add_unsafe(self) -> None:
raise NotImplementedError
def _set_candidate(self, candidate) -> None:
def _set_candidate(self, candidate: sympy.logic.boolalg.Boolean) -> None:
raise NotImplementedError("TODO convert from sympy")
shape_sel, coeff, intercept, radius = candidate # TODO parse candidate
# Variable Aliases
m = self._gp_model
......@@ -240,7 +240,7 @@ class GurobiTeacherBase(TeacherBase):
self._old_state.lb = lb
self._old_state.ub = ub
def check(self, candidate) -> z3.CheckSatResult:
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
self._set_candidate(candidate)
self._gp_model.optimize()
if self._gp_model.status == gp.GRB.INF_OR_UNBD:
......@@ -275,3 +275,73 @@ class GurobiTeacherBase(TeacherBase):
return '(model is unbounded)'
else: # TODO other status code
return '(gurobi model status code %d)' % self._gp_model.status
class SymPyTeacherBase(TeacherBase):
def __init__(self, name: str,
state_dim: int, perc_dim: int, ctrl_dim: int, norm_ord: Literal[1, 2, "inf"]) -> None:
self._norm_ord = norm_ord
# Old state variables
self._old_state = sympy.symbols(names=[f"x[{i}]" for i in range(state_dim)], real=True)
# New state variables
self._new_state = sympy.symbols(names=[f"x'[{i}]" for i in range(state_dim)], real=True)
# Perception variables
self._percept = sympy.symbols(names=[f"z[{i}]" for i in range(perc_dim)], real=True)
# Control variables
self._control = sympy.symbols(names=[f"u[{i}]" for i in range(ctrl_dim)], real=True)
self._var_bounds = {
var: (-sympy.oo, sympy.oo)
for var in itertools.chain(self._old_state, self._new_state, self._percept, self._control)
} # type: Dict[dreal.Variable, Tuple[float, float]]
self._not_inv_cons = [] # type: List
self._add_system()
self._add_unsafe()
@abc.abstractmethod
def _add_system(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def _add_unsafe(self) -> None:
raise NotImplementedError
@property
def state_dim(self) -> int:
return len(self._old_state)
@property
def perc_dim(self) -> int:
return len(self._percept)
@property
def ctrl_dim(self) -> int:
return len(self._control)
def _set_var_bound(self, sym_vars: List[dreal.Variable], lb: Sequence[float], ub: Sequence[float]) -> None:
assert len(lb) == len(sym_vars)
assert len(ub) == len(sym_vars)
for var_i, lb_i, ub_i in zip(sym_vars, lb, ub):
self._var_bounds[var_i] = (lb_i, ub_i)
def set_old_state_bound(self, lb: Sequence[float], ub: Sequence[float]) -> None:
self._set_var_bound(self._old_state, lb, ub)
def check(self, candidate: sympy.logic.boolalg.Boolean) -> z3.CheckSatResult:
raise NotImplementedError
def dump_system_formula(self, basename: str = "") -> None:
query = sympy.And(
*(sympy.And(lb_i <= var_i, var_i <= ub_i)
for var_i, (lb_i, ub_i) in self._var_bounds.items()),
*self._not_inv_cons)
print(sympy.simplify(query))
def model(self) -> Sequence[Tuple]:
raise NotImplementedError
def reason_unknown(self) -> str:
raise NotImplementedError
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment