Skip to content
Snippets Groups Projects
Commit ea841feb authored by aastorg2's avatar aastorg2
Browse files

sygus learner that roughly works

parent ba1ca4ec
No related branches found
No related tags found
No related merge requests found
from typing import MutableMapping, Sequence, Optional, Tuple
import numpy as np
from z3 import *
from learner_base import LearnerBase
import re
class SygusLearner(LearnerBase):
def __init__(self, state_dim, perc_dim, timeout=10000):
self.file = 'firstpass_learner.sl'
def set_grammar(self, grammar) -> None:
raise NotImplementedError
def add_positive_examples(self, *args) -> None:
# TODO: For optimization purposes,
# read and then write
lines = []
pos = 0
synth_pos = None
with open(self.file, mode='r') as fin:
for line in fin:
if "(check-synth)" in line:
synth_pos = pos
lines.append(line)
pos += 1
prefix = lines[0:synth_pos]
converted = SygusLearner.convertExamplesToStringSygusFormat(*args, pos_or_neg=1)
prefix.extend(converted)
#print(prefix)
#print("======")
#print(lines[synth_pos:])
# Todo: shorter code
#z3_func = Function("inShape", RealSort(),RealSort(),RealSort(),RealSort(), RealSort(), BoolSort())
final = prefix + lines[synth_pos:]
with open(self.file, mode='w') as fout:
fout.writelines(final)
def add_negative_examples(self, *args) -> None:
lines = []
pos = 0
synth_pos = None
with open(self.file, mode='r') as fin:
for line in fin:
if "(check-synth)" in line:
synth_pos = pos
lines.append(line)
pos += 1
prefix = lines[0:synth_pos]
converted = SygusLearner.convertExamplesToStringSygusFormat(*args,pos_or_neg=2)
prefix.extend(converted)
final = prefix + lines[synth_pos:]
with open(self.file, mode='w') as fout:
fout.writelines(final)
def add_implication_examples(self, *args) -> None:
raise NotImplementedError
def learn(self):
x1 = Real('x1')
x2 = Real('x2')
x3 = Real('x3')
z1 = Real('z1')
z2 = Real('z2')
l2_norm = Function("L2_norm",RealSort(),RealSort(),RealSort() )
l1_norm = Function("L1_norm",RealSort(),RealSort(),RealSort() )
loo_norm = Function("Loo_norm",RealSort(),RealSort(),RealSort() )
# sqr = Function('sqr', RealSort(),RealSort() )
variable_map = { str(x1):x1, str(x2):x2,str(x3):x3, str(z1):z1, str(z2):z2,str(loo_norm):loo_norm ,str(l1_norm):l1_norm, str(l2_norm):l2_norm}
res = os.popen('../cvc4-1.8 --sygus-out=sygus-standard --lang=sygus2 firstpass_learner.sl').read()
print(f"cvc4 output {res}")
expr_body = res[77:-2]#.replace("(* 1 x1)", "(* (- 1) x1)")
print(f"expr body\n{expr_body}")
z3_vector_expr:AstVector = parse_smt2_string("(assert "+expr_body+" )", decls=variable_map)
print(z3_vector_expr)
constants = SygusLearner.get_constants(z3_vector_expr[0])
#TODO: check that constants have valid values
if len(constants) == 4:
return constants
@staticmethod
def get_shape_and_Body:
# Gurobi encoding
# 0 is diamond
# 1 s circle
# 2 is squares
pass
@staticmethod
def get_constants(e):
r = set()
ais_array = np.empty((2, 3), float)
bis_array = np.empty((1, 2), float)
ais = []
bis = []
radius = None
if is_app_of(e, Z3_OP_LE) and e.num_args() == 2:
print("radius: "+ str(e.arg(1)))
simp_arg = simplify(e.arg(1))
float_val = float(simp_arg.as_string())
print(float_val)
radius = float_val
def collect(t):
if is_app_of(t, Z3_OP_MUL) and t.num_args() == 2:
print("multiplication term: "+ f"{t}")
print("constant: "+ f"{simplify(t.arg(0)) }")
print("constant: "+ f"{type(simplify(t.arg(0))) }")
simp_arg = simplify(t.arg(0))
float_val = float(simp_arg.as_string())
print(float_val)
r.add(float_val)
ais.append(float_val)
return
if is_app_of(t, Z3_OP_ADD) and t.num_args() == 4:
print("Addition term: "+ f"{t}")
print("constant: "+ f"{simplify(t.arg(3)) }")
print("constant: "+ f"{type(simplify(t.arg(3))) }")
simp_arg = simplify(t.arg(3))
float_val = float(simp_arg.as_string())
print(float_val)
r.add(float_val)
bis.append(float_val)
#bis_array.append(simplify(t.arg(3)))
for c in t.children():
collect(c)
else:
for c in t.children():
collect(c)
return
collect(e)
#print(ais)
ais_arr = np.array(ais)
ais_arr = np.reshape(ais_arr,(2,3))
bis_arr = np.array(bis)
print(ais_arr)
print(bis_arr)
print(radius)
return (0, ais_arr, bis_arr, radius)
@staticmethod
def convertExamplesToStringSygusFormat(*examples, pos_or_neg:int ):
#assert pos_or_neg == 2 or pos_or_neg == 1
if pos_or_neg == 2:
prefix_sygus = "(constraint (not (inShape "
else:
prefix_sygus = "(constraint (inShape "
new_lines = []
for example in examples:
str_to_add = prefix_sygus
for val in example:
if abs(val) <= 10**-2:
val = 0.0
pat = r"\-(\d+\.\d*)"
s_val = str(val)
gres= re.match(pat,s_val)
if gres != None and len(gres.groups()) >= 1:
str_to_add += " "+"(- "+gres.groups()[0]+")"
else:
str_to_add += " "+s_val
str_to_add += ")"+ (")"*pos_or_neg)+"\n"
new_lines.extend(str_to_add)
return new_lines
if __name__ == "__main__":
l = SygusLearner(3,2)
l.add_positive_examples(*[(2.372100525863791, 0.9678239108338289, 0.21114638864503218, 3.1077494621276855, -0.169 )])
test = "(2.372100525863791, 0.9678239108338289, 0.21114638864503218, -3.1077494621276855, -0.169)"
# pat = r"\-(\d+\.\d*)"
# gres= re.search(pat,test)
# print(gres.groups())
#print(re.sub(pat,"(- "+gres.groups()[0]+")",test ) )
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment