Skip to content

Latest commit

 

History

History
443 lines (352 loc) · 10.7 KB

03-04-stack-and-recursion.md

File metadata and controls

443 lines (352 loc) · 10.7 KB

Stack and recursion

When running a program, the interpreter (Lua in this case) keeps track of variables defined in a scope and which function you are currently in. It organizes this information into a list in memory called the stack. The first item in the stack is the starting point - the root of your application. Take the following example:

local two = function()
  print('two')
end

local one = function()
  print('one')
  two()
end

one()

When starting the program, the start of the stack is the top level of the module. The Lua stack calls this the "main chunk". When a function is invoked, another layer is added to the stack. Every time a function is called from another function, the stack continues to build. So with the example code above, The stack will follow the progression:

  • Stack is { "main chunk" }.
  • Now start executing one. Stack is { "main chunk", "one" }.
  • Now start executing two while still in one. Stack is { "main chunk", "one", "two" }.
  • two is done executing. Stack is now { "main chunk", "one" }.
  • one is done executing. Stack is now { "main chunk" }.
  • Program exits.

This can be visualized by throwing an error at any point the program. The interpreter will give you back a stack trace that details where it was when the problem occurred. Lua provides a helpful error function for debugging that we can use here:

local three = function()
  error('This is an error.')
end

local two = function()
  print('two')
  three()
end

local one = function()
  print('one')
  two()
end

one()

Unfortunately the REPL doesn't provide us with stack traces, but if you have a Lua interpreter on your computer (lua command, luajit, or LÖVE) you will see the error message and a stack trace like this:

lua: test.lua:2: This is an error.
stack traceback:
        [C]: in function 'error'
        test.lua:2: in upvalue 'three'
        test.lua:7: in upvalue 'two'
        test.lua:12: in local 'one'
        test.lua:15: in main chunk
        [C]: in ?

From the "stack traceback" you can see the newest from the top of the stack to the oldest on the bottom. In complex programs is can be very beneficial to see which function invoked another function to help trace down how an error came about.

Understanding the stack is beneficial for more than just reading errors. Let's switch the conversation over to something seemingly unrelated for a bit.

Recursion

When thinking of loops, many programmers first think of the for loop or the while loop. Another common method is to make a function call itself. Similar to the while loop, you can create infinite loops like this one

local loop
loop = function()
  print('hello!')
  loop()
end

When a function invokes itself, whether directly or indirectly, this is called recursion. The same function will recur again and again until a condition changes. Or in the case above, loop() will be called unconditionally. Without a condition, any kind of loop will run infinitely (or crash trying). Here's a loop that is a little safer to run:

local count_to_5
count_to_5 = function(current_number)
  print(current_number)
  if current_number < 5 then
    count_to_5(current_number + 1)
  end
end

count_to_5(1)

Which prints:

1
2
3
4
5

One quick little aside; Notice how the function was defined in both these situations:

local loop
loop = function()
  ...

The variable was defined before the function was created. Since the function needs to access the variable inside itself, the variable needs to exist at the time the function's scope is created. Variables created after the function are unknown to the function. This is discussed in 1.17 - Scopes and is a limitation of Lua's design. Fortunately there is shorthand syntax for writing recursive functions:

local function count_to_5(current_number)
  print(current_number)
  if current_number < 5 then
    count_to_5(current_number + 1)
  end
end

count_to_5(1)

is the same as writing:

local count_to_5
count_to_5 = function(current_number)
  ...

Let's try another recursive loop:

local grocery_list = {
  'pumpkin',
  'pecans',
  'butter',
  'flour',
  'sugar'
}

local function print_items(list, index)
  index = index or 1
  if index <= #list then
    print(list[index])
    print_items(list, index + 1)
  end
end

print_items(grocery_list)

Which prints the grocery list. Don't forget the local at the beginning of local function print_items, otherwise you will accidentally generate global variables in your code when trying to define functions.

We can even re-implement our map function from earlier to use recursion instead of a for loop.

local grocery_list = {
  'pumpkin',
  'pecans',
  'butter',
  'flour',
  'sugar'
}

local function map(orig_list, transform_fn, new_list)
  new_list = new_list or {}
  if #new_list < #orig_list then
    local index = #new_list + 1
    new_list[index] = transform_fn(orig_list[index], index)
    return map(orig_list, transform_fn, new_list)
  end
  return new_list
end

local new_list = map(grocery_list, function(value, index)
  return index .. '. ' .. value
end)

map(new_list, function(value)
  print(value)
  return value
end)

Which prints:

1. pumpkin
2. pecans
3. butter
4. flour
5. sugar

Stack overflow

So what does the stack look like during recursion when a function enters itself? Here's a script to test:

local function recur(n)
  -- assert is like error, but takes an expression to test. If the
  -- expression passed becomes false then it throws the error message.
  assert(n < 5, 'This is a conditional error')
  print(n)
  recur(n + 1)
end

recur(1)
lua: test2.lua:2: This is a conditional error
stack traceback:
        [C]: in function 'assert'
        test2.lua:2: in upvalue 'recur'
        test2.lua:4: in upvalue 'recur'
        test2.lua:4: in upvalue 'recur'
        test2.lua:4: in upvalue 'recur'
        test2.lua:4: in local 'recur'
        test2.lua:7: in main chunk
        [C]: in ?

Every time the function recurs we get another addition to the stack. This can be a problem if you are looping over a large set of data because the stack will consume more and more memory as it stacks up. This can be accomplished by creating a recursive loop that runs infinitely. If you haven't tried so already, here's an easy example:

local function recur()
  recur()
end

recur()

When the stack reaches a critical size, you get a stack overflow error:

lua: test3.lua:2: stack overflow
stack traceback:
        test3.lua:2: in upvalue 'recur'
        test3.lua:2: in upvalue 'recur'
        ...
        test3.lua:2: in upvalue 'recur'
        test3.lua:2: in upvalue 'recur'
        test3.lua:2: in local 'recur'
        test3.lua:5: in main chunk
        [C]: in ?

With a specific return statement added to the loop, however, we no longer get a stack overflow:

local function recur()
  return recur()
end

recur()

This will run until you manually kill the application process. Killing it returns a somewhat mysterious stack track:

lua: test4.lua:2: interrupted!
stack traceback:
        test4.lua:2: in function <test4.lua:1>
        (...tail calls...)
        test4.lua:2: in function <test4.lua:1>
        (...tail calls...)
        test4.lua:5: in main chunk
        [C]: in ?

So how did our modification save us from overflowing our stack?

Tail call optimization

Inside a function when you return another function call, the interpreter has the ability to re-use the same layer of the stack instead of creating another layer. This works with direct recursion (function calling itself) and indirect (mutual) recursion such as two functions calling each other:

local one
local two

one = function()
  return two()
end

two = function()
  return one()
end

one()

Programming in Lua goes into greater detail on when a recursion will or won't be optimized, but the simple thing to remember is that the function(s) must return the value of invoking a function for this to work. The following will be tail-call optimized:

local one
local two

one = function(n)
  print(n)
  return two(n + 1)
end

two = function(n)
  print(n)
  return one(n + 1)
end

-- Count until we run out of numbers
one(1)

But the following won't, since it returns an operation including the function call instead of just the function call itself:

local one
local two

one = function(n)
  print(n)
  return 1 + two(n)
end

two = function(n)
  print(n)
  return 1 + one(n)
end

-- This won't work!
one(1)

The case for recursive loops

So why would we want to do recursion? It seems trickier than a for loop and perhaps just as easy to mess up as a while loop.

It's not necessarily a replacement for the for loop, but allows you to do certain things you can't easily do without recursion. Take this example from Rosetta Code which will flatten a list of lists into a single, flat list. It uses a for loop and a recursive loop in conjunction with each other:

local function flatten(list)
  if type(list) ~= "table" then return {list} end
  local flat_list = {}
  for _, elem in ipairs(list) do
    for _, val in ipairs(flatten(elem)) do
      flat_list[#flat_list + 1] = val
    end
  end
  return flat_list
end

local test_list = {
  {1},
  2,
  {{3,4}, 5},
  {{{}}},
  {{{6}}},
  7,
  8,
  {}
}

print(table.concat(flatten(test_list), ","))

Which prints:

1,2,3,4,5,6,7,8

This function isn't tail-call optimized, but it probably won't be passed a nested list deep enough to cause a stack overflow.

Here's just a few of the many situations where recursion is usually the best tool for the job:

  • Sorting data
  • Searching trees (nested data) in a database or nested folders in a filesystem.
  • Finding the shortest path between two points
  • Loops that increment or decrement in irregular patterns
  • Evaluating a finite set of moves in a game like chess

The point isn't to replace the for loop, although you can. Take the following example, which returns the factorial of the given number (5):

local fact = function(n)
  local acc = 1
  for iteration = n, 1, -1 do
    acc = acc * iteration
  end
  return acc
end

print(fact(5))

The same functionality written with a recursive loop would look very different:

local function fact(n, acc)
  acc = acc or 1
  if n == 0 then
    return acc
  end
  return fact(n-1, n*acc)
end

...but one method wouldn't offer an advantage over the other here. Depending on the language you are working in, one method may be easier to read than the other. Maybe the language supports one type of loop and not the other. These are the factors that will often do the deciding for you.