Skip to content
Snippets Groups Projects
Commit b6124e8f authored by aastorg2's avatar aastorg2
Browse files

sygus learner done

properly parsing cvc4 output into shape and body

adding cvc4 executable
parent 73f6c261
No related branches found
No related tags found
No related merge requests found
File added
;; The background theory ;; The background theory
(set-logic NRA) (set-logic LRA)
(define-fun split ((x Real) (c Real) (then_pred Bool) (e1se_pred Bool)) Bool (define-fun split ((x Real) (c Real) (then_pred Bool) (e1se_pred Bool)) Bool
(and (=> (>= x c) then_pred) (=> (<= x c) e1se_pred)) (and (=> (>= x c) then_pred) (=> (<= x c) e1se_pred))
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
(define-fun max ((x1 Real) (x2 Real)) Real (ite (>= x1 x2) x1 x2)) (define-fun max ((x1 Real) (x2 Real)) Real (ite (>= x1 x2) x1 x2))
(define-fun Loo_norm ((x1 Real) (x2 Real)) Real (max (Abs x1) (Abs x2))) ; squares (define-fun Loo_norm ((x1 Real) (x2 Real)) Real (max (Abs x1) (Abs x2))) ; squares
(synth-fun inShape ((x1 Real) (x2 Real) (z1 Real) (z2 Real)) Bool (synth-fun inShape ((x1 Real) (x2 Real) (x3 Real) (z1 Real) (z2 Real)) Bool
;; Declare the non-terminals that would be used in the grammar ;; Declare the non-terminals that would be used in the grammar
( (B Bool) (Dist Real) (X Real) (C Real)) ( (B Bool) (Dist Real) (X Real) (C Real))
...@@ -25,11 +25,12 @@ ...@@ -25,11 +25,12 @@
(split X C B B) (split X C B B)
)) ))
(Dist Real ( (Dist Real (
(L1_norm (- z1 (+ (* C x1) (* C x2) C)) (- z2 (+ (* C x1) (* C x2) C))) (L1_norm (- z1 (+ (* C x1) (* C x2) (* C x3) C) ) (- z2 (+ (* C x1) (* C x2) (* C x3) C)))
(L2_norm (- z1 (+ (* C x1) (* C x2) C)) (- z2 (+ (* C x1) (* C x2) C))) (L2_norm (- z1 (+ (* C x1) (* C x2) (* C x3) C) ) (- z2 (+ (* C x1) (* C x2) (* C x3) C)))
(Loo_norm (- z1 (+ (* C x1) (* C x2) C)) (- z2 (+ (* C x1) (* C x2) C))) (Loo_norm (- z1 (+ (* C x1) (* C x2) (* C x3) C) ) (- z2 (+ (* C x1) (* C x2) (* C x3) C)))
; TODO ellipsoid norm can be done with an upper triangle matrix U by ||Uv|| <= r ; TODO ellipsoid norm can be done with an upper triangle matrix U by ||Uv|| <= r
)) )
)
(X Real (x1 x2)) (X Real (x1 x2))
(C Real ( (Constant Real) )) (C Real ( (Constant Real) ))
) )
...@@ -41,13 +42,13 @@ ...@@ -41,13 +42,13 @@
; (declare-var z2 Real) ; (declare-var z2 Real)
;; Define the semantic constraints on the function ;; Define the semantic constraints on the function
(constraint (inShape 10 (- 10) 11 (- 11))) ; (constraint (inShape 2.372100525863791 0.9678239108338289 0.21114638864503218 3.1077494621276855 (- 0.16957582533359528)) )
(constraint (inShape 10 (- 10) 9 (- 9))) ; (constraint (inShape 20.547616744645502 0.9678266805486828 0.21114634046254058 0.08310934901237488 (- 0.1517469435930252)) )
(constraint (inShape 1 1 1.5 2.3 )) ; (constraint (inShape 18.006150893559386 0.9678099296179283 0.21114528600009771 3.360142707824707 (- 0.1889626532793045)) )
(constraint (not (inShape 1 1 0.5 0.5))) ;(constraint (not (inShape 1 1 0.5 0.5)))
(constraint (not (inShape 1 1 1.5 1.5))) ;(constraint (not (inShape 1 1 1.5 1.5)))
(constraint (not (inShape 9 9 5 5)))
;(constraint (=> (<= (L2_norm (- z1 (+ (* 1 x1) (* 0 x2) 0)) (- z2 (+ (* 1 x1) (* 2 x2) 0))) 2) (inShape x1 x2 z1 z2) ) )
; outputs (define-fun inShape ((x1 Real) (x2 Real) (z1 Real) (z2 Real)) Bool (<= (Loo_norm (- z1 (+ (* 0 x1) (* 0 x2) 6)) (- z2 (+ (* 0 x1) (* 1 x2) 0))) 5))
;(constraint (=> (<= (L2_norm (- z1 (+ (* 1 x1) (* 0 x2) 0)) (- z2 (+ (* 1 x1) (* 2 x2) 0))) 2) (inShape x1 x2 z1 z2) ) )
(check-synth) (check-synth)
...@@ -47,17 +47,17 @@ def test_synth_region(): ...@@ -47,17 +47,17 @@ def test_synth_region():
# teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0]) # teacher.set_old_state_bound(lb=[6.0, 5.0], ub=[11.0, 10.0])
#[::20] #[::20]
synth_region(positive_examples[:20:], teacher, num_max_iterations=100) synth_region(positive_examples[20:40:], teacher, num_max_iterations=20)
#Gurobi encoding
# 0 is diamond
# 1 s circle
# 2 is squares
#TODO: Angello consider a returning a tuple instead of just candidates #TODO: Angello consider a returning a tuple instead of just candidates
def synth_region(positive_examples, teacher, num_max_iterations: int = 10): def synth_region(positive_examples, teacher, num_max_iterations: int = 10):
learner = Learner(state_dim=teacher.state_dim, learner = Learner(state_dim=teacher.state_dim,
perc_dim=teacher.perc_dim, timeout=20000) perc_dim=teacher.perc_dim, timeout=20000)
print(positive_examples[0])
#return
# 0 is diamond
# 1 s circle
# 2 is squares
learner.add_positive_examples(*positive_examples) learner.add_positive_examples(*positive_examples)
past_candidate_list = [] past_candidate_list = []
...@@ -69,7 +69,10 @@ def synth_region(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -69,7 +69,10 @@ def synth_region(positive_examples, teacher, num_max_iterations: int = 10):
#return #return
for k in range(num_max_iterations): for k in range(num_max_iterations):
print(f"Iteration {k}:", sep='') print(f"Iteration {k}:", sep='')
print("learning ....")
candidate = learner.learn() candidate = learner.learn()
print("done learning")
if candidate is None: # learning FAILED if candidate is None: # learning FAILED
print("Learning Failed.") print("Learning Failed.")
return return
...@@ -84,7 +87,7 @@ def synth_region(positive_examples, teacher, num_max_iterations: int = 10): ...@@ -84,7 +87,7 @@ def synth_region(positive_examples, teacher, num_max_iterations: int = 10):
# TODO check if negative example state is spurious or true courterexample # TODO check if negative example state is spurious or true courterexample
print(f"negative examples: {m}") print(f"negative examples: {m}")
learner.add_negative_examples(*m) learner.add_negative_examples(*m)
return
elif result == z3.unsat: elif result == z3.unsat:
print("we are done!") print("we are done!")
return past_candidate_list return past_candidate_list
......
...@@ -73,24 +73,44 @@ class SygusLearner(LearnerBase): ...@@ -73,24 +73,44 @@ class SygusLearner(LearnerBase):
loo_norm = Function("Loo_norm",RealSort(),RealSort(),RealSort() ) loo_norm = Function("Loo_norm",RealSort(),RealSort(),RealSort() )
# sqr = Function('sqr', RealSort(),RealSort() ) # sqr = Function('sqr', RealSort(),RealSort() )
variable_map = { str(x1):x1, str(x2):x2,str(x3):x3, str(z1):z1, str(z2):z2,str(loo_norm):loo_norm ,str(l1_norm):l1_norm, str(l2_norm):l2_norm} variable_map = { str(x1):x1, str(x2):x2,str(x3):x3, str(z1):z1, str(z2):z2,str(loo_norm):loo_norm ,str(l1_norm):l1_norm, str(l2_norm):l2_norm}
res = os.popen('../cvc4-1.8 --sygus-out=sygus-standard --lang=sygus2 firstpass_learner.sl').read() res = os.popen('bin/cvc4-1.8 --sygus-out=sygus-standard --lang=sygus2 firstpass_learner.sl').read()
print(f"cvc4 output {res}") print(f"cvc4 output {res}")
expr_body = res[77:-2]#.replace("(* 1 x1)", "(* (- 1) x1)") expr_shape_body = SygusLearner.get_shape_and_Body(res)
print(f"expr body\n{expr_body}")
z3_vector_expr:AstVector = parse_smt2_string("(assert "+expr_body+" )", decls=variable_map) print(f"expr body\n{expr_shape_body[1]}")
print(z3_vector_expr) z3_vector_expr:AstVector = parse_smt2_string("(assert "+expr_shape_body[1]+" )", decls=variable_map)
#print(z3_vector_expr)
constants = SygusLearner.get_constants(z3_vector_expr[0]) constants = SygusLearner.get_constants(z3_vector_expr[0])
#TODO: check that constants have valid values #TODO: check that constants have valid values
if len(constants) == 4: conjecture: Tuple = (expr_shape_body[0],)+constants
return constants return conjecture
@staticmethod @staticmethod
def get_shape_and_Body: def get_shape_and_Body(cvc4_output:str):
# Gurobi encoding # Gurobi encoding
# 0 is diamond # 0 is diamond
# 1 s circle # 1 s circle
# 2 is squares # 2 is squares
pass ret_index = cvc4_output.index("Bool")
body = cvc4_output[ret_index+4: len(cvc4_output)-2].lstrip()
ret = tuple()
if "L1_norm" in body:
#diamond
ret += (0,)
elif "L2_norm" in body:
#circle
ret += (1,)
elif "Loo_norm" in body:
#squares
ret += (2,)
ret +=(body,)
return ret
@staticmethod
def to_float(var: z3.ExprRef) -> float:
x2 = var.as_fraction()
return float(x2.numerator) / float(x2.denominator)
@staticmethod @staticmethod
def get_constants(e): def get_constants(e):
...@@ -103,15 +123,16 @@ class SygusLearner(LearnerBase): ...@@ -103,15 +123,16 @@ class SygusLearner(LearnerBase):
if is_app_of(e, Z3_OP_LE) and e.num_args() == 2: if is_app_of(e, Z3_OP_LE) and e.num_args() == 2:
print("radius: "+ str(e.arg(1))) print("radius: "+ str(e.arg(1)))
simp_arg = simplify(e.arg(1)) simp_arg = simplify(e.arg(1))
float_val = float(simp_arg.as_string()) #float_val = float(simp_arg.as_string())
float_val = SygusLearner.to_float(simp_arg)
print(float_val) print(float_val)
radius = float_val radius = float_val
def collect(t): def collect(t):
if is_app_of(t, Z3_OP_MUL) and t.num_args() == 2: if is_app_of(t, Z3_OP_MUL) and t.num_args() == 2:
print("multiplication term: "+ f"{t}") # print("multiplication term: "+ f"{t}")
print("constant: "+ f"{simplify(t.arg(0)) }") # print("constant: "+ f"{simplify(t.arg(0)) }")
print("constant: "+ f"{type(simplify(t.arg(0))) }") # print("constant: "+ f"{type(simplify(t.arg(0))) }")
simp_arg = simplify(t.arg(0)) simp_arg = simplify(t.arg(0))
float_val = float(simp_arg.as_string()) float_val = float(simp_arg.as_string())
print(float_val) print(float_val)
...@@ -119,9 +140,9 @@ class SygusLearner(LearnerBase): ...@@ -119,9 +140,9 @@ class SygusLearner(LearnerBase):
ais.append(float_val) ais.append(float_val)
return return
if is_app_of(t, Z3_OP_ADD) and t.num_args() == 4: if is_app_of(t, Z3_OP_ADD) and t.num_args() == 4:
print("Addition term: "+ f"{t}") # print("Addition term: "+ f"{t}")
print("constant: "+ f"{simplify(t.arg(3)) }") # print("constant: "+ f"{simplify(t.arg(3)) }")
print("constant: "+ f"{type(simplify(t.arg(3))) }") # print("constant: "+ f"{type(simplify(t.arg(3))) }")
simp_arg = simplify(t.arg(3)) simp_arg = simplify(t.arg(3))
float_val = float(simp_arg.as_string()) float_val = float(simp_arg.as_string())
print(float_val) print(float_val)
...@@ -145,7 +166,7 @@ class SygusLearner(LearnerBase): ...@@ -145,7 +166,7 @@ class SygusLearner(LearnerBase):
print(ais_arr) print(ais_arr)
print(bis_arr) print(bis_arr)
print(radius) print(radius)
return (0, ais_arr, bis_arr, radius) return (ais_arr, bis_arr, radius)
@staticmethod @staticmethod
def convertExamplesToStringSygusFormat(*examples, pos_or_neg:int ): def convertExamplesToStringSygusFormat(*examples, pos_or_neg:int ):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment