Skip to content

Commit

Permalink
Support object types as values.
Browse files Browse the repository at this point in the history
Fixes pytries#7.
  • Loading branch information
b4hand committed May 19, 2014
1 parent dec3942 commit 687f074
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 6 deletions.
85 changes: 83 additions & 2 deletions src/hat_trie.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from chat_trie cimport *

cimport cpython

cdef class BaseTrie:
"""
Base HAT-Trie wrapper.
Expand Down Expand Up @@ -81,9 +83,9 @@ cdef class BaseTrie:
return value_ptr != NULL


cdef class Trie(BaseTrie):
cdef class IntTrie(BaseTrie):
"""
HAT-Trie with unicode support.
HAT-Trie with unicode support that stores int as value.
XXX: Internal encoding is hardcoded as UTF8. This is the fastest
encoding that can handle all unicode symbols and doesn't have
Expand Down Expand Up @@ -130,3 +132,82 @@ cdef class Trie(BaseTrie):

def keys(self):
return [key.decode('utf8') for key in self.iterkeys()]


cdef class Trie(BaseTrie):
"""
HAT-Trie with unicode support and arbitrary values.
XXX: Internal encoding is hardcoded as UTF8. This is the fastest
encoding that can handle all unicode symbols and doesn't have
zero bytes.
This may seem sub-optimal because it is multibyte encoding;
single-byte language-specific encoding (such as cp1251)
seems to be faster. But this is not the case because:
1) the bottleneck of this wrapper is string encoding, not trie traversal;
2) python's unicode encoding utilities are optimized for utf8;
3) users will have to select language-specific encoding for the trie;
4) non-hardcoded encoding causes extra overhead and prevents cython
optimizations.
That's why hardcoded utf8 is up to 9 times faster than configurable cp1251.
XXX: char-walking utilities may become tricky with multibyte
internal encoding.
"""

def __dealloc__(self):
cdef cpython.PyObject *o
if self._trie:
for k in self.iterkeys():
o = <cpython.PyObject *> self._getitem(k)
cpython.Py_XDECREF(o)


def __getitem__(self, unicode key):
cdef bytes bkey = key.encode('utf8')
return self._fromvalue(self._getitem(bkey))

def __contains__(self, unicode key):
cdef bytes bkey = key.encode('utf8')
return self._contains(bkey)

def __setitem__(self, unicode key, value):
cdef bytes bkey = key.encode('utf8')
self._setitem(bkey, self._tovalue(value))

def get(self, unicode key, value=None):
cdef bytes bkey = key.encode('utf8')
try:
return self._fromvalue(self._getitem(bkey))
except KeyError:
return value

def setdefault(self, unicode key, value):
cdef bytes bkey = key.encode('utf8')
return self._setdefault(bkey, self._tovalue(value))

def keys(self):
return [key.decode('utf8') for key in self.iterkeys()]

cdef void _setitem(self, char* key, value_t value):
cdef cpython.PyObject *o
cdef value_t* value_ptr = hattrie_tryget(self._trie, key, len(key))
if value_ptr != NULL:
o = <cpython.PyObject *> value_ptr[0]
cpython.Py_XDECREF(o)
hattrie_get(self._trie, key, len(key))[0] = value

cdef object _fromvalue(self, value_t value):
cdef cpython.PyObject *o
o = <cpython.PyObject *> value
cpython.Py_XINCREF(o)
return <object> o

cdef value_t _tovalue(self, object obj):
cdef cpython.PyObject *o
o = <cpython.PyObject *> obj
cpython.Py_XINCREF(o)
return <value_t> o
70 changes: 70 additions & 0 deletions tests/test_inttrie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
import string
import random

import pytest
import hat_trie

def test_getitem_set():
trie = hat_trie.IntTrie()
trie['foo'] = 5
trie['bar'] = 10
assert trie['foo'] == 5
assert trie['bar'] == 10

with pytest.raises(KeyError):
trie['f']

with pytest.raises(KeyError):
trie['foob']

with pytest.raises(KeyError):
trie['x']

non_ascii_key = 'вася'
trie[non_ascii_key] = 20
assert trie[non_ascii_key] == 20

def test_get():
trie = hat_trie.IntTrie()

assert trie.get('foo') is -1
assert trie.get('bar') is -1
assert trie.get('foo', 5) == 5

trie['foo'] = 5
trie['bar'] = 10

assert trie.get('foo') == 5
assert trie.get('bar') == 10

def test_contains():
trie = hat_trie.IntTrie()
assert 'foo' not in trie
trie['foo'] = 5
assert 'foo' in trie
assert 'f' not in trie


def test_get_set_fuzzy():
russian = 'абвгдеёжзиклмнопрстуфхцчъыьэюя'
alphabet = russian.upper() + string.ascii_lowercase
words = list(set([
"".join([random.choice(alphabet) for x in range(random.randint(2,10))])
for y in range(20000)
]))

trie = hat_trie.IntTrie()

enumerated_words = list(enumerate(words))

for index, word in enumerated_words:
trie[word] = index

random.shuffle(enumerated_words)
for index, word in enumerated_words:
assert word in trie, word
assert trie[word] == index, (word, index)

assert sorted(trie.keys()) == sorted(words)
10 changes: 6 additions & 4 deletions tests/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
def test_getitem_set():
trie = hat_trie.Trie()
trie['foo'] = 5
trie['bar'] = 10
trie['bar'] = 'asdf'
trie['baz'] = (10, 'quuz')
assert trie['foo'] == 5
assert trie['bar'] == 10
assert trie['bar'] == 'asdf'
assert trie['baz'] == (10, 'quuz')

with pytest.raises(KeyError):
trie['f']
Expand All @@ -29,8 +31,8 @@ def test_getitem_set():
def test_get():
trie = hat_trie.Trie()

assert trie.get('foo') == -1
assert trie.get('bar') == -1
assert trie.get('foo') is None
assert trie.get('bar') is None
assert trie.get('foo', 5) == 5

trie['foo'] = 5
Expand Down

0 comments on commit 687f074

Please sign in to comment.