Skip to content
Snippets Groups Projects
Commit eb62c8a4 authored by chsieh16's avatar chsieh16
Browse files

Store mapping from all features to coefficients of base features

parent aa451672
No related branches found
No related tags found
No related merge requests found
......@@ -49,7 +49,7 @@ class DTreeLearner(LearnerBase):
def set_grammar(self, grammar) -> None:
base_features: List[str] = []
derived_feature_map = OrderedDict()
derived_feature_map: Dict[str, Tuple[Dict, str]] = OrderedDict()
s2f_func_list = []
for i, (a_mat, b_vec) in enumerate(grammar):
......@@ -59,11 +59,19 @@ class DTreeLearner(LearnerBase):
base_features.extend(ith_vars)
derived_feature_map.update(
self._generate_derived_features(ith_vars))
self._var_coeff_map.update({
var: {var: 1} for var in base_features
})
# Store mapping from all feature names to coefficients of base features
self._var_coeff_map.update({
var: {var: 1} for var in base_features
})
self._var_coeff_map.update({
var: coeff_map for var, (coeff_map, _) in derived_feature_map.items()
})
# One sample to feature vector function for many linear transformations
self._s2f_func = self._compose_s2f_functions(s2f_func_list)
# Write names file
file_lines = ["precondition."] + \
[f"{var}: continuous." for var in base_features] + \
[f"{var} := {expr}." for var, (_, expr) in derived_feature_map.items()] + \
......@@ -193,7 +201,10 @@ def test_dtree_learner():
b_vec = np.zeros(2)
learner = DTreeLearner(state_dim=3, perc_dim=2)
learner.set_grammar([(a_mat, b_vec)])
learner.set_grammar([(a_mat, b_vec), (a_mat, b_vec)])
print(*learner._var_coeff_map.items(), sep='\n')
return
pos_examples = [
(1., 2., 3., -2., -3.),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment