Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
cs598mp-fall2021-proj
Manage
Activity
Members
Labels
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Model registry
Analyze
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
chsieh16
cs598mp-fall2021-proj
Commits
dc91fffb
Commit
dc91fffb
authored
2 years ago
by
chsieh16
Browse files
Options
Downloads
Patches
Plain Diff
Check safety and use z3 api to (de-)serialize
parent
b5a9397e
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
dtree_synth.py
+1
-1
1 addition, 1 deletion
dtree_synth.py
monitoring_gem_stanley.py
+33
-29
33 additions, 29 deletions
monitoring_gem_stanley.py
with
34 additions
and
30 deletions
dtree_synth.py
+
1
−
1
View file @
dc91fffb
...
...
@@ -127,7 +127,7 @@ def synth_dtree_per_part(
"
feature_domain
"
:
feature_domain
,
"
ultimate_bound
"
:
ult_bound
,
"
status
"
:
"
found
"
,
"
result
"
:
{
"
k
"
:
k
,
"
formula
"
:
z3_expr
.
sexpr
(),
"
smtlib
"
:
smtlib
},
"
result
"
:
{
"
k
"
:
k
,
"
formula
"
:
z3_expr
.
sexpr
(),
"
smtlib
"
:
z3_expr
.
serialize
()
},
"
teacher time
"
:
time_info
[
0
],
"
learner time
"
:
time_info
[
1
]})
else
:
...
...
This diff is collapsed.
Click to expand it.
monitoring_gem_stanley.py
+
33
−
29
View file @
dc91fffb
...
...
@@ -14,7 +14,7 @@ def fp_to_real(val: float):
return
z3
.
fpToReal
(
z3
.
FPVal
(
val
,
z3
.
Float64
()))
def
build_
z3
_predicate
(
def
build_
aap
_predicate
(
state_vars
:
z3
.
ExprRef
,
perc_vars
:
z3
.
ExprRef
,
aap_json
)
->
z3
.
BoolRef
:
...
...
@@ -36,13 +36,7 @@ def build_z3_predicate(
else
:
assert
np
.
isposinf
(
ub
)
z3_astvec
=
z3
.
parse_smt2_string
(
m
[
"
smtlib
"
])
if
len
(
z3_astvec
)
==
0
:
assert
m
[
'
formula
'
]
==
"
true
"
dtree
=
z3
.
BoolVal
(
True
)
else
:
assert
len
(
z3_astvec
)
==
1
dtree
=
z3_astvec
[
0
]
dtree
=
z3
.
deserialize
(
m
[
"
smtlib
"
])
aap_formula_list
.
append
(
z3
.
And
(
*
pre_list
,
dtree
))
...
...
@@ -61,9 +55,7 @@ def monitoring(
gt_vars
=
z3
.
Reals
(
gt_var_names
)
gte_vars
=
z3
.
Reals
(
gte_var_names
)
z3_astvec
=
z3
.
parse_smt2_string
(
aap_pred_smtlib
)
assert
len
(
z3_astvec
)
==
1
aap_pred
=
z3_astvec
[
0
]
aap_pred
=
z3
.
deserialize
(
aap_pred_smtlib
)
# Extract only relevant fields and order the fields correctly
gt_list
=
gt_trace_arr
[[
'
cte
'
,
'
psi
'
]].
tolist
()
...
...
@@ -77,7 +69,7 @@ def monitoring(
z3
.
simplify
(
z3
.
substitute
(
aap_pred
,
*
subs
))
)
bool_list
=
[
z3
.
is_true
(
b
)
for
b
in
bool_list
]
return
sum
(
bool_list
),
len
(
bool_list
)
return
bool_list
def
main
():
...
...
@@ -85,39 +77,51 @@ def main():
AAP_JSON_FILE
=
"
diff0_0/out/dtree_synth.4x10.out.json
"
TRACE_PKL_FILE
=
"
data/gem_stanley-straight-1000_traces-500_psicte.pickle
"
state_vars
=
[
z3
.
Real
(
f
"
x_
{
i
}
"
)
for
i
in
range
(
3
)]
perc_vars
=
[
z3
.
Real
(
f
"
z_
{
i
}
"
)
for
i
in
range
(
2
)]
with
open
(
TRACE_PKL_FILE
,
"
rb
"
)
as
pkl
:
trace_pairs
=
pickle
.
load
(
pkl
)
with
open
(
AAP_JSON_FILE
)
as
json_fp
:
data
=
json
.
load
(
json_fp
)
aap_pred
=
build_z3_predicate
(
state_vars
,
perc_vars
,
data
)
print
(
"
# Traces:
"
,
len
(
trace_pairs
))
print
(
"
# States in each trace:
"
,
len
(
trace_pairs
[
0
][
1
]))
# NOTE temporarily convert state/latent variables to gt variables
gt_var_names
=
[
"
d
"
,
"
psi
"
]
gte_var_names
=
[
"
d_e
"
,
"
psi_e
"
]
gt_vars
=
z3
.
Reals
(
gt_var_names
)
gte_vars
=
z3
.
Reals
(
gte_var_names
)
safety_pred
=
z3
.
Abs
(
gt_vars
[
0
])
<=
1.6
print
(
"
Safety predicate:
"
,
safety_pred
)
in_safe_trace_list
=
Parallel
(
NUM_JOBS
)(
delayed
(
monitoring
)(
gt_arr
,
gte_arr
,
gt_var_names
,
gte_var_names
,
safety_pred
.
serialize
())
for
gt_arr
,
gte_arr
in
trace_pairs
)
num_safe_total_list
=
[(
sum
(
bool_trace
),
len
(
bool_trace
))
for
bool_trace
in
in_safe_trace_list
]
safe_arr
,
total_arr
=
np
.
array
(
num_safe_total_list
).
T
safe_rate
=
safe_arr
/
total_arr
print
(
f
"
% states in safe:
{
np
.
mean
(
safe_rate
)
*
100
:
.
2
f
}
%
"
)
state_vars
=
[
z3
.
Real
(
f
"
x_
{
i
}
"
)
for
i
in
range
(
3
)]
perc_vars
=
[
z3
.
Real
(
f
"
z_
{
i
}
"
)
for
i
in
range
(
2
)]
with
open
(
AAP_JSON_FILE
)
as
json_fp
:
data
=
json
.
load
(
json_fp
)
aap_pred
=
build_aap_predicate
(
state_vars
,
perc_vars
,
data
)
# NOTE temporarily convert state/latent variables to gt variables
subs
=
[(
x
,
-
gt
)
for
x
,
gt
in
zip
(
state_vars
[
1
:],
gt_vars
)]
+
\
[(
p
,
gte
)
for
p
,
gte
in
zip
(
perc_vars
,
gte_vars
)]
aap_pred
=
z3
.
simplify
(
z3
.
substitute
(
aap_pred
,
*
subs
))
# Convert to SMTLib string for pickling and multiprocessing
solver
=
z3
.
Solver
()
solver
.
add
(
aap_pred
)
aap_pred_smtlib
=
solver
.
to_smt2
()
with
open
(
TRACE_PKL_FILE
,
"
rb
"
)
as
pkl
:
trace_pairs
=
pickle
.
load
(
pkl
)
num_pass_total_list
=
Parallel
(
NUM_JOBS
)(
in_aap_trace_list
=
Parallel
(
NUM_JOBS
)(
delayed
(
monitoring
)(
gt_arr
,
gte_arr
,
gt_var_names
,
gte_var_names
,
aap_pred
_smtlib
)
aap_pred
.
serialize
()
)
for
gt_arr
,
gte_arr
in
trace_pairs
)
num_pass_total_list
=
[(
sum
(
bool_trace
),
len
(
bool_trace
))
for
bool_trace
in
in_aap_trace_list
]
pass_arr
,
total_arr
=
np
.
array
(
num_pass_total_list
).
T
pass_rate
=
pass_arr
/
total_arr
print
(
np
.
mean
(
pass_rate
))
print
(
f
"
% states in AAP:
{
np
.
mean
(
pass_rate
)
*
100
:
.
2
f
}
%
"
)
if
__name__
==
"
__main__
"
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment