from collections import OrderedDict
import csv
import itertools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Literal, MutableSet, Tuple

import numpy as np
import z3

from learner_base import LearnerBase


class DTreeLearner(LearnerBase):
    """Learn candidate preconditions as decision trees with an external C5.0 build.

    Samples are turned into feature vectors by the affine transformations given
    to :meth:`set_grammar`. Labelled feature vectors are appended to a C5.0
    ``.data`` file, the ``c50exact/c5.0dbg`` executable is run by :meth:`learn`,
    and the JSON tree it leaves behind is converted into a z3 formula over
    state variables ``x_i`` and perception variables ``z_i``.
    """

    def __init__(self, state_dim: int, perc_dim: int,
                 timeout: int = 10000) -> None:
        """Create the learner and (re)initialise its working files under ``out/``.

        :param state_dim: number of state variables ``x_i``.
        :param perc_dim: number of perception variables ``z_i``.
        :param timeout: currently unused; kept for interface compatibility.
        """
        super().__init__()
        # Negative samples / feature vectors already added, used to detect repeats.
        self.debug_neg_conc = set()  # type: MutableSet[Tuple[float,...]]
        self.debug_neg_perc = set()  # type: MutableSet[Tuple[float,...]]
        self._state_dim: int = state_dim
        self._perc_dim: int = perc_dim
        self.count_neg_dup = 0
        # Identity until set_grammar installs the composed sample-to-feature function.
        self._s2f_func = lambda x: x
        # Name of the s2f construction ("concat" or "diff") chosen in set_grammar.
        self._cons_s2f_method = None

        # Given a base or derived feature name,
        # returns a mapping from base feature names to coefficients
        self._var_coeff_map: Dict[str, Dict[str, int]] = {}
        # Given a base feature name, returns the affine transformation provided
        # in the grammar together with the index of the feature it produces
        self._basevar_trans_map: Dict[str, Tuple[Any, int]] = {}

        # Create the output directory if needed; exist_ok avoids the
        # check-then-create race of a separate isdir() test.
        self.dir_name = "out"
        os.makedirs(self.dir_name, exist_ok=True)

        path_prefix = self.dir_name+"/pre"
        self.data_file = path_prefix + ".data"
        self.names_file = path_prefix + ".names"
        self.tree_out = path_prefix + ".json"

        # create empty data files or truncate existing data in files
        open(self.data_file, 'w').close()

        # C5.0 command line; it is given the shared file prefix, and learn()
        # expects the resulting tree in self.tree_out afterwards.
        self.exec = f'c50exact/c5.0dbg  -I 1 -m 1 -f {path_prefix}'

    @property
    def state_dim(self) -> int:
        """Number of state variables ``x_i``."""
        return self._state_dim

    @property
    def perc_dim(self) -> int:
        """Number of perception variables ``z_i``."""
        return self._perc_dim

    def set_grammar(self, grammar, s2f_method: Literal["concat", "diff"] = "diff") -> None:
        """Install the feature grammar and write the C5.0 ``.names`` file.

        For the i-th ``(a_mat, b_vec)`` pair in ``grammar`` this creates base
        features ``fvar{j}_A{i}`` and derived features built from them. Any
        mappings from a previous call are replaced, not extended.

        :param grammar: iterable of ``(a_mat, b_vec)`` affine transformations.
        :param s2f_method: ``"concat"`` or ``"diff"``; selects how a sample is
            mapped to features for each transformation.
        :raises ValueError: if ``s2f_method`` is not recognised.
        """
        if s2f_method == "concat":
            construct_s2f = construct_sample_to_feature_func_by_concatenate
        elif s2f_method == "diff":
            construct_s2f = construct_sample_to_feature_func_by_diff
        else:
            raise ValueError(f"Unknown method '{s2f_method}' "
                             "to construct sample to feature func")
        # Only record the method once it is known to be valid.
        self._cons_s2f_method = s2f_method

        # The names file is rewritten below, so drop mappings of any old grammar.
        self._basevar_trans_map.clear()
        self._var_coeff_map.clear()

        base_features: List[str] = []
        derived_feature_map: Dict[str, Tuple[Dict, str]] = OrderedDict()
        s2f_func_list = []
        for i, trans in enumerate(grammar):
            feature_dim, s2f_func = construct_s2f(*trans)
            s2f_func_list.append(s2f_func)
            ith_vars = [f"fvar{j}_A{i}" for j in range(feature_dim)]

            self._basevar_trans_map.update((var, (trans, j)) for j, var in enumerate(ith_vars))

            base_features.extend(ith_vars)
            derived_feature_map.update(
                self._generate_derived_features(ith_vars))

        # Store mapping from all feature names to coefficients of base features
        self._var_coeff_map.update((var, {var: 1}) for var in base_features)
        self._var_coeff_map.update(
            (var, coeff_map) for var, (coeff_map, _) in derived_feature_map.items()
        )

        # One sample to feature vector function for many linear transformations
        self._s2f_func = self._compose_s2f_functions(s2f_func_list)

        # Write names file
        file_lines = ["precondition."] + \
            [f"{var}:  continuous." for var in base_features] + \
            [f"{var} := {expr}." for var, (_, expr) in derived_feature_map.items()] + \
            ["precondition:  true, false."]
        with open(self.names_file, "w") as f:
            f.write('\n'.join(file_lines))

    @staticmethod
    def _compose_s2f_functions(s2f_func_list):
        """Return a function mapping a sample to the concatenation of all features.

        The result is a flat list containing, in order, every element produced
        by each function in ``s2f_func_list``.
        """
        def composed_func(sample):
            # chain avoids the quadratic cost of sum(list_of_lists, []).
            return list(itertools.chain.from_iterable(f(sample) for f in s2f_func_list))
        return composed_func

    @staticmethod
    def _generate_derived_features(
            base_vars: List[str], k: int = 2) -> List[Tuple[str, Tuple[Any, str]]]:
        """Build derived features from ``base_vars``.

        Produces the negation of every base variable, and, when at least ``k``
        base variables exist, every combination of ``k`` distinct variables
        with every choice of ``+1``/``-1`` coefficients.

        :returns: list of ``(name, (coeff_map, expr))`` where ``coeff_map``
            maps base variable names to integer coefficients and ``expr`` is
            the C5.0 expression text.
        """
        res = []
        for var in base_vars:
            var_coeff_map = {var: -1}
            expr = f"(-1*{var})"
            name = expr
            res.append((name, (var_coeff_map, expr)))

        if len(base_vars) < k:
            return res

        coeff_combinations = list(itertools.product([1, -1], repeat=k))
        for selected_var_ids in itertools.combinations(range(len(base_vars)), k):
            for coeff in coeff_combinations:
                var_coeff_map = {base_vars[i]: c
                                 for c, i in zip(coeff, selected_var_ids)}
                expr = " + ".join(f"({c}*{base_vars[i]})"
                                  for c, i in zip(coeff, selected_var_ids))
                name = f"({expr})"
                res.append((name, (var_coeff_map, expr)))
        return res

    def add_implication_examples(self, *args) -> None:
        """Delegate implication examples to the base class."""
        return super().add_implication_examples(*args)

    def add_positive_examples(self, *args) -> None:
        """Append each sample's feature vector to the data file labelled ``true``."""
        feature_vec_list = [self._s2f_func(sample) for sample in args]
        self._append_to_data_file(feature_vec_list, "true")

    def add_negative_examples(self, *args) -> None:
        """Append each sample's feature vector to the data file labelled ``false``.

        Samples and their feature vectors are remembered so that repeated
        negative examples are detected.

        :raises ValueError: if every sample, or every feature vector of the
            not-yet-seen samples, has already been added as a negative example.
        """
        # NOTE the size of nonrepeating_samp_list and nonrepeating_fv_list can be different.
        if len(args) == 0:
            return
        # Compare as tuples so ndarray/list samples are hashable set members.
        nonrepeating_samp_list = [
            samp for samp in (tuple(s) for s in args)
            if samp not in self.debug_neg_conc
        ]
        if not nonrepeating_samp_list:
            raise ValueError(f"All negative examples {args} are repeated.")

        fv_list = [
            tuple(self._s2f_func(samp)) for samp in nonrepeating_samp_list
        ]
        nonrepeating_fv_list = [
            fv for fv in fv_list if fv not in self.debug_neg_perc
        ]
        if not nonrepeating_fv_list:
            raise ValueError(f"All negative feature vectors {fv_list} are repeated.")

        self.debug_neg_perc.update(nonrepeating_fv_list)
        self.debug_neg_conc.update(nonrepeating_samp_list)

        # All given samples are written, including ones seen before.
        feature_vec_list = [self._s2f_func(sample) for sample in args]

        print("Negative feature vectors:", feature_vec_list)
        self._append_to_data_file(feature_vec_list, "false")

    def _append_to_data_file(self, feature_vec_list, label: str):
        """Append one CSV row per feature vector with ``label`` as last column."""
        # newline='' as required by the csv module, so it controls line endings.
        with open(self.data_file, 'a', newline='') as d_file:
            data_out = csv.writer(d_file)
            for f in feature_vec_list:
                row = itertools.chain(f, [label])  # append label at the end of each row
                data_out.writerow(row)

    def learn(self) -> z3.BoolRef:
        """Run C5.0 on the accumulated data and return the learned candidate.

        :returns: a z3 formula over ``x_i``/``z_i`` equivalent to the tree.
        :raises AssertionError: if no JSON tree file was produced.
        :raises RuntimeError: if no known s2f construction method is set.
        """
        # Reading to EOF and closing the pipe (via ``with``) waits for the
        # process to finish and releases the pipe instead of leaking it.
        with os.popen(self.exec) as proc:
            proc.read()
        assert os.path.exists(self.tree_out), "if learned successfully, " \
            f"there should be a json file in {self.dir_name}"

        ite_expr = self.get_pre_from_json(self.tree_out)
        os.remove(self.tree_out)  # Remove the generated json to avoid reusing old trees

        # FIXME need to ensure the substitution matches the construction used in set_grammar
        if self._cons_s2f_method == "concat":
            ite_expr = self._subs_basevar_w_states_by_concatenate(ite_expr)
        elif self._cons_s2f_method == "diff":
            ite_expr = self._subs_basevar_w_states_by_diff(ite_expr)
        else:
            raise RuntimeError(f"Unknown method '{self._cons_s2f_method}' to reconstruct the candidate from the decision tree")
        return ite_expr

    def _subs_basevar_w_states_by_concatenate(self, ite_expr) -> z3.BoolRef:
        """Replace base features of the "concat" construction by state/perception terms.

        Feature ``j < perc_dim`` becomes ``(a_mat @ x + b_vec)[j]``; feature
        ``j >= perc_dim`` becomes ``z_{j - perc_dim}``.
        """
        state_vars = z3.Reals([f"x_{i}" for i in range(self.state_dim)])
        perc_vars = z3.Reals([f"z_{i}" for i in range(self.perc_dim)])
        subs_basevar = []
        for basevar, (trans, j) in self._basevar_trans_map.items():
            if j < self.perc_dim:
                a_mat, b_vec = trans
                expanded_basevar = ((a_mat @ state_vars)[j] + b_vec[j])
            else:
                assert j < 2*self.perc_dim
                expanded_basevar = perc_vars[j - self.perc_dim]
            expanded_basevar = z3.simplify(expanded_basevar)
            subs_basevar.append((z3.Real(basevar), expanded_basevar))
        return z3.substitute(ite_expr, *subs_basevar)

    def _subs_basevar_w_states_by_diff(self, ite_expr) -> z3.BoolRef:
        """Replace base features of the "diff" construction by state/perception terms.

        Feature ``j`` becomes ``z_j - (a_mat @ x + b_vec)[j]``.
        """
        state_vars = z3.Reals([f"x_{i}" for i in range(self.state_dim)])
        perc_vars = z3.Reals([f"z_{i}" for i in range(self.perc_dim)])
        subs_basevar = []
        for basevar, (trans, j) in self._basevar_trans_map.items():
            a_mat, b_vec = trans
            expanded_basevar = perc_vars[j] - ((a_mat @ state_vars)[j] + b_vec[j])
            expanded_basevar = z3.simplify(expanded_basevar)
            subs_basevar.append((z3.Real(basevar), expanded_basevar))
        return z3.substitute(ite_expr, *subs_basevar)

    def get_pre_from_json(self, path):
        """Load the JSON tree at ``path`` and convert it with :meth:`parse_tree`.

        :raises ValueError: if the file is not valid JSON.
        """
        with open(path) as json_file:
            try:
                tree = json.load(json_file)
            except json.JSONDecodeError as err:
                raise ValueError(f"cannot parse {path} as a json file") from err
        return self.parse_tree(tree)

    def parse_tree(self, tree) -> z3.BoolRef:
        """Convert a C5.0 JSON tree (as loaded by ``json``) into a z3 formula.

        Leaves (``children`` is ``None``) become ``True``/``False`` according to
        their ``classification``. An inner node tests ``attribute <= cut``: the
        first child applies when the test holds, the second otherwise.

        :raises ValueError: if an inner node does not have exactly two children.
        """
        if tree['children'] is None:
            # At a leaf node, return the clause
            return z3.BoolVal(bool(tree['classification']))
        elif len(tree['children']) == 2:
            # Post-order traversal
            left = self.parse_tree(tree['children'][0])
            right = self.parse_tree(tree['children'][1])
            # Rebuild the (possibly derived) attribute as a sum over base features
            z3_expr = z3.Sum(*(coeff*z3.Real(base_fvar) for base_fvar, coeff
                               in self._var_coeff_map[tree['attribute']].items()))
            # Convert the double cut value exactly into a z3 real
            z3_cut = z3.simplify(z3.fpToReal(z3.FPVal(tree['cut'], z3.Float64())))
            # Collapse nodes whose children are both constant leaves
            if z3.is_true(left):
                if z3.is_true(right):
                    return z3.BoolVal(True)
                elif z3.is_false(right):
                    return (z3_expr <= z3_cut)
            if z3.is_false(left):
                if z3.is_true(right):
                    return (z3_expr > z3_cut)
                elif z3.is_false(right):
                    return z3.BoolVal(False)
            return z3.If((z3_expr <= z3_cut), left, right)
        else:
            raise ValueError("error parsing the json object as a binary decision tree")


def construct_sample_to_feature_func_by_concatenate(a_mat: np.ndarray, b_vec: np.ndarray) \
        -> Tuple[int, Callable[[np.ndarray], np.ndarray]]:
    """Build a feature map that concatenates the transformed state with the perception.

    A sample is laid out as ``state`` (``a_mat.shape[1]`` entries) followed by
    ``perc`` (``a_mat.shape[0]`` entries); its feature vector is
    ``[a_mat @ state + b_vec, perc]``.

    :returns: ``(feature_dim, func)`` with ``feature_dim == 2 * a_mat.shape[0]``.
    """
    n_perc, n_state = a_mat.shape
    n_total = n_state + n_perc

    def sample_to_feature_vec(sample):
        assert len(sample) == n_total
        predicted = a_mat @ np.array(sample[:n_state]) + b_vec
        observed = np.array(sample[n_state:n_total])
        return np.concatenate((predicted, observed), axis=0)

    # Output holds the transformed state and the raw perception side by side.
    return 2 * n_perc, sample_to_feature_vec

def construct_sample_to_feature_func_by_diff(a_mat: np.ndarray, b_vec: np.ndarray) \
        -> Tuple[int, Callable[[np.ndarray], np.ndarray]]:
    """Build a feature map giving the residual between perception and transformed state.

    A sample is laid out as ``state`` (``a_mat.shape[1]`` entries) followed by
    ``perc`` (``a_mat.shape[0]`` entries); its feature vector is
    ``perc - (a_mat @ state + b_vec)``.

    :returns: ``(feature_dim, func)`` with ``feature_dim == a_mat.shape[0]``.
    """
    n_perc, n_state = a_mat.shape

    def sample_to_feature_vec(sample):
        assert len(sample) == n_state + n_perc
        state, perc = (np.array(sample[:n_state]),
                       np.array(sample[n_state:n_state + n_perc]))
        predicted = a_mat @ state + b_vec
        return perc - predicted

    # One residual entry per perception component.
    return n_perc, sample_to_feature_vec


def test_dtree_learner():
    """End-to-end smoke run: learn a tree for a small example and print it.

    Requires the external ``c50exact/c5.0dbg`` executable.
    """
    a_mat_0 = np.array([[0., -1., 0.],
                        [0., 0., -1.]])
    b_vec_0 = np.zeros(2)

    a_mat_1 = np.array([[0., -0.75, 0.],
                        [0., 0., -1.25]])
    b_vec_1 = np.zeros(2)

    learner = DTreeLearner(state_dim=3, perc_dim=2)
    learner.set_grammar([(a_mat_0, b_vec_0), (a_mat_1, b_vec_1)])

    # logging.debug has no ``sep`` keyword and treats extra positional args
    # as %-format arguments, so pass one pre-joined message argument instead.
    logging.debug("basevar_trans_map:\n%s",
                  "\n".join(map(str, learner._basevar_trans_map.items())))
    logging.debug("var_coeff_map:\n%s",
                  "\n".join(map(str, learner._var_coeff_map.items())))

    pos_examples = [
        (1., 2., 3., -2., -3.),
        (1., 2., 3., -1., -2.)
    ]
    learner.add_positive_examples(*pos_examples)

    neg_examples = [
        (10., 1.0, 1.0, 0.5, 0.5),
        (10., 1.0, 1.0, 1.5, 1.5),
        (10., 9.0, 9.0, 5.0, 5.0),
    ]
    learner.add_negative_examples(*neg_examples)

    print("Learned ITE expression:", learner.learn())


def test_sample_to_feature():
    """Check the "diff" sample-to-feature map on two hand-computed samples."""
    # Affine map (a_mat, b_vec) that negates the 2nd and 3rd state components.
    a_mat = np.array([[0., -1., 0.],
                      [0., 0., -1]])
    b_vec = np.zeros(2)

    # The constructor yields the feature dimension and the sample -> feature function.
    feature_dim, sample_to_feature_func = \
        construct_sample_to_feature_func_by_diff(a_mat, b_vec)

    expectations = (
        (np.array([1., 2., 3., -2., -3.]), np.array([0., 0.])),
        (np.array([1., 2., 3., -1., -2.]), np.array([1., 1.])),
    )
    for sample, expected in expectations:
        feature_vec = sample_to_feature_func(sample)
        assert len(feature_vec) == feature_dim, "The dimension of feature vector should match."
        print("sample: " + str(feature_vec))
        assert np.array_equal(feature_vec, expected)


def test_parse_json_with_s2f_by_concat():
    """Parse a hand-written C5.0 JSON tree under the "concat" grammar and print it."""
    tree_json = json.loads("""
    {"attribute":"((1*fvar0_A0) + (1*fvar1_A0))","cut":-0.125,"classification":0,
     "children":[
         {"attribute":"fvar2_A0","cut":0.375,"classification":0,
          "children":[{"attribute":"","cut":0,"classification":true,"children":null},
                      {"attribute":"","cut":0,"classification":false,"children":null}]
         },
         {"attribute":"fvar3_A1","cut":-0.5,"classification":0,
          "children":[{"attribute":"","cut":0,"classification":true,"children":null},
                      {"attribute":"","cut":0,"classification":false,"children":null}]
         }
     ]
    }""")
    grammar = [
        (np.array([[0., -1., 0.],
                   [0., 0., -1.]]), np.zeros(2)),
        (np.array([[0., -0.75, 0.],
                   [0., 0., -1.25]]), np.zeros(2)),
    ]

    learner = DTreeLearner(state_dim=3, perc_dim=2)
    learner.set_grammar(grammar, s2f_method="concat")
    ite_tree = learner.parse_tree(tree_json)
    print(learner._subs_basevar_w_states_by_concatenate(ite_tree))


