diff --git a/learner_base.py b/learner_base.py index 18ead2220a09c079eb2ca1d23c47ed0706843579..383a8de5a90351c42b4c6fe9d2802dd06066ed3a 100644 --- a/learner_base.py +++ b/learner_base.py @@ -1,5 +1,4 @@ import abc -from typing import Optional import z3 @@ -29,10 +28,15 @@ class LearnerBase(abc.ABC): class Z3LearnerBase(LearnerBase): - def __init__(self, state_dim, perc_dim) -> None: + def __init__(self, state_dim:int, perc_dim:int, timeout:int=10000) -> None: super().__init__() - self._solver = z3.SolverFor('QF_LRA') - self._solver.set(timeout=1000) + z3.set_param("timeout", timeout, + "solver.timeout", timeout, + "model.completion", True, + "smt.logic", "QF_LRA") + + self._solver = z3.AndThen('qfnra-nlsat', 'smt').solver() + # self._solver.set('ctrl_c', True) real_var_vec = z3.RealVarVector(state_dim + perc_dim) self._state_vars = real_var_vec[0:state_dim] self._percept_vars = real_var_vec[state_dim:state_dim+perc_dim] @@ -44,14 +48,3 @@ class Z3LearnerBase(LearnerBase): @property def perc_dim(self) -> int: return len(self._percept_vars) - - def learn(self) -> Optional[z3.ModelRef]: - res = self._solver.check() - if res == z3.sat: - return self._solver.model() - elif res == z3.unsat: - print("(unsat)") - return None - else: - print("(unknown)") - return None