Hashing, Grouping &
Advanced Algorithms
We can scan, filter, project, and compute — all lazily, all in batches, all with schema propagation. But real data pipelines need two more things: combining data from multiple sources, and summarizing data into groups. These aren't just new plan nodes. They require real algorithms.
Until now, every plan node we've built has been a streaming operation: it pulls a row from its child, does something to it, and hands it up. No buffering, no memory growth. That streak ends here. Joins need to build a lookup table. Aggregations need running accumulators. Both force us to think about memory vs. speed — the oldest tradeoff in computing.
This module introduces two families of algorithms — hash-based and
sort-based — and shows how they power the GROUP BY and
JOIN operations that every data engineer writes daily. By
the end, you'll understand the machinery behind
orders.join(customers, on="customer_id") — and you'll have
built it yourself.
Lesson 4.1 The Hash Join
Here's a scenario every data pipeline encounters. You have two tables:
orders = [
    (1, 101, 250),  # (order_id, customer_id, amount)
    (2, 102, 180),
    (3, 101, 320),
    (4, 103, 90),
]
customers = [
    (101, "Alice"),  # (customer_id, name)
    (102, "Bob"),
    (103, "Carol"),
]
You want to combine them: for each order, attach the customer's name. In
SQL, this is SELECT * FROM orders JOIN customers ON orders.customer_id = customers.customer_id.
Simple enough. But how do you actually implement that?
The naive approach: nested loops
The most obvious strategy is what you'd do by hand — for every order, scan through all customers looking for a match:
def nested_loop_join(left, right, l_key, r_key):
    results = []
    for l_row in left:
        for r_row in right:
            if l_row[l_key] == r_row[r_key]:
                results.append(l_row + r_row)
    return results
It works. And for four orders and three customers, it's fine — 12 comparisons. But the cost is O(n × m): every left row is compared against every right row, so a million orders against 100,000 customers means 10^11 comparisons. That turns a join that should take seconds into one that takes hours.
The fix is the same one Python uses for dict lookups: hashing.
Instead of scanning the entire right table for every left row, we build a
hash table from the right side once, then look up each left row's
key in O(1).
The hash join algorithm
A hash join has two phases:
Phase 1 — Build. Iterate through the right (smaller) table and insert every row into a dictionary keyed by the join column. This costs O(m) time and O(m) memory, where m is the size of the right table.
Phase 2 — Probe. Stream through the left (larger) table. For each row, extract its key and look it up in the dictionary. If there's a match, combine the rows. This costs O(n) time, where n is the size of the left table.
Total: O(n + m) time, O(m) memory. That's a massive improvement over O(n × m).
from collections import defaultdict

def hash_join(left, right, l_key, r_key):
    # Phase 1: Build hash table from the right side
    index = defaultdict(list)
    for row in right:
        index[row[r_key]].append(row)
    # Phase 2: Probe with the left side
    results = []
    for row in left:
        for match in index[row[l_key]]:
            results.append(row + match)
    return results
Notice the defaultdict(list). A single customer can have many orders — the hash table maps each key to a list of matching rows, not a single row. It also means probing a key with no match yields an empty list instead of a KeyError (though as a side effect it inserts that empty list into the dict, which is why the real JoinNode we'll see shortly probes with .get() instead).
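Running the join on the sample tables from the start of the lesson shows the many-to-one behaviour in action — Alice (customer 101) appears in two output rows. The function is restated here so the snippet runs on its own:

```python
from collections import defaultdict

def hash_join(left, right, l_key, r_key):
    # Phase 1: build a hash table from the (smaller) right side
    index = defaultdict(list)
    for row in right:
        index[row[r_key]].append(row)
    # Phase 2: probe with each left row
    results = []
    for row in left:
        for match in index[row[l_key]]:
            results.append(row + match)
    return results

orders = [(1, 101, 250), (2, 102, 180), (3, 101, 320), (4, 103, 90)]
customers = [(101, "Alice"), (102, "Bob"), (103, "Carol")]

# Join key: column 1 of orders, column 0 of customers
joined = hash_join(orders, customers, 1, 0)
for row in joined:
    print(row)
# First row: (1, 101, 250, 101, 'Alice'); 101 matches orders 1 and 3
```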
import pyfloe as pf

orders = pf.LazyFrame([
    {"order_id": 1, "customer_id": 101, "amount": 250},
    {"order_id": 2, "customer_id": 102, "amount": 180},
    {"order_id": 3, "customer_id": 101, "amount": 320},
    {"order_id": 4, "customer_id": 103, "amount": 90},
])
customers = pf.LazyFrame([
    {"customer_id": 101, "name": "Alice"},
    {"customer_id": 102, "name": "Bob"},
    {"customer_id": 103, "name": "Carol"},
])

# pyfloe uses a hash join internally — see the plan
joined = orders.join(customers, on="customer_id")
print(joined.explain())
print()

# Execute and see the result
joined.collect()
for row in joined.to_pylist():
    print(f"  Order {row['order_id']}: ${row['amount']} — {row['name']}")

# Under the hood, pyfloe's JoinNode builds a hash table
# from the right side (customers), then probes it with
# each row from the left side (orders). O(n+m) instead of O(n*m).
print(f"\nInternal plan node: {type(joined._plan).__name__}")
Try it: change the join above from plain on="customer_id" to also pass how="left", and add a customer with no orders — say, {"customer_id": 105, "name": "Eve"}. What happens to Eve in the result? Now try how="full". Add an order with a customer_id that doesn't exist in the customers table. Where do the None values appear?
The next subsection on _make_key_fn is a performance detail — safe to skip on a first read.
From indices to key functions
Our simplified version uses a single integer index for the join key. But
real-world joins often use multiple columns —
JOIN ON a.year = b.year AND a.region = b.region. We need a
function that extracts a composite key as a tuple.
This is where Python's operator.itemgetter enters the picture.
It's a stdlib function that creates a fast key-extraction callable:
from operator import itemgetter

row = ("EU", 2024, "Alice", 250)

# Extract columns 0 and 1 as a composite key
get_key = itemgetter(0, 1)
get_key(row)  # → ("EU", 2024)
Why not just use a lambda like lambda row: (row[0], row[1])?
Because itemgetter is implemented in C — it's 2–5× faster
than an equivalent lambda for tuple indexing. When you're extracting keys
from millions of rows in a tight loop, that difference adds up.
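You can check the claim yourself with timeit. A rough sketch — the exact ratio depends on your Python version and hardware, so treat the printed numbers as indicative, not gospel:

```python
from operator import itemgetter
import timeit

row = ("EU", 2024, "Alice", 250)
get_ig = itemgetter(0, 1)
get_lam = lambda r: (r[0], r[1])

# Both extract the same composite key...
assert get_ig(row) == get_lam(row) == ("EU", 2024)

# ...but itemgetter does the indexing in C, with no Python
# bytecode executed per extraction.
t_ig = timeit.timeit(lambda: get_ig(row), number=100_000)
t_lam = timeit.timeit(lambda: get_lam(row), number=100_000)
print(f"itemgetter: {t_ig:.4f}s  lambda: {t_lam:.4f}s")
```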
pyfloe wraps this in a small helper that handles both single-column and multi-column keys:
def _make_key_fn(indices: list[int]) -> Callable[[tuple], tuple]:
    if len(indices) == 1:
        idx = indices[0]
        return lambda row: (row[idx],)
    return itemgetter(*indices)
For a single-column key, it wraps the value in a tuple so the return type
is always consistent. For multi-column keys, it delegates straight to
itemgetter. This tiny function gets called for every single
row in both the build and probe phases — making it fast matters.
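A quick check of both branches — the helper is restated (with the typing import it needs) so the snippet runs standalone:

```python
from operator import itemgetter
from typing import Callable

def _make_key_fn(indices: list[int]) -> Callable[[tuple], tuple]:
    if len(indices) == 1:
        idx = indices[0]
        return lambda row: (row[idx],)  # wrap in a tuple for consistency
    return itemgetter(*indices)

row = ("EU", 2024, "Alice", 250)
single = _make_key_fn([0])
multi = _make_key_fn([0, 1])
print(single(row))  # ('EU',) — still a tuple, not a bare string
print(multi(row))   # ('EU', 2024)
```

Because both branches return tuples, hash-table keys from single-column and multi-column joins compare the same way.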
pyfloe's JoinNode — the real implementation
Now let's look at the actual JoinNode. The structure follows
our simplified version closely, but it integrates with the plan tree
(schema, batched execution) and handles the details we glossed over.
Let's take it in stages.
First, the class skeleton. Like every plan node, it uses
__slots__ and stores references to its two child nodes:
class JoinNode(PlanNode):
    __slots__ = ("left", "right", "left_on", "right_on", "how")

    def __init__(self, left: PlanNode, right: PlanNode,
                 left_on: list[str],
                 right_on: list[str],
                 how: JoinHow = "inner"):
        self.left = left
        self.right = right
        self.left_on = left_on
        self.right_on = right_on
        self.how = how

    def schema(self) -> LazySchema:
        return self.left.schema().merge(self.right.schema())
Notice the schema() method: a join's output contains all
columns from both sides, so it merges the two schemas. No data
is touched — this runs at plan-build time, not execution time.
Now the heart of it — execute_batched(). The build phase
consumes the entire right side into a defaultdict:
def execute_batched(self) -> Iterator[list[tuple]]:
    rs = self.right.schema()
    r_map = {n: i for i, n in enumerate(rs.column_names)}
    r_idx = [r_map[c] for c in self.right_on]
    r_key = _make_key_fn(r_idx)

    right_ht: dict[tuple, list] = defaultdict(list)
    for chunk in self.right.execute_batched():
        for row in chunk:
            right_ht[r_key(row)].append(row)
Line by line: it resolves column names ("customer_id")
to column indices (position in the tuple) using the schema. Then
it builds the key function and populates the hash table. The pattern
r_map → r_idx → _make_key_fn is a
recurring idiom in pyfloe — you'll see it in AggNode too.
Then the probe phase streams the left side, looking up each key:
ls = self.left.schema()
l_map = {n: i for i, n in enumerate(ls.column_names)}
l_idx = [l_map[c] for c in self.left_on]
l_key = _make_key_fn(l_idx)

buf: list = []
for chunk in self.left.execute_batched():
    for left_row in chunk:
        key = l_key(left_row)
        matches = right_ht.get(key)
        if matches:
            for right_row in matches:
                buf.append(left_row + right_row)
        if len(buf) >= _BATCH_SIZE:
            yield buf
            buf = []
if buf:
    yield buf
Two things to notice. First, left_row + right_row — because
rows are tuples, concatenation is just +. An order tuple
(1, 101, 250) joined with a customer tuple
(101, "Alice") becomes (1, 101, 250, 101, "Alice").
Second, the output buffering. Results accumulate in buf
and are yielded in batches of _BATCH_SIZE (1024). This is
the same batching pattern from Module 3 — the join produces data in chunks,
keeping memory predictable even when there are many matches.
When you write orders.join(customers, on="customer_id") in
Polars or pyfloe, this is the algorithm. The right table becomes a hash map,
the left streams through probing it, and results come out in batches.
O(n + m) is why joins on millions of rows return in seconds.
Lesson 4.2 Join Variants — Inner, Left, Full
The code we just saw implements an inner join — it only emits rows where both sides have a matching key. But what about an order whose customer got deleted? Or a customer who hasn't placed any orders yet? Different join types handle these cases differently.
Inner Both sides must match
Only emit a row when the left key finds a match in the right hash table. Unmatched rows from either side are silently dropped.
Left Keep every left row
If a left row has no match, emit it anyway with None for
all right columns. No left row is ever lost.
Full Keep everything
Keep all unmatched left rows (null-filled right) and all unmatched right rows (null-filled left). Nothing is dropped.
Right join?
A right join is just a left join with the sides swapped. pyfloe (and Polars) don't provide a separate right join — you simply switch which table goes on which side.
The key insight is that join variants don't change the algorithm. The build-and-probe structure is identical. What changes is what happens when a key has no match.
Null padding
When a left row has no match in a left join, we need to fill the right
side's columns with None. pyfloe precomputes a "null row"
tuple so it can be appended cheaply:
null_right = (None,) * len(rs.column_names)
null_left = (None,) * len(ls.column_names)
If the right table has 4 columns, null_right is
(None, None, None, None). This tuple is created once and
reused for every unmatched row — a small but important optimization when
thousands of rows don't match.
The full probe loop with join variants
Here's how the probe phase actually looks in pyfloe, with all three join types handled:
matched_keys: set = set()
buf: list = []
for chunk in self.left.execute_batched():
    for left_row in chunk:
        key = l_key(left_row)
        matches = right_ht.get(key)
        if matches:
            matched_keys.add(key)
            for right_row in matches:
                buf.append(left_row + right_row)
        elif self.how in ("left", "full"):
            buf.append(left_row + null_right)
        if len(buf) >= _BATCH_SIZE:
            yield buf
            buf = []
Walk through it: when there's a match, we record the key in
matched_keys and emit the combined row. When there's no
match and this is a left or full join, we emit the left row padded with
nulls. For an inner join, unmatched rows simply disappear — the
elif never fires.
But a full join needs one more pass. After processing all left rows, there
may be right-side rows whose keys were never matched. That's what
matched_keys is for:
if self.how == "full":
    for key, rows in right_ht.items():
        if key not in matched_keys:
            for right_row in rows:
                buf.append(null_left + right_row)
        if len(buf) >= _BATCH_SIZE:
            yield buf
            buf = []
if buf:
    yield buf
It iterates over the entire hash table, finds keys that were never
matched during the probe, and emits those right rows padded with
null_left. This is the only place in the algorithm that
touches the hash table twice — and only for full joins.
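All three variants fit in one standalone sketch — plain lists and integer key indices instead of pyfloe's plan nodes and batching. The function name and the explicit column-count parameters are inventions for this example:

```python
from collections import defaultdict

def hash_join_variants(left, right, l_key, r_key,
                       n_left_cols, n_right_cols, how="inner"):
    # Build: hash table over the right side
    right_ht = defaultdict(list)
    for row in right:
        right_ht[row[r_key]].append(row)

    null_left = (None,) * n_left_cols     # precomputed null padding
    null_right = (None,) * n_right_cols
    matched_keys = set()
    out = []

    # Probe: stream the left side
    for l_row in left:
        matches = right_ht.get(l_row[l_key])
        if matches:
            matched_keys.add(l_row[l_key])
            for r_row in matches:
                out.append(l_row + r_row)
        elif how in ("left", "full"):
            out.append(l_row + null_right)

    # Backfill: right rows whose keys were never matched (full join only)
    if how == "full":
        for key, rows in right_ht.items():
            if key not in matched_keys:
                for r_row in rows:
                    out.append(null_left + r_row)
    return out

orders = [(1, 101, 250), (5, 999, 60)]      # 999: unknown customer
customers = [(101, "Alice"), (105, "Eve")]  # Eve: no orders
result = hash_join_variants(orders, customers, 1, 0, 3, 2, how="full")
print(result)
```

With how="full", order 5 comes out padded on the right, Eve comes out padded on the left, and nothing is dropped.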
pyfloe defines JoinHow = Literal["inner", "left", "full"]
in expr.py — a Literal type that restricts the
allowed values to exactly three strings. This gives you IDE autocomplete,
type-checker guarantees, and a clear signal to anyone reading the code that
no other join types exist. It's a small touch of type-driven design that
costs nothing at runtime.
In Polars, PySpark, or pyfloe, the how parameter maps
directly to these if/elif branches. The
hash-build-and-probe algorithm is identical — only the unmatched-key
handling differs.
Lesson 4.3 Hash Aggregation with Running Accumulators
Joins combine rows from two tables. Aggregation does the opposite — it
collapses rows. A million rows of individual transactions become
a handful of rows summarizing totals per region, averages per product, or
counts per day. In SQL: SELECT region, SUM(amount) FROM orders GROUP BY region.
The question, as always, is: how?
The naive approach: collect everything, then compute
The simplest strategy is to group all the rows first, then apply the aggregation function to each group:
from collections import defaultdict

def naive_group_by_sum(rows, key_col, val_col):
    groups = defaultdict(list)
    for row in rows:
        groups[row[key_col]].append(row[val_col])
    return {k: sum(vals) for k, vals in groups.items()}
This works, but look at what it stores: every single value for every group. If you have 10 million orders across 50 regions, this collects 10 million integers into 50 lists before computing a single sum. Memory usage is O(n) — proportional to the total number of rows, not the number of groups.
Do we actually need to keep every value? Only a few aggregations, such as
n_unique, truly require storing every value — and even then, only the
distinct ones. For most
aggregations, we can maintain a tiny accumulator per group instead of a
growing list.
Running accumulators — O(k) memory
The insight is simple: instead of storing all values per group, store only the accumulator state you need to compute the final result. For a sum, that's a single running total. For a mean, it's a running sum and a count. For min/max, it's the current extremum.
def running_group_by_sum(rows, key_col, val_col):
    accumulators = {}  # key → running_sum
    for row in rows:
        key = row[key_col]
        val = row[val_col]
        if key not in accumulators:
            accumulators[key] = 0
        accumulators[key] += val
    return accumulators
Memory usage drops from O(n) to O(k), where k is the number of distinct groups. If you're summarizing 10 million rows into 50 regions, you store 50 integers instead of 10 million. That's the difference between megabytes and bytes.
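A quick run on a handful of rows makes the behaviour concrete. The function is restated (compressed slightly with dict.get) so the snippet is self-contained:

```python
def running_group_by_sum(rows, key_col, val_col):
    accumulators = {}  # key → running sum, one int per group
    for row in rows:
        key, val = row[key_col], row[val_col]
        accumulators[key] = accumulators.get(key, 0) + val
    return accumulators

rows = [("EU", 250), ("US", 180), ("EU", 320), ("APAC", 90)]
totals = running_group_by_sum(rows, 0, 1)
print(totals)
# {'EU': 570, 'US': 180, 'APAC': 90}
```

Four rows in, three accumulators out — and no per-group list of raw values ever existed.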
Generalizing accumulators
The running-sum approach works for sum, but what about
mean, min, count? Each aggregation
needs its own accumulator shape and update logic. pyfloe solves this with
three small functions that work as a team:
_init_acc(agg) — creates the initial
accumulator state. sum needs just {"s": 0}.
mean needs {"s": 0.0, "n": 0} (running sum and count).
first uses a flag to capture only the first non-null value.
Full accumulator init and finalize code
def _init_acc(agg: AggExpr) -> dict:
    kind = agg.agg_name
    if kind == "sum": return {"s": 0}
    elif kind == "count": return {"n": 0}
    elif kind == "mean": return {"s": 0.0, "n": 0}
    elif kind == "min": return {"v": None}
    elif kind == "max": return {"v": None}
    elif kind == "first": return {"v": None, "set": False}
    elif kind == "last": return {"v": None}
    elif kind == "n_unique": return {"s": set()}
    else: return {"vals": []}
_make_updater(agg) — returns a function that
knows how to update a specific accumulator type:
def _make_updater(agg: AggExpr) -> Callable:
    kind = agg.agg_name
    if kind == "sum":
        def _update(acc, val):
            if val is not None: acc["s"] += val
    elif kind == "count":
        def _update(acc, val):
            if val is not None: acc["n"] += 1
    elif kind == "min":
        def _update(acc, val):
            if val is not None:
                if acc["v"] is None or val < acc["v"]:
                    acc["v"] = val
    # ... max, mean, first, last, n_unique follow the same pattern
    return _update
Each updater is a closure that captures the accumulator shape. The
if val is not None guard is critical — nulls are skipped,
matching the SQL convention that SUM ignores nulls.
_finalize_acc(acc, agg) — extracts the final
result from the accumulator:
def _finalize_acc(acc: dict, agg: AggExpr) -> Any:
    kind = agg.agg_name
    if kind == "sum": return acc["s"]
    elif kind == "count": return acc["n"]
    elif kind == "mean":
        return acc["s"] / acc["n"] if acc["n"] else 0.0
    elif kind in ("min", "max", "first", "last"):
        return acc["v"]
    elif kind == "n_unique":
        return len(acc["s"])
    else:
        return agg.eval_agg(acc["vals"])
The mean case divides the running sum by the count — only possible because
the accumulator tracked both. n_unique returns the size of the set.
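The three-function protocol is easiest to see for a single aggregation. Here is the mean case stripped to a self-contained sketch, following the accumulator shapes shown above:

```python
def init_mean():
    return {"s": 0.0, "n": 0}

def update_mean(acc, val):
    if val is not None:   # SQL convention: nulls are skipped
        acc["s"] += val
        acc["n"] += 1

def finalize_mean(acc):
    # Only possible because the accumulator tracked both sum and count
    return acc["s"] / acc["n"] if acc["n"] else 0.0

acc = init_mean()
for val in [250, None, 180, 320]:
    update_mean(acc, val)
print(finalize_mean(acc))  # 250.0 — the None is ignored, n is 3
```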
This init/update/finalize trio is a well-worn pattern. Spark exposes it
directly: its Aggregator class
has zero, reduce, and merge methods.
Polars' internal aggregation kernels follow the same structure. Once you
recognize the pattern, you'll see it in every system that needs to
summarize large datasets without materializing them.
The AggNode — wiring it together
The three helper functions handle individual accumulators. The
AggNode orchestrates them across all groups and all
aggregation expressions. Here's how it consumes its input:
def execute_batched(self) -> Iterator[list[tuple]]:
    col_map = {n: i for i, n in enumerate(
        self.child.schema().column_names)}
    g_idx = [col_map[c] for c in self.group_by]
    key_fn = _make_key_fn(g_idx)
    n_aggs = len(self.agg_exprs)

    # Pre-compile expression evaluators and updaters
    compiled_evals = [a.expr.compile(col_map) for a in self.agg_exprs]
    updaters = [_make_updater(a) for a in self.agg_exprs]

    accumulators: dict[tuple, list] = {}
Setup: resolve group-by column names to indices, build the key function
(same _make_key_fn we saw in JoinNode), and
pre-compile the expression evaluators and updaters. The
compiled_evals list contains one fast callable per
aggregation — each knows how to extract its value from a tuple row.
The updaters list contains the corresponding accumulator
update functions.
Then the main loop processes every row:
for chunk in self.child.execute_batched():
    for row in chunk:
        key = key_fn(row)
        try:
            accs = accumulators[key]
        except KeyError:
            accs = [_init_acc(a) for a in self.agg_exprs]
            accumulators[key] = accs
        for i in range(n_aggs):
            updaters[i](accs[i], compiled_evals[i](row))
For each row: extract the group key, look up (or create) the accumulator
list, then update every accumulator. Notice the try/except KeyError
instead of if key not in accumulators — this is a deliberate
micro-optimization. In Python, the "ask forgiveness" (EAFP) pattern is
faster when misses are rare, which is the case once most groups have been
seen.
After consuming all input, the accumulators are finalized and emitted:
buf: list = []
for key, accs in accumulators.items():
    agg_vals = tuple(
        _finalize_acc(accs[i], self.agg_exprs[i])
        for i in range(n_aggs)
    )
    buf.append(key + agg_vals)
    if len(buf) >= _BATCH_SIZE:
        yield buf
        buf = []
if buf:
    yield buf
Each output row is key + agg_vals — the group-by columns
followed by the aggregation results. For
group_by("region").agg(col("amount").sum().alias("total")),
you'd get rows like ("EU", 430) and ("US", 320).
So when you write
df.group_by("region").agg(col("amount").sum().alias("total")),
pyfloe builds an AggNode with group key ["region"] and a
sum over col("amount"). The running-accumulator
algorithm processes every row exactly once, maintaining one integer per group.
Lesson 4.4 The Sorted Alternative — O(1) Memory
Hash joins and hash aggregations are fast and general-purpose. But they come with a cost: the hash table. For joins, the entire right side lives in memory. For aggregations, all group accumulators live in memory. When the data is large and the groups are many, that memory pressure adds up.
There's an alternative — but it requires a precondition: the data must be sorted by the key columns. When that's true, something remarkable happens: both joins and aggregations can run in constant memory.
Sorted aggregation — one group at a time
When rows are sorted by the group key, all rows belonging to the same group appear in a contiguous run. You don't need a hash table — you just watch for the key to change:
def sorted_group_by_sum(sorted_rows, key_col, val_col):
    prev_key = None
    running_sum = 0
    for row in sorted_rows:
        key = row[key_col]
        if key != prev_key:
            if prev_key is not None:
                yield (prev_key, running_sum)
            prev_key = key
            running_sum = 0
        running_sum += row[val_col]
    if prev_key is not None:
        yield (prev_key, running_sum)
Only one accumulator is alive at any time. When the key changes, the current group is finalized and emitted, and a fresh accumulator is created for the next group. Memory usage: O(1).
Python's itertools.groupby encapsulates this exact pattern.
It groups consecutive elements with the same key — but it has a
critical gotcha:
Unlike SQL's GROUP BY, Python's itertools.groupby
only groups adjacent elements with the same key. If your data is
[A, A, B, A], you'll get three groups: [A, A],
[B], and [A] — not two. This is exactly right
for SortedAggNode — its input is guaranteed sorted. The
danger is using groupby on unsorted data elsewhere.
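The gotcha takes three lines to demonstrate:

```python
from itertools import groupby

data = ["A", "A", "B", "A"]
groups = [(key, list(grp)) for key, grp in groupby(data)]
print(groups)
# [('A', ['A', 'A']), ('B', ['B']), ('A', ['A'])] — three groups, not two

# Sorting first restores SQL-like semantics
sorted_groups = [(k, list(g)) for k, g in groupby(sorted(data))]
print(sorted_groups)
# [('A', ['A', 'A', 'A']), ('B', ['B'])]
```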
pyfloe's SortedAggNode uses groupby directly,
trusting that its input is sorted:
def execute_batched(self) -> Iterator[list[tuple]]:
    col_map = {n: i for i, n in enumerate(
        self.child.schema().column_names)}
    g_idx = [col_map[c] for c in self.group_by]
    key_func = _make_key_fn(g_idx)
    compiled_evals = [a.expr.compile(col_map) for a in self.agg_exprs]
    updaters = [_make_updater(a) for a in self.agg_exprs]

    buf: list = []
    for key, group_rows in groupby(self.child.execute(), key=key_func):
        accs = [_init_acc(a) for a in self.agg_exprs]
        for row in group_rows:
            for i in range(len(self.agg_exprs)):
                updaters[i](accs[i], compiled_evals[i](row))
        buf.append(
            key + tuple(_finalize_acc(accs[i], self.agg_exprs[i])
                        for i in range(len(self.agg_exprs))))
        if len(buf) >= _BATCH_SIZE:
            yield buf
            buf = []
    if buf:
        yield buf
The structure mirrors AggNode almost exactly — same
_init_acc, same _make_updater, same
_finalize_acc. The only difference is
where the grouping happens. AggNode uses a
hash table (accumulators dict). SortedAggNode
uses itertools.groupby, which detects groups by watching for
key changes in sorted input. The accumulator logic is identical.
Sorted merge join — two cursors
The same principle applies to joins. When both sides are sorted by their join key, you don't need a hash table. Instead, you advance two cursors in lockstep:
# Both inputs sorted by join key
left_cursor = next(left_iter)
right_cursor = next(right_iter)

while both_active:
    if left_key < right_key:
        advance_left()     # left row has no match
    elif left_key > right_key:
        advance_right()    # right row has no match
    else:
        # Keys match! Emit cross product of matching groups
        emit_matches()
Because the data is sorted, the cursor that's behind always advances. When the cursors land on equal keys, the rows match and get emitted. Memory usage is O(1) for one-to-one joins. For many-to-many matches (multiple rows with the same key on both sides), it briefly collects the matching group to compute the cross product — that's O(g) where g is the group size.
pyfloe's SortedMergeJoinNode implements this algorithm in
full, with support for inner, left, and full join modes. The code is the
longest in the plan module (~80 lines) because it needs to handle all
the edge cases: exhausted iterators, mismatched keys at the boundaries,
and unmatched rows for outer joins. We won't reproduce all of it here,
but here's the critical matching section:
else:  # lk == rk — keys match
    match_key = lk
    left_group = [l_row]
    right_group = [r_row]
    # Collect all left rows with this key
    while True:
        try:
            l_row = next(left_iter)
            if l_key(l_row) == match_key:
                left_group.append(l_row)
            else:
                break
        except StopIteration:
            l_exhausted = True
            break
    # Same for right side... then cross product:
    for lr in left_group:
        for rr in right_group:
            yield lr + rr
When keys match, both cursors advance to collect all rows sharing that key. Then the cross product is emitted. After that, the cursors resume their lockstep advance from where they left off.
Choosing the right algorithm
pyfloe exposes both strategies through a single sorted
parameter:
# Hash join (default) — works on any data
orders.join(customers, on="customer_id")

# Sort-merge join — O(1) memory for pre-sorted input
orders.sort("id").join(
    customers.sort("id"), on="id", sorted=True)

# Hash aggregation (default)
df.group_by("region").agg(col("amount").sum())

# Sorted streaming aggregation — O(1) memory per group
df.sort("region").group_by(
    "region", sorted=True).agg(col("amount").sum())
| | Hash-based | Sorted |
|---|---|---|
| Join time | O(n + m) | O(n + m) if pre-sorted |
| Join memory | O(m) right side | O(1) base |
| Agg time | O(n) | O(n) |
| Agg memory | O(k) groups | O(1) |
| Best when | Unsorted data, small right side | Pre-sorted data, streaming pipelines |
In a pipeline like
read_csv("sorted_log.csv").join(lookup, on="key", sorted=True).to_csv(...),
the sort-merge join never materializes either side. Rows flow from the
CSV reader through the join and out to the writer with constant memory —
no matter whether the file is 1 MB or 100 GB. This is the same technique
that makes Spark's sort-merge join viable on datasets larger than RAM.
Deep dive: Why does pyfloe keep both algorithms?
You might wonder: if sorted algorithms use less memory, why not always sort first and then use them? Because sorting itself costs O(n log n) time and O(n) memory. If your data isn't already sorted, the sort negates the memory savings of the sorted join or aggregation.
The hash-based algorithms are the right default — they work on any data with no preconditions. The sorted algorithms are a specialization for when the data happens to be sorted already (common with time-series data, log files, and pre-sorted exports).
A cost-based optimizer (like the ones in PostgreSQL or Spark) would
choose automatically based on statistics — estimated row counts,
key cardinality, available memory. pyfloe's rule-based optimizer
doesn't make this choice, so it's left to the user via the
sorted=True parameter. This is a deliberate design
tradeoff: simplicity over automation.
A smarter optimizer rule could detect sorted inputs and swap
JoinNode for SortedMergeJoinNode automatically.
That's one of the open problems we'll leave you with.
Exercises
Test your understanding of the algorithms we've built.
Quick check
1. A hash join has two phases. Which side is used for the "build" phase?
2. What's the time complexity of a nested-loop join on tables of size n and m?
3. Why does hash aggregation use running accumulators instead of storing all values?
4. What's the critical gotcha with itertools.groupby?
5. In a full join, what does matched_keys track?
Before diving into the optimizer, take a detour and wire everything you've built — scan, filter, join, aggregate — into a complete mini-pipeline you can run end to end.
Take the side quest →

Source References
Browse the pyfloe source code for the classes and functions covered in this module.