Skip to content
Snippets Groups Projects
dtree_learner.py 1.03 KiB
Newer Older
from typing import Tuple

import numpy as np

# Layout of a flat sample: STATE_DIM state entries followed by PERC_DIM
# perception entries.
STATE_DIM, PERC_DIM = 3, 2


def construct_sample_to_feature_func(lin_trans):
    """Build a function mapping a flat sample to its perception residual.

    ``lin_trans`` is an affine map ``(a_mat, b_vec)`` from state to
    perception. The returned function splits a sample laid out as
    ``[state..., perc...]`` and returns ``perc - (a_mat @ state + b_vec)``,
    i.e. how far the observed perception is from the affine prediction.

    Args:
        lin_trans: Pair ``(a_mat, b_vec)``. ``a_mat`` must be 2-D with shape
            ``(perc_dim, state_dim)``; ``b_vec`` must broadcast against a
            vector of length ``perc_dim``.

    Returns:
        A callable taking a 1-D sample of length ``state_dim + perc_dim``
        (any array-like) and returning a 1-D array of length ``perc_dim``.
        The callable raises ``ValueError`` if the sample has the wrong length.
    """
    a_mat, b_vec = lin_trans
    a_mat = np.asarray(a_mat)
    b_vec = np.asarray(b_vec)
    # Take the dimensions from the model itself rather than module-level
    # constants, so the same helper works for systems of any size.
    perc_dim, state_dim = a_mat.shape

    def sample_to_feature_vec(sample):
        sample = np.asarray(sample)
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(sample) != state_dim + perc_dim:
            raise ValueError(
                f"expected sample of length {state_dim + perc_dim}, "
                f"got {len(sample)}"
            )
        state = sample[:state_dim]
        perc = sample[state_dim:]
        # Parenthesized so the offset is subtracted together with the linear
        # term: the residual is perc minus the *whole* affine prediction.
        return perc - (a_mat @ state + b_vec)

    return sample_to_feature_vec


def test_sample_to_feature():
    """Check residual features against hand-computed values.

    The linear part maps a state (s0, s1, s2) to (-s1, -s2) and the offset
    is zero, so a perception equal to (-s1, -s2) has a zero residual.
    """
    a_mat = np.array([[0., -1., 0.],
                      [0., 0., -1]])
    b_vec = np.zeros(2)
    to_feature = construct_sample_to_feature_func((a_mat, b_vec))

    # Each case: (sample = state followed by perception, expected residual).
    cases = [
        (np.array([1., 2., 3., -2., -3.]), np.array([0., 0.])),
        (np.array([1., 2., 3., -1., -2.]), np.array([1., 1.])),
    ]
    for sample, expected in cases:
        assert np.array_equal(to_feature(sample), expected)


if __name__ == "__main__":
    # Running this file directly executes its self-test.
    test_sample_to_feature()