Python Yield From

yield from is a powerful feature of Python 3 (available since Python 3.3) that makes recursion with generators practical. When a function contains a yield statement, calling it returns a generator. Suppose you want a recursive algorithm to pass every value it calculates through to the original caller as a single generator, for debugging or to produce a sequence of numbers, no matter which recursive call produced each value. A plain yield won't do that.
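Here is a minimal sketch of that idea, separate from the factorial example that follows. It uses a hypothetical flatten helper (not from the original post) that walks an arbitrarily nested list. Each recursive call's values reach the outermost caller through yield from, however deep they were found.

```python
def flatten(items):
    """Yield every non-list element of an arbitrarily nested list, depth first."""
    for item in items:
        if isinstance(item, list):
            # Delegate to the recursive call: everything it yields is passed
            # straight through to whoever is iterating over this generator.
            yield from flatten(item)
        else:
            yield item


print(list(flatten([1, [2, [3, 4]], [], 5])))  # [1, 2, 3, 4, 5]
```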

As an example, here's a factorial generator that looks like it should work. The expected output is a generator that produces the factorials from 1 to n:

factorial(1), factorial(2), …, factorial(n-1), factorial(n)

[1,2,6,24,…]

def factorial(n):
    if n == 1:
        yield 1
        return 1
    else:
        x = n * factorial(n-1)
        yield x
        return x

print(list(factorial(3)))
Traceback (most recent call last):
  File "factorialyield.py", line 9, in <module>
    print(list(factorial(3)))
  File "factorialyield.py", line 6, in factorial
    x = n * factorial(n-1)
TypeError: unsupported operand type(s) for *: 'int' and 'generator'
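A quick way to see why this fails is shown below. The isinstance check and the try/except are additions for illustration. Calling factorial(3) raises nothing, because calling a generator function only creates a generator object. The TypeError appears on the first next(). That is when the body starts running and tries to multiply an int by the generator that factorial(2) returns.

```python
import types


def factorial(n):
    # The broken version from above.
    if n == 1:
        yield 1
        return 1
    else:
        x = n * factorial(n - 1)  # factorial(n - 1) is a generator object here
        yield x
        return x


gen = factorial(3)  # nothing has run yet, so no error is raised
print(isinstance(gen, types.GeneratorType))  # True

try:
    next(gen)  # the body starts running here...
except TypeError as exc:
    print("TypeError:", exc)  # ...and fails on int * generator
```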

What's happening here is that factorial(n) calls factorial(n-1). Calling a generator function doesn't run its body. It only creates a generator object, which would later yield factorial(n-1) and then return it. We can't multiply n by a generator, so we would need to fix this example in a cringeworthy way, like the following:

def factorial(n):
    if n == 1:
        yield 1
        return 1
    else:
        # we know factorial(n-1) yields exactly one value, so we can do this
        generator = list(factorial(n-1))
        x = n * generator[0]
        yield x
        return x

print(list(factorial(3)))
[6]

The intention here is to get rid of the int * generator TypeError. We do get the final value out as a generator, because that's how we defined the function. But we've lost the point of the exercise: the intermediate factorials (1 and 2 here) never reach the caller. This approach also requires each recursive generator to yield exactly one value, which may not always be the case.
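Before yield from, the usual way to keep the intermediate values was to re-yield the recursive call's output in a for loop. yield from (PEP 380) generalizes that pattern. The sketch below is not from the original post. It works, but it relies on the convention that the last value yielded is the sub-result, because a for loop throws away the generator's return value.

```python
def factorial(n):
    """Yield 1!, 2!, ..., n! by re-yielding the recursive call's values by hand."""
    if n == 1:
        yield 1
        return
    last = None
    for last in factorial(n - 1):
        # Pass each intermediate factorial up to our caller.
        yield last
    # Relies on the recursive call's *last* yielded value being (n-1)!.
    yield n * last


print(list(factorial(5)))  # [1, 2, 6, 24, 120]
```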

This leads us back to the original question posed at the beginning of the article.

How do I seamlessly pass all the intermediate values produced by a recursive function to the original caller? Enter yield from:

def factorial(n):
    if n == 1:
        yield 1
        return 1
    else:
        x = n * (yield from factorial(n-1))
        yield x
        return x

print(list(factorial(6)))
[1, 2, 6, 24, 120, 720]

The only difference from the first example is the yield from expression. yield from passes each value yielded by factorial(n-1) through as a value yielded by factorial(n). That value is then passed up through factorial(n+1)'s yield from, or on to the original caller.

The return statements are there because yield from serves two purposes in this example. It passes along the values yielded by factorial(n-1), and it also evaluates to factorial(n-1)'s return value. The yielded values can't be used inside any factorial call. They exist purely to build the generator that delivers every value to the original caller of factorial(n). The return statements are what actually compute the factorial along the recursive call chain.
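To make the two purposes concrete: a generator's return x reaches its caller as StopIteration(x), and yield from is what unpacks it. The sketch below expands x = n * (yield from factorial(n-1)) by hand. It is simplified. The real yield from also forwards send(), throw(), and close() to the subgenerator (PEP 380), and this loop doesn't attempt that.

```python
def factorial(n):
    if n == 1:
        yield 1
        return 1
    sub = factorial(n - 1)
    # Hand-expanded (simplified) version of: result = yield from sub
    while True:
        try:
            value = next(sub)
        except StopIteration as stop:
            # The subgenerator's `return` value travels on StopIteration.value.
            result = stop.value
            break
        yield value  # pass the subgenerator's value straight up to our caller
    x = n * result
    yield x
    return x


print(list(factorial(6)))  # [1, 2, 6, 24, 120, 720]
```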

Here's a non-cached version of the Fibonacci numbers:

def fib(n):
    if n == 0:
        yield 0
        return 0
    elif n == 1:
        yield 1
        return 1
    else:
        x = yield from fib(n-2)
        y = yield from fib(n-1)
        yield x + y
        return x + y

print(list(fib(7)))

[1, 0, 1, 1, 2, 0, 1, 1, 1, 0, 1, 1, 2, 3, 5, 0, 1, 1, 1, 0, 1, 1, 2, 3, 1, 0, 1, 1, 2, 0, 1, 1, 1, 0, 1, 1, 2, 3, 5, 8, 13]

This Fibonacci example (Python 3.4) produces a generator that yields every Fibonacci number computed by the algorithm. Recall that a naive recursive Fibonacci algorithm makes an exponentially large number of function calls. This output shows all of the redundant ones.
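Because every call yields exactly one value, the length of the output equals the number of calls. The quick check below shows how fast that count grows (the loop at the bottom is an addition for illustration). For this function the count works out to 2*F(n+1) - 1, where F(n) is the n-th Fibonacci number. So fib(7) makes 41 calls, matching the 41 values printed above.

```python
def fib(n):
    # The naive, uncached generator from above: each call yields exactly once.
    if n == 0:
        yield 0
        return 0
    elif n == 1:
        yield 1
        return 1
    else:
        x = yield from fib(n - 2)
        y = yield from fib(n - 1)
        yield x + y
        return x + y


for n in (0, 5, 10, 15, 20):
    # One value per call, so the length of the output counts the calls.
    print(n, len(list(fib(n))))
# 0 1
# 5 15
# 10 177
# 15 1973
# 20 21891
```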

Calculating fib(3) requires fib(2) + fib(1), fib(4) requires fib(3) + fib(2), etc., so this example does an increasing amount of work for each successive Fibonacci number. The natural way to speed this up is to cache previously calculated values. Without a cache, calculating fib(5) means running fib(4) and fib(3) and doing the work for each one separately. With a cache, fib(4) and fib(3) are already known, so fib(5) is just the sum of those cached values. The only work required is a couple of table lookups and an addition.
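For comparison, here's a sketch of the same caching idea on an ordinary, non-generator function. fib_memo and its explicit cache argument are illustrative names, not from the original post. Each n is computed once, and every later request for it is a dictionary lookup.

```python
def fib_memo(n, cache):
    """Plain (non-generator) Fibonacci that stores each result in `cache`."""
    if n not in cache:
        # Reached at most once per n: one addition of two values that are
        # cached (or about to be) by the recursive calls.
        cache[n] = fib_memo(n - 2, cache) + fib_memo(n - 1, cache)
    return cache[n]  # every later request for n is just a dict lookup


cache = {0: 0, 1: 1}
print(fib_memo(30, cache))  # 832040
print(len(cache))           # 31 entries: fib(0) through fib(30), each computed once
```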

If we add a cache for the Fibonacci numbers, the speedup of the improved algorithm shows up directly in the generator's output.

from functools import partial
def cache_yield(func, default=None):
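    # copy the default so each decorated function gets its own cache
    cache = {} if default is None else dict(default)
    def inner_function(n):
        if n in cache:
            # cache hit: yield nothing, just hand back the stored value
            return cache[n]
        else:
            # cache miss: pass func(n)'s values through, then store its return value
            cache[n] = yield from func(n)
            return cache[n]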
    return inner_function

fib_cache_yield = partial(cache_yield, default={0:0, 1:1})

@fib_cache_yield
def fib(n):
    if n == 0:
        yield 0
        return 0
    elif n == 1:
        yield 1
        return 1
    else:
        x = yield from fib(n-2)
        y = yield from fib(n-1)
        yield x + y
        return x + y

print(list(fib(7)))

[1, 2, 3, 5, 8, 13]

Now that we've added a cache, the body of fib runs only once for each n, so the output contains no redundant values. If you want to see cache hits, add a yield cache[n] inside the if n in cache: branch, before its return.
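Here's a sketch of that modification. It uses the same cache_yield and fib as above, with one extra yield in the cache-hit branch. The dict(default) copy and the second print call are additions for illustration. In the first call, each cache hit shows up as a repeated value. In the second call, fib(7) is already cached, so it yields only that one value.

```python
from functools import partial


def cache_yield(func, default=None):
    # Copy the default so each decorated function gets its own cache.
    cache = {} if default is None else dict(default)

    def inner_function(n):
        if n in cache:
            yield cache[n]  # the added line: report the cache hit
            return cache[n]
        else:
            cache[n] = yield from func(n)
            return cache[n]

    return inner_function


fib_cache_yield = partial(cache_yield, default={0: 0, 1: 1})


@fib_cache_yield
def fib(n):
    if n == 0:
        yield 0
        return 0
    elif n == 1:
        yield 1
        return 1
    else:
        x = yield from fib(n - 2)
        y = yield from fib(n - 1)
        yield x + y
        return x + y


print(list(fib(7)))  # [1, 0, 1, 1, 2, 1, 2, 3, 5, 3, 5, 8, 13]
print(list(fib(7)))  # [13]: fib(7) is now a cache hit
```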