def test_parse_json_with_s2f_by_diff():
    """Parse a hand-written C5.0 JSON tree under the "diff" grammar and print it."""
    tree_json = json.loads("""
    {"attribute":"((1*fvar0_A0) + (1*fvar1_A0))","cut":-0.125,"classification":0,
     "children":[
         {"attribute":"fvar1_A0","cut":0.625,"classification":0,
          "children":[{"attribute":"","cut":0,"classification":true,"children":null},
                      {"attribute":"","cut":0,"classification":false,"children":null}]
         },
         {"attribute":"fvar1_A1","cut":-1.5,"classification":0,
          "children":[{"attribute":"","cut":0,"classification":true,"children":null},
                      {"attribute":"","cut":0,"classification":false,"children":null}]
         }
     ]
    }""")
    grammar = [
        (np.array([[0., -1., 0.],
                   [0., 0., -1.]]), np.zeros(2)),
        (np.array([[0., -0.75, 0.],
                   [0., 0., -1.25]]), np.zeros(2)),
    ]

    learner = DTreeLearner(state_dim=3, perc_dim=2)
    learner.set_grammar(grammar, s2f_method="diff")
    ite_tree = learner.parse_tree(tree_json)
    print(learner._subs_basevar_w_states_by_diff(ite_tree))


if __name__ == "__main__":
    # Run the ad-hoc checks in their original order.
    for run_check in (test_sample_to_feature,
                      test_dtree_learner,
                      test_parse_json_with_s2f_by_concat,
                      test_parse_json_with_s2f_by_diff):
        run_check()