# sygus_learner.py
    import math
    import re
    import shutil
    import subprocess
    from typing import MutableMapping, Sequence, Optional, Tuple

    import numpy as np
    from z3 import *

    from learner_base import LearnerBase
    
    
    class SygusLearner(LearnerBase):
        """Shape learner backed by the CVC4 SyGuS solver.

        Positive/negative examples are inserted as ``constraint`` commands into a
        working SyGuS problem file (``self.file``), which is initialised from
        ``self.template_file``.  ``learn`` runs CVC4 on that file, parses the
        synthesized function body with z3 and extracts its numeric constants.
        """

        # SyGuS command before which new constraint lines are inserted.
        _CHECK_SYNTH = "(check-synth)"
        # Solver invocation; the working file name is appended as the last argument.
        _CVC4_CMD = ('bin/cvc4-1.8', '--sygus-out=sygus-standard', '--lang=sygus2')

        def __init__(self, state_dim, perc_dim, timeout=10000):
            """Reset the working SyGuS file to a fresh copy of the template.

            ``state_dim``, ``perc_dim`` and ``timeout`` are accepted but currently
            unused by this implementation.  Both file names are relative to the
            current working directory; an existing working file is overwritten.
            """
            self.file = 'shapes_split.sl'
            self.template_file = 'template_sygus.sl'
            shutil.copyfile(self.template_file, self.file)

        def set_grammar(self, grammar) -> None:
            raise NotImplementedError

        def add_positive_examples(self, *args) -> None:
            """Add each example in ``args`` as a constraint that ``inShape`` holds on it.

            Each example is a sequence of numbers (the ``inShape`` arguments, in order).
            :raises ValueError: if the working file has no ``(check-synth)`` command.
            """
            self._insert_before_check_synth(
                SygusLearner.convertExamplesToStringSygusFormat(*args, pos_or_neg=1))

        def add_negative_examples(self, *args) -> None:
            """Add each example in ``args`` as a constraint that ``inShape`` fails on it.

            Each example is a sequence of numbers (the ``inShape`` arguments, in order).
            :raises ValueError: if the working file has no ``(check-synth)`` command.
            """
            self._insert_before_check_synth(
                SygusLearner.convertExamplesToStringSygusFormat(*args, pos_or_neg=2))

        def add_implication_examples(self, *args) -> None:
            raise NotImplementedError

        def _insert_before_check_synth(self, new_lines: Sequence[str]) -> None:
            """Rewrite ``self.file`` with ``new_lines`` inserted before ``(check-synth)``.

            If the command occurs more than once, the last occurrence is used.
            :raises ValueError: if the command is missing (previously this silently
                wrote the whole file twice).
            """
            with open(self.file, mode='r') as fin:
                lines = fin.readlines()

            synth_pos = None
            for idx, line in enumerate(lines):
                if self._CHECK_SYNTH in line:
                    synth_pos = idx
            if synth_pos is None:
                raise ValueError(f"no {self._CHECK_SYNTH} command found in {self.file}")

            final = lines[:synth_pos] + list(new_lines) + lines[synth_pos:]
            with open(self.file, mode='w') as fout:
                fout.writelines(final)

        def learn(self):
            """Run CVC4 on the working SyGuS file and decode the synthesized shape.

            :return: ``(shape_code, ais_arr, bis_arr, radius)``; ``shape_code`` is as
                returned by ``get_shape_and_Body`` and the remaining items are as
                returned by ``get_constants``.
            :raises OSError: if the CVC4 binary cannot be executed.
            :raises ValueError: if CVC4 produced no solution, or the solution uses
                none of ``L1_norm``/``L2_norm``/``Loo_norm``.
            """
            x1, x2, x3, z1, z2 = Reals('x1 x2 x3 z1 z2')
            l2_norm = Function("L2_norm", RealSort(), RealSort(), RealSort())
            l1_norm = Function("L1_norm", RealSort(), RealSort(), RealSort())
            loo_norm = Function("Loo_norm", RealSort(), RealSort(), RealSort())
            split = Function("split", RealSort(), RealSort(), BoolSort(), BoolSort(), BoolSort())
            # Declarations the synthesized body may reference, keyed by their SMT-LIB name.
            variable_map = {
                str(decl): decl
                for decl in (x1, x2, x3, z1, z2, loo_norm, l1_norm, l2_norm, split)
            }

            # Pass an argument list (no shell) so the file name is used verbatim.
            cmd = list(self._CVC4_CMD) + [self.file]
            res = subprocess.run(cmd, stdout=subprocess.PIPE, universal_newlines=True).stdout
            print(f"cvc4 output {res}")

            shape_and_body = SygusLearner.get_shape_and_Body(res)
            if len(shape_and_body) != 2:
                # get_shape_and_Body returns (body,) when no known norm is used.
                raise ValueError(
                    f"synthesized body uses none of L1_norm/L2_norm/Loo_norm: {shape_and_body[0]}")
            shape, body = shape_and_body

            print(f"expr body\n{body}")
            z3_vector_expr = parse_smt2_string("(assert " + body + " )", decls=variable_map)

            constants = SygusLearner.get_constants(z3_vector_expr[0])
            # TODO: check that constants have valid values
            return (shape,) + constants

        @staticmethod
        def get_shape_and_Body(cvc4_output: str):
            """Split CVC4's synthesized definition into ``(shape_code, body)``.

            The body is the text after the first ``Bool`` (the return sort), minus
            the last two characters of the output -- presumably the closing paren of
            the ``define-fun`` and the trailing newline.  Shape codes follow the
            Gurobi encoding, checked in this order: 0 = diamond (``L1_norm``),
            1 = circle (``L2_norm``), 2 = square (``Loo_norm``).  If the body uses
            none of these, ``(body,)`` is returned.

            :raises ValueError: if the output contains no ``Bool`` (e.g. CVC4
                reported an error or no solution).
            """
            try:
                ret_index = cvc4_output.index("Bool")
            except ValueError:
                raise ValueError(
                    f"no synthesized Bool function in cvc4 output: {cvc4_output!r}") from None
            body = cvc4_output[ret_index + 4: len(cvc4_output) - 2].lstrip()

            for norm_name, shape_code in (("L1_norm", 0), ("L2_norm", 1), ("Loo_norm", 2)):
                if norm_name in body:
                    return (shape_code, body)
            return (body,)

        @staticmethod
        def to_float(var: ExprRef) -> float:
            """Convert a z3 rational numeral to the nearest Python float.

            Converting the exact ``Fraction`` is correctly rounded and, unlike
            dividing two floats, does not overflow for very large numerators or
            denominators.
            """
            return float(var.as_fraction())

        @staticmethod
        def get_constants(e, ais_shape: Tuple[int, int] = (2, 3)):
            """Extract the numeric constants of a synthesized shape formula ``e``.

            - ``radius``: if ``e`` is a top-level binary ``<=``, its simplified
              right-hand side as a float; otherwise None.
            - ``ais``: for every binary product found (the search does not descend
              into products), its simplified first factor, in depth-first order,
              reshaped to ``ais_shape``.
            - ``bis``: for every 4-ary sum, its simplified last summand, in
              depth-first order, as a 1-D array.

            :return: ``(ais_arr, bis_arr, radius)``.
            :raises ValueError: if the number of products found cannot be
                reshaped to ``ais_shape``.
            """
            radius = None
            if is_app_of(e, Z3_OP_LE) and e.num_args() == 2:
                print("radius: " + str(e.arg(1)))
                radius = SygusLearner.to_float(simplify(e.arg(1)))
                print(radius)

            ais = []
            bis = []

            def collect(t):
                if is_app_of(t, Z3_OP_MUL) and t.num_args() == 2:
                    ais.append(SygusLearner.to_float(simplify(t.arg(0))))
                    return
                if is_app_of(t, Z3_OP_ADD) and t.num_args() == 4:
                    # Record the sum's constant term, then keep searching its children.
                    bis.append(SygusLearner.to_float(simplify(t.arg(3))))
                for child in t.children():
                    collect(child)

            collect(e)

            try:
                ais_arr = np.reshape(np.array(ais, dtype=float), ais_shape)
            except ValueError as err:
                raise ValueError(
                    f"found {len(ais)} product coefficients, cannot reshape to {ais_shape}"
                ) from err
            bis_arr = np.array(bis, dtype=float)
            print(ais_arr)
            print(bis_arr)
            print(radius)
            return (ais_arr, bis_arr, radius)

        @staticmethod
        def _format_real(val) -> str:
            """Render a number as an SMT-LIB real literal.

            Magnitudes up to 1e-2 are clamped to ``0.0``.  Negative values are
            written as ``(- x)`` because SMT-LIB has no negative numeric literals,
            and integers are written with a decimal point (``3`` -> ``3.0``).

            :raises ValueError: for NaN or infinity, which have no SMT-LIB form.
            """
            num = float(val)
            if not math.isfinite(num):
                raise ValueError(f"cannot encode non-finite value {val!r} as a SyGuS real")
            if abs(num) <= 10**-2:
                num = 0.0
            text = repr(abs(num))
            if 'e' in text:
                # repr() switches to scientific notation for large magnitudes,
                # which is not a valid SMT-LIB decimal; use fixed-point instead.
                text = format(abs(num), 'f')
            return f"(- {text})" if num < 0 else text

        @staticmethod
        def convertExamplesToStringSygusFormat(*examples, pos_or_neg: int):
            """Convert examples into SyGuS ``constraint`` lines.

            :param examples: each example is an iterable of numbers, used in order
                as the arguments of ``inShape``.
            :param pos_or_neg: 1 for positive examples (``inShape`` must hold),
                2 for negative examples (``(not (inShape ...))`` must hold).
            :return: list with one newline-terminated constraint string per example.
            :raises ValueError: if ``pos_or_neg`` is not 1 or 2, or a value is
                not finite.
            """
            if pos_or_neg == 1:
                template = "(constraint (inShape {}))\n"
            elif pos_or_neg == 2:
                template = "(constraint (not (inShape {})))\n"
            else:
                raise ValueError(
                    f"pos_or_neg must be 1 (positive) or 2 (negative), got {pos_or_neg!r}")
            return [
                template.format(" ".join(SygusLearner._format_real(v) for v in example))
                for example in examples
            ]
    
    if __name__ == "__main__":
        # Smoke run: constructing the learner overwrites shapes_split.sl with a copy
        # of template_sygus.sl (both relative to the working directory), then one
        # positive example is inserted before the (check-synth) command.
        learner = SygusLearner(3, 2)
        learner.add_positive_examples(
            (2.372100525863791, 0.9678239108338289, 0.21114638864503218,
             3.1077494621276855, -0.169),
        )