Skip to content
Snippets Groups Projects
Commit e85e1fad authored by chsieh16's avatar chsieh16
Browse files

Fix type info warnings

parent 1b6c9f8e
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,7 @@ import itertools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Literal, MutableSet, Tuple
from typing import Any, Callable, Dict, List, Literal, Set, Tuple
import numpy as np
import z3
......@@ -16,13 +16,13 @@ class DTreeLearner(LearnerBase):
def __init__(self, state_dim: int, perc_dim: int,
timeout: int = 10000) -> None:
super().__init__()
self.debug_neg_conc = set() # type: MutableSet[Tuple[float,...]]
self.debug_neg_perc = set() # type: MutableSet[Tuple[float,...]]
self.debug_neg_conc = set() # type: Set[Tuple[float,...]]
self.debug_neg_perc = set() # type: Set[Tuple[float,...]]
self._state_dim: int = state_dim
self._perc_dim: int = perc_dim
self.count_neg_dup = 0
self._s2f_func = lambda x: x
self._cons_s2f_method = None
self._cons_s2f_method = None # type: Literal["concat", "diff", None]
# Given a base or derived feature name,
# returns a mapping from base feature names to coefficients
......@@ -137,7 +137,7 @@ class DTreeLearner(LearnerBase):
self._append_to_data_file(feature_vec_list, "true")
def add_negative_examples(self, *args) -> None:
#TODO: add precondition that args should not contains np.array type
# TODO: add precondition that args should not contains np.array type
# NOTE the size of nonrepeating_samp_list and nonrepeating_fv_list can be different.
if len(args) == 0:
return
......@@ -146,7 +146,7 @@ class DTreeLearner(LearnerBase):
]
if len(nonrepeating_samp_list) == 0:
raise ValueError(f"All negative examples {args} are repeated.")
fv_list = [
tuple(self._s2f_func(samp)) for samp in nonrepeating_samp_list
]
......@@ -268,6 +268,7 @@ def construct_sample_to_feature_func_by_concatenate(a_mat: np.ndarray, b_vec: np
# In this case, the output dimension is two times of perception dimension
return 2*perc_dim, sample_to_feature_vec
def construct_sample_to_feature_func_by_diff(a_mat: np.ndarray, b_vec: np.ndarray) \
-> Tuple[int, Callable[[np.ndarray], np.ndarray]]:
perc_dim, state_dim = a_mat.shape
......@@ -328,7 +329,7 @@ def test_sample_to_feature():
sample = np.array([1., 2., 3., -2., -3.])
# sample_to_feature_func will compute dBar and psiBar
feature_vec = sample_to_feature_func(sample)
assert len(feature_vec ) == feature_dim, "The dimension of feature vector should match."
assert len(feature_vec) == feature_dim, "The dimension of feature vector should match."
print("sample: " + str(feature_vec))
assert np.array_equal(feature_vec, np.array([0., 0.]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment