From dcbcd07c7f5c8b7de2b1e19a453242deeef8c368 Mon Sep 17 00:00:00 2001 From: Eduardo Salinas Date: Thu, 2 Feb 2023 12:05:12 -0500 Subject: [PATCH] fix!: [py] use full word for namespace and add test (#4485) --- python/pylibvw.cc | 23 +++++++++++------------ python/tests/test_pyvw.py | 14 +++++++++++--- python/vowpalwabbit/pyvw.py | 14 ++++++++++++-- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/python/pylibvw.cc b/python/pylibvw.cc index 3893eff2267..69bde8d1089 100644 --- a/python/pylibvw.cc +++ b/python/pylibvw.cc @@ -624,10 +624,8 @@ void ex_push_feature(example_ptr ec, unsigned char ns, feature_index fid, float } // List[Union[Tuple[Union[str,int], float], str,int]] -void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& a) +void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns_first_letter, uint64_t ns_hash, py::list& a) { // warning: assumes namespace exists! - char ns_str[2] = {(char)ns, 0}; - uint64_t ns_hash = VW::hash_space(*vw, ns_str); size_t count = 0; for (ssize_t i = 0; i < len(a); i++) { @@ -678,7 +676,7 @@ void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& } if (got) { - ec->feature_space[ns].push_back(f.x, f.weight_index); + ec->feature_space[ns_first_letter].push_back(f.x, f.weight_index); count++; } } @@ -688,11 +686,9 @@ void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& } // Dict[Union[str,int],Union[int,float]] -void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns, PyObject* o) +void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns_first_letter, uint64_t ns_hash, PyObject* o) { // warning: assumes namespace exists! - char ns_str[2] = {(char)ns, 0}; - uint64_t ns_hash = VW::hash_space(*vw, ns_str); size_t count = 0; const char* key_chars; @@ -729,7 +725,7 @@ void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns, PyObject* continue; } - ec->feature_space[ns].push_back(feat_value, feat_index); + ec->feature_space[ns_first_letter].push_back(feat_value, feat_index); count++; } @@ -759,15 +755,18 @@ void ex_push_dictionary(example_ptr ec, vw_ptr vw, PyObject* o) { py::extract ns_e(ns_raw); if (ns_e().length() < 1) continue; - unsigned char ns = ns_e()[0]; - ex_ensure_namespace_exists(ec, ns); + std::string ns_full = ns_e(); + unsigned char ns_first_letter = ns_full[0]; + uint64_t ns_hash = VW::hash_space(*vw, ns_full); - if (PyDict_Check(feats)) { ex_push_feature_dict(ec, vw, ns, feats); } + ex_ensure_namespace_exists(ec, ns_first_letter); + + if (PyDict_Check(feats)) { ex_push_feature_dict(ec, vw, ns_first_letter, ns_hash, feats); } else { py::list list = py::extract(feats); - ex_push_feature_list(ec, vw, ns, list); + ex_push_feature_list(ec, vw, ns_first_letter, ns_hash, list); } } } diff --git a/python/tests/test_pyvw.py b/python/tests/test_pyvw.py index 5127e7e00e3..8c061a43783 100644 --- a/python/tests/test_pyvw.py +++ b/python/tests/test_pyvw.py @@ -540,14 +540,22 @@ def test_example_features(): def test_example_features_dict(): vw = Workspace(quiet=True) ex = vw.example( - {"a": {"two": 1, "features": 1.0}, "b": {"more": 1, "features": 1, 5: 1.5}} + { + "a": {"two": 1, "features": 1.0}, + "namespace": {"more": 1, "feature": 1, 5: 1.5}, + } ) fs = list(ex.iter_features()) + fs_keys = [f[0] for f in fs] + + expected = [53373, 165129, 24716, 242309, 5] + + assert set(fs_keys) == set(expected) assert (ex.get_feature_id("a", "two"), 1) in fs assert (ex.get_feature_id("a", "features"), 1) in fs - assert (ex.get_feature_id("b", "more"), 1) in fs - assert (ex.get_feature_id("b", "features"), 1) in fs + assert (ex.get_feature_id("namespace", "more"), 1) in fs + assert (ex.get_feature_id("namespace", "feature"), 1) in fs assert (5, 1.5) in fs diff --git a/python/vowpalwabbit/pyvw.py b/python/vowpalwabbit/pyvw.py index 919d771ef06..dcaa06c84a9 100644 --- a/python/vowpalwabbit/pyvw.py +++ b/python/vowpalwabbit/pyvw.py @@ -974,6 +974,7 @@ def __init__(self, ex: "Example", id: Union[int, str]): - If int, uses that as an index into this Examples list of feature groups to get the namespace id character - If str, uses the first character as the namespace id character """ + self.full = None if isinstance(id, int): # you've specified a namespace by index if id < 0 or id >= ex.num_namespaces(): raise Exception("namespace " + str(id) + " out of bounds") @@ -983,6 +984,7 @@ def __init__(self, ex: "Example", id: Union[int, str]): elif isinstance(id, str): # you've specified a namespace by string if len(id) == 0: id = " " + self.full = id self.id = None # we don't know and we don't want to do the linear search required to find it self.ns = id[0] self.ord_ns = ord(self.ns) @@ -1695,6 +1697,7 @@ def num_features_in(self, ns: Union[NamespaceId, str, int]) -> int: """ return pylibvw.example.num_features_in(self, self.get_ns(ns).ord_ns) + # pytype: disable=attribute-error def get_feature_id( self, ns: Union[NamespaceId, str, int], @@ -1722,7 +1725,13 @@ def get_feature_id( return feature if isinstance(feature, str): if ns_hash is None: - ns_hash = self.vw.hash_space(self.get_ns(ns).ns) + if type(ns) != NamespaceId: + ns = self.get_ns(ns) + ns_hash = ( + self.vw.hash_space(ns.full) + if ns.full + else self.vw.hash_space(ns.ns) + ) return self.vw.hash_feature(feature, ns_hash) raise Exception("cannot extract feature of type: " + str(type(feature))) @@ -1839,8 +1848,9 @@ def push_features( """ ns = self.get_ns(ns) self.ensure_namespace_exists(ns) + ns_hash = self.vw.hash_space(ns.full) if ns.full else self.vw.hash_space(ns.ns) self.push_feature_list( - self.vw, ns.ord_ns, featureList + self.vw, ns.ord_ns, ns_hash, featureList ) # much faster just to do it in C++ # ns_hash = self.vw.hash_space( ns.ns ) # for feature in featureList: