User:Jeff/Algorithms/Memoization


Memoization is a fundamental and widely useful programming technique. It speeds up a recursive function by keeping a lookup table of return values, so that the function is evaluated only once for each distinct set of parameters.

The Fibonacci sequence
For example, take a program which computes the nth Fibonacci number. The Fibonacci sequence is 0, 1, 1, 2, 3, 5, 8, 13, etc., where each number is the sum of the previous two. The most straightforward way to program this is to actually create the sequence of numbers, and return the nth number calculated.

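#include <vector>
using namespace std;

// Builds the sequence from the bottom up and returns its nth term.
int getFibonacci(int n) {
    if (n == 0) return 0;          // avoids writing fib[1] in a one-element vector
    vector<int> fib(n + 1, 0);
    fib[1] = 1;
    for (int i = 2; i <= n; ++i)
        fib[i] = fib[i-1] + fib[i-2];
    return fib[n];
}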

Another way to solve this problem is to use a recursive approach. Since our function will return any requested Fibonacci number, we can ask for the sum of the previous two terms directly, before actually calculating them.

int getFibRecursive(int n) {
    if (n == 0) return 0;
    if (n == 1) return 1;
    return getFibRecursive(n-1) + getFibRecursive(n-2);
}

Let's take a closer look at how this program runs by evaluating getFibRecursive(4).

* getFib(4) = getFib(3) + getFib(2)
+-* getFib(3) = getFib(2) + getFib(1)
| +-* getFib(2) = getFib(1) + getFib(0)
| | +-* getFib(1) = 1
| | +-* getFib(0) = 0
| +-* getFib(2) = 1 + 0 = 1
| +-* getFib(1) = 1
+-* getFib(3) = 1 + 1 = 2
+-* getFib(2) = getFib(1) + getFib(0)
| +-* getFib(1) = 1
| +-* getFib(0) = 0
+-* getFib(2) = 1 + 0 = 1
* getFib(4) = 2 + 1 = 3
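The trace can also be checked mechanically. The sketch below is only an illustration, not part of the original program: it repeats the recursion of getFibRecursive and tallies every invocation. The names callCount, getFibCounted and countCalls are hypothetical, introduced just for this example.

```cpp
#include <cassert>

// Hypothetical global counter, added only for illustration.
long long callCount = 0;

// Same recursion as getFibRecursive, but every invocation is tallied.
int getFibCounted(int n) {
    assert(n >= 0);
    ++callCount;
    if (n == 0) return 0;
    if (n == 1) return 1;
    return getFibCounted(n - 1) + getFibCounted(n - 2);
}

// Resets the counter, runs the recursion, and reports how many calls it took.
long long countCalls(int n) {
    callCount = 0;
    getFibCounted(n);
    return callCount;
}
```

Under these assumptions, countCalls(4) returns 9 and countCalls(20) returns 21891, which shows how quickly the repeated work grows with n.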

For an input of 4, our function gets called a total of 9 times. Note how getFib(2) is called twice, from the body of getFib(4) and getFib(3). Ideally, we would want to save (or cache) the value for getFib(2) the first time we calculate it, and return this value immediately the next time it's asked for. This is the principle behind memoization, and it is easy to implement in C++ using the map datatype.

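#include <map>
using namespace std;

int getFibMemoize(int n, map<int, int> &mem) {
    // Return the cached answer if this n has been seen before.
    if (mem.find(n) != mem.end())
        return mem[n];
    if (n == 0) return mem[n] = 0;
    if (n == 1) return mem[n] = 1;
    // Store the result in the cache while returning it.
    return mem[n] = getFibMemoize(n-1, mem) + getFibMemoize(n-2, mem);
}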

The function now takes an extra parameter, "mem", which is the associative array of parameters and return values that we've already found. The ampersand signifies pass by reference; we want every function call to share the same map object, instead of the default behavior (which would make a copy of mem at every step). There are two new lines of code at the beginning of the function, which test whether we already know the value of getFibMemoize(n) and return it if so. Instead of simply returning each answer, we simultaneously save it into mem.
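To see what the cache buys us, here is a minimal sketch of the same memoized recursion with a call counter added. The counter memoCalls, the function getFibMemoCounted and the wrapper fibWithFreshCache are hypothetical names used only for this illustration; the wrapper shows one possible way to hide the map from callers.

```cpp
#include <cassert>
#include <map>
using namespace std;

// Hypothetical counter, added only to make the saving visible.
long long memoCalls = 0;

// The memoized recursion described above, with every invocation tallied.
int getFibMemoCounted(int n, map<int, int> &mem) {
    assert(n >= 0);
    ++memoCalls;
    map<int, int>::iterator it = mem.find(n);
    if (it != mem.end()) return it->second;   // cache hit: no further recursion
    if (n == 0) return mem[n] = 0;
    if (n == 1) return mem[n] = 1;
    // Separate statements fix the order in which the subproblems are solved.
    int a = getFibMemoCounted(n - 1, mem);
    int b = getFibMemoCounted(n - 2, mem);
    return mem[n] = a + b;
}

// Convenience wrapper: owns the cache so callers need not create one.
int fibWithFreshCache(int n) {
    map<int, int> mem;
    memoCalls = 0;
    return getFibMemoCounted(n, mem);
}
```

With this sketch, fibWithFreshCache(20) returns 6765 after only 39 calls, whereas the plain recursion needs 21891 calls for the same input. Using the iterator returned by find also avoids a second lookup through mem[n].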

Memoizing an arbitrary function
Using the same technique presented in the Fibonacci example, we can memoize any recursive function. Say we have the following function header:

string myFunction(vector<int> &data, int a, char b);

The first argument, data, is being passed by reference. Since it will be the same for each function call, we don't need to store its values in our cache. However, the other two parameters will presumably differ between calls, so we need to memoize on both variables at the same time. This can be achieved by use of the pair datatype:

pair<int, char> key(a, b);

Now we can check to see if the combined value of "key" is in the map. An STL map needs two template parameters: the key type and the value type. The key type is pair<int, char>, and the value type is the same as the return type of our function (in this case, string). So before our program calls myFunction for the first time, we need to instantiate a map object like so:

map<pair<int, char>, string> mem;
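A pair can serve as a map key because the standard library compares pairs lexicographically (first element, then second). The following sketch shows storing and looking up values under such a combined key. The alias Cache and the helpers lookup and store are hypothetical, introduced only for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
using namespace std;

typedef map<pair<int, char>, string> Cache;   // hypothetical alias for brevity

// Returns true and fills `out` if a value is cached under the key (a, b).
bool lookup(const Cache &mem, int a, char b, string &out) {
    Cache::const_iterator it = mem.find(make_pair(a, b));
    if (it == mem.end()) return false;
    out = it->second;
    return true;
}

// Stores a result under the combined key (a, b).
void store(Cache &mem, int a, char b, const string &value) {
    mem[make_pair(a, b)] = value;
}
```

The check uses find rather than mem[key] because operator[] would insert a default-constructed string for a key that is not yet present.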

In conclusion, there are four changes that must be made to a function to memoize it:
 * 1) Add the map to the function header as a pass-by-reference parameter.
 * 2) At the beginning of the function body, create a key from the applicable parameters. Check to see if the key exists in mem, and if so, return the previously calculated value.
 * 3) Replace all instances of "return x" with "return mem[key] = x".
 * 4) Finally, remember to change all recursive function calls so they pass along "mem" as a parameter as well.

Example
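// Headers used by both versions
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// Before memoization
string myFunction(vector<int> &data, int a, char b) {
    string ans;
    // ...
    ans = myFunction(data, a - 1, 'w');
    // ...
    return ans;
}

int main() {
    vector<int> data;   // a named object; a temporary cannot bind to a non-const reference
    cout << myFunction(data, 0, 'a');
    return 0;
}

// After memoization
string myFunction(vector<int> &data, int a, char b,
                  map<pair<int, char>, string> &mem) {
    pair<int, char> key(a, b);
    if (mem.find(key) != mem.end())
        return mem[key];
    string ans;
    // ...
    ans = myFunction(data, a - 1, 'w', mem);
    // ...
    return mem[key] = ans;
}

int main() {
    map<pair<int, char>, string> mem;
    vector<int> data;
    cout << myFunction(data, 0, 'a', mem);
    return 0;
}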
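Since the template above leaves the function body elided, a concrete instance may help. The sketch below applies the same four changes to a made-up problem: counting the paths through a grid from the top-left to the bottom-right cell, moving only right or down and avoiding blocked cells. The problem, the aliases Grid and PathCache, and the function countPaths are all hypothetical and chosen only for illustration. As with "data" above, the grid is the same for every call, so only the two coordinates form the key.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>
using namespace std;

typedef vector<vector<int> > Grid;                  // 1 = blocked cell, 0 = free
typedef map<pair<int, int>, long long> PathCache;   // (row, column) -> path count

// Counts paths from (r, c) to the bottom-right cell, moving only right or down.
long long countPaths(const Grid &grid, int r, int c, PathCache &mem) {
    int rows = static_cast<int>(grid.size());
    int cols = static_cast<int>(grid[0].size());
    if (r >= rows || c >= cols || grid[r][c] == 1) return 0;  // off-grid or blocked

    pair<int, int> key(r, c);                      // build the key
    PathCache::iterator it = mem.find(key);
    if (it != mem.end()) return it->second;        // already computed

    if (r == rows - 1 && c == cols - 1) return mem[key] = 1;  // reached the goal

    long long down = countPaths(grid, r + 1, c, mem);   // pass mem along
    long long right = countPaths(grid, r, c + 1, mem);
    return mem[key] = down + right;                // save while returning
}

// Convenience wrapper that creates the cache for a fresh grid.
long long countPaths(const Grid &grid) {
    assert(!grid.empty() && !grid[0].empty());
    PathCache mem;
    return countPaths(grid, 0, 0, mem);
}
```

For an open 3 by 3 grid this returns 6, and blocking the centre cell reduces the count to 2. A new cache is created per grid because cached answers are only valid for the grid they were computed on.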