Skip to content
Snippets Groups Projects
Commit 4fd9c950 authored by aastorg2's avatar aastorg2
Browse files

properly finding Ais

parent 48278131
No related branches found
No related tags found
No related merge requests found
...@@ -6,56 +6,7 @@ from firstpass_learner import FirstpassLearner ...@@ -6,56 +6,7 @@ from firstpass_learner import FirstpassLearner
from firstpass_teacher import FirstpassTeacher from firstpass_teacher import FirstpassTeacher
def visitor(e, seen):
if e in seen:
return
seen[e] = True
yield e
if is_app(e):
# print(" call to ToReal()")
# print(e)
# and e.decl() == "+"
descendants = e.children()
for ch in descendants:
# print("child. "+ str(ch))
# print()
for e in visitor(ch, seen):
if is_app_of(e, Z3_OP_ADD) and e.num_args() == 3:
yield e
# return
return
# ANGELLO: DONT NEED THIS METHOD
if is_quantifier(e):
for e in visitor(e.body(), seen):
yield e
return
def extractAiBi(term):
assert is_app_of(term, Z3_OP_ADD) and term.num_args() == 3
seen = {}
def subterms(term):
descendants = term.children()
for ch in descendants:
if ch in seen:
continue
seen[ch] = True
yield ch
constants_terms = subterms(ch)
for sub in constants_terms:
if sub.decl().kind() == Z3_OP_ANUM:
yield RealVal(sub)
if sub.decl().kind() == Z3_OP_UMINUS:
yield sub
ret = [t for t in subterms(term)]
return ret
# x, y = Ints('x y') # x, y = Ints('x y')
# fml = x + x + y > 2 # fml = x + x + y > 2
# seen = {} # seen = {}
...@@ -78,12 +29,14 @@ def test_synth_region(): ...@@ -78,12 +29,14 @@ def test_synth_region():
# teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0]) # teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
synth_region(positive_examples, teacher) synth_region(positive_examples, teacher)
#TODO: Angello consider a returning a tuple instead of just candidates
def synth_region(positive_examples, teacher, num_max_iterations: int = 10): def synth_region(positive_examples, teacher, num_max_iterations: int = 10):
learner = FirstpassLearner(state_dim=teacher.state_dim, learner = FirstpassLearner(state_dim=teacher.state_dim,
perc_dim=teacher.perc_dim) perc_dim=teacher.perc_dim)
learner.add_positive_examples(*positive_examples) learner.add_positive_examples(*positive_examples)
past_candidate_list = [] past_candidate_list = []
for k in range(num_max_iterations): for k in range(num_max_iterations):
candidate = learner.learn() candidate = learner.learn()
if candidate is None: # learning FAILED if candidate is None: # learning FAILED
...@@ -107,22 +60,43 @@ def synth_region(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -107,22 +60,43 @@ def synth_region(positive_examples, teacher, num_max_iterations: int = 10):
def test_parse_sygus_output(): def test_parse_sygus_output():
pass # Angello: create learner and add constraints to grammar
# Angello create learner and add constraints to grammar x1 = Real('x1')
# x = Real('x') x2 = Real('x2')
# y = Real('y') z1 = Real('z1')
# x_p = Real('x_p') z2 = Real('z2')
# y_p = Real('y_p') l2_norm = Function("L2_norm",RealSort(),RealSort(),RealSort() )
# sqr = Function('sqr', RealSort(),RealSort() ) # sqr = Function('sqr', RealSort(),RealSort() )
# variable_map = { str(x):x, str(y):y, str(x_p):x_p, str(y_p):y_p, str(sqr):sqr} variable_map = { str(x1):x1, str(x2):x2, str(z1):z1, str(z2):z2, str(l2_norm):l2_norm}
# res = os.popen('../cvc4-1.8 --sygus-out=sygus-standard --lang=sygus2 circleSygus.sl').read() res = os.popen('../cvc4-1.8 --sygus-out=sygus-standard --lang=sygus2 circleSygus.sl').read()
# print(res) print(f"cvc4 output {res}")
# print("######") expr_body = res[66:-2].replace("(* 1 x1)", "(* (- 1) x1)")
# exprBody = res[67:-2].strip() print(f"expr body {expr_body}")
# print(exprBody) z3_vector_expr:AstVector = parse_smt2_string("(assert "+expr_body+" )", decls=variable_map)
# print("======") print(z3_vector_expr)
get_constants(z3_vector_expr[0])
# z3_expr = parse_smt2_string("(assert "+exprBody+" )", decls=variable_map)
def get_constants(e):
r = set()
def collect(t):
#if is_app(t):
if is_app_of(t, Z3_OP_MUL) and t.num_args() == 2:
print("multiplication term: "+ f"{t}")
print("constant: "+ f"{simplify(t.arg(0)) }")
print("constant: "+ f"{type(simplify(t.arg(0))) }")
# return
else:
for c in t.children():
collect(c)
return
collect(e)
return []
#for a in z3_expr:
# print(type(a)) # bool ref
# seen = {} # seen = {}
# seen_sub = {} # seen_sub = {}
# sub_term_sqr = [] # sub_term_sqr = []
...@@ -149,4 +123,6 @@ def test_parse_sygus_output(): ...@@ -149,4 +123,6 @@ def test_parse_sygus_output():
if __name__ == "__main__": if __name__ == "__main__":
test_synth_region() test_parse_sygus_output()
#test_synth_region()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment