# dtree_learner.py
from collections import OrderedDict
import csv
import itertools
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from learner_base import LearnerBase

class DTreeLearner(LearnerBase):
    """Precondition learner backed by the C5.0 decision-tree executable.

    Samples are mapped to numeric feature vectors (see ``set_grammar``) and
    appended as CSV rows labelled ``true``/``false`` to ``out/pre.data``.
    ``learn`` runs ``c50exact/c5.0dbg`` on the ``out/pre`` file stem. It then
    parses the JSON tree found at ``out/pre.json`` into a disjunction of
    conjunctions of ``(attribute, op, cut)`` atoms.
    """

    def __init__(self, state_dim: int, perc_dim: int,
                 timeout: int = 10000) -> None:
        """Create the learner and reset its C5.0 data file.

        Args:
            state_dim: Exposed via the ``state_dim`` property.
            perc_dim: Number of base features generated per grammar entry;
                also exposed via the ``perc_dim`` property.
            timeout: Accepted for interface compatibility; currently unused.
        """
        super().__init__()
        self._state_dim: int = state_dim
        self._perc_dim: int = perc_dim

        # Identity until set_grammar() installs the grammar-derived transform.
        self._s2f_func = lambda x: x

        # Given a (base or derived) feature name, this map returns a mapping
        # from base feature names to their integer coefficients.
        self._var_coeff_map: Dict[str, Dict[str, int]] = {}
        # Create the output directory if it does not exist yet.
        self.dir_name = "out"
        os.makedirs(self.dir_name, exist_ok=True)
        path_prefix = self.dir_name + "/pre"
        self.data_file = path_prefix + ".data"
        self.names_file = path_prefix + ".names"
        self.tree_out = path_prefix + ".json"
        # Create an empty data file, or truncate data left by a previous run.
        open(self.data_file, 'w').close()
        self.exec = f'c50exact/c5.0dbg  -I 1 -m 1 -f {path_prefix}'

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def perc_dim(self) -> int:
        return self._perc_dim

    def set_grammar(self, grammar) -> None:
        """Install the feature grammar and write the C5.0 ``.names`` file.

        Args:
            grammar: Iterable of ``(a_mat, b_vec)`` pairs. For the i-th pair,
                ``perc_dim`` continuous base features ``fvar{j}_A{i}`` are
                declared. Component j is ``perc - (state . a_mat^T + b_vec)``
                (see ``construct_sample_to_feature_func``). The derived
                features from ``_generate_derived_features`` are declared too.
        """
        base_features: List[str] = []
        derived_feature_map: Dict[str, Tuple[Dict[str, int], str]] = OrderedDict()
        s2f_func_list = []
        for i, (a_mat, b_vec) in enumerate(grammar):
            s2f_func_list.append(
                construct_sample_to_feature_func(a_mat, b_vec))
            ith_vars = [f"fvar{j}_A{i}" for j in range(self.perc_dim)]
            base_features.extend(ith_vars)
            derived_feature_map.update(
                self._generate_derived_features(ith_vars))

        # Register every feature, base and derived, with its coefficients
        # over base features (the contract stated for _var_coeff_map).
        self._var_coeff_map.update({var: {var: 1} for var in base_features})
        self._var_coeff_map.update(
            {name: coeff_map
             for name, (coeff_map, _) in derived_feature_map.items()})

        self._s2f_func = self._compose_s2f_functions(s2f_func_list)
        file_lines = (
            ["precondition."]
            + [f"{var}:  continuous." for var in base_features]
            + [f"{var} := {expr}."
               for var, (_, expr) in derived_feature_map.items()]
            + ["precondition:  true, false."]
        )
        with open(self.names_file, "w") as f:
            f.write('\n'.join(file_lines))

    @staticmethod
    def _compose_s2f_functions(s2f_func_list):
        """Return a function that concatenates every transform's output.

        The returned callable maps one sample to a flat ``list`` of the
        elements produced by each function in *s2f_func_list*, in order.
        """
        def composed_func(sample):
            # chain.from_iterable is linear; sum(lists, start=[]) is quadratic.
            return list(itertools.chain.from_iterable(
                f(sample) for f in s2f_func_list))
        return composed_func

    @staticmethod
    def _generate_derived_features(
            base_vars: List[str],
            k: int = 2) -> List[Tuple[str, Tuple[Dict[str, int], str]]]:
        """Build derived features over *base_vars* for the ``.names`` file.

        Produces the negation ``(-1*v)`` of every base variable. For every
        combination of *k* distinct base variables, it also produces every
        sum whose coefficients are drawn from ``{1, -1}``.

        Returns:
            ``(name, (coeff_map, expr))`` tuples. *coeff_map* maps each base
            variable used to its coefficient, and *expr* is the formula text
            written after ``:=``.
        """
        res = []
        for var in base_vars:
            var_coeff_map = {var: -1}
            expr = f"(-1*{var})"
            name = expr
            res.append((name, (var_coeff_map, expr)))

        if len(base_vars) < k:
            return res

        coeff_combinations = list(itertools.product([1, -1], repeat=k))
        var_id_iter = range(len(base_vars))
        for selected_var_ids in itertools.combinations(var_id_iter, k):
            for coeff in coeff_combinations:
                var_coeff_map = {base_vars[i]: c
                                 for c, i in zip(coeff, selected_var_ids)}
                expr = " + ".join(f"({c}*{base_vars[i]})"
                                  for c, i in zip(coeff, selected_var_ids))
                name = f"({expr})"
                res.append((name, (var_coeff_map, expr)))
        return res

    def add_implication_examples(self, *args) -> None:
        """Delegate to ``LearnerBase.add_implication_examples``."""
        return super().add_implication_examples(*args)

    def add_positive_examples(self, *args) -> None:
        """Append each sample in *args* to the data file labelled ``true``."""
        feature_vec_list = [self._s2f_func(sample) for sample in args]
        print("Positive feature vectors:", feature_vec_list)
        self._append_to_data_file(feature_vec_list, "true")

    def add_negative_examples(self, *args) -> None:
        """Append each sample in *args* to the data file labelled ``false``."""
        feature_vec_list = [self._s2f_func(sample) for sample in args]
        print("Negative feature vectors:", feature_vec_list)
        self._append_to_data_file(feature_vec_list, "false")

    def _append_to_data_file(self, feature_vec_list, label: str) -> None:
        """Append one CSV row per feature vector, with *label* as last column."""
        # The csv module requires newline='' so it controls line endings.
        with open(self.data_file, 'a', newline='') as d_file:
            data_out = csv.writer(d_file)
            for f in feature_vec_list:
                print(f)
                # list() also accepts tuples/arrays, e.g. from the identity
                # _s2f_func used before set_grammar() is called.
                data_out.writerow(list(f) + [label])

    def learn(self) -> Tuple:
        """Run C5.0 on the accumulated data and print the parsed precondition.

        Nothing is returned yet, despite the ``Tuple`` annotation.

        Raises:
            AssertionError: if C5.0 leaves no JSON tree at ``self.tree_out``.
            ValueError: if that file cannot be parsed (see get_pre_from_json).
        """
        # Drop any tree from an earlier run so that a failed C5.0 invocation
        # cannot be mistaken for a fresh result.
        try:
            os.remove(self.tree_out)
        except FileNotFoundError:
            pass
        with os.popen(self.exec) as proc:
            res = proc.read()
        print(res)
        assert os.path.exists(self.tree_out), "if learned successfully, " \
            f"there should be a json file in {self.dir_name}"

        result = self.get_pre_from_json(self.tree_out)
        # TODO(review): translate each (var, op, cut) atom of `result` back
        # to base features via self._var_coeff_map. Note that `result` is
        # None when no leaf of the tree is classified true.
        print(result)

    def get_pre_from_json(self, path):
        """Load the JSON tree at *path* and return ``parse_tree`` of it.

        Raises:
            ValueError: if *path* is not valid JSON, or the tree is not binary.
        """
        try:
            with open(path) as json_file:
                tree = json.load(json_file)
        except json.JSONDecodeError as err:
            raise ValueError(f"cannot parse {path} as a json file") from err
        return self.parse_tree(tree)

    def parse_tree(self, tree) -> Optional[List[List]]:
        """Convert a binary decision tree into DNF over its true leaves.

        Each conjunct is the list of ``(attribute, op, cut)`` tests on the
        path from *tree* to a leaf whose ``classification`` is truthy. The
        op is ``"<="`` when the path goes to the first child and ``">"`` when
        it goes to the second.

        Returns:
            The list of conjuncts, ``[[]]`` for a single true leaf, or
            ``None`` when no leaf under *tree* is classified true.

        Raises:
            ValueError: if an internal node does not have exactly two children.
        """
        children = tree['children']
        if children is None:
            # Leaf: an empty conjunct stands for "true", None for "false".
            return [[]] if tree['classification'] else None
        if len(children) != 2:
            raise ValueError(
                "error parsing the json object as a binary decision tree")

        # Post-order traversal: collect both subtrees' true paths first.
        left = self.parse_tree(children[0])
        right = self.parse_tree(children[1])
        if left is None and right is None:
            return None

        attribute, cut = tree['attribute'], tree['cut']
        res_left = [] if left is None else [
            [(attribute, "<=", cut)] + conjunct for conjunct in left]
        res_right = [] if right is None else [
            [(attribute, ">", cut)] + conjunct for conjunct in right]
        assert res_left or res_right
        return res_left + res_right


def construct_sample_to_feature_func(a_mat: np.ndarray, b_vec: np.ndarray):
    """Return a function mapping a (state, percept) sample to its residual.

    ``a_mat`` has shape ``(n_perc, n_state)``. A sample is a sequence of
    ``n_state`` state components followed by ``n_perc`` percept components.
    The returned function computes ``perc - (state . a_mat^T + b_vec)``.
    """
    n_perc, n_state = a_mat.shape
    total_len = n_state + n_perc

    def sample_to_feature_vec(sample):
        assert len(sample) == total_len
        state, perc = sample[0:n_state], sample[n_state:total_len]
        predicted = np.dot(state, a_mat.T) + b_vec
        return perc - predicted

    return sample_to_feature_vec


def test_dtree_learner():
    """End-to-end run: feed a few labelled samples to DTreeLearner and learn.

    Writes files under ``out/`` and needs the ``c50exact/c5.0dbg`` executable.
    """
    transform = np.array([[0., -1., 0.],
                          [0., 0., -1]])
    offset = np.zeros(2)

    learner = DTreeLearner(state_dim=3, perc_dim=2)
    learner.set_grammar([(transform, offset)])

    learner.add_positive_examples(
        (1., 2., 3., -2., -3.),
        (1., 2., 3., -1., -2.),
    )
    learner.add_negative_examples(
        (10., 1.0, 1.0, 0.5, 0.5),
        (10., 1.0, 1.0, 1.5, 1.5),
        (10., 9.0, 9.0, 5.0, 5.0),
    )

    learner.learn()


def test_sample_to_feature():
    """Check construct_sample_to_feature_func on two hand-computed samples."""
    # With this matrix the predicted percept is (-state[1], -state[2]),
    # and the offset is zero.
    a_mat = np.array([[0., -1., 0.],
                      [0., 0., -1]])
    b_vec = np.zeros(2)
    to_feature = construct_sample_to_feature_func(a_mat, b_vec)

    cases = [
        (np.array([1., 2., 3., -2., -3.]), np.array([0., 0.])),
        (np.array([1., 2., 3., -1., -2.]), np.array([1., 1.])),
    ]
    for sample, expected in cases:
        feature_vec = to_feature(sample)
        print("sample: " + str(feature_vec))
        assert np.array_equal(feature_vec, expected)


if __name__ == "__main__":
    # End-to-end learner run. test_sample_to_feature() is a quicker check
    # that does not invoke the C5.0 executable.
    test_dtree_learner()