Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add join #31

Merged
merged 6 commits into from
Jun 28, 2014
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions cytoolz/itertoolz.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,20 @@ cdef class _pluck_list_default:


cpdef object pluck(object ind, object seqs, object default=*)

cdef class join:
cdef Py_ssize_t n
cdef object iterseq
cdef object leftkey
cdef object leftseq
cdef object rightkey
cdef object rightseq
cdef object matches
cdef object right
cdef object key
cdef object d
cdef object d_items
cdef object seen_keys
cdef object is_rightseq_exhausted
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use cdef bint for a fast C boolean type that is compatible with Python bools.

cdef object left_default
cdef object right_default
135 changes: 135 additions & 0 deletions cytoolz/itertoolz.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,141 @@ cpdef object pluck(object ind, object seqs, object default=no_default):
return _pluck_index_default(ind, seqs, default)


def getter(index):
if isinstance(index, list):
if len(index) == 1:
index = index[0]
return lambda x: (x[index],)
else:
return itemgetter(*index)
else:
return itemgetter(index)


cdef class join:
""" Join two sequences on common attributes

This is a semi-streaming operation. The LEFT sequence is fully evaluated
and placed into memory. The RIGHT sequence is evaluated lazily and so can
be arbitrarily large.

>>> friends = [('Alice', 'Edith'),
... ('Alice', 'Zhao'),
... ('Edith', 'Alice'),
... ('Zhao', 'Alice'),
... ('Zhao', 'Edith')]

>>> cities = [('Alice', 'NYC'),
... ('Alice', 'Chicago'),
... ('Dan', 'Syndey'),
... ('Edith', 'Paris'),
... ('Edith', 'Berlin'),
... ('Zhao', 'Shanghai')]

>>> # Vacation opportunities
>>> # In what cities do people have friends?
>>> result = join(second, friends,
... first, cities)
>>> for ((a, b), (c, d)) in sorted(unique(result)):
... print((a, d))
('Alice', 'Berlin')
('Alice', 'Paris')
('Alice', 'Shanghai')
('Edith', 'Chicago')
('Edith', 'NYC')
('Zhao', 'Chicago')
('Zhao', 'NYC')
('Zhao', 'Berlin')
('Zhao', 'Paris')

Specify outer joins with keyword arguments ``left_default`` and/or
``right_default``. Here is a full outer join in which unmatched elements
are paired with None.

>>> identity = lambda x: x
>>> list(join(identity, [1, 2, 3],
... identity, [2, 3, 4],
... left_default=None, right_default=None))
[(2, 2), (3, 3), (None, 4), (1, None)]

Usually the key arguments are callables to be applied to the sequences. If
the keys are not obviously callable then it is assumed that indexing was
intended, e.g. the following is a legal change

>>> # result = join(second, friends, first, cities)
>>> result = join(1, friends, 0, cities) # doctest: +SKIP
"""
def __init__(self,
object leftkey, object leftseq,
object rightkey, object rightseq,
object left_default=no_default,
object right_default=no_default):
if not callable(leftkey):
leftkey = getter(leftkey)
if not callable(rightkey):
rightkey = getter(rightkey)

self.left_default = left_default
self.right_default = right_default

self.leftkey = leftkey
self.rightkey = rightkey
self.rightseq = iter(rightseq)

self.d = groupby(leftkey, leftseq)
self.seen_keys = set()
self.matches = iter(())
self.right = None

self.is_rightseq_exhausted = False


def __iter__(self):
return self

def __next__(self):
if not self.is_rightseq_exhausted:
try:
match = next(self.matches)
return (match, self.right)
except StopIteration: # iterator of matches exhausted
try:
item = next(self.rightseq) # get a new item
except StopIteration: # no items, switch to outer join
self.is_rightseq_exhausted = True
if self.right_default is not no_default:
self.d_items = iter(self.d.items())
self.matches = iter(())
return next(self)
else:
raise

key = self.rightkey(item)
self.seen_keys.add(key)

try:
self.matches = iter(self.d[key]) # get left matches
except KeyError:
if self.left_default is not no_default:
return (self.left_default, item) # outer join

self.right = item
return next(self)

else: # we've exhausted the right sequence, lets iterate over unseen
# items on the left
try:
match = next(self.matches)
return (match, self.right_default)
except StopIteration:
key, matches = next(self.d_items)
while(key in self.seen_keys and matches):
key, matches = next(self.d_items)
self.key = key
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.key appears to be unnecessary.

self.matches = iter(matches)
return next(self)


# I find `_consume` convenient for benchmarking. Perhaps this belongs
# elsewhere, so it is private (leading underscore) and hidden away for now.

Expand Down
95 changes: 94 additions & 1 deletion cytoolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from itertools import starmap
from cytoolz.utils import raises
from functools import partial
from cytoolz.itertoolz import (remove, groupby, merge_sorted,
Expand All @@ -9,7 +10,7 @@
rest, last, cons, frequencies,
reduceby, iterate, accumulate,
sliding_window, count, partition,
partition_all, take_nth, pluck)
partition_all, take_nth, pluck, join)

from cytoolz.compatibility import range, filter
from operator import add, mul
Expand Down Expand Up @@ -264,3 +265,95 @@ def test_pluck():

assert raises(IndexError, lambda: list(pluck(1, [[0]])))
assert raises(KeyError, lambda: list(pluck('name', [{'id': 1}])))


def test_join():
names = [(1, 'one'), (2, 'two'), (3, 'three')]
fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)]

def addpair(pair):
return pair[0] + pair[1]

result = set(starmap(add, join(first, names, second, fruit)))

expected = set([((1, 'one', 'apple', 1)),
((1, 'one', 'orange', 1)),
((2, 'two', 'banana', 2)),
((2, 'two', 'coconut', 2))])

print(result)
print(expected)
assert result == expected


def test_key_as_getter():
squares = [(i, i**2) for i in range(5)]
pows = [(i, i**2, i**3) for i in range(5)]

assert set(join(0, squares, 0, pows)) == set(join(lambda x: x[0], squares,
lambda x: x[0], pows))

get = lambda x: (x[0], x[1])
assert set(join([0, 1], squares, [0, 1], pows)) == set(join(get, squares,
get, pows))

get = lambda x: (x[0],)
assert set(join([0], squares, [0], pows)) == set(join(get, squares,
get, pows))


def test_join_double_repeats():
names = [(1, 'one'), (2, 'two'), (3, 'three'), (1, 'uno'), (2, 'dos')]
fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)]

result = set(starmap(add, join(first, names, second, fruit)))

expected = set([((1, 'one', 'apple', 1)),
((1, 'one', 'orange', 1)),
((2, 'two', 'banana', 2)),
((2, 'two', 'coconut', 2)),
((1, 'uno', 'apple', 1)),
((1, 'uno', 'orange', 1)),
((2, 'dos', 'banana', 2)),
((2, 'dos', 'coconut', 2))])

print(result)
print(expected)
assert result == expected


def test_join_missing_element():
names = [(1, 'one'), (2, 'two'), (3, 'three')]
fruit = [('apple', 5), ('orange', 1)]

result = list(join(first, names, second, fruit))
print(result)
result = set(starmap(add, result))

expected = set([((1, 'one', 'orange', 1))])

assert result == expected


def test_left_outer_join():
result = set(join(identity, [1, 2], identity, [2, 3], left_default=None))
expected = set([(2, 2), (None, 3)])

print(result)
print(expected)
assert result == expected


def test_right_outer_join():
result = set(join(identity, [1, 2], identity, [2, 3], right_default=None))
expected = set([(2, 2), (1, None)])

assert result == expected


def test_outer_join():
result = set(join(identity, [1, 2], identity, [2, 3],
left_default=None, right_default=None))
expected = set([(2, 2), (1, None), (None, 3)])

assert result == expected