Newer
Older
import os
import shutil
import subprocess
from typing import MutableMapping, Sequence, Optional, Tuple
import numpy as np
from z3 import *
from learner_base import LearnerBase
import re
class SygusLearner(LearnerBase):
def __init__(self, state_dim, perc_dim, timeout=10000):
# Start from a fresh copy of the SyGuS template: shutil.copyfile overwrites
# self.file with template_sygus.sl, and add_positive_examples /
# add_negative_examples later rewrite self.file in place.
# NOTE(review): state_dim, perc_dim and timeout are not used in the visible
# lines; the lines after this one in the paste are bare numbers (15-76), so
# the rest of __init__ (and possibly other methods) appears to be missing.
self.file = 'shapes_split.sl'
self.template_file = 'template_sygus.sl'
shutil.copyfile(self.template_file, self.file)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def set_grammar(self, grammar) -> None:
    """Reject grammar changes; this learner does not support them.

    Always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def add_positive_examples(self, *args) -> None:
    """Insert positive example constraints into the SyGuS file.

    Each positional argument is one example: a sequence of numbers that
    ``convertExamplesToStringSygusFormat`` renders as a
    ``(constraint (inShape ...))`` line. The lines are inserted immediately
    before the ``(check-synth)`` line of ``self.file`` (the last one if there
    are several), and the file is rewritten in place.

    Raises:
        ValueError: if ``self.file`` contains no ``(check-synth)`` line.
            Splicing without an anchor would duplicate the whole file.
    """
    with open(self.file, mode='r') as fin:
        lines = fin.readlines()

    synth_pos = None
    for idx, line in enumerate(lines):
        if "(check-synth)" in line:
            synth_pos = idx
    if synth_pos is None:
        raise ValueError(f"no (check-synth) command found in {self.file!r}")

    # Convert before opening for writing so a conversion error cannot leave
    # the file truncated.
    converted = SygusLearner.convertExamplesToStringSygusFormat(*args, pos_or_neg=1)
    with open(self.file, mode='w') as fout:
        fout.writelines(lines[:synth_pos])
        fout.writelines(converted)
        fout.writelines(lines[synth_pos:])
def add_negative_examples(self, *args) -> None:
    """Insert negative example constraints into the SyGuS file.

    Each positional argument is one example: a sequence of numbers that
    ``convertExamplesToStringSygusFormat`` renders as a
    ``(constraint (not (inShape ...)))`` line. The lines are inserted
    immediately before the ``(check-synth)`` line of ``self.file`` (the last
    one if there are several), and the file is rewritten in place.

    Raises:
        ValueError: if ``self.file`` contains no ``(check-synth)`` line.
            Splicing without an anchor would duplicate the whole file.
    """
    with open(self.file, mode='r') as fin:
        lines = fin.readlines()

    synth_pos = None
    for idx, line in enumerate(lines):
        if "(check-synth)" in line:
            synth_pos = idx
    if synth_pos is None:
        raise ValueError(f"no (check-synth) command found in {self.file!r}")

    # Convert before opening for writing so a conversion error cannot leave
    # the file truncated.
    converted = SygusLearner.convertExamplesToStringSygusFormat(*args, pos_or_neg=2)
    with open(self.file, mode='w') as fout:
        fout.writelines(lines[:synth_pos])
        fout.writelines(converted)
        fout.writelines(lines[synth_pos:])
def add_implication_examples(self, *args) -> None:
    """Implication examples are not supported by this learner.

    Always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def learn(self):
    """Run cvc4 on ``self.file`` and return the synthesized conjecture.

    cvc4 (``bin/cvc4-1.8``, resolved relative to the working directory) is run
    on the SyGuS v2 file with standard SyGuS output. Processing then happens
    in three steps:

    1. ``get_shape_and_Body`` extracts the synthesized body.
    2. z3 parses the body using the declarations below.
    3. ``get_constants`` extracts its constants.

    Returns:
        ``(shape_id,) + constants``. ``shape_id`` is element 0 of
        ``get_shape_and_Body``'s result. ``constants`` is whatever
        ``get_constants`` returns (presumably a tuple -- TODO confirm).
    """
    # Declarations the synthesized body may reference, keyed by str(decl),
    # i.e. their names, so parse_smt2_string can resolve them.
    decls = [
        Real('x1'),
        Real('x2'),
        Real('x3'),
        Real('z1'),
        Real('z2'),
        Function("Loo_norm", RealSort(), RealSort(), RealSort()),
        Function("L1_norm", RealSort(), RealSort(), RealSort()),
        Function("L2_norm", RealSort(), RealSort(), RealSort()),
        Function("split", RealSort(), RealSort(), BoolSort(), BoolSort(), BoolSort()),
    ]
    variable_map = {str(decl): decl for decl in decls}

    # argv list instead of a shell string: no shell is involved, so the file
    # name needs no quoting. Only stdout is captured (stderr passes through)
    # and the exit status is not checked; the output is parsed regardless.
    completed = subprocess.run(
        ['bin/cvc4-1.8', '--sygus-out=sygus-standard', '--lang=sygus2', self.file],
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    res = completed.stdout

    expr_shape_body = SygusLearner.get_shape_and_Body(res)
    print(f"expr body\n{expr_shape_body[1]}")
    z3_vector_expr = parse_smt2_string("(assert " + expr_shape_body[1] + " )", decls=variable_map)
    constants = SygusLearner.get_constants(z3_vector_expr[0])
    # TODO: check that constants have valid values
    conjecture: Tuple = (expr_shape_body[0],) + constants
    return conjecture
# NOTE(review): no `def` line precedes the code below in this paste, and
# `cvc4_output` is never bound in the visible lines. learn() calls
# SygusLearner.get_shape_and_Body(res) and indexes the result with [0] and
# [1], and no such definition is visible. This is presumably the body of that
# (static) method -- confirm and restore its header.
# Returns (shape_id, body): `body` is the text taken from the cvc4 output and
# `shape_id` identifies which norm function it mentions.
# Gurobi encoding
# 0 is diamond
# 1 is circle
# 2 is squares
# Skip past the first "Bool" (4 characters) and drop the last 2 characters
# of the output. The "Bool" is presumably the return sort of the printed
# define-fun; the last 2 are presumably its closing paren and a newline --
# TODO confirm against actual cvc4 output.
# str.index raises ValueError if "Bool" does not occur (e.g. empty output).
ret_index = cvc4_output.index("Bool")
body = cvc4_output[ret_index+4: len(cvc4_output)-2].lstrip()
ret = tuple()
# The first matching norm name wins (L1, then L2, then Loo).
if "L1_norm" in body:
#diamond
ret += (0,)
elif "L2_norm" in body:
#circle
ret += (1,)
elif "Loo_norm" in body:
#squares
ret += (2,)
# NOTE(review): if no norm name occurs, `ret` ends up as just (body,), and
# learn()'s [1] index then raises IndexError.
ret +=(body,)
return ret
@staticmethod
def to_float(var: "z3.ExprRef") -> float:
    """Convert a z3 rational numeral to the nearest Python float.

    ``var`` must provide ``as_fraction()`` returning a ``fractions.Fraction``
    (z3 rational numerals do). Converting the Fraction directly is correctly
    rounded. Unlike ``float(numerator) / float(denominator)``, it does not
    overflow when the numerator or denominator alone exceeds the float range.
    """
    return float(var.as_fraction())
@staticmethod
def get_constants(e):
# Collect numeric constants from the z3 expression `e` and print them.
# learn() passes the first assertion parsed from the synthesized body.
# NOTE(review): this paste has lost all indentation, so the nesting described
# below is inferred from statement order -- confirm against the real file.
# NOTE(review): the lines following this block in the paste are bare numbers
# (174-211), so the rest of this function, including its return value, is
# not visible. learn() appends the result to a tuple, so it presumably
# returns a tuple -- TODO confirm.
# NOTE(review): `r`, `ais_array` and `bis_array` are never read in the
# visible lines.
r = set()
ais_array = np.empty((2, 3), float)
bis_array = np.empty((1, 2), float)
ais = []
bis = []
radius = None
# If `e` is a binary `<=`, its simplified right-hand side is printed,
# converted with to_float and kept as the radius.
if is_app_of(e, Z3_OP_LE) and e.num_args() == 2:
print("radius: "+ str(e.arg(1)))
simp_arg = simplify(e.arg(1))
#float_val = float(simp_arg.as_string())
float_val = SygusLearner.to_float(simp_arg)
print(float_val)
radius = float_val
# Recursive walk over the expression tree.
# NOTE(review): `float_val` is only bound when `e` is a binary `<=`;
# otherwise the first matching term below raises NameError.
# NOTE(review): both branches record `float_val` (the radius) rather than
# the term's own constant. The commented-out prints inspect
# simplify(t.arg(0)) / simplify(t.arg(3)), which suggests that was the
# intent -- confirm.
def collect(t):
# Binary multiplication: record a value in `ais` and stop; its children are
# not visited.
if is_app_of(t, Z3_OP_MUL) and t.num_args() == 2:
# print("multiplication term: "+ f"{t}")
# print("constant: "+ f"{simplify(t.arg(0)) }")
# print("constant: "+ f"{type(simplify(t.arg(0))) }")
r.add(float_val)
ais.append(float_val)
return
# Addition with exactly four operands: record a value in `bis`, then
# recurse into its operands.
if is_app_of(t, Z3_OP_ADD) and t.num_args() == 4:
# print("Addition term: "+ f"{t}")
# print("constant: "+ f"{simplify(t.arg(3)) }")
# print("constant: "+ f"{type(simplify(t.arg(3))) }")
r.add(float_val)
bis.append(float_val)
#bis_array.append(simplify(t.arg(3)))
for c in t.children():
collect(c)
# Presumably pairs with the 4-ary addition test above (any other term:
# recurse into its children). If it was instead a for-else on the loop
# above, the children of 4-ary additions would be walked twice -- confirm.
else:
for c in t.children():
collect(c)
return
collect(e)
#print(ais)
ais_arr = np.array(ais)
# Raises ValueError unless exactly 6 values were collected into `ais`.
ais_arr = np.reshape(ais_arr,(2,3))
bis_arr = np.array(bis)
print(ais_arr)
print(bis_arr)
print(radius)
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
@staticmethod
def convertExamplesToStringSygusFormat(*examples, pos_or_neg: int):
    """Render examples as SyGuS ``constraint`` lines over ``inShape``.

    Args:
        *examples: each example is a sequence of real numbers, used as the
            arguments of one ``inShape`` application.
        pos_or_neg: 1 for positive examples ``(constraint (inShape ...))``,
            2 for negative examples ``(constraint (not (inShape ...)))``.
            The value doubles as the number of extra closing parentheses
            needed after ``inShape``'s own one.

    Each value is written as an SMT-LIB decimal:

    - Values with ``abs(v) <= 0.01`` become ``0.0``.
    - Values are written in positional notation (SMT-LIB has no exponent
      form), always with a digit after the point (``3`` -> ``3.0``).
    - Negative values are written as ``(- x)``, since SMT-LIB has no negative
      literals.

    Returns:
        One newline-terminated string per example, in input order.

    Raises:
        ValueError: if ``pos_or_neg`` is not 1 or 2, or a value is NaN or
            infinite (neither can be written as an SMT-LIB real).
    """
    if pos_or_neg == 2:
        prefix_sygus = "(constraint (not (inShape "
    elif pos_or_neg == 1:
        prefix_sygus = "(constraint (inShape "
    else:
        raise ValueError(f"pos_or_neg must be 1 or 2, got {pos_or_neg!r}")
    closing = ")" + ")" * pos_or_neg + "\n"

    def to_smtlib_real(val) -> str:
        # Snap values this close to zero to exactly 0.0.
        if abs(val) <= 10**-2:
            val = 0.0
        val = float(val)
        if not np.isfinite(val):
            raise ValueError(f"cannot encode non-finite value {val!r} in SyGuS")
        # Shortest round-trip digits, positional (no exponent); trim='0'
        # keeps one digit after the point ("2.0", not "2.").
        magnitude = np.format_float_positional(abs(val), trim='0')
        return f"(- {magnitude})" if val < 0 else magnitude

    new_lines = []
    for example in examples:
        args = "".join(" " + to_smtlib_real(val) for val in example)
        # append (not extend): one string per example, not one per character.
        new_lines.append(prefix_sygus + args + closing)
    return new_lines
if __name__ == "__main__":
    # Smoke run: create a learner (copies the SyGuS template to
    # shapes_split.sl) and insert one positive example into that file.
    learner = SygusLearner(3, 2)
    learner.add_positive_examples(
        (2.372100525863791, 0.9678239108338289, 0.21114638864503218, 3.1077494621276855, -0.169)
    )