<p><em>Jeremie Coullon's blog: posts about programming, statistics, and jazz</em></p>
<p><strong>Early Monte Carlo methods - Part 2: the Metropolis sampler</strong> (2021-06-23)</p>
<p>This post is the second of a two-part series on early Monte Carlo methods from the 1940s and 1950s. In my <a href="/2021/06/23/early_Monte_Carlo_1949_conference/">previous post</a> I gave an overview of Monte Carlo methods in the 40s and focused on the 1949 conference in Los Angeles. In this post I’ll go over the classic paper <a href="https://bayes.wustl.edu/Manual/EquationOfState.pdf"><em>Equation of State Calculations by Fast Computing Machines</em></a> (1953) by Nick Metropolis and co-authors. I’ll give an overview of the paper and its main results and then give some context about how it was written. I’ll then delve into some details of how the MANIAC computer worked to give an idea of what it must have been like to write algorithms such as MCMC on it.</p>
<h1 id="the-classic-metropolis-sampler-paper">The classic Metropolis sampler paper</h1>
<h2 id="overview-of-paper">Overview of paper</h2>
<p>The objective of the <a href="https://bayes.wustl.edu/Manual/EquationOfState.pdf">paper</a> is to estimate properties (in this case the pressure) of a system of interacting particles. The system studied consists of \(N\) particles (which they model as hard disks) in a 2D domain. They mention, however, that they are working on a 3D problem as well. The potential energy of the system is given by:</p>
\[E = \frac{1}{2} \sum_{i=1}^N \sum_{j=1, j\neq i}^N V(d_{ij})\]
<p>Here \(d_{ij}\) is the distance between molecules \(i\) and \(j\), and \(V\) is the potential between two molecules.</p>
<p>Up to the 1950s, the usual way to estimate properties of such complicated systems seems to have been to approximate them analytically. The paper compares results from MCMC with two standard approximations. The authors also mention that an alternative would be ordinary Monte Carlo: sample many particle configurations uniformly and weight each one by its probability (namely by its Boltzmann factor \(\exp(-E/kT)\)). The problem with this approach is that, when many particles are involved, almost every uniformly drawn configuration has a weight that is essentially zero. Their solution is to do the opposite: sample particle configurations according to their probabilities and then take an average with uniform weights.</p>
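<p>To see why the weights collapse, here is a minimal sketch (not from the paper) for hard disks in a periodic unit square. With a hard-core potential the Boltzmann weight of a configuration is 1 if no two disks overlap and 0 otherwise, so the fraction of uniformly drawn configurations with non-zero weight measures how many samples would actually contribute. The box size, disk diameter, and sample counts are arbitrary illustrative choices.</p>

```python
import numpy as np

def has_overlap(positions, diameter, box):
    """True if any two disks overlap, using periodic (minimum-image) distances."""
    diff = positions[:, None, :] - positions[None, :, :]
    diff -= box * np.round(diff / box)        # minimum-image convention
    dist2 = np.sum(diff ** 2, axis=-1)
    np.fill_diagonal(dist2, np.inf)           # a disk does not overlap itself
    return bool(np.any(dist2 < diameter ** 2))

def fraction_nonzero_weight(n_particles, diameter=0.1, box=1.0,
                            n_samples=2000, seed=0):
    """Fraction of uniformly drawn configurations whose hard-disk Boltzmann
    weight is non-zero, i.e. that contain no overlapping disks."""
    rng = np.random.default_rng(seed)
    hits = 0
    for _ in range(n_samples):
        positions = rng.uniform(0.0, box, size=(n_particles, 2))
        if not has_overlap(positions, diameter, box):
            hits += 1
    return hits / n_samples

for n in (2, 10, 30):
    print(n, fraction_nonzero_weight(n))
```

<p>With these settings the useful fraction falls from close to 1 for two disks to essentially 0 for thirty, even though thirty disks cover less than a quarter of the box.</p>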
<p>After an overview of the problem they introduce the Metropolis sampler: for each particle suggest a new position using a uniform proposal and allow this move with a certain probability. Interestingly, this sampler would be called these days a Metropolis-within-Gibbs sampler or a single-site Metropolis sampler as each particle is updated one at a time. In figure 1 we see the description of the algorithm from the paper.</p>
<p> </p>
<figure class="post_figure">
<img src="/assets/early_Monte_Carlo/Metropolis_method.png" />
<figcaption>Figure 1: Description of the Metropolis sampler</figcaption>
</figure>
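<p>The single-site update described above can be sketched in Python as follows. This is an illustration under stated assumptions, not the authors' code: hard disks in a periodic unit square, so the energy change of a move is either 0 or infinite and accepting with probability \(\min(1, \exp(-\Delta E/kT))\) reduces to "accept unless the move creates an overlap". Each proposal displaces one particle uniformly within a square of half-width <code class="highlighter-rouge">alpha</code>. The values of <code class="highlighter-rouge">alpha</code>, the diameter, and the 8 x 8 lattice start are hypothetical choices, not the paper's settings.</p>

```python
import numpy as np

def overlaps(positions, i, new_pos, diameter, box):
    """True if disk i placed at new_pos would overlap any other disk (periodic box)."""
    diff = positions - new_pos
    diff -= box * np.round(diff / box)      # minimum-image convention
    dist2 = np.sum(diff ** 2, axis=1)
    dist2[i] = np.inf                        # ignore disk i's own old position
    return bool(np.any(dist2 < diameter ** 2))

def metropolis_sweep(positions, alpha, diameter, box, rng):
    """One cycle: propose a uniform move for each particle in turn and accept
    it with probability min(1, exp(-dE/kT)); for hard disks this means
    'accept if and only if there is no overlap'. Rejected moves leave the
    particle where it is. Returns the fraction of accepted moves."""
    n_accepted = 0
    for i in range(len(positions)):
        proposal = positions[i] + alpha * rng.uniform(-1.0, 1.0, size=2)
        proposal %= box                      # wrap back into the periodic box
        if not overlaps(positions, i, proposal, diameter, box):
            positions[i] = proposal
            n_accepted += 1
    return n_accepted / len(positions)

# Start from a non-overlapping 8 x 8 square lattice of disks.
rng = np.random.default_rng(1)
box, diameter, alpha = 1.0, 0.08, 0.05
grid = (np.arange(8) + 0.5) / 8
positions = np.array([[x, y] for x in grid for y in grid])
for cycle in range(20):
    rate = metropolis_sweep(positions, alpha, diameter, box, rng)
print("acceptance rate in last cycle:", rate)
```

<p>For a smooth potential \(V\), the overlap test would be replaced by computing \(\Delta E\) for the moved particle and accepting with probability \(\min(1, \exp(-\Delta E/kT))\).</p>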
<p>The algorithm was coded up to run on the MANIAC computer, and took around 3 minutes to update all 242 particles (which is obviously slow by today’s standards). Note that they used the Middle Square method to generate the uniform random numbers in the proposal distribution and the accept-reject step. This random number generator has some issues but is fast and therefore much more convenient than reading in random numbers from a table (this method was introduced in the 1949 conference and is discussed in my <a href="/2021/06/23/early_Monte_Carlo_1949_conference/">previous post</a>).</p>
<p>They then justify the new sampler by giving an argument for why it converges to the target distribution: they show that the system is ergodic and that detailed balance is satisfied. Finally, they run experiments and estimate the pressure of a system of particles. They compare the results to two standard analytic approximations, and find that the MCMC results agree with the approximations in the parameter region where those are known to be accurate.</p>
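<p>Detailed balance is easy to check numerically on a toy problem. The sketch below (an illustration, not from the paper) builds the Metropolis transition matrix on a small ring of states with a symmetric proposal (left or right neighbour, probability 1/2 each) and a Boltzmann target \(\pi_i \propto \exp(-E_i)\), then verifies \(\pi_i P_{ij} = \pi_j P_{ji}\) and that \(\pi\) is stationary. The energies are arbitrary.</p>

```python
import numpy as np

def metropolis_matrix(energies):
    """Transition matrix of a Metropolis chain on a ring of states: propose the
    left or right neighbour with probability 1/2 each (a symmetric proposal)
    and accept with probability min(1, exp(-(E_new - E_old)))."""
    n = len(energies)
    P = np.zeros((n, n))
    for i in range(n):
        for j in ((i - 1) % n, (i + 1) % n):
            accept = min(1.0, np.exp(-(energies[j] - energies[i])))
            P[i, j] += 0.5 * accept
        P[i, i] = 1.0 - P[i].sum()       # rejected proposals stay put
    return P

energies = np.array([0.0, 1.5, 0.3, 2.0, 0.8])
pi = np.exp(-energies)
pi /= pi.sum()                            # normalised Boltzmann target
P = metropolis_matrix(energies)

flux = pi[:, None] * P                    # flux[i, j] = pi_i * P_ij
print("detailed balance:", np.allclose(flux, flux.T))
print("stationary:", np.allclose(pi @ P, pi))
```

<p>Summing the detailed balance condition over \(i\) gives \(\sum_i \pi_i P_{ij} = \pi_j\), which is why the second check follows from the first.</p>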
<h2 id="discussion">Discussion</h2>
<p>This paper explains its new sampler clearly and simply, includes some theoretical justification, and has experiments showing that the method can work in practice. One thing I’m not clear about is that they don’t have access to the “ground truth”, so it’s not completely clear how we know that the MCMC results are correct. However, they do show that the analytic approximations diverge from the MCMC results exactly in the parameter regions where those approximations are expected to break down.</p>
<p>Another point is that they include some discussion of the Monte Carlo error, but they seem to compute the error using the variance of the samples without correcting for the correlation between samples. We now know that we must calculate the <a href="https://dfm.io/posts/autocorr/">integrated autocorrelation time</a> and use it to find the <a href="https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html">effective sample size</a>. So a nitpick of the paper would be that their Monte Carlo error estimate is too small! We’ll have to wait until <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.452.6839&rep=rep1&type=pdf">Hastings’s paper</a> (see section 2) in 1970 for a rigorous treatment of the variance of estimates from correlated samples.</p>
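<p>To see how much the naive error bar can understate the uncertainty, here is a minimal sketch. It simulates a strongly autocorrelated AR(1) chain as a stand-in for MCMC output, estimates the integrated autocorrelation time \(\tau\) from an FFT-based autocorrelation function truncated with a self-consistent window (in the spirit of the estimator described in the linked post), and compares the naive standard error \(\sigma/\sqrt{n}\) with the corrected \(\sigma\sqrt{\tau/n}\). For an AR(1) process with coefficient \(\phi\) the exact value is \(\tau = (1+\phi)/(1-\phi)\), so \(\phi = 0.9\) gives \(\tau = 19\). The window constant <code class="highlighter-rouge">c = 5</code> is an assumed, commonly used choice.</p>

```python
import numpy as np

def autocorrelation(x):
    """Normalised autocorrelation function, computed with a zero-padded FFT."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    n = len(x)
    f = np.fft.rfft(x, n=2 * n)
    acf = np.fft.irfft(f * np.conj(f))[:n]
    return acf / acf[0]

def integrated_autocorr_time(x, c=5.0):
    """tau = 1 + 2 * sum_{t>=1} rho(t), truncated at the first lag M with M >= c * tau(M)."""
    rho = autocorrelation(x)
    taus = 2.0 * np.cumsum(rho) - 1.0         # taus[M] = 1 + 2 * sum_{t=1}^{M} rho(t)
    lags = np.arange(len(taus))
    within_window = lags >= c * taus
    window = np.argmax(within_window) if within_window.any() else len(taus) - 1
    return taus[window]

# AR(1) chain x_t = phi * x_{t-1} + noise, with exact tau = (1 + phi) / (1 - phi) = 19.
rng = np.random.default_rng(42)
phi, n = 0.9, 100_000
noise = rng.normal(size=n)
x = np.empty(n)
x[0] = noise[0] / np.sqrt(1 - phi ** 2)       # start in the stationary distribution
for t in range(1, n):
    x[t] = phi * x[t - 1] + noise[t]

tau = integrated_autocorr_time(x)
ess = n / tau                                 # effective sample size
naive_se = np.std(x) / np.sqrt(n)             # ignores the correlation
mcmc_se = np.std(x) * np.sqrt(tau / n)        # accounts for the correlation
print(f"tau = {tau:.1f}, ESS = {ess:.0f}")
print(f"naive SE = {naive_se:.4f}, corrected SE = {mcmc_se:.4f}")
```

<p>Here the corrected error bar is roughly \(\sqrt{19} \approx 4.4\) times the naive one, and the 100,000 correlated draws are worth only about 5,000 independent ones.</p>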
<p>Finally, they use 16 cycles as burn-in (one cycle involves updating all the particles) and then run 48 to 64 cycles to sample configurations (they run different experiments). I haven’t reproduced this experiment to see the trace plots but intuitively this seems fairly low. However they were limited by the computing power available to them and this must have been enough to get the estimates they were looking for.</p>
<p>An interesting <a href="https://www.dsf.unica.it/~fiore/GoR.pdf">article</a> from 2004 includes the recollections of Marshall Rosenbluth (one of the authors), in which he explains the contributions of each of the authors of the paper. It turns out that he and Arianna Rosenbluth (his wife) did most of the work. More specifically, he did the mathematical work, Arianna wrote the code that ran on the MANIAC, Augusta Teller wrote an earlier version of the code, and Edward Teller gave some critical suggestions about the methodology. Finally, Nick Metropolis provided the computing time; as he was the first author, the method was named after him. But perhaps a more appropriate name would have been the Rosenbluth algorithm.</p>
<h1 id="maniac">MANIAC</h1>
<p>A lot of the early Monte Carlo methods were intimately linked with the development of modern computers, such as the <a href="https://en.wikipedia.org/wiki/ENIAC">ENIAC</a> and the <a href="https://en.wikipedia.org/wiki/MANIAC_I">MANIAC 1</a> (which was used in the Metropolis paper). It therefore makes sense to look in some detail at how these computers worked, to see what it must have been like to write code for them. We’ll look into some sections from the <a href="http://www.bitsavers.org/pdf/lanl/LA-1725_The_MANIAC_Jul54.pdf">MANIAC’s documentation</a> to get a taste for how this computer worked.</p>
<h2 id="historical-context">Historical context</h2>
<p>The MANIAC computer was a step in a series of progressively faster and more reliable computers. Its construction started in 1946; it became operational in 1952 (just in time for the 1953 Metropolis paper) and was shut down in 1958. It was used to work on a range of problems such as PDEs, integral equations, and stochastic processes. It was succeeded by the MANIAC II (built in 1957) and the MANIAC III (1964), which were faster and easier to use. To give some context, Fortran came out in 1957, Lisp in 1958, and Cobol in 1959. So code written for the MANIAC was not at all portable; you had to learn how to program this specific machine.</p>
<h2 id="arithmetic">Arithmetic</h2>
<p>We start by looking at the introduction (often a nice place to start) of the <a href="http://www.bitsavers.org/pdf/lanl/LA-1725_The_MANIAC_Jul54.pdf">documentation</a> (page 6 in the pdf). We find that numbers are represented as binary digits (which was to be expected). Note that they use the word <em>bigit</em> to mean a binary digit; it’s perhaps a shame that this term didn’t stick. I’ll use it throughout the rest of the post as I like it.</p>
<p>Each number stored on the MANIAC consists of:</p>
<ul>
<li>1 sign bigit</li>
<li>39 numerical bigits</li>
</ul>
<p>We then put a binary point <em>before</em> the first numerical bigit. The number <code class="highlighter-rouge">0.75</code> (in decimal) would then be represented on the MANIAC as <code class="highlighter-rouge">0.11</code> (binary), since \(0.75 = 1/2 + 1/4\). Note that this means that numbers can only be between \(-1\) and \(1\). So if your program generates numbers outside this range you must either scale the numbers before doing the calculations or adjust their magnitudes on the fly.</p>
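<p>Here is a small Python sketch of this fixed-point format for non-negative numbers: the sign bigit is <code class="highlighter-rouge">0</code> and the 39 numerical bigits are worth \(1/2, 1/4, \ldots, 2^{-39}\). Truncating (rather than rounding) the expansion is an assumption, and the helper names are hypothetical. It also shows that a number like \(0.2\) has no finite binary expansion, so it can only be stored approximately.</p>

```python
from fractions import Fraction

N_BIGITS = 39  # numerical bigits after the binary point

def to_bigits(x):
    """Word '0.b1b2...b39' for 0 <= x < 1: sign bigit 0, then 39 bigits
    (the binary expansion is truncated, not rounded)."""
    if not 0 <= x < 1:
        raise ValueError("only 0 <= x < 1 is handled here")
    x = Fraction(x)                   # exact value of the input
    bigits = []
    for _ in range(N_BIGITS):
        x *= 2
        bigit = 1 if x >= 1 else 0    # next bigit of the expansion
        bigits.append(str(bigit))
        x -= bigit
    return "0." + "".join(bigits)

def from_bigits(word):
    """Exact value of a word '0.b1b2...', where bigit k is worth 2**-k."""
    return sum(Fraction(int(b), 2 ** k) for k, b in enumerate(word[2:], start=1))

print(to_bigits(0.75)[:6])                    # 0.1100
print(to_bigits(0.2)[:14])                    # 0.001100110011
print(float(from_bigits(to_bigits(0.2))))     # slightly below 0.2
```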
<h2 id="negative-numbers">Negative numbers</h2>
<p>We now consider negative numbers on the MANIAC (page 7 of the pdf). The natural way to represent a negative number would be to have the sign bigit be <code class="highlighter-rouge">0</code> for a positive number and <code class="highlighter-rouge">1</code> for a negative number (or vice versa), leaving the numerical bigits unchanged. But the MANIAC represents the negative of the number \(x\) as the complement of \(x\) with respect to \(2\), namely:</p>
\[c = 2 - |x|\]
<p>As \(0 < |x| < 1\) we have \(1 < c < 2\). This means that the sign bigit will always be <code class="highlighter-rouge">1</code> for negative numbers. The numerical bigits of \(c\) are those of \(|x|\) flipped, plus one in the last place; equivalently, the rightmost <code class="highlighter-rouge">1</code> of \(|x|\) and any zeros after it are kept, and every bigit to its left is flipped.</p>
<p>To illustrate this, suppose that \(x = -.101110101...011\) (in binary). Its representation on the MANIAC will then be <code class="highlighter-rouge">c = 1.010001010...101</code>. Note that the sign bigit is <code class="highlighter-rouge">1</code>, and that, because \(|x|\) ends in a <code class="highlighter-rouge">1</code>, every numerical bigit is flipped except the last one, which stays <code class="highlighter-rouge">1</code>. You can check this by calculating the difference \(2-|x|\) in binary, namely <code class="highlighter-rouge">10.000...000 - 0.101110101...011</code>.</p>
<p>This way of representing negative numbers may feel convoluted, but taking the complement of a number is cheap on the MANIAC: you flip every bigit (the sign bigit included) and add one in the last place.</p>
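<p>A minimal Python sketch of the complement, assuming (for illustration only) that a 40-bigit word is modelled as an integer counting units of \(2^{-39}\), so the sign bigit is worth \(1\) and the word wraps around at \(2\). The example \(|x|\) ends in zeros, which shows that flipping alone is not enough: the trailing zeros and the last <code class="highlighter-rouge">1</code> are kept, and plain flipping would be off by one unit in the last place. The helper names are hypothetical.</p>

```python
WORD_BITS = 40                    # 1 sign bigit + 39 numerical bigits
UNIT = 2 ** 39                    # integer value of 1.0 (the sign bigit's weight)
MASK = (1 << WORD_BITS) - 1       # keeps 40 bigits, i.e. works modulo 2

def complement(word):
    """2 - x: flip every bigit (sign included), then add one in the last place."""
    return ((word ^ MASK) + 1) & MASK

def show(word):
    """Format a word as sign bigit, binary point, 39 numerical bigits."""
    bits = format(word, "040b")
    return bits[0] + "." + bits[1:]

x = 0b101110101 << 30             # |x| = 0.101110101 followed by 30 zero bigits
c = complement(x)
print(show(x))                    # 0.101110101 then 30 zeros
print(show(c))                    # 1.010001011 then 30 zeros
```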
<h2 id="subtraction">Subtraction</h2>
<p>The benefits of writing negative numbers in this weird way become even more apparent when we consider subtraction. To subtract \(b\) from \(a\) you simply add \(a\) to the <em>complement of \(b\)</em>. This means that \(a-b\) becomes \(a + (2-b)\) (assuming that \(a,b>0\)).</p>
<p>Let’s check that this works for \(a,b>0\). We have to consider two cases: \(a>b\) and \(a &lt; b\).</p>
<h3 id="first-case-ab">first case: \(a>b\)</h3>
<p>The first case to consider is when \(a>b\). As both numbers are between \(0\) and \(1\) we have:</p>
\[\begin{align}
0 &< a - b < 1 \\
2 &< 2 + (a - b) < 3
\end{align}\]
<p>We represent the number \(2 + (a-b)\) in binary as <code class="highlighter-rouge">10.&lt;numerical bigits&gt;</code>. However, the leftmost bigit is outside the capacity of the computer, so it is dropped (which subtracts \(2\) from the result), and you end up with the correct number: \(a-b\).</p>
<h3 id="second-case-ab">second case: \(a &lt; b\)</h3>
<p>In this case we have:</p>
\[\begin{align}
-1 &< a - b < 0 \\
1 &< 2 + (a - b) < 2 \\
1 &< 2 - (b - a) < 2
\end{align}\]
<p>The term \(2- (b-a)\) is simply the complement of \(b-a\), namely the MANIAC’s representation of \(-(b-a) = a-b\). This is also exactly the result we wanted!</p>
<p>The other cases - as recommended in the documentation (see figure 2 below) - are left as exercises to the reader.</p>
<figure class="post_figure">
<img src="/assets/early_Monte_Carlo/subtraction_left_as_exercises.png" />
<figcaption>Figure 2: the documentation's recommendation </figcaption>
</figure>
<p>So subtraction reduces to flipping some bigits and doing addition.</p>
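<p>Both cases can be checked with a small Python sketch (an illustration, not MANIAC code). As an assumption for the example, a 40-bigit word is modelled as an integer counting units of \(2^{-39}\), negatives are stored as \(2-|x|\), and \(a-b\) is computed as \(a\) plus the complement of \(b\), with <code class="highlighter-rouge">&amp; MASK</code> dropping the carry into the bigit worth \(2\). The helper names are hypothetical, and inputs are assumed to be exactly representable.</p>

```python
from fractions import Fraction

WORD_BITS = 40
UNIT = 2 ** 39                       # integer value of 1.0
MASK = (1 << WORD_BITS) - 1          # keep 40 bigits: arithmetic modulo 2

def encode(x):
    """Word for -1 <= x < 1 (x assumed exactly representable); a negative x
    becomes 2 - |x| because Python's & on a negative int wraps modulo 2**40."""
    return round(x * UNIT) & MASK

def decode(word):
    """Signed value of a word: if the sign bigit is 1, the word holds 2 - |x|."""
    value = Fraction(word, UNIT)
    return value - 2 if word >> 39 else value

def complement(word):
    """2 - x: flip all bigits and add one in the last place."""
    return ((word ^ MASK) + 1) & MASK

def subtract(a_word, b_word):
    """a - b computed as a + (2 - b); '& MASK' drops the overflow bigit worth 2."""
    return (a_word + complement(b_word)) & MASK

a, b = encode(0.75), encode(0.25)
print(decode(subtract(a, b)))   # 1/2   (case a > b)
print(decode(subtract(b, a)))   # -1/2  (case a < b)
```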
<h3 id="modular-arithmetic">Modular arithmetic</h3>
<p>Finally, this way of doing subtraction doesn’t feel so weird anymore if we think about <a href="https://en.wikipedia.org/wiki/Modular_arithmetic">modular arithmetic</a>. For example if we are working modulo \(10\) and want to subtract \(3\) from a number, we can simply add the complement of \(3\) which is \(10-3=7\). Namely:</p>
\[\begin{align}
- 3 &\equiv 7 \hspace{2mm} (10) \\
x - 3 &\equiv x + 7 \hspace{2mm} (10) \\
\end{align}\]
<p>This is exactly how the MANIAC does subtraction, but modulo \(2\) rather than \(10\).</p>
<h2 id="writing-programs">Writing programs</h2>
<p>The programmer had to be very comfortable with this way of representing numbers while writing complicated numerical methods such as MCMC or PDE solvers. They would describe every step of the program on a “computing sheet” before transferring it to a punch card, which was finally fed into the MANIAC. Figure 3 shows a sample from the documentation which describes the instructions needed to implement subtraction.</p>
<figure class="post_figure">
<img src="/assets/early_Monte_Carlo/maniac_difference_code.png" />
<figcaption>Figure 3: The computing sheet for subtraction</figcaption>
</figure>
<p>The documentation then goes on to describe more complicated operations, such as taking square roots, as well as the basics of MANIAC’s internal design.</p>
<p>In conclusion, coding was complicated and required working with engineers, as the machines broke down regularly. The only built-in operators were (+), (-), (*), (/), and comparisons (equalities and inequalities). Figure 4 (below) is a sample from the documentation and shows how this computer was an improvement on previous versions: it was so simple that even a programmer could turn it on and off!</p>
<figure class="post_figure">
<img src="/assets/early_Monte_Carlo/maniac_turn_on_off.png" />
<figcaption>Figure 4: an intuitive UI (page 291)</figcaption>
</figure>
<h1 id="conclusion">Conclusion</h1>
<p>The development of Monte Carlo methods took off in the 1940s, and this work was closely tied to the progress of computing. The usual applications at the time were in particle and nuclear physics, fluid dynamics, and statistical mechanics. In this post we went over the classic 1953 paper on the Metropolis sampler and looked at some details of the MANIAC computer. We saw that the computers used in scientific research were powerful (for the era), but were complicated machines that required a lot of patience and detailed knowledge to use.</p>
<p><em>Thanks to <a href="https://users.monash.edu.au/~gmartin/">Gael Martin</a> and <a href="https://warwick.ac.uk/fac/sci/statistics/staff/academic-research/robert/">Christian Robert</a> for useful feedback on this post</em></p>
<p><strong>Early Monte Carlo methods - Part 1: the 1949 conference</strong> (2021-06-23)</p>
<p>This post is the first of a two-part series on early Monte Carlo methods from the 1940s and 1950s. Although Monte Carlo methods had been previously hinted at (for example <a href="https://en.wikipedia.org/wiki/Buffon%27s_needle_problem">Buffon’s needle</a>), these methods only started getting serious attention in the 1940s with the development of fast computers. In 1949 a conference on Monte Carlo was held in Los Angeles in which the participants - many of them legendary mathematicians, physicists, and statisticians - shared their work involving Monte Carlo methods. It can be enlightening to review this research today to understand what issues these researchers were grappling with. In this post I’ll give a brief overview of computing and Monte Carlo in the 1940s, and then discuss some of the talks from this 1949 conference.
In the <a href="/2021/06/23/early_Monte_Carlo_Metropolis_1953/">next post</a> I’ll go over the classic 1953 paper on the Metropolis sampler and give some details of the MANIAC computer.</p>
<h1 id="historical-context-monte-carlo-in-the-1940s">Historical context: Monte Carlo in the 1940s</h1>
<h2 id="computing-during-world-war-2">Computing during World War 2</h2>
<p>During WW2, researchers at Los Alamos worked on building the atomic bomb, which involved solving nonlinear problems (for example in fluid dynamics and neutron transport). As the project advanced they used a series of increasingly powerful computers to do this. Firstly, they used mechanical desk calculators; these were standard calculators at the time, but were slow and unreliable. They then upgraded to electro-mechanical business machines, which were bulkier but faster. One application was to solve partial differential equations (PDEs). To solve a PDE numerically, one punch card was used to represent each point in space and time, and a set of punch cards would represent the state of the system at some point in time. You would pass each set of cards through several machines to solve the PDE forward in time. The machines would often break down and would have to be fixed, which made doing any calculation laborious. Finally, John von Neumann recommended the scientists use the ENIAC computer, which was able to do calculations an order of magnitude faster. You can find more details on computing at Los Alamos in this <a href="https://permalink.lanl.gov/object/tr?what=info:lanl-repo/lareport/LA-UR-83-5073">report</a> by Francis Harlow and Nick Metropolis.</p>
<h2 id="monte-carlo-in-the-40s">Monte Carlo in the 40s</h2>
<p>A well known story about early Monte Carlo methods is about Stan Ulam playing solitaire while he was ill (see <a href="http://www-star.st-and.ac.uk/~kw25/teaching/mcrt/MC_history_3.pdf">this article</a> by Roger Eckhardt for more details). He wondered what the probability was of winning at solitaire if you simply laid out a random configuration. After trying to calculate this mathematically, he then considered laying out many random configurations and simply counting how many of them were successful. This is a simple example of the Monte Carlo method: using randomness to estimate a non-random quantity. Ulam, along with John von Neumann, then started thinking about how to apply this method to problems involving neutron diffusion. Von Neumann was important in developing new Monte Carlo methods and new computers as well as in giving legitimacy to these new fields; that such a well respected scientist was interested in these things helped these fields become respectable among other scientists. At that time Monte Carlo methods mainly seemed to spread by word of mouth. For example we can see in figure 1 an extract from a letter written by von Neumann to Ulam in 1947 describing rejection sampling. You can see more of the letter in <a href="http://www-star.st-and.ac.uk/~kw25/teaching/mcrt/MC_history_3.pdf">Eckhardt’s article</a>.</p>
<figure class="post_figure">
<img src="/assets/early_Monte_Carlo/rejection_sampler_letter.png" />
<figcaption>Figure 1: sample of von Neumann's letter to Ulam describing rejection sampling.</figcaption>
</figure>
<h1 id="the-1949-conference-on-monte-carlo">The 1949 Conference on Monte Carlo</h1>
<p>In 1949 a conference was held in Los Angeles to discuss recent research on Monte Carlo. This was the first time a lot of these methods were published: the speakers gave talks and some of these were written up in a <a href="https://books.google.co.uk/books?id=4CJUylwIOGAC&printsec=frontcover&source=gbs_ge_summary_r&cad=0#v=onepage&q&f=false">report</a>. The attendees of this conference included a lot of famous mathematicians, statisticians, and physicists such as <a href="https://en.wikipedia.org/wiki/John_von_Neumann">von Neumann</a>, <a href="https://en.wikipedia.org/wiki/John_Tukey">Tukey</a>, <a href="https://en.wikipedia.org/wiki/Alston_Scott_Householder">Householder</a>, <a href="https://en.wikipedia.org/wiki/John_Wishart_(statistician)">Wishart</a>, <a href="https://en.wikipedia.org/wiki/Jerzy_Neyman">Neyman</a>, <a href="https://en.wikipedia.org/wiki/William_Feller">Feller</a>, <a href="https://en.wikipedia.org/wiki/Ted_Harris_(mathematician)">Harris</a>, and <a href="https://en.wikipedia.org/wiki/Mark_Kac">Kac</a>. The speakers introduced such sampling methods as the rejection sampler, the middle square method, and splitting schemes for rare event sampling (these three methods were actually suggested by von Neumann). There were also talks about applications in such areas as particle and nuclear physics as well as statistical mechanics.</p>
<p>We’ll now focus on two of the talks, both presenting techniques invented by von Neumann.</p>
<h3 id="the-middle-square-method">The Middle Square method</h3>
<p>Von Neumann introduced the Middle Square method, which was a way to quickly generate pseudorandom numbers. A standard way at the time of obtaining random numbers was to generate a list of uniform random numbers using some kind of physical process, then use that list of random numbers in calculations. However when using computers such as the ENIAC or MANIAC, reading this list into the computer was much too slow. It was therefore necessary to generate these random numbers on the fly, even if that meant obtaining random samples of lower quality.</p>
<p>The method works as follows:</p>
<ul>
<li>Start from an integer with \(n\) digits (with \(n\) even)</li>
<li>Square the number</li>
<li>If the square has \(2n-1\) digits, add a leading zero (to the left of it) so that the new number has \(2n\) digits</li>
<li>Keep the \(n\) digits in the middle of the number: this is the next number in the sequence, and the steps are repeated from it.</li>
</ul>
<p>For example, if we start from the seed \(a_0 = 20\) (2 digits) the algorithm runs as follows:</p>
<ul>
<li>\(20^2=0400\), so \(a_1=40\)</li>
<li>\(40^2=1600\), so \(a_2 = 60\)</li>
<li>Then \(a_i = 60\) for \(i>2\). Here the process has converged and will only generate the number \(60\) from now on.</li>
</ul>
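<p>The steps above fit in a few lines of Python. This is a minimal sketch (the function name is hypothetical) that reproduces the 2-digit example and, for comparison, runs a 4-digit seed chosen arbitrarily.</p>

```python
def middle_square(seed, n_digits, count):
    """Return `count` successive middle-square numbers starting from `seed`."""
    if n_digits % 2:
        raise ValueError("n_digits must be even")
    half = n_digits // 2
    values, current = [], seed
    for _ in range(count):
        square = str(current ** 2).zfill(2 * n_digits)   # left-pad to 2n digits
        current = int(square[half:half + n_digits])       # keep the middle n digits
        values.append(current)
    return values

print(middle_square(20, 2, 5))       # [40, 60, 60, 60, 60]: stuck at 60
print(middle_square(1234, 4, 5))     # [5227, 3215, 3362, 3030, 1809]
```

<p>To turn these integers into uniform numbers on \([0,1)\), one would divide by \(10^n\).</p>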
<p>So we notice that we need to use a carefully chosen <em>large</em> seed (namely, at least 10 digits). If we don’t choose a good seed the numbers might quickly converge and the algorithm will keep on generating the same number for ever. A lot of careful work was therefore done on finding good seeds and doing statistical tests on the generated numbers to assess their quality. One of the talks in the conference discusses some tests done on this algorithm (see talk number 12 by <a href="https://en.wikipedia.org/wiki/George_Forsythe">George Forsythe</a>).</p>
<p>Obviously today we have much better PRNGs, so we should use these modern methods and not the middle square process. But it is interesting to see some of the early work on PRNGs.</p>
<h3 id="rejection-sampling">Rejection sampling</h3>
<p>Von Neumann also introduced rejection sampling (though we saw in figure 1 that he had developed this method several years earlier). This talk is called <em>“Various techniques used in connection with random digits”</em> and a pdf of the talk (without the rest of the conference) can be found <a href="https://dornsifecms.usc.edu/assets/sites/520/docs/VonNeumann-ams12p36-38.pdf">here</a>. There also seem to be several dates attributed to this paper (see this <a href="https://math.stackexchange.com/questions/186626/finding-a-paper-by-john-von-neumann-written-in-1951">stackexchange question</a> about how to properly cite it). In the talk von Neumann first reviews methods used to generate uniform random numbers, then introduces the rejection sampling method and shows how to use it with a few examples.</p>
<p><strong>Generating uniform random numbers</strong></p>
<p>Von Neumann starts by considering that a physical process (such as a nuclear process) can be used to generate high quality random numbers, and that a device could be built to generate these as needed. However, he points out that it would be impossible to reproduce the results (if you need to debug the program, for example), as there would be no random seed that would reproduce the same random numbers. He concludes this by saying:</p>
<blockquote>
<p>“I think that the direct use of a physical supply of random digits is absolutely inacceptable for this reason and for this reason alone”.</p>
</blockquote>
<p>In light of this comment it is interesting to consider how most of modern research code involving random numbers does <em>not</em> keep track of the random seed used, and is therefore not reproducible in this sense (namely, in terms of the realisation of the random variables used).</p>
<p>He then considers generating random numbers using a physical process and printing out the results (this is discussed in another talk in the conference); one can then read the random numbers into the computer as needed. However, he points out that reading numbers into a computer is its slowest operation (in more modern terms, the problem would be <a href="https://en.wikipedia.org/wiki/I/O_bound">I/O bound</a>).</p>
<p>Finally, he concludes that using arithmetic methods such as the Middle Square method is a nice practical alternative that is both fast and reproducible. However one needs to check whether a random seed will generate random samples of sufficiently high quality. It is here that he famously says:</p>
<blockquote>
<p>“Any one who considers arithmetic methods of producing random digits is, of course, in a state of sin. For, as has been pointed out several times, there is no such thing as a random number - there are only methods to produce random numbers, and a strict arithmetic procedure of course is not such a method”</p>
</blockquote>
<p>However he takes the very practical approach of recommending we use such arithmetic methods and simply test the generated samples to make sure they are good enough for applications.</p>
<p><strong>The rejection method</strong></p>
<p>In the second part of the talk he considers the problem of generating non-uniform random numbers following some distribution \(f\), namely how to sample \(X \sim f\). This requires using uniformly distributed random numbers (generated using the middle square method for example) and transforming them appropriately.</p>
<p>He first considers using the <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">inverse transform</a> but judges it to be too inefficient. He then introduces the rejection method, which is as follows:</p>
<ul>
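<li>choose a scaling factor \(a \in \mathbb{R}^+\) such that \(af(x) \leq 1\) for all \(x\) (here \(f\) is a density on \([0,1]\), since \(X\) is drawn uniformly on that interval)</li>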
<li>sample \(X, Y \sim \mathcal{U}(0,1)\)</li>
<li>Accept \(X\) if \(Y \leq af(X)\). Reject otherwise</li>
</ul>
<p>This last step corresponds to accepting \(X\) with probability \(af(X)\) (which is between \(0\) and \(1\)).</p>
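<p>A minimal sketch of this version of the method, assuming (as the uniform draw of \(X\) requires) that \(f\) is a density on \([0,1]\). The Beta(2, 2) target \(f(x) = 6x(1-x)\), whose maximum is \(1.5\), and the choice \(a = 1/1.5\) are illustrative; on average a fraction \(a\) of the proposals is accepted.</p>

```python
import random

def rejection_sample_unit(f, a, rng):
    """Draw X ~ f on [0, 1]: propose X ~ U(0,1) and accept if Y <= a f(X), Y ~ U(0,1)."""
    while True:
        x, y = rng.random(), rng.random()
        if y <= a * f(x):
            return x

def beta22(x):
    """Beta(2, 2) density on [0, 1]; its maximum value is 1.5 at x = 0.5."""
    return 6.0 * x * (1.0 - x)

rng = random.Random(0)
samples = [rejection_sample_unit(beta22, 1 / 1.5, rng) for _ in range(50_000)]
print(sum(samples) / len(samples))   # close to the Beta(2, 2) mean of 0.5
```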
<p>Note that the more modern version of this method considers a general proposal distribution \(g\) (see these <a href="https://people.eecs.berkeley.edu/~jordan/courses/260-spring10/lectures/lecture17.pdf">lecture notes</a> or any textbook on Monte Carlo):</p>
<ul>
<li>let \(l(x) = cf(x)\) be the un-normalised target density (\(f\) is the target, \(c\) is unknown)</li>
<li>Let \(M \in \mathbb{R}^+\) be such that \(Mg(x) \geq l(x)\) for all \(x\)</li>
<li>draw \(X \sim g\) and compute: \(r = \frac{l(X)}{Mg(X)}\)</li>
<li>Accept \(X\) with probability \(r\)</li>
</ul>
<p>Here we sample \(X\) from a general distribution \(g\) and similarly choose a constant \(M\) such that \(r\) is between \(0\) and \(1\).</p>
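<p>The general version can be sketched the same way. As an illustrative assumption, the unnormalised target is \(l(x) = \exp(-x^2/2)\) (a standard normal whose normalising constant we pretend not to know) and the proposal \(g\) is a standard Cauchy density. The ratio \(l(x)/g(x) = \pi(1+x^2)e^{-x^2/2}\) is largest at \(x = \pm 1\), so \(M = 2\pi e^{-1/2}\) makes \(r\) lie in \([0,1]\).</p>

```python
import math
import random

def rejection_sample(target, g_pdf, g_sample, M, rng):
    """Draw X with density proportional to `target`, using a proposal g with M g(x) >= target(x)."""
    while True:
        x = g_sample(rng)
        r = target(x) / (M * g_pdf(x))     # acceptance probability, in [0, 1]
        if rng.random() <= r:
            return x

def unnormalised_target(x):
    """Standard normal density without its 1/sqrt(2 pi) constant."""
    return math.exp(-0.5 * x * x)

def cauchy_pdf(x):
    return 1.0 / (math.pi * (1.0 + x * x))

def cauchy_sample(rng):
    """Standard Cauchy draw by the inverse transform."""
    return math.tan(math.pi * (rng.random() - 0.5))

M = 2.0 * math.pi * math.exp(-0.5)         # max of target(x) / cauchy_pdf(x), at x = +/-1
rng = random.Random(1)
xs = [rejection_sample(unnormalised_target, cauchy_pdf, cauchy_sample, M, rng)
      for _ in range(50_000)]
mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / len(xs)
print(mean, var)                           # roughly 0 and 1
```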
<p>Von Neumann then gives several examples, such as how to sample from the exponential distribution by computing \(X = - \log(T)\) with \(T \sim \mathcal{U}(0,1)\). However, he considers it silly to generate a random number and then plug it into a complicated power series to approximate the logarithm function. He therefore gives an alternative method for transforming the random number \(T\) using only simple operations rather than the logarithm.</p>
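<p>Below is a sketch of the comparison-based exponential sampler usually attributed to this talk. It is written from the standard textbook description rather than from von Neumann's text, so treat the details as an assumption. Draw \(U_1, U_2, \ldots\) until the sequence stops decreasing; if the total number of draws is even, return \(U_1\) plus the number of failed attempts so far, otherwise count a failure and start again. Since \(P(U_1 > U_2 > \cdots > U_n \mid U_1 = x) = x^{n-1}/(n-1)!\), an attempt succeeds with probability \(1 - x + x^2/2 - \cdots = e^{-x}\), which produces the exponential density without computing a logarithm.</p>

```python
import random

def von_neumann_exponential(rng):
    """Exp(1) sample using only comparisons, counting, and one addition (no logarithm)."""
    failures = 0
    while True:
        u1 = rng.random()
        previous, n_draws = u1, 1
        while True:
            u = rng.random()
            n_draws += 1
            if u > previous:          # the decreasing run u1 > u2 > ... has ended
                break
            previous = u
        if n_draws % 2 == 0:          # happens with probability exp(-u1)
            return failures + u1
        failures += 1                 # each failure adds 1 to the integer part

rng = random.Random(2)
draws = [von_neumann_exponential(rng) for _ in range(100_000)]
print(sum(draws) / len(draws))        # close to 1, the mean of Exp(1)
```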
<h3 id="round-table-discussion">Round table discussion</h3>
<p>The conference ended with a round table discussion led by <a href="https://en.wikipedia.org/wiki/John_Tukey">John Tukey</a> which was documented in section 14 of the <a href="https://books.google.co.uk/books?id=4CJUylwIOGAC&printsec=frontcover&source=gbs_ge_summary_r&cad=0#v=onepage&q&f=false">report</a>. This was a mix of prepared statements as well as discussions between the participants.</p>
<p>To open the discussion, <a href="https://en.wikipedia.org/wiki/Leonard_Jimmie_Savage">Leonard Savage</a> read from <a href="https://en.wikipedia.org/wiki/Parallel_Lives">Plutarch’s Lives</a>; in particular a passage about the <a href="https://en.wikipedia.org/wiki/Siege_of_Syracuse_(213%E2%80%93212_BC)">siege of Syracuse</a>, where Archimedes built machines to defend the city. The passage discusses how Archimedes would not have built these applied machines if the king hadn’t asked him to; these machines were the <em>“mere holiday sport of a geometrician”</em>. Savage then reads about <a href="https://en.wikipedia.org/wiki/Archytas">Archytas</a> and <a href="https://en.wikipedia.org/wiki/Eudoxus_of_Cnidus">Eudoxus</a>, friends of Plato, who developed mathematical mechanics. Plato was apparently unhappy about this development, as:</p>
<blockquote>
<p>[mechanics was] destroying the real excellence of geometry by making it leave the region of pure intellect and come within that of the senses and become mixed up with bodies which require much base servile labor.</p>
</blockquote>
<p>As a result mechanics was separated from geometry and considered among the military arts. These passages illustrate how applied research and engineering have always been looked down upon by many theoretical researchers. In 1949 Monte Carlo methods and scientific computing were just getting developed so this would indeed have been the case.</p>
<p>Another interesting portion of this discussion was given by <a href="https://en.wikipedia.org/wiki/John_Wishart_(statistician)">John Wishart</a>: he pointed out that he was impressed by the interactions between physicists, mathematicians, and statisticians, and considered that the different groups would be able to learn from each other. He also told stories of how <a href="https://en.wikipedia.org/wiki/Karl_Pearson">Karl Pearson</a> was very practically minded and would regularly use his “hand computing machine” to solve integrals that helped with his research. Pearson and his student <a href="https://en.wikipedia.org/wiki/L._H._C._Tippett">Leonard Tippett</a> also generated long lists of random numbers to use in research: these lists allowed them to estimate the sampling distributions of the statistics they were studying.</p>
<p>The rest of the discussion goes over different practical problems and the benefits of having interactions between statistics and physics.</p>
<h1 id="thoughts-on-this-conference">Thoughts on this conference</h1>
<p>There seemed to be a strong focus in the conference on practical calculations and experiments. The reading by Leonard Savage at the beginning of the round table discussion seems to reflect the general tone of the conference: being equally comfortable with data and computers as with maths and theory. Indeed computers at the time were unreliable and regularly broke down, so researchers had to be very comfortable with engineering. Von Neumann’s remarks on pseudorandom numbers also show a very practical mindset of using “whatever works” rather than trying to find the “perfect” random number generator.</p>
<p>I also noticed that throughout the talks there was very little mention of Monte Carlo <em>error</em>. Researchers in the 40s had <a href="https://en.wikipedia.org/wiki/Central_limit_theorem#History">known for a long time</a> about the central limit theorem (CLT), but I didn’t find any explicit link between a Monte Carlo estimate (which is simply an average of samples) and CLT which would have given them an error bar (perhaps this was too obvious to mention?). The main mention of Monte Carlo error I found was by Alston Householder who - in his talk - gives estimates along with the standard errors (page 8). The only other hint of this that I found is the sentence by Ted Harris in the discussion at the end of the conference where he is talking about two different ways of obtaining Monte Carlo samples to estimate some quantity:</p>
<blockquote>
<p>… we know that if we do use the second estimate instead of the first and if we do it for long enough then we will come out with exactly the right answer. I am leaving aside the question of the relative variability of the two estimates.</p>
</blockquote>
<p>My guess is that either the CLT was too obvious to mention, or that simply estimating these quantities was hard enough without also computing error bars. In conclusion, I would recommend reading through some of the talks in the <a href="https://books.google.co.uk/books?id=4CJUylwIOGAC&printsec=frontcover&source=gbs_ge_summary_r&cad=0#v=onepage&q&f=false">report</a>, in particular the round table discussion.</p>
<h3 id="conclusion">Conclusion</h3>
<p>We saw how Monte Carlo methods took off in the 1940s and that a lot of these new methods and applications were presented in the 1949 conference. In my <a href="/2021/06/23/early_Monte_Carlo_Metropolis_1953/">next post</a> I’ll go over the classic 1953 paper on the Metropolis sampler and give some details of the MANIAC computer.</p>
<p><em>Thanks to <a href="https://users.monash.edu.au/~gmartin/">Gael Martin</a> and <a href="https://warwick.ac.uk/fac/sci/statistics/staff/academic-research/robert/">Christian Robert</a> for useful feedback on this post</em></p>This post is the first of a two-part series on early Monte Carlo methods from the 1940s and 1950s. Although Monte Carlo methods had been previously hinted at (for example Buffon’s needle), these methods only started getting serious attention in the 1940s with the development of fast computers. In 1949 a conference on Monte Carlo was held in Los Angeles in which the participants - many of them legendary mathematicians, physicists, and statisticians - shared their work involving Monte Carlo methods. It can be enlightening to review this research today to understand what issues these researchers were grappling with. In this post I’ll give a brief overview on computing and Monte Carlo in the 1940s, and then discuss some of the talks from this 1949 conference. In the next post I’ll go over the classic 1953 paper on the Metropolis sampler and give some details of the MANIAC computer.Ensemble samplers can sometimes work in high dimensions2021-02-26T08:00:00+00:002021-02-26T08:00:00+00:00/2021/02/26/functional_ensemble_sampler<p>A few years ago <a href="https://bob-carpenter.github.io/">Bob Carpenter</a> wrote a fantastic <a href="https://statmodeling.stat.columbia.edu/2017/03/15/ensemble-methods-doomed-fail-high-dimensions/">post</a> on how why ensemble methods cannot work in high dimensions. He explains how high dimensions proposals that interpolate and extrapolate among samples are unlikely to fall in the typical set, causing the sampler to fail. During my PhD I worked on sampling problems with infinite-dimension parameters (ie: functions) and intractible gradients. After reading Bob’s post I always wondered if there was a way around this problem.</p>
<p>I finally got round to working on this idea and wrote a <a href="https://link.springer.com/article/10.1007/s11222-021-10004-y">paper</a> (<a href="https://arxiv.org/abs/2010.15181">arXiv</a>) with <a href="https://cims.nyu.edu/~rw2515/">Robert J Webber</a> about an ensemble method for function spaces (ie: infinite dimensional parameters) which works very well. This sampler is aimed at sampling functions when gradients are unavailable (due to a black-box code base or discontinuous likelihoods for example).</p>
<p>So does that mean that Bob was wrong about ensemble samplers? Of course not; it rather turns out that not all high dimensional distributions are the same. Namely: there is sometimes a low-dimensional subset of parameter space that represents all the “interesting bits” of the posterior that you can focus on.</p>
<p>After giving a problem to motivate function space samplers, I’ll introduce the functional ensemble sampler (FES). I’ll then discuss what this means for using gradient-free samplers such as ensemble samplers in high dimensional spaces. You can skip directly to the <a href="#discussion">discussion section</a> for a reply to Bob’s post, which can be read independently of the description of the FES algorithm.</p>
<h1 id="functional-ensemble-sampler">Functional ensemble sampler</h1>
<h2 id="a-motivational-problem">A motivational problem</h2>
<p>Consider the 1D advection equation, a linear hyperbolic PDE. This PDE models how a quantity, such as the density of a fluid, propagates through a 1D domain over time. It is a special case of more complicated nonlinear PDEs that arise in fluid dynamics, acoustics, motorway traffic flow, and many other applications. The equation is given as follows (with subscripts denoting partial differentiation):</p>
\[\rho_t + c \rho_x = 0\]
<p>Here \(\rho\) is density and \(c \in \mathbb{R}\) is the wave speed of the fluid. We define the initial condition \(\rho_0 \equiv \rho_0(x)\) to be the density profile of the fluid at time \(t=0\). This linear PDE simply advects the initial condition to the right or left of the domain depending on the sign of \(c\).</p>
<p>So the solution can be written as \(\rho(x,t) = \rho_0(x-ct)\). The following figure shows how an initial density profile is advected (ie: transported) to the right with wavespeed \(c=1\).</p>
<figure class="post_figure post_figure_larger">
<img src="/assets/FES_post/advection_solution.png" />
<figcaption>Figure 1: The left-most panel shows the initial condition. The next 2 panels show it being transported to the right with speed c=1</figcaption>
</figure>
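<p>As a quick sanity check of the solution \(\rho(x,t) = \rho_0(x-ct)\), here is a minimal Python/NumPy sketch. It uses a hypothetical Gaussian-bump initial condition (not the one in figure 1), checks with central finite differences that \(\rho_t + c\rho_x\) is numerically zero, and confirms that the peak moves a distance \(ct\) to the right.</p>

```python
import numpy as np

def rho0(x):
    # Hypothetical initial density: a Gaussian bump centred at x = 1
    return np.exp(-10.0 * (x - 1.0) ** 2)

def advect(initial, x, t, c):
    """Exact solution of rho_t + c * rho_x = 0: the initial profile shifted by c * t."""
    return initial(x - c * t)

c, t = 1.0, 1.5
x = np.linspace(0.0, 4.0, 401)
h = 1e-5  # step size for central finite differences

# PDE residual rho_t + c * rho_x, which should be ~0 up to finite-difference error
rho_t = (advect(rho0, x, t + h, c) - advect(rho0, x, t - h, c)) / (2 * h)
rho_x = (advect(rho0, x + h, t, c) - advect(rho0, x - h, t, c)) / (2 * h)
residual = np.max(np.abs(rho_t + c * rho_x))

# With c = 1 the peak moves from x = 1 to x = 1 + c * t = 2.5
peak = x[np.argmax(advect(rho0, x, t, c))]
print(f"max residual: {residual:.1e}, peak at x = {peak:.2f}")
```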
<p>An inverse problem might then be: given some noisy observations of <em>flow</em> at a few locations (these could correspond to detectors along a pipe for example), recover the initial condition as well as the wave speed of the fluid. Note that flow is the product of density and speed: \(q = \rho c\). Using this relation and the solution of the PDE above, we can write the data model for a detector at location \(x\) measuring flow at time \(t\):</p>
\[q(x, t) = c\rho_0(x-ct) + \xi\]
<p>with \(\xi \sim \mathcal{N}(0,1)\) the observational noise.</p>
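<p>The data model can be sketched in a few lines. This is a minimal illustration rather than the code used in the paper: the initial condition, the detector locations and the observation times are hypothetical, and the Gaussian log-likelihood assumes unit noise variance as in the data model above.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def rho0(x):
    # Hypothetical initial density profile (any function of x works here)
    return 1.0 + 0.5 * np.sin(x)

def flow(x, t, c, rho0):
    """Noise-free flow q = c * rho0(x - c t) at detector position x and time t."""
    return c * rho0(x - c * t)

c_true = 1.0
detectors = np.array([1.0, 2.0, 3.0])  # hypothetical detector locations
times = np.array([0.5, 1.0, 1.5])      # hypothetical observation times
X, T = np.meshgrid(detectors, times, indexing="ij")

# Observed flow: the PDE solution plus xi ~ N(0, 1) noise (9 observations)
q_obs = flow(X, T, c_true, rho0) + rng.standard_normal(X.shape)

def log_likelihood(c, rho0, q_obs):
    # Gaussian log-likelihood with unit noise variance (up to a constant)
    resid = q_obs - flow(X, T, c, rho0)
    return -0.5 * np.sum(resid ** 2)

print(log_likelihood(c_true, rho0, q_obs))
```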
<p><!-- so we're observing flow and inferring the latent parameters $$\rho_0$$ (a function) and $$c$$ (a scalar). --></p>
<p>To solve this inverse problem, we need to set a prior on both parameters, so we choose a uniform prior for the wave speed, and a Gaussian Process (GP) prior for the initial condition.</p>
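<p>For concreteness, here is one possible discretised version of this prior. The squared-exponential covariance, its hyperparameters, the domain \([0, 1]\) and the bounds of the uniform prior on \(c\) are assumptions made for illustration only, since the post does not specify them; only the 200-point grid matches the discretisation used later.</p>

```python
import numpy as np

x_grid = np.linspace(0.0, 1.0, 200)  # discretisation of rho_0 (200 points)

def se_kernel(x, length=0.1, var=1.0):
    # Hypothetical squared-exponential covariance; hyperparameters are illustrative
    d = x[:, None] - x[None, :]
    return var * np.exp(-0.5 * (d / length) ** 2)

C0 = se_kernel(x_grid) + 1e-6 * np.eye(x_grid.size)  # jitter for numerical stability
L0 = np.linalg.cholesky(C0)
C_MIN, C_MAX = 0.5, 1.5  # hypothetical bounds of the uniform wave-speed prior

def log_prior(c, rho0_vals):
    """Log prior (up to a constant): uniform on c, zero-mean GP on rho_0."""
    if not (C_MIN <= c <= C_MAX):
        return -np.inf
    # rho0^T C0^{-1} rho0 = |L0^{-1} rho0|^2
    z = np.linalg.solve(L0, rho0_vals)
    return -0.5 * z @ z

rng = np.random.default_rng(1)
rho0_sample = L0 @ rng.standard_normal(x_grid.size)  # a draw from the GP prior
c_sample = rng.uniform(C_MIN, C_MAX)                  # a draw from the uniform prior
print(log_prior(c_sample, rho0_sample))
```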
<h2 id="sampling-from-function-spaces">Sampling from function spaces</h2>
<h3 id="pcn">pCN</h3>
<p>The most basic gradient-free sampler defined on function space is the preconditioned Crank-Nicolson (pCN) sampler (see <a href="https://arxiv.org/abs/1202.0709">this paper</a> for an overview). This sampler makes the following simple proposal, with \(u\) the current MCMC sample, \(\beta \in (0,1)\) and \(\xi \sim \mathcal{N}(0, \Sigma_0)\) a sample <em>from the prior</em>:</p>
\[\tilde{u} = \sqrt{1-\beta^2}u + \beta \xi\]
<p>The acceptance rate for this sampler is <em>independent of dimension</em>, so it is well suited to sampling functions (though in practice we sample discretisations of functions). However, this sampler can mix slowly if the posterior is very different from the prior, for example if some of the components of the function are very correlated or multimodal.</p>
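<p>A single pCN step can be sketched as follows, assuming the prior covariance is supplied through a factor \(L_0\) with \(\Sigma_0 = L_0 L_0^T\) (for example a Cholesky factor). Because the proposal leaves the prior invariant, the Metropolis-Hastings acceptance probability only involves the log-likelihood. The prior and likelihood in the usage example are hypothetical.</p>

```python
import numpy as np

def pcn_step(u, log_lik, L0, beta, rng):
    """One pCN step for a posterior with Gaussian prior N(0, Sigma_0), Sigma_0 = L0 @ L0.T.

    Proposal: u_tilde = sqrt(1 - beta^2) * u + beta * xi, with xi ~ N(0, Sigma_0).
    The proposal leaves the prior invariant, so the acceptance probability
    only involves the likelihood: min(1, exp(log_lik(u_tilde) - log_lik(u))).
    """
    xi = L0 @ rng.standard_normal(u.size)  # draw from the prior
    u_tilde = np.sqrt(1.0 - beta ** 2) * u + beta * xi
    log_alpha = log_lik(u_tilde) - log_lik(u)
    if np.log(rng.uniform()) < log_alpha:
        return u_tilde, True
    return u, False

# Hypothetical example: standard normal prior in d = 200 dimensions and a
# likelihood that only informs the first 3 components.
rng = np.random.default_rng(2)
d = 200
L0 = np.eye(d)

def log_lik(u):
    return -0.5 * np.sum((u[:3] - 1.0) ** 2) / 0.1

u = np.zeros(d)
n_iter, n_acc = 2000, 0
for _ in range(n_iter):
    u, accepted = pcn_step(u, log_lik, L0, beta=0.2, rng=rng)
    n_acc += accepted
print(f"acceptance rate: {n_acc / n_iter:.2f}")
```

<p>Note that with a flat likelihood every pCN proposal is accepted, and the chain simply samples the prior.</p>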
<h3 id="fes">FES</h3>
<p>The idea of the functional ensemble sampler is to do an eigenfunction expansion of the prior and use it to isolate a low-dimensional subspace that includes the “difficult bit” of the posterior. This means that we can represent functions \(u\) in this space by truncating the eigenexpansion to \(M\) basis elements: \(u_+ = \sum_{j=1}^{M} u_j \phi_j\). The posterior on this subspace might have very correlated components or nonlinear structure, and more generally be difficult to sample from.</p>
<p>Functions in the rest of the space (ie: the complementary subspace) can be represented by \(u_- = \sum_{j=M+1}^{\infty} u_j \phi_j\) and are assumed to look like the prior. We can therefore alternate between using a finite-dimensional sampler (we’ll use the <a href="https://msp.org/camcos/2010/5-1/camcos-v5-n1-p04-p.pdf">affine invariant ensemble sampler</a>) on the low-dimensional subspace and using pCN on the complementary subspace.</p>
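<p>On a grid, the eigenfunction expansion becomes an eigendecomposition of the prior covariance matrix. The sketch below uses a hypothetical squared-exponential covariance and an illustrative \(M = 10\): it computes the coefficients \(u_j\), splits a prior draw into \(u_+\) and \(u_-\), and reports how much of the prior variance the first \(M\) modes carry. It only illustrates the decomposition; it is not the FES implementation from the paper.</p>

```python
import numpy as np

x_grid = np.linspace(0.0, 1.0, 200)
diff = x_grid[:, None] - x_grid[None, :]
C0 = np.exp(-0.5 * (diff / 0.1) ** 2)  # hypothetical squared-exponential prior covariance

# Discrete analogue of the eigenfunction expansion: C0 = Phi diag(lam) Phi^T
lam, Phi = np.linalg.eigh(C0)
order = np.argsort(lam)[::-1]          # sort eigenvalues in decreasing order
lam = np.clip(lam[order], 0.0, None)   # remove tiny negative round-off
Phi = Phi[:, order]

M = 10  # truncation parameter (illustrative value)

def split(u):
    """Split a discretised function into u_plus (first M modes) and u_minus (the rest)."""
    coeffs = Phi.T @ u                 # u_j = <u, phi_j>
    u_plus = Phi[:, :M] @ coeffs[:M]
    u_minus = Phi[:, M:] @ coeffs[M:]
    return coeffs, u_plus, u_minus

# A prior draw: u = sum_j sqrt(lam_j) * z_j * phi_j with z_j ~ N(0, 1)
rng = np.random.default_rng(3)
u = Phi @ (np.sqrt(lam) * rng.standard_normal(lam.size))
coeffs, u_plus, u_minus = split(u)
print(f"fraction of prior variance in the first {M} modes: {lam[:M].sum() / lam.sum():.4f}")
```

<p>Because the prior eigenvalues decay quickly, a handful of modes carries almost all of the prior variance, which is what makes a low-dimensional \(u_+\) plausible.</p>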
<p>So the functional ensemble sampler is a Metropolis-within-Gibbs algorithm that alternates sampling from the low dimensional space using AIES and the complementary space using pCN. You can find more detail of the algorithm in the <a href="https://arxiv.org/pdf/2010.15181.pdf">paper</a> (see algorithm 1).</p>
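<p>For the low-dimensional block, the affine invariant ensemble sampler uses the “stretch move” from the linked paper. Below is a minimal serial sketch of one sweep over the ensemble, applied to a hypothetical, strongly correlated 2D Gaussian rather than to the FES subspace. Roughly speaking, in FES a sweep like this would act on the first \(M\) coefficients of each walker, with pCN updates for the remaining coefficients; see algorithm 1 of the paper for the actual scheme.</p>

```python
import numpy as np

def stretch_move_sweep(walkers, log_prob, rng, a=2.0):
    """One serial sweep of the Goodman & Weare stretch move.

    walkers has shape (L, d). Walker k moves along the line through another
    walker j: Y = X_j + Z * (X_k - X_j), where g(Z) is proportional to
    1/sqrt(Z) on [1/a, a]. The move is accepted with probability
    min(1, Z^(d-1) * pi(Y) / pi(X_k)).
    """
    walkers = walkers.copy()
    L, d = walkers.shape
    n_accept = 0
    for k in range(L):
        j = rng.integers(L - 1)  # pick another walker uniformly, j != k
        if j >= k:
            j += 1
        z = ((a - 1.0) * rng.uniform() + 1.0) ** 2 / a  # inverse-CDF draw from g
        proposal = walkers[j] + z * (walkers[k] - walkers[j])
        log_alpha = (d - 1) * np.log(z) + log_prob(proposal) - log_prob(walkers[k])
        if np.log(rng.uniform()) < log_alpha:
            walkers[k] = proposal
            n_accept += 1
    return walkers, n_accept / L

# Hypothetical target: a strongly correlated 2D Gaussian
rng = np.random.default_rng(5)
Sigma = np.array([[1.0, 0.95], [0.95, 1.0]])
P = np.linalg.inv(Sigma)

def log_prob(x):
    return -0.5 * x @ P @ x

walkers = rng.standard_normal((20, 2))  # L = 20 walkers
chain = []
for _ in range(1000):
    walkers, acc_rate = stretch_move_sweep(walkers, log_prob, rng)
    chain.append(walkers)
samples = np.concatenate(chain[200:])  # drop burn-in, pool the walkers
print(samples.mean(axis=0), np.corrcoef(samples.T)[0, 1])
```

<p>The stretch move is invariant to affine transformations of the parameters, so the strong linear correlation in this target does not slow it down the way it slows down an isotropic random walk.</p>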
<h3 id="performance">Performance</h3>
<p>We go back to the advection equation and try out the algorithm.
Given the equation \(\rho_t + c\rho_x = 0\), the inverse problem consists of inferring the wave speed \(c\) and initial conditions \(\rho_0\) from 9 noisy observations of <em>flow</em> \(q\). These observations come from equally spaced detectors at several time points. We discretise the initial condition \(\rho_0\) using 200 equally spaced points, which we note is a dimension where ensemble methods would usually fail.</p>
<!-- We run FES with $$L=100$$ walkers and try different values of the truncation parameter: $$M \in \{0,1,5,10,20\}$$. We note that setting $$M=0$$ corresponds to using AIES for the wave speed $$c$$ and pCN for the initial condition. -->
<p>We compare FES to a standard pCN sampler which uses the following joint update:</p>
<ul>
<li>Gaussian proposals for the wave speed \(c\)</li>
<li>pCN for the initial condition \(\rho_0(x)\)</li>
</ul>
<p>We run the samplers and find that the ensemble sampler can be up to two orders of magnitude faster than pCN (in terms of integrated autocorrelation times, or IATs). We also find that there’s an optimal value of the truncation parameter \(M\) that gives the fastest mixing. You can find details of this study in section 4.1 of the <a href="https://arxiv.org/pdf/2010.15181.pdf">paper</a>.</p>
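<p>The IAT measures roughly how many correlated MCMC samples are worth one independent sample, so lower is better. A common estimator, sketched below, sums the empirical autocorrelation function up to a window chosen with Sokal’s automatic windowing rule (similar in spirit to the estimator in the emcee package); it is not necessarily the estimator used in the paper. The AR(1) test chain is a hypothetical example whose exact IAT is \((1+\phi)/(1-\phi)\).</p>

```python
import numpy as np

def autocorr(x):
    """Normalised autocorrelation function of a 1D chain, computed with an FFT."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    n = len(x)
    f = np.fft.rfft(x, n=2 * n)  # zero-pad to avoid circular wrap-around
    acf = np.fft.irfft(f * np.conj(f))[:n]
    return acf / acf[0]

def integrated_autocorr_time(x, c=5.0):
    """IAT estimate tau = 1 + 2 * sum_{t>=1} rho(t), with Sokal's automatic window:
    the smallest window W such that W >= c * tau(W)."""
    rho = autocorr(x)
    taus = 2.0 * np.cumsum(rho) - 1.0  # tau(W) = 1 + 2 * sum_{t=1}^{W} rho(t)
    windows = np.arange(len(taus))
    mask = windows >= c * taus
    W = np.argmax(mask) if np.any(mask) else len(taus) - 1
    return taus[W]

# Hypothetical test chain: AR(1) process x_t = phi * x_{t-1} + noise
rng = np.random.default_rng(6)
phi, n = 0.9, 100_000
noise = rng.standard_normal(n)
x = np.empty(n)
x[0] = noise[0]
for t in range(1, n):
    x[t] = phi * x[t - 1] + noise[t]

tau = integrated_autocorr_time(x)
print(f"estimated IAT: {tau:.1f}, exact: {(1 + phi) / (1 - phi):.1f}")
```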
<p>To understand why the posterior was so challenging for a simple random walk, we plot in figure 2 samples from \(\rho_0\) conditioned on three values of the wave speed \(c\).</p>
<figure class="post_figure">
<img src="/assets/FES_post/advection_conditional_c.png" height="100%" />
<figcaption>Figure 2: Samples of the initial conditions given three values of the wave speed c.</figcaption>
</figure>
<p>This figure reveals the strong negative correlation between the wave speed and the mean of \(\rho_0\). This correlation between the two parameters can be understood from the solution of the PDE for flow as described <a href="#a-motivational-problem">earlier</a>: \(q(x, t) = c\rho_0(x-ct)\). If the wave speed \(c\) increases, the mean of \(\rho_0\) must decrease to keep flow at the detectors approximately constant (and vice-versa). Since the pCN sampler does not account for this correlation structure, large pCN updates are highly unlikely to be accepted and the sampler is slow. In contrast, FES adapts to the correlation structure, eliminating the major bottleneck in the sampling.</p>
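<p>A simplified version of this argument can be checked numerically. Assume, for illustration only, that the initial condition is constant, \(\rho_0(x) = m\): every detector then measures \(q = cm\) plus noise, so the likelihood depends on \((c, m)\) only through the product \(cm\). The sketch below finds the most likely \(m\) for several values of \(c\) and shows that it follows \(m = \bar{q}/c\), which decreases as \(c\) increases.</p>

```python
import numpy as np

# Simplifying assumption: a constant initial condition rho_0(x) = m, so each
# detector measures q = c * m plus N(0, 1) noise (9 hypothetical observations).
rng = np.random.default_rng(4)
c_true, m_true = 1.0, 2.0
q_obs = c_true * m_true + rng.standard_normal(9)

def log_lik(c, m):
    """Gaussian log-likelihood (unit noise variance) evaluated on an array of m values."""
    m = np.atleast_1d(m)
    return -0.5 * np.sum((q_obs[None, :] - c * m[:, None]) ** 2, axis=1)

cs = np.linspace(0.8, 2.0, 7)         # wave speeds to condition on
m_grid = np.linspace(0.1, 5.0, 4901)  # candidate constant densities (step 0.001)
best_m = np.array([m_grid[np.argmax(log_lik(c, m_grid))] for c in cs])

for c, m in zip(cs, best_m):
    print(f"c = {c:.1f}: most likely m = {m:.3f}, c * m = {c * m:.3f}")
```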
<!-- We emphasise that we observe flow data (not density) as described in the data model in the [previous section](#a-motivational-problem). -->
<h1 id="discussion">Discussion</h1>
<p>Bob Carpenter’s explanation of why ensemble methods fail in high dimensions is of course correct: ensemble methods will fail in high dimensions because proposals that interpolate and extrapolate between points will fall outside the typical set. However, high dimensional spaces are not all high dimensional in the same way.</p>
<p>Throughout the post, Bob uses a high dimensional Gaussian (ie: a doughnut) to illustrate when ensemble methods fail. Indeed, we statisticians often use the Gaussian distribution to illustrate ideas or simplify reasoning. For example, a standard line of thinking when developing a new method might be to argue: <em>“in the case of everything being Gaussian our method becomes exact...”</em>. This makes sense because Gaussians are easy to manipulate and pop up everywhere thanks to the <a href="https://en.wikipedia.org/wiki/Central_limit_theorem">central limit theorem</a>.</p>
<p>However, it can be unhelpful to build our mental models of difficult concepts, such as high dimensional distributions, solely on Gaussians. Indeed, difficult sampling problems are hard precisely because they are not Gaussian. This is similar to the problem of using <a href="https://en.wikipedia.org/wiki/Model_organism">model organisms</a> in biology to learn about humans. So in a way, Gaussians are the <a href="https://en.wikipedia.org/wiki/Drosophila_melanogaster">fruit flies</a> of statistics.</p>
<h3 id="high-dimensional-distributions">High-dimensional distributions</h3>
<p>There are many ways in which high dimensional distributions might be different from a spherical Gaussian which we can use to help with sampling. One way is that there might exist a low dimensional subspace that represents most of the “interesting bits” of the posterior. The rest of the space (the complementary subspace) is then relatively unaffected by the likelihood and therefore acts like the prior. This is similar to the idea behind PCA where the data mainly lives on a low dimensional subspace.</p>
<p>In our <a href="#a-motivational-problem">inverse problem</a> involving functional parameters, the prior gives a natural way to find this low dimensional subspace. This is because the Gaussian process prior imposes a lot of structure on the parameter which allows the infinite dimensional problem to be well posed.</p>
<p>To be perfectly clear: if you have gradients, you should use them! In the function space setting there are many gradient-based methods that will do much better than this ensemble method, such as function space versions of <a href="https://spiral.imperial.ac.uk/bitstream/10044/1/66130/7/1-s2.0-S0021999116307033-main.pdf">HMC and MALA</a> as well as <a href="https://arxiv.org/abs/1403.4680">dimensionality reduction methods</a>. However in our paper we rather focused on the setting where gradients were unavailable. So gradient-free samplers are about doing the best you can with limited information.</p>
<h1 id="conclusion">Conclusion</h1>
<p>Making MCMC work is about finding a good parametrisation. HMC uses Hamilton’s equations to find a natural parametrisation (ie: level sets). However, this can still break down with difficult geometries and another <a href="https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html">reparametrisation</a> might be necessary.</p>
<p>If you don’t have gradients then ensemble methods can be a good solution (example: the <a href="https://emcee.readthedocs.io/en/stable/">emcee</a> package). However if the dimension is too high then these will break down as discussed in Bob Carpenter’s <a href="https://statmodeling.stat.columbia.edu/2017/03/15/ensemble-methods-doomed-fail-high-dimensions/">post</a>. Thankfully, this is not the end of the road for ensemble samplers! You’ll then need to think about your problem to identify natural groupings of parameters to apply your gradient-free sampler to. In this post I went over a practical way to do this in the case of infinite-dimensional inverse problems, yielding a simple but powerful gradient-free ensemble sampler defined on function spaces.</p>
<p><em>Thanks to <a href="https://cims.nyu.edu/~rw2515/">Robert J Webber</a> and <a href="https://bob-carpenter.github.io/">Bob Carpenter</a> for useful feedback on this post</em></p>A few years ago Bob Carpenter wrote a fantastic post on how why ensemble methods cannot work in high dimensions. He explains how high dimensions proposals that interpolate and extrapolate among samples are unlikely to fall in the typical set, causing the sampler to fail. During my PhD I worked on sampling problems with infinite-dimension parameters (ie: functions) and intractible gradients. After reading Bob’s post I always wondered if there was a way around this problem.How to add a progress bar to JAX scans and loops2021-01-29T08:00:00+00:002021-01-29T08:00:00+00:00/2021/01/29/Jax_progress_bar<p>JAX allows you to write optimisers and samplers which are really fast if you use the <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html"><code class="highlighter-rouge">scan</code></a> or <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html"><code class="highlighter-rouge">fori_loop</code></a> functions. However if you write them in this way it’s not obvious how to add progress bar for your algorithm. This post explains how to make a progress bar using Python’s <code class="highlighter-rouge">print</code> function as well as using <a href="https://pypi.org/project/tqdm/">tqdm</a>. After briefly setting up the sampler, we first go over how to create a basic version using Python’s <code class="highlighter-rouge">print</code> function, and then show how to create a nicer version using tqdm. You can find the code for the basic version <a href="https://gist.github.com/jeremiecoullon/4ae89676e650370936200ec04a4e3bef">here</a> and the code for the tqdm version <a href="https://gist.github.com/jeremiecoullon/f6a658be4c98f8a7fd1710418cca0856">here</a>.</p>
<p><em>Update January 2023: this is now available in a pip-installable package: <a href="https://github.com/jeremiecoullon/jax-tqdm">JAX-tqdm</a></em></p>
<h1 id="setup-sampling-a-gaussian">Setup: sampling a Gaussian</h1>
<p>We’ll use an <a href="https://en.wikipedia.org/wiki/Langevin_dynamics">Unadjusted Langevin Algorithm</a> (ULA) to sample from a Gaussian to illustrate how to write the progress bar. Let’s start by defining the log-posterior of a d-dimensional Gaussian; we’ll use JAX to get its gradient:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">log_posterior</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">*</span><span class="n">jnp</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">x</span><span class="p">)</span>
<span class="n">grad_log_post</span> <span class="o">=</span> <span class="n">jit</span><span class="p">(</span><span class="n">grad</span><span class="p">(</span><span class="n">log_posterior</span><span class="p">))</span>
</code></pre></div></div>
<p>We now define ULA using the <code class="highlighter-rouge">scan</code> function (see <a href="/2020/11/10/MCMCJax3ways/">this post</a> for an explanation of the <code class="highlighter-rouge">scan</code> function).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,))</span>
<span class="k">def</span> <span class="nf">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">):</span>
<span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="n">paramGrad</span> <span class="o">=</span> <span class="n">grad_log_post</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
<span class="n">noise_term</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">subkey</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">param</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
<span class="n">param</span> <span class="o">=</span> <span class="n">param</span> <span class="o">+</span> <span class="n">dt</span><span class="o">*</span><span class="n">paramGrad</span> <span class="o">+</span> <span class="n">noise_term</span>
<span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">param</span>
<span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,))</span>
<span class="k">def</span> <span class="nf">ula_sampler</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">dt</span><span class="p">,</span> <span class="n">x_0</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">ula_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">):</span>
<span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">carry</span>
<span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">),</span> <span class="n">param</span>
<span class="n">carry</span> <span class="o">=</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x_0</span><span class="p">)</span>
<span class="n">_</span><span class="p">,</span> <span class="n">samples</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">ula_step</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_samples</span><span class="p">))</span>
<span class="k">return</span> <span class="n">samples</span>
</code></pre></div></div>
<p>If we add a call to <code class="highlighter-rouge">print</code> in <code class="highlighter-rouge">ula_step</code> above, it will only run once: the first time <code class="highlighter-rouge">ula_sampler</code> is called, which is when it is compiled. This is because printing is a side effect, and <a href="https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions">compiled JAX functions are pure</a>.</p>
<h1 id="basic-progress-bar">Basic progress bar</h1>
<p>As a workaround, the JAX team has added the <a href="https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html"><code class="highlighter-rouge">host_callback</code></a> module (which is still experimental, so things may change). This module defines functions that allow you to call Python functions from within a JAX function. Here’s how you would use the <a href="https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-id-tap-to-call-a-jax-function-on-another-device-with-no-returned-values-but-full-jax-transformation-support"><code class="highlighter-rouge">id_tap</code></a> function to create a progress bar (from this <a href="https://github.com/google/jax/discussions/4763#discussioncomment-121452">discussion</a>):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">jax.experimental</span> <span class="kn">import</span> <span class="n">host_callback</span>
<span class="k">def</span> <span class="nf">_print_consumer</span><span class="p">(</span><span class="n">arg</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span>
<span class="n">iter_num</span><span class="p">,</span> <span class="n">num_samples</span> <span class="o">=</span> <span class="n">arg</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Iteration </span><span class="si">{</span><span class="n">iter_num</span><span class="p">:,</span><span class="si">}</span><span class="s"> / </span><span class="si">{</span><span class="n">num_samples</span><span class="p">:,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">progress_bar</span><span class="p">(</span><span class="n">arg</span><span class="p">,</span> <span class="n">result</span><span class="p">):</span>
<span class="s">"""
Print progress of a scan/loop only if the iteration number is a multiple of the print_rate
Usage: `carry = progress_bar((iter_num + 1, num_samples, print_rate), carry)`
Pass in `iter_num + 1` so that counting starts at 1 and ends at `num_samples`
"""</span>
<span class="n">iter_num</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">print_rate</span> <span class="o">=</span> <span class="n">arg</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">cond</span><span class="p">(</span>
<span class="n">iter_num</span> <span class="o">%</span> <span class="n">print_rate</span><span class="o">==</span><span class="mi">0</span><span class="p">,</span>
<span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">host_callback</span><span class="p">.</span><span class="n">id_tap</span><span class="p">(</span><span class="n">_print_consumer</span><span class="p">,</span> <span class="p">(</span><span class="n">iter_num</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">),</span> <span class="n">result</span><span class="o">=</span><span class="n">result</span><span class="p">),</span>
<span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">result</span><span class="p">,</span>
<span class="n">operand</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span>
</code></pre></div></div>
<p>The <code class="highlighter-rouge">id_tap</code> function behaves like the identity function, so calling <code class="highlighter-rouge">host_callback.id_tap(_print_consumer, (iter_num, num_samples), result=result)</code> will simply return <code class="highlighter-rouge">result</code>. However while doing this, it will also call the function <code class="highlighter-rouge">_print_consumer((iter_num, num_samples))</code> which we’ve defined to print the iteration number.</p>
<p>You need to pass an argument in this way because you need to include a data dependency to make sure that the print function gets called at the correct time. This is linked to the fact that computations in JAX are run <a href="https://jax.readthedocs.io/en/latest/async_dispatch.html">only when needed</a>. So you need to pass in a variable that changes throughout the algorithm such as the PRNG key at that iteration.</p>
<p>Note also that the <code class="highlighter-rouge">_print_consumer</code> function takes in <code class="highlighter-rouge">arg</code> (which holds the current iteration number as well as the total number of iterations) and <code class="highlighter-rouge">transform</code>. This <code class="highlighter-rouge">transform</code> argument isn’t used here, but apparently should be included in the consumer for <code class="highlighter-rouge">id_tap</code> (namely: the Python function that gets called).</p>
<p>Here’s how you would use the progress bar in the ULA sampler:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">ula_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">):</span>
<span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">carry</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">progress_bar</span><span class="p">((</span><span class="n">iter_num</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">print_rate</span><span class="p">),</span> <span class="n">key</span><span class="p">)</span>
<span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">),</span> <span class="n">param</span>
</code></pre></div></div>
<p>We passed the <code class="highlighter-rouge">key</code> into the progress bar which comes out unchanged. We also set the print rate to be 10% of the number of samples. Note that this would also work for <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html"><code class="highlighter-rouge">lax.fori_loop</code></a> except that the first argument of <code class="highlighter-rouge">ula_step</code> would be the current iteration number.</p>
<h3 id="put-it-in-a-decorator">Put it in a decorator</h3>
<p>We can make this even easier to use by putting the progress bar in a decorator. Note that the decorator takes in <code class="highlighter-rouge">num_samples</code> as an argument.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">progress_bar_scan</span><span class="p">(</span><span class="n">num_samples</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">_progress_bar_scan</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
        <span class="n">print_rate</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_samples</span><span class="o">/</span><span class="mi">10</span><span class="p">)</span>
        <span class="k">def</span> <span class="nf">wrapper_progress_bar</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">):</span>
            <span class="n">iter_num</span> <span class="o">=</span> <span class="n">progress_bar</span><span class="p">((</span><span class="n">iter_num</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">print_rate</span><span class="p">),</span> <span class="n">iter_num</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">func</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">wrapper_progress_bar</span>
    <span class="k">return</span> <span class="n">_progress_bar_scan</span>
</code></pre></div></div>
<p>Remember that writing a decorator with arguments means writing a function that returns a decorator (which itself is a function that returns a modified version of the main function you care about). See this <a href="https://stackoverflow.com/questions/5929107/decorators-with-parameters">StackOverflow question</a> for more details.</p>
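<p>To see these three layers in isolation, here is a minimal pure-Python sketch (no JAX involved) of a decorator factory with the same shape as <code class="highlighter-rouge">progress_bar_scan</code>. The names <code class="highlighter-rouge">report_every</code> and <code class="highlighter-rouge">step</code>, and the <code class="highlighter-rouge">reports</code> attribute, are hypothetical: instead of printing, the wrapper simply records the iterations at which a message would be shown.</p>

```python
def report_every(num_samples):
    """Decorator factory: takes an argument and returns a decorator."""
    print_rate = max(1, num_samples // 10)  # report roughly 10 times

    def decorator(func):
        """Decorator: takes the step function and returns a wrapped version."""
        def wrapper(carry, iter_num):
            # Record the iterations at which a progress message would be shown
            if iter_num % print_rate == 0:
                wrapper.reports.append(iter_num)
            return func(carry, iter_num)

        wrapper.reports = []
        return wrapper

    return decorator


@report_every(50)
def step(carry, iter_num):
    """A toy step function with the (carry, x) signature used by lax.scan."""
    return carry + 1, carry


carry = 0
for i in range(50):
    carry, _ = step(carry, i)
print(step.reports)
```

<p>Writing <code class="highlighter-rouge">@report_every(50)</code> first calls <code class="highlighter-rouge">report_every(50)</code>, which returns <code class="highlighter-rouge">decorator</code>; Python then calls <code class="highlighter-rouge">decorator(step)</code> and binds the result to the name <code class="highlighter-rouge">step</code>.</p>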
<p>Putting it all together, the result is very easy to use:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">ula_sampler_pbar</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">dt</span><span class="p">,</span> <span class="n">x_0</span><span class="p">):</span>
    <span class="s">"ULA sampler with progress bar"</span>
    <span class="o">@</span><span class="n">progress_bar_scan</span><span class="p">(</span><span class="n">num_samples</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">ula_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">):</span>
        <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">carry</span>
        <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
        <span class="k">return</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">),</span> <span class="n">param</span>
    <span class="n">carry</span> <span class="o">=</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x_0</span><span class="p">)</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">samples</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">ula_step</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_samples</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">samples</span>
</code></pre></div></div>
<p>Now that we have a progress bar, we might also want to know when the function is compiling (which is especially useful when compilation takes a while). Here we can use the fact that Python’s <code class="highlighter-rouge">print</code> function is only executed while JAX traces the function during compilation, not when the compiled code runs. We can add <code class="highlighter-rouge">print("Compiling..")</code> at the beginning of <code class="highlighter-rouge">ula_sampler_pbar</code> and <code class="highlighter-rouge">print("Running:")</code> at the end. Both messages will then only be displayed the first time the function is called (or whenever it has to be recompiled). You can find the code for this sampler <a href="https://gist.github.com/jeremiecoullon/4ae89676e650370936200ec04a4e3bef">here</a>.</p>
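<p>Here is a small sketch of this behaviour using a toy function, <code class="highlighter-rouge">sum_of_squares</code> (hypothetical, not from the gist). The Python <code class="highlighter-rouge">print</code> and the list append only run while JAX traces the function, so they fire on the first call and then again only when a new input shape forces a retrace.</p>

```python
import jax
import jax.numpy as jnp

trace_events = []  # filled only while JAX traces the function


@jax.jit
def sum_of_squares(x):
    print("Compiling..")          # Python side effect: runs at trace time only
    trace_events.append(x.shape)  # also trace time only
    return jnp.sum(x ** 2)


first = sum_of_squares(jnp.arange(3.0))        # traced and compiled
second = sum_of_squares(jnp.arange(1.0, 4.0))  # same shape/dtype: cached, no print
third = sum_of_squares(jnp.arange(4.0))        # new shape: traced again
```

<p>The second call reuses the compiled version, so nothing is printed; the returned values are still computed correctly every time.</p>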
<h1 id="tqdm-progress-bar">tqdm progress bar</h1>
<p>We’ll now use the same ideas to build a fancier progress bar, namely one that uses <a href="https://pypi.org/project/tqdm/">tqdm</a>. We’ll use <code class="highlighter-rouge">host_callback.id_tap</code> to create a <code class="highlighter-rouge">tqdm</code> progress bar and then call <code class="highlighter-rouge">tqdm.update</code> regularly to update it. We’ll also need to close the progress bar once we’re finished, or else <code class="highlighter-rouge">tqdm</code> will misbehave. To do this we’ll define a decorator that takes in arguments, just like we did for the simple progress bar.</p>
<p>This decorator creates the tqdm progress bar at the first iteration, updates it every <code class="highlighter-rouge">print_rate</code> iterations, and finally closes it at the end. You can optionally pass in a message to display at the beginning of the progress bar.</p>
<p>Some extra logic makes sure the progress bar behaves correctly in corner cases, such as when <code class="highlighter-rouge">num_samples</code> is less than 20 or is not a multiple of <code class="highlighter-rouge">print_rate</code>. Note also that tqdm is closed at the last iteration only <em>after</em> the parameter update is done.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">progress_bar_scan</span><span class="p">(</span><span class="n">num_samples</span><span class="p">,</span> <span class="n">message</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="s">"Progress bar for a JAX scan"</span>
    <span class="k">if</span> <span class="n">message</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">message</span> <span class="o">=</span> <span class="sa">f</span><span class="s">"Running for </span><span class="si">{</span><span class="n">num_samples</span><span class="p">:,</span><span class="si">}</span><span class="s"> iterations"</span>
    <span class="n">tqdm_bars</span> <span class="o">=</span> <span class="p">{}</span>
    <span class="k">if</span> <span class="n">num_samples</span> <span class="o">></span> <span class="mi">20</span><span class="p">:</span>
        <span class="n">print_rate</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_samples</span> <span class="o">/</span> <span class="mi">20</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">print_rate</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># if you run the sampler for less than 20 iterations
</span>    <span class="n">remainder</span> <span class="o">=</span> <span class="n">num_samples</span> <span class="o">%</span> <span class="n">print_rate</span>
    <span class="k">def</span> <span class="nf">_define_tqdm</span><span class="p">(</span><span class="n">arg</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span>
        <span class="n">tqdm_bars</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_samples</span><span class="p">))</span>
        <span class="n">tqdm_bars</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_description</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">refresh</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">_update_tqdm</span><span class="p">(</span><span class="n">arg</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span>
        <span class="n">tqdm_bars</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">update</span><span class="p">(</span><span class="n">arg</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">_update_progress_bar</span><span class="p">(</span><span class="n">iter_num</span><span class="p">):</span>
        <span class="s">"Updates tqdm progress bar of a JAX scan or loop"</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">cond</span><span class="p">(</span>
            <span class="n">iter_num</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">host_callback</span><span class="p">.</span><span class="n">id_tap</span><span class="p">(</span><span class="n">_define_tqdm</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="n">result</span><span class="o">=</span><span class="n">iter_num</span><span class="p">),</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">iter_num</span><span class="p">,</span>
            <span class="n">operand</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">cond</span><span class="p">(</span>
            <span class="c1"># update tqdm every multiple of `print_rate` except at the end
</span>            <span class="p">(</span><span class="n">iter_num</span> <span class="o">%</span> <span class="n">print_rate</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">iter_num</span> <span class="o">!=</span> <span class="n">num_samples</span><span class="o">-</span><span class="n">remainder</span><span class="p">),</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">host_callback</span><span class="p">.</span><span class="n">id_tap</span><span class="p">(</span><span class="n">_update_tqdm</span><span class="p">,</span> <span class="n">print_rate</span><span class="p">,</span> <span class="n">result</span><span class="o">=</span><span class="n">iter_num</span><span class="p">),</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">iter_num</span><span class="p">,</span>
            <span class="n">operand</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">cond</span><span class="p">(</span>
            <span class="c1"># update tqdm by `remainder`
</span>            <span class="n">iter_num</span> <span class="o">==</span> <span class="n">num_samples</span><span class="o">-</span><span class="n">remainder</span><span class="p">,</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">host_callback</span><span class="p">.</span><span class="n">id_tap</span><span class="p">(</span><span class="n">_update_tqdm</span><span class="p">,</span> <span class="n">remainder</span><span class="p">,</span> <span class="n">result</span><span class="o">=</span><span class="n">iter_num</span><span class="p">),</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">iter_num</span><span class="p">,</span>
            <span class="n">operand</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
        <span class="p">)</span>
    <span class="k">def</span> <span class="nf">_close_tqdm</span><span class="p">(</span><span class="n">arg</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span>
        <span class="n">tqdm_bars</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">close</span><span class="p">()</span>
    <span class="k">def</span> <span class="nf">close_tqdm</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">lax</span><span class="p">.</span><span class="n">cond</span><span class="p">(</span>
            <span class="n">iter_num</span> <span class="o">==</span> <span class="n">num_samples</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">host_callback</span><span class="p">.</span><span class="n">id_tap</span><span class="p">(</span><span class="n">_close_tqdm</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="n">result</span><span class="o">=</span><span class="n">result</span><span class="p">),</span>
            <span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">result</span><span class="p">,</span>
            <span class="n">operand</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
        <span class="p">)</span>
    <span class="k">def</span> <span class="nf">_progress_bar_scan</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
        <span class="s">"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
        Note that `body_fun` must either be looping over `np.arange(num_samples)`,
        or be looping over a tuple whose first element is `np.arange(num_samples)`
        This means that `iter_num` is the current iteration number
        """</span>
        <span class="k">def</span> <span class="nf">wrapper_progress_bar</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
            <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="ow">is</span> <span class="nb">tuple</span><span class="p">:</span>
                <span class="n">iter_num</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="n">x</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">iter_num</span> <span class="o">=</span> <span class="n">x</span>
            <span class="n">_update_progress_bar</span><span class="p">(</span><span class="n">iter_num</span><span class="p">)</span>
            <span class="n">result</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">close_tqdm</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">wrapper_progress_bar</span>
    <span class="k">return</span> <span class="n">_progress_bar_scan</span>
</code></pre></div></div>
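<p>The update arithmetic in this decorator can be checked without JAX or tqdm. The sketch below uses a hypothetical helper, <code class="highlighter-rouge">tqdm_increments</code> (not part of the gist), that replays the two update <code class="highlighter-rouge">lax.cond</code> branches in plain Python and returns the amounts that would be passed to <code class="highlighter-rouge">tqdm.update</code>. They should always add up to <code class="highlighter-rouge">num_samples</code>, including the corner cases mentioned above.</p>

```python
def tqdm_increments(num_samples):
    """Replay, in plain Python, the tqdm updates made by the decorator.

    Returns the list of amounts passed to `tqdm.update`, in order.
    """
    if num_samples > 20:
        print_rate = int(num_samples / 20)
    else:
        print_rate = 1  # fewer than 20 iterations: update at every step
    remainder = num_samples % print_rate
    increments = []
    for iter_num in range(num_samples):
        # mirrors the second lax.cond: regular update, except at the end
        if iter_num % print_rate == 0 and iter_num != num_samples - remainder:
            increments.append(print_rate)
        # mirrors the third lax.cond: final update by `remainder`
        if iter_num == num_samples - remainder:
            increments.append(remainder)
    return increments


for n in [5, 20, 21, 45, 50, 1000, 1001]:
    incs = tqdm_increments(n)
    print(f"num_samples={n}: {len(incs)} updates, total {sum(incs)}")
```

<p>For example, with 1001 iterations <code class="highlighter-rouge">print_rate</code> is 50 and <code class="highlighter-rouge">remainder</code> is 1: the bar is advanced by 50 twenty times and then by 1 at iteration 1000.</p>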
<p>Although this progress bar is more complicated than the previous one, you use it in exactly the same way. You simply add the decorator to the step function used in <code class="highlighter-rouge">lax.scan</code>, with the number of samples as an argument (and optionally the message to print at the beginning of the progress bar).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">ula_sampler_pbar</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">dt</span><span class="p">,</span> <span class="n">x_0</span><span class="p">):</span>
    <span class="s">"ULA sampler with progress bar"</span>
    <span class="o">@</span><span class="n">progress_bar_scan</span><span class="p">(</span><span class="n">num_samples</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">ula_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">iter_num</span><span class="p">):</span>
        <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">carry</span>
        <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
        <span class="k">return</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">),</span> <span class="n">param</span>
    <span class="n">carry</span> <span class="o">=</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x_0</span><span class="p">)</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">samples</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">ula_step</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_samples</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">samples</span>
</code></pre></div></div>
<h3 id="conclusion">Conclusion</h3>
<p>So we’ve built two progress bars: a basic version and a nicer version that uses tqdm. The code for both is in these two gists: <a href="https://gist.github.com/jeremiecoullon/4ae89676e650370936200ec04a4e3bef">here</a> and <a href="https://gist.github.com/jeremiecoullon/f6a658be4c98f8a7fd1710418cca0856">here</a>.</p>JAX allows you to write optimisers and samplers which are really fast if you use the scan or fori_loop functions. However if you write them in this way it’s not obvious how to add a progress bar to your algorithm. This post explains how to make a progress bar using Python’s print function as well as using tqdm. After briefly setting up the sampler, we first go over how to create a basic version using Python’s print function, and then show how to create a nicer version using tqdm. You can find the code for the basic version here and the code for the tqdm version here.MCMC in JAX with benchmarks: 3 ways to write a sampler2020-11-10T08:00:00+00:002020-11-10T08:00:00+00:00/2020/11/10/MCMCJax3ways<p>This post goes over 3 ways to write a sampler using JAX. I found that although there are a bunch of tutorials about learning the basics of JAX, it was not clear to me what the best way to write a sampler in JAX was. In particular, how much of the sampler should you write in JAX? Just the log-posterior (or the loss in the case of optimisation), or the entire loop? This blog post tries to answer this by going over 3 ways to write a sampler while focusing on the speed of each sampler.</p>
<p>I’ll assume that you already know some JAX, in particular the functions <code class="highlighter-rouge">grad</code>, <code class="highlighter-rouge">vmap</code>, and <code class="highlighter-rouge">jit</code>, as well as its random number generator. If not, you can check out how to use these in this <a href="https://colinraffel.com/blog/you-don-t-know-jax.html">blog post</a> or in the <a href="https://jax.readthedocs.io/en/latest/notebooks/quickstart.html">JAX documentation</a>! Instead, I’ll focus on the different ways of using JAX for sampling (using the ULA sampler) and on the speed of each implementation. I’ll then redo these benchmarks for 2 other samplers (MALA and SGLD). The benchmarks are done on both CPU (in the post) and GPU (in the appendix) for comparison. You can find the code to reproduce all these examples on <a href="https://github.com/jeremiecoullon/jax_MCMC_blog_post">Github</a>.</p>
<h2 id="sampler-and-model">Sampler and model</h2>
<p>To benchmark the samplers we’ll use Bayesian logistic regression throughout. As our sampler we’ll start with the unadjusted Langevin algorithm (ULA) with Euler discretisation, as it is one of the simplest gradient-based samplers out there: it has no accept-reject step. Let \(\theta_n \in \mathcal{R}^d\) be the parameter at iteration \(n\), \(\nabla \log \pi(\theta)\) the gradient of the log-posterior, \(dt\) the step size, and \(\xi \sim \mathcal{N}(0, I_d)\). Given the current position of the chain, the next sample is given by the equation:</p>
\[\theta_{n+1} = \theta_n + dt\nabla\log\pi(\theta_n) + \sqrt{2dt}\xi\]
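<p>As a quick sanity check of this update equation, here is a minimal NumPy sketch that runs many independent ULA chains on a toy standard Gaussian target (an assumption for illustration; the post itself uses logistic regression). For \(\pi = \mathcal{N}(0, 1)\) we have \(\nabla \log \pi(\theta) = -\theta\), so the update becomes \(\theta_{n+1} = (1-dt)\theta_n + \sqrt{2dt}\,\xi\). Its stationary variance \(v\) solves \(v = (1-dt)^2 v + 2dt\), giving \(v = 1/(1 - dt/2)\) rather than 1: without an accept-reject step, ULA samples from a slightly biased distribution.</p>

```python
import numpy as np


def ula_gaussian_chains(num_chains, num_steps, dt, seed=0):
    """Run independent ULA chains targeting a standard Gaussian.

    For pi = N(0, 1), grad log pi(theta) = -theta, so each step is
    theta <- theta + dt * (-theta) + sqrt(2 dt) * xi.
    """
    rng = np.random.default_rng(seed)
    theta = np.zeros(num_chains)  # all chains start at 0
    for _ in range(num_steps):
        grad_log_pi = -theta
        xi = rng.standard_normal(num_chains)
        theta = theta + dt * grad_log_pi + np.sqrt(2 * dt) * xi
    return theta


dt = 0.1
final = ula_gaussian_chains(num_chains=20_000, num_steps=200, dt=dt)
# mean should be near 0 and variance near 1 / (1 - dt / 2), about 1.053 here
print(final.mean(), final.var())
```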
<p>The setup of the logistic regression model is the same as the one from this <a href="https://arxiv.org/abs/1907.06986">SG-MCMC review paper</a>:</p>
<ul>
<li>Matrix of covariates \(\textbf{X} \in \mathcal{R}^{N\times d}\), and vector of responses: \(\textbf{y} = \{ y_i \}_{i=1}^N\)</li>
<li>Parameters: \(\theta \in \mathcal{R}^d\)</li>
</ul>
<p><strong>Model:</strong></p>
<ul>
<li>\(y_i \sim \text{Bernoulli}(p_i)\) with \(p_i = \frac{1}{ 1+\exp(-\theta^T x_i)}\)</li>
<li>Prior: \(\theta \sim \mathcal{N}(0, \Sigma_{\theta})\) with \(\Sigma_{\theta} = 10\textbf{I}_d\)</li>
<li>Likelihood: \(p(X,y \mid \theta) = \prod_{i=1}^N p_i^{y_i}(1-p_i)^{1-y_i}\)</li>
</ul>
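<p>ULA only needs \(\nabla \log \pi(\theta)\). For reference, this gradient can also be worked out by hand for this model (a derivation added here; in the code it is obtained with <code class="highlighter-rouge">grad</code>). Writing \(\textbf{p} = (p_1, \dots, p_N)\) and using \(\frac{d}{dz}\log\sigma(z) = 1-\sigma(z)\) and \(\frac{d}{dz}\log(1-\sigma(z)) = -\sigma(z)\) for the sigmoid \(\sigma(z) = 1/(1+e^{-z})\), so that \(p_i = \sigma(\theta^T x_i)\):</p>
\[\log \pi(\theta) = \sum_{i=1}^N \left[ y_i \log p_i + (1-y_i)\log(1-p_i) \right] - \frac{1}{2}\theta^T \Sigma_{\theta}^{-1}\theta + \text{const}\]
\[\nabla \log \pi(\theta) = \sum_{i=1}^N (y_i - p_i)\, x_i - \Sigma_{\theta}^{-1}\theta = \textbf{X}^T(\textbf{y} - \textbf{p}) - \frac{\theta}{10}\]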
<h2 id="version-1-python-loop-with-jax-for-the-log-posterior">Version 1: Python loop with JAX for the log-posterior</h2>
<p>In this version we only use JAX to write the log-posterior function (or the loss function in the case of optimisation). We use <code class="highlighter-rouge">vmap</code> to calculate the log-likelihood for each data point, <code class="highlighter-rouge">jit</code> to compile the function, and <code class="highlighter-rouge">grad</code> to get the gradient (see the code for the model on <a href="https://github.com/jeremiecoullon/jax_MCMC_blog_post/blob/master/logistic_regression_model.py">Github</a>). The rest of the sampler is a simple Python loop with NumPy to store the samples, as is shown below:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">ula_sampler_python</span><span class="p">(</span><span class="n">grad_log_post</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">dt</span><span class="p">,</span> <span class="n">x_0</span><span class="p">,</span> <span class="n">print_rate</span><span class="o">=</span><span class="mi">500</span><span class="p">):</span>
    <span class="n">dim</span><span class="p">,</span> <span class="o">=</span> <span class="n">x_0</span><span class="p">.</span><span class="n">shape</span>
    <span class="n">samples</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_samples</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span>
    <span class="n">paramCurrent</span> <span class="o">=</span> <span class="n">x_0</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Python sampler:"</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_samples</span><span class="p">):</span>
        <span class="n">paramGradCurrent</span> <span class="o">=</span> <span class="n">grad_log_post</span><span class="p">(</span><span class="n">paramCurrent</span><span class="p">)</span>
        <span class="n">paramCurrent</span> <span class="o">=</span> <span class="p">(</span><span class="n">paramCurrent</span> <span class="o">+</span> <span class="n">dt</span><span class="o">*</span><span class="n">paramGradCurrent</span> <span class="o">+</span>
                        <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">paramCurrent</span><span class="p">.</span><span class="n">shape</span><span class="p">)))</span>
        <span class="n">samples</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">paramCurrent</span>
        <span class="k">if</span> <span class="n">i</span><span class="o">%</span><span class="n">print_rate</span><span class="o">==</span><span class="mi">0</span><span class="p">:</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Iteration </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">/</span><span class="si">{</span><span class="n">num_samples</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">samples</span>
</code></pre></div></div>
<p>In this sampler we write the update equation using NumPy and store the samples in the array <code class="highlighter-rouge">samples</code>.</p>
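<p>The model code on Github isn’t reproduced in this post, but a minimal sketch of a log-posterior written with <code class="highlighter-rouge">vmap</code>, <code class="highlighter-rouge">jit</code> and <code class="highlighter-rouge">grad</code> might look like the following. The function names and the synthetic data are assumptions for illustration, not the post’s actual code. <code class="highlighter-rouge">vmap</code> maps the single-point log-likelihood over the rows of \(\textbf{X}\) and entries of \(\textbf{y}\) while keeping \(\theta\) fixed, and the Bernoulli log-likelihood is written in a numerically stable form with <code class="highlighter-rouge">logaddexp</code>.</p>

```python
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap


def log_likelihood_single(theta, x, y):
    """Bernoulli log-likelihood of one data point, with p = sigmoid(theta^T x)."""
    logit = jnp.dot(theta, x)
    # log p = -log(1 + exp(-logit)) and log(1 - p) = -log(1 + exp(logit))
    return -y * jnp.logaddexp(0.0, -logit) - (1 - y) * jnp.logaddexp(0.0, logit)


def log_prior(theta):
    """Gaussian prior N(0, 10 I), up to an additive constant."""
    return -0.5 * jnp.dot(theta, theta) / 10.0


# Map over the data points (rows of X, entries of y), not over theta
batched_log_lik = vmap(log_likelihood_single, in_axes=(None, 0, 0))


@jit
def log_post(theta, X, y):
    return log_prior(theta) + jnp.sum(batched_log_lik(theta, X, y))


grad_log_post = jit(grad(log_post))  # gradient with respect to theta only

# Small synthetic dataset, made up for illustration
rng = np.random.default_rng(0)
N, d = 200, 3
X = jnp.asarray(rng.normal(size=(N, d)))
y = jnp.asarray((rng.uniform(size=N) < 0.5).astype(np.float32))
theta = jnp.array([0.5, -1.0, 0.25])
print(grad_log_post(theta, X, y))
```

<p>To plug this into the samplers above, which call <code class="highlighter-rouge">grad_log_post(param)</code> with a single argument, the data would need to be fixed beforehand, for example with <code class="highlighter-rouge">functools.partial(grad_log_post, X=X, y=y)</code>.</p>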
<h2 id="version-2-jax-for-the-transition-kernel">Version 2: JAX for the transition kernel</h2>
<p>With JAX we can compile functions using <code class="highlighter-rouge">jit</code>, which makes them run faster (we did this for the log-posterior function). Could we not put the body of the loop in a function and compile that? The issue is that for <code class="highlighter-rouge">jit</code> to work, the function has to use JAX arrays rather than NumPy arrays, and it can’t use the NumPy random number generator (<code class="highlighter-rouge">np.random.normal()</code>), which relies on a hidden global state.</p>
<p>JAX handles random numbers differently from NumPy. I won’t explain the details here; you can read about them in the <a href="https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG">documentation</a>. The main idea is that jit-compiled JAX functions don’t allow side effects, such as updating a global random state. As a result, you have to explicitly pass a PRNG key (called <code class="highlighter-rouge">key</code>) to every function that uses randomness, and split the key to get different pseudorandom numbers.</p>
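<p>A minimal sketch of this explicit key handling, using only <code class="highlighter-rouge">jax.random</code>: reusing a key reproduces exactly the same numbers, so you split off a fresh subkey every time you need new randomness.</p>

```python
from jax import random

key = random.PRNGKey(0)

# Split before every use: keep `key` for later, consume `subkey` now
key, subkey = random.split(key)
a = random.normal(subkey, shape=(3,))
b = random.normal(subkey, shape=(3,))  # same subkey: identical numbers

key, subkey = random.split(key)
c = random.normal(subkey, shape=(3,))  # fresh subkey: different numbers
print(a, b, c)
```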
<p>Below is the transition kernel of the sampler, rewritten to only use JAX functions and arrays (so that it can be compiled). The <code class="highlighter-rouge">partial</code> decorator and the <code class="highlighter-rouge">static_argnums</code> argument specify which arguments will not change once the function is compiled. Indeed, the function for the gradient of the log-posterior and the step size will not change throughout the sampler, but the PRNG key and the parameter definitely will! This means that the function can run faster, as these static values/functions are hardcoded during compilation. Note that if the argument is a function (as is the case for <code class="highlighter-rouge">grad_log_post</code>) you don’t have a choice and must set it as static. See the <a href="https://jax.readthedocs.io/en/latest/jax.html#jax.jit">documentation</a> for more information.</p>
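<p>One consequence of marking an argument as static is that JAX compiles a separate version for each distinct static value. The sketch below uses a hypothetical <code class="highlighter-rouge">gradient_step</code> function and a toy Gaussian gradient (neither is from the post) and counts how often the function is traced: calling it again with the same static function and step size reuses the compiled version, while a new step size triggers a recompilation.</p>

```python
from functools import partial

import jax
import jax.numpy as jnp

num_traces = [0]  # incremented only when JAX (re)traces the function


@partial(jax.jit, static_argnums=(1, 2))
def gradient_step(param, grad_fn, dt):
    num_traces[0] += 1  # Python side effect: runs at trace time only
    return param + dt * grad_fn(param)


def grad_gaussian(x):
    """Gradient of the log-density of a standard Gaussian (toy example)."""
    return -x


x = jnp.ones(2)
x1 = gradient_step(x, grad_gaussian, 0.1)  # traced and compiled
x2 = gradient_step(x, grad_gaussian, 0.1)  # same static values: cache hit
x3 = gradient_step(x, grad_gaussian, 0.2)  # new dt value: recompiled
print(num_traces[0])
```

<p>This is why a static step size works well for a sampler, where \(dt\) is fixed for the whole run, but would be costly if the value changed at every call.</p>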
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
    <span class="n">paramGrad</span> <span class="o">=</span> <span class="n">grad_log_post</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
    <span class="n">param</span> <span class="o">=</span> <span class="n">param</span> <span class="o">+</span> <span class="n">dt</span><span class="o">*</span><span class="n">paramGrad</span> <span class="o">+</span> <span class="n">jnp</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">subkey</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">param</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">param</span>
</code></pre></div></div>
<p>The main loop in the previous function now becomes:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_samples</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
    <span class="n">samples</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span>
    <span class="k">if</span> <span class="n">i</span><span class="o">%</span><span class="n">print_rate</span><span class="o">==</span><span class="mi">0</span><span class="p">:</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Iteration </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">/</span><span class="si">{</span><span class="n">num_samples</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<p>Notice that we now split the random key inside the <code class="highlighter-rouge">ula_kernel()</code> function, which means the split gets compiled too (JAX’s random number generator can be <a href="https://github.com/google/jax/issues/968">slow in some cases</a>). We still save the samples in the NumPy array <code class="highlighter-rouge">samples</code> as in the previous case. Running this function several times with the same starting PRNG key will now produce exactly the same samples, which means that the sampler is completely reproducible.</p>
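<p>As a quick check of this reproducibility, here is a minimal self-contained sketch. It assumes a toy standard Gaussian target (so the gradient of the log-posterior is simply <code class="highlighter-rouge">-x</code>) in place of the logistic regression model, and <code class="highlighter-rouge">ula_sampler_python_loop</code> is a hypothetical name for the Python-loop sampler:</p>

```python
from functools import partial

import numpy as np
import jax.numpy as jnp
from jax import jit, random


def grad_log_post(x):
    # Toy target (an assumption for illustration): standard Gaussian, so grad log pi(x) = -x
    return -x


@partial(jit, static_argnums=(2, 3))
def ula_kernel(key, param, grad_log_post, dt):
    # The key is split inside the jitted kernel, so the split is compiled as well
    key, subkey = random.split(key)
    paramGrad = grad_log_post(param)
    param = param + dt * paramGrad + jnp.sqrt(2 * dt) * random.normal(key=subkey, shape=param.shape)
    return key, param


def ula_sampler_python_loop(key, grad_log_post, num_samples, dt, x_0):
    # Python loop around the compiled kernel; samples are stored in a NumPy array
    samples = np.zeros((num_samples, x_0.shape[0]))
    param = x_0
    for i in range(num_samples):
        key, param = ula_kernel(key, param, grad_log_post, dt)
        samples[i] = param
    return samples


x_0 = jnp.zeros(5)
run_a = ula_sampler_python_loop(random.PRNGKey(0), grad_log_post, 100, 1e-2, x_0)
run_b = ula_sampler_python_loop(random.PRNGKey(0), grad_log_post, 100, 1e-2, x_0)
run_c = ula_sampler_python_loop(random.PRNGKey(1), grad_log_post, 100, 1e-2, x_0)
print(np.array_equal(run_a, run_b), np.array_equal(run_a, run_c))  # True False
```

<p>The same key gives identical chains, while a different key gives a different realisation.</p>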
<h2 id="version-3-full-jax">Version 3: full JAX</h2>
<p>We’ve written more of our function in JAX, but there is still some Python left. Could we rewrite the entire sampler in JAX? It turns out that we can! JAX does allow us to write loops, but as it is designed to work on <a href="https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions">pure functions</a> you need to use a structured control-flow primitive such as the <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html"><code class="highlighter-rouge">scan</code></a> function. This function allows you to loop over an array (similar to doing <code class="highlighter-rouge">for elem in mylist</code> in Python).</p>
<p>The way to use <code class="highlighter-rouge">scan</code> is to pass in a function that is called at every iteration. This function takes in <code class="highlighter-rouge">carry</code>, which contains all the information you use in each iteration (and which you update as you go along). It also takes in <code class="highlighter-rouge">x</code>, which is the value of the array you’re iterating over. It should return an updated version of <code class="highlighter-rouge">carry</code> along with anything whose progress you want to keep track of: in our case, we want to store all the samples as we iterate.</p>
<p>Note that JAX also has a similar <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html"><code class="highlighter-rouge">fori_loop</code></a> function which apparently you should only use if you can’t use <code class="highlighter-rouge">scan</code> (see the <a href="https://github.com/google/jax/discussions/3850">discussion on GitHub</a>). In the case of our sampler <code class="highlighter-rouge">scan</code> is easier to use as you don’t need to explicitly keep track of the entire chain of samples; <code class="highlighter-rouge">scan</code> does it for you. In contrast, when using <code class="highlighter-rouge">fori_loop</code> you have to pass an array of samples in <code class="highlighter-rouge">state</code> which you update yourself as you go along. In terms of performance I did a quick benchmark of both and didn’t see a speed difference in this case, though the <a href="https://github.com/google/jax/discussions/3850">discussion on GitHub</a> says there can be speed benefits.</p>
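<p>To make the difference concrete, here is a sketch of the same ULA loop written both ways, again assuming a toy standard Gaussian target; <code class="highlighter-rouge">ula_update</code>, <code class="highlighter-rouge">ula_sampler_fori</code> and <code class="highlighter-rouge">ula_sampler_scan</code> are hypothetical names. With <code class="highlighter-rouge">fori_loop</code> the samples array is allocated up front and written into with <code class="highlighter-rouge">.at[i].set()</code>, whereas <code class="highlighter-rouge">scan</code> stacks the per-iteration outputs for us:</p>

```python
import jax.numpy as jnp
from jax import lax, random


def ula_update(key, param, dt):
    # One ULA step for a toy standard Gaussian target (grad log pi(x) = -x)
    key, subkey = random.split(key)
    param = param + dt * (-param) + jnp.sqrt(2 * dt) * random.normal(subkey, param.shape)
    return key, param


def ula_sampler_fori(key, num_samples, dt, x_0):
    def body(i, state):
        key, param, samples = state
        key, param = ula_update(key, param, dt)
        # fori_loop does not collect outputs, so we write each sample into the array ourselves
        samples = samples.at[i].set(param)
        return key, param, samples

    samples = jnp.zeros((num_samples,) + x_0.shape)
    _, _, samples = lax.fori_loop(0, num_samples, body, (key, x_0, samples))
    return samples


def ula_sampler_scan(key, num_samples, dt, x_0):
    def ula_step(carry, x):
        key, param = ula_update(*carry, dt)
        # The second output is stacked by scan into `samples`
        return (key, param), param

    _, samples = lax.scan(ula_step, (key, x_0), None, num_samples)
    return samples


key, x_0 = random.PRNGKey(0), jnp.zeros(3)
samples_fori = ula_sampler_fori(key, 200, 1e-2, x_0)
samples_scan = ula_sampler_scan(key, 200, 1e-2, x_0)
print(jnp.allclose(samples_fori, samples_scan, atol=1e-5))
```

<p>Both versions consume the PRNG key in the same order, so they produce the same chain; the <code class="highlighter-rouge">scan</code> version simply has less bookkeeping.</p>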
<p>Here is the function that we’ll pass in <code class="highlighter-rouge">scan</code>. Note that the first line unpacks <code class="highlighter-rouge">carry</code>. The <code class="highlighter-rouge">ula_kernel</code> function then generates the new key and parameter. We then return the new version of <code class="highlighter-rouge">carry</code> (ie: <code class="highlighter-rouge">(key, param)</code>) which includes the updated key and parameter, and return the current parameter (<code class="highlighter-rouge">param</code>) which <code class="highlighter-rouge">scan</code> will save in an array.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">ula_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">carry</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">),</span> <span class="n">param</span>
</code></pre></div></div>
<p>You can then pass this function along with the initial state in <code class="highlighter-rouge">scan</code>, and recover the final <code class="highlighter-rouge">carry</code> along with all the samples. The last two arguments in the <code class="highlighter-rouge">scan</code> function below mean that we don’t care what we’re iterating over; we simply want to run the sampler for <code class="highlighter-rouge">num_samples</code> number of iterations (as always, see the <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html">docs</a> for details).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">carry</span> <span class="o">=</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x_0</span><span class="p">)</span>
<span class="n">carry</span><span class="p">,</span> <span class="n">samples</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">ula_step</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">)</span>
</code></pre></div></div>
<p>Putting it all together in a single function, we get the following. Notice that we compile the entire function with <code class="highlighter-rouge">grad_log_post</code>, <code class="highlighter-rouge">num_samples</code>, and <code class="highlighter-rouge">dt</code> kept as static. We allow the PRNG key and the starting point of the chain <code class="highlighter-rouge">x_0</code> to vary so we can get different realisations of our chain.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">ula_sampler_full_jax_jit</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">dt</span><span class="p">,</span> <span class="n">x_0</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">ula_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">carry</span>
        <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="o">=</span> <span class="n">ula_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">)</span>
        <span class="k">return</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">),</span> <span class="n">param</span>
    <span class="n">carry</span> <span class="o">=</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x_0</span><span class="p">)</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">samples</span> <span class="o">=</span> <span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">ula_step</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">samples</span>
</code></pre></div></div>
<p>Having the entire function written in JAX means that once the function is compiled it will usually be faster (see benchmarks below), and we can rerun it for different PRNG keys or different initial conditions to get different realisations of the chain. We can also run this function in <code class="highlighter-rouge">vmap</code> (mapping over the keys or initial conditions) to get several chains running in parallel. Check out this <a href="https://rlouf.github.io/post/jax-random-walk-metropolis/">blog post</a> for a benchmark of a Metropolis sampler in parallel using JAX and TensorFlow.</p>
<p>Note that another way to do this would have been to split the initial key once at the beginning (<code class="highlighter-rouge">keys = random.split(key, num_samples)</code>) and scan over (ie: loop over) all these keys: <code class="highlighter-rouge">lax.scan(ula_step, carry, keys)</code>. The <code class="highlighter-rouge">ula_step</code> and <code class="highlighter-rouge">ula_kernel</code> functions would then have to be modified slightly for this to work. This would simplify the code even more, as you would no longer need to split the key at each iteration.</p>
<p>The only thing left to do in the full JAX version is to print the progress of the chain, which is especially useful for long runs. This is not as straightforward to do with jitted functions as with standard Python functions, but this <a href="https://github.com/google/jax/discussions/4763">discussion on GitHub</a> goes over how to do this.</p>
<p>The final thing to point out is that this JAX code runs on a GPU without any modifications. See the appendix for benchmarks on a GPU.</p>
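<p>Here is a minimal sketch of running several chains in parallel with <code class="highlighter-rouge">vmap</code>. It assumes a toy standard Gaussian target, and the function names are hypothetical. The static settings (<code class="highlighter-rouge">num_samples</code> and <code class="highlighter-rouge">dt</code>) are fixed in a small wrapper so that <code class="highlighter-rouge">vmap</code> only maps over the PRNG key and the starting point:</p>

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, lax, random, vmap


def grad_log_post(x):
    # Toy target (assumption): standard Gaussian, so grad log pi(x) = -x
    return -x


@partial(jit, static_argnums=(1, 2))
def ula_sampler(key, num_samples, dt, x_0):
    def ula_step(carry, x):
        key, param = carry
        key, subkey = random.split(key)
        param = param + dt * grad_log_post(param) + jnp.sqrt(2 * dt) * random.normal(subkey, param.shape)
        return (key, param), param

    _, samples = lax.scan(ula_step, (key, x_0), None, num_samples)
    return samples


def run_chain(key, x_0):
    # Fix the static settings here so vmap only sees the arguments we map over
    return ula_sampler(key, 1000, 1e-2, x_0)


num_chains = 4
keys = random.split(random.PRNGKey(0), num_chains)
x_0s = random.normal(random.PRNGKey(1), (num_chains, 5))  # a different starting point per chain
samples = vmap(run_chain)(keys, x_0s)
print(samples.shape)  # (4, 1000, 5)
```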
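<p>A sketch of this variant, under the same toy-target assumption and with a hypothetical function name: the carry now only holds the parameter, and each iteration receives its own subkey as the scanned-over value. The trade-off is that all <code class="highlighter-rouge">num_samples</code> keys are created up front.</p>

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, lax, random


def grad_log_post(x):
    # Toy target (assumption): standard Gaussian
    return -x


@partial(jit, static_argnums=(1, 2))
def ula_sampler_scan_keys(key, num_samples, dt, x_0):
    # Split once up front: one subkey per iteration
    keys = random.split(key, num_samples)

    def ula_step(param, subkey):
        # The carry is just the parameter; the subkey arrives as the scanned-over value
        param = param + dt * grad_log_post(param) + jnp.sqrt(2 * dt) * random.normal(subkey, param.shape)
        return param, param

    _, samples = lax.scan(ula_step, x_0, keys)
    return samples


samples = ula_sampler_scan_keys(random.PRNGKey(0), 1000, 1e-2, jnp.zeros(5))
print(samples.shape)  # (1000, 5)
```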
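<p>Recent versions of JAX also provide <code class="highlighter-rouge">jax.debug.print</code>, which can be called from inside jitted code. Here is a minimal sketch, assuming a recent JAX version, a toy standard Gaussian target and hypothetical names: we scan over the iteration index and use <code class="highlighter-rouge">lax.cond</code> so that only every <code class="highlighter-rouge">print_rate</code>-th iteration prints.</p>

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, random


def grad_log_post(x):
    # Toy target (assumption): standard Gaussian
    return -x


@partial(jax.jit, static_argnums=(1, 2, 3))
def ula_sampler_with_progress(key, num_samples, dt, print_rate, x_0):
    def ula_step(carry, i):
        key, param = carry
        # Print only every `print_rate` iterations; lax.cond chooses the branch at run time
        lax.cond(
            i % print_rate == 0,
            lambda: jax.debug.print(f"Iteration {{}}/{num_samples}", i),
            lambda: None,
        )
        key, subkey = random.split(key)
        param = param + dt * grad_log_post(param) + jnp.sqrt(2 * dt) * random.normal(subkey, param.shape)
        return (key, param), param

    # Scan over the iteration index so each step knows which iteration it is
    _, samples = lax.scan(ula_step, (key, x_0), jnp.arange(num_samples))
    return samples


samples = ula_sampler_with_progress(random.PRNGKey(0), 20_000, 1e-2, 5_000, jnp.zeros(5))
```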
<h1 id="benchmarks">Benchmarks</h1>
<p>Now that we’ve gone over 3 ways to write an MCMC sampler we’ll show some speed benchmarks for ULA along with two other algorithms. We use the logistic regression model presented above and run <code class="highlighter-rouge">20 000</code> samples throughout.</p>
<p>These benchmarks ran on my laptop (a standard MacBook Pro). You can find the benchmarks of the same samplers on a GPU in the appendix.</p>
<h2 id="unadjusted-langevin-algorithm">Unadjusted Langevin algorithm</h2>
<h3 id="increase-amount-of-data">Increase amount of data:</h3>
<p>We run ULA for <code class="highlighter-rouge">20 000</code> samples for a 5 dimensional parameter. We vary the amount of data used and see how fast the algorithms are (time is in seconds).</p>
<table>
<thead>
<tr>
<th>dataset size</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(10^3\)</td>
<td>11</td>
<td>3.4</td>
<td>0.53</td>
<td>0.18</td>
</tr>
<tr>
<td>\(10^4\)</td>
<td>11</td>
<td>4.6</td>
<td>2.0</td>
<td>1.6</td>
</tr>
<tr>
<td>\(10^5\)</td>
<td>32</td>
<td>32</td>
<td>24</td>
<td>24</td>
</tr>
<tr>
<td>\(10^6\)</td>
<td>280</td>
<td>280</td>
<td>250</td>
<td>250</td>
</tr>
</tbody>
</table>
<p>We can see that for small amounts of data the full JAX sampler is much faster than the Python loop. In particular, for 1000 data points the full JAX sampler (once compiled) is about 60 times faster than the Python loop version (0.18 versus 11 seconds).</p>
<p>Note that all the samplers use JAX to get the gradient of the log-posterior (including the Python loop version), so the speedup comes from everything else in the sampler being compiled. We also notice that for small amounts of data there’s a big difference between the first full JAX run (where the function is being compiled) and the second one (where the function is already compiled). This speedup would be especially useful if you need to run the sampler many times for different starting points or realisations (ie: choosing a different PRNG key). We can also see that simply writing the transition kernel in JAX already gives a 3x speedup over the Python loop version.</p>
<p>However as we add more data, the differences between the algorithms get smaller. The full JAX version is still the fastest, but not by much. This is probably because evaluating the gradient of the log-posterior dominates the computational cost of the sampler as the dataset grows. As that function is the same for all samplers, they end up having similar timings.</p>
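<p>When reproducing timings like these, keep in mind that JAX dispatches work asynchronously, so a timer should wait for the result with <code class="highlighter-rouge">block_until_ready()</code>; otherwise it may only measure the dispatch. Here is a sketch that separates the first (compiling) run from the second (compiled) run, assuming a toy standard Gaussian target and hypothetical names:</p>

```python
import time
from functools import partial

import jax.numpy as jnp
from jax import jit, lax, random


def grad_log_post(x):
    # Toy target (assumption): standard Gaussian
    return -x


@partial(jit, static_argnums=(1, 2))
def ula_sampler(key, num_samples, dt, x_0):
    def ula_step(carry, x):
        key, param = carry
        key, subkey = random.split(key)
        param = param + dt * grad_log_post(param) + jnp.sqrt(2 * dt) * random.normal(subkey, param.shape)
        return (key, param), param

    _, samples = lax.scan(ula_step, (key, x_0), None, num_samples)
    return samples


def time_run(key):
    start = time.perf_counter()
    # JAX dispatches work asynchronously: wait for the result before stopping the clock
    ula_sampler(key, 20_000, 1e-2, jnp.zeros(5)).block_until_ready()
    return time.perf_counter() - start


t_first = time_run(random.PRNGKey(0))   # includes tracing and compilation
t_second = time_run(random.PRNGKey(1))  # new key, same shapes and static args: reuses the compiled code
print(f"1st run: {t_first:.3f}s, 2nd run: {t_second:.3f}s")
```

<p>Changing the key does not trigger a recompilation because the key is a traced argument; changing <code class="highlighter-rouge">num_samples</code> or <code class="highlighter-rouge">dt</code> would, since they are static.</p>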
<h3 id="increase-the-dimension">Increase the dimension:</h3>
<p>We now run the samplers with a fixed dataset size of 1000 data points, and run each sampler for 20K iterations while varying the dimension:</p>
<table>
<thead>
<tr>
<th>dimension</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(5\)</td>
<td>11</td>
<td>3.4</td>
<td>0.56</td>
<td>0.19</td>
</tr>
<tr>
<td>\(500\)</td>
<td>12</td>
<td>5.0</td>
<td>2.5</td>
<td>1.6</td>
</tr>
<tr>
<td>\(1000\)</td>
<td>13</td>
<td>7.7</td>
<td>4.3</td>
<td>3.4</td>
</tr>
<tr>
<td>\(2000\)</td>
<td>13</td>
<td>16</td>
<td>14</td>
<td>13</td>
</tr>
</tbody>
</table>
<p>Here the story is similar to above: for small dimensionality the full JAX sampler is 60x faster than the Python loop version. But as you increase the dimension the gap gets smaller. As in the previous case, this is probably because the main effect of increasing the dimensionality is seen in the log-posterior function (which is in JAX for all the samplers).</p>
<p>The only difference to note is that the JAX kernel version is slower than the Python loop version for dimension 2000. <a href="http://vanderplas.com/">Jake VanderPlas</a> suggests that this has to do with moving data around, which has a low overhead for NumPy but can be expensive when JAX and NumPy interact. In any case, this reinforces the idea that you should always benchmark your code to make sure it’s fast.</p>
<h2 id="stochastic-gradient-langevin-dynamics-sgld">Stochastic gradient Langevin dynamics (SGLD)</h2>
<p>We now try the same experiment with the <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_Langevin_dynamics">stochastic gradient Langevin dynamics</a> (SGLD) sampler. This is the same as ULA but calculates gradients based on minibatches rather than on the full dataset. This makes it well suited to applications with very large datasets, but the sampler produces samples that aren’t exactly from the target distribution (often the variance is too high).</p>
<p>The transition kernel below is therefore quite similar to ULA, but randomly chooses minibatches of data to calculate gradients with. Note also that the <code class="highlighter-rouge">grad_log_post</code> function takes the minibatch data as arguments. Also note that we sample minibatches <em>with</em> replacement (<code class="highlighter-rouge">random.choice</code> has <code class="highlighter-rouge">replace=True</code> as default). This is because sampling without replacement is very expensive in JAX and would dramatically slow down the sampler!</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># X and y_data are arrays, which are not hashable, so they are passed as traced (non-static) arguments</span>
<span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">6</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">sgld_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">grad_log_post</span><span class="p">,</span> <span class="n">dt</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y_data</span><span class="p">,</span> <span class="n">minibatch_size</span><span class="p">):</span>
    <span class="n">N</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">subkey1</span><span class="p">,</span> <span class="n">subkey2</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
    <span class="n">idx_batch</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">choice</span><span class="p">(</span><span class="n">subkey1</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">minibatch_size</span><span class="p">,))</span>
    <span class="n">paramGrad</span> <span class="o">=</span> <span class="n">grad_log_post</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">X</span><span class="p">[</span><span class="n">idx_batch</span><span class="p">],</span> <span class="n">y_data</span><span class="p">[</span><span class="n">idx_batch</span><span class="p">])</span>
    <span class="n">param</span> <span class="o">=</span> <span class="n">param</span> <span class="o">+</span> <span class="n">dt</span><span class="o">*</span><span class="n">paramGrad</span> <span class="o">+</span> <span class="n">jnp</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">subkey2</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">param</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">param</span>
</code></pre></div></div>
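<p>The full JAX SGLD sampler benchmarked below follows the same <code class="highlighter-rouge">scan</code> pattern as the ULA one. Here is one possible way to write it, as a sketch only: the logistic-regression log-posterior with a standard Gaussian prior, the rescaling of the minibatch log-likelihood by <code class="highlighter-rouge">N / minibatch_size</code> (the usual SGLD gradient estimator), the simulated data and all function names are assumptions for illustration, not the code used for the benchmarks. <code class="highlighter-rouge">X</code> and <code class="highlighter-rouge">y_data</code> are passed as regular traced arguments, since static arguments must be hashable.</p>

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, jit, lax, random


def log_post_minibatch(param, X_batch, y_batch, N):
    # Hypothetical logistic regression: N(0, I) prior plus the minibatch log-likelihood,
    # rescaled by N / batch size so that its gradient estimates the full-data gradient
    logits = X_batch @ param
    log_lik = jnp.sum(y_batch * logits - jnp.logaddexp(0.0, logits))
    return -0.5 * jnp.dot(param, param) + (N / X_batch.shape[0]) * log_lik


grad_log_post = grad(log_post_minibatch)


@partial(jit, static_argnums=(1, 2, 5))
def sgld_sampler_full_jax(key, num_samples, dt, X, y_data, minibatch_size, x_0):
    N = X.shape[0]

    def sgld_step(carry, x):
        key, param = carry
        key, subkey1, subkey2 = random.split(key, 3)
        # Minibatch indices drawn with replacement (the default for random.choice)
        idx_batch = random.choice(subkey1, N, shape=(minibatch_size,))
        param_grad = grad_log_post(param, X[idx_batch], y_data[idx_batch], N)
        param = param + dt * param_grad + jnp.sqrt(2 * dt) * random.normal(subkey2, param.shape)
        return (key, param), param

    _, samples = lax.scan(sgld_step, (key, x_0), None, num_samples)
    return samples


# Simulated data, for illustration only
key, key_X, key_y = random.split(random.PRNGKey(0), 3)
N, dim = 1000, 5
X = random.normal(key_X, (N, dim))
y_data = random.bernoulli(key_y, jax.nn.sigmoid(X @ jnp.ones(dim))).astype(jnp.float32)

samples = sgld_sampler_full_jax(key, 2000, 1e-4, X, y_data, 100, jnp.zeros(dim))
print(samples.shape)  # (2000, 5)
```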
<h3 id="increase-amount-of-data-1">Increase amount of data:</h3>
<p>We run the same experiment as before: <code class="highlighter-rouge">20 000</code> samples for a 5 dimensional parameter, while increasing the amount of data. As before the timings are in seconds. The minibatch sizes we use are \(10\%\), \(10\%\), \(1\%\), and \(0.1\%\) of the dataset respectively.</p>
<table>
<thead>
<tr>
<th>dataset size</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(10^3\)</td>
<td>59</td>
<td>4.2</td>
<td>1.1</td>
<td>0.056</td>
</tr>
<tr>
<td>\(10^4\)</td>
<td>60</td>
<td>4.4</td>
<td>3.8</td>
<td>0.9</td>
</tr>
<tr>
<td>\(10^5\)</td>
<td>73</td>
<td>3.9</td>
<td>1.6</td>
<td>0.40</td>
</tr>
<tr>
<td>\(10^6\)</td>
<td>65</td>
<td>4.0</td>
<td>2.0</td>
<td>0.69</td>
</tr>
</tbody>
</table>
<p>Here we see that unlike in the case of ULA, we keep the large speedup from compiling everything in JAX. This is because the minibatches allow us to keep the cost of the log-posterior low.</p>
<p>We also notice that the Python and JAX kernel versions are slower than their ULA counterparts for low and medium dataset sizes. This is probably due to the cost of sampling the minibatches and the fact that for these small dataset sizes the log-posterior function is efficient enough to not actually need minibatches. However for the last dataset (1 million data points) the benefit of using minibatches becomes clear.</p>
<h3 id="increase-the-dimension-1">Increase the dimension:</h3>
<p>We now run the samplers with a fixed dataset size of 1000 data points, and run each sampler for 20K iterations while varying the dimension. We use a minibatch size of 10% of the data for all 4 runs.</p>
<table>
<thead>
<tr>
<th>dimension</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(5\)</td>
<td>61</td>
<td>4.2</td>
<td>1.1</td>
<td>0.055</td>
</tr>
<tr>
<td>\(500\)</td>
<td>62</td>
<td>4.2</td>
<td>1.9</td>
<td>0.56</td>
</tr>
<tr>
<td>\(1000\)</td>
<td>62</td>
<td>5.0</td>
<td>2.3</td>
<td>0.98</td>
</tr>
<tr>
<td>\(2000\)</td>
<td>68</td>
<td>6.4</td>
<td>3.3</td>
<td>1.95</td>
</tr>
</tbody>
</table>
<p>Here the two JAX samplers benefit from using minibatches, while the Python version is slower than its ULA counterpart in all cases.</p>
<h2 id="metropolis-adjusted-langevin-algorithm-mala">Metropolis Adjusted Langevin algorithm (MALA)</h2>
<p>We now re-run the same experiment but with <a href="https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm">MALA</a>, which is like ULA but with a Metropolis-Hastings correction that removes the discretisation bias, so that the chain targets the exact posterior. This correction means that the transition kernel is more computationally expensive:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">jit</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">5</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">mala_kernel</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">paramCurrent</span><span class="p">,</span> <span class="n">paramGradCurrent</span><span class="p">,</span> <span class="n">log_post</span><span class="p">,</span> <span class="n">logpostCurrent</span><span class="p">,</span> <span class="n">dt</span><span class="p">):</span>
    <span class="n">key</span><span class="p">,</span> <span class="n">subkey1</span><span class="p">,</span> <span class="n">subkey2</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
    <span class="n">paramProp</span> <span class="o">=</span> <span class="n">paramCurrent</span> <span class="o">+</span> <span class="n">dt</span><span class="o">*</span><span class="n">paramGradCurrent</span> <span class="o">+</span> <span class="n">jnp</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">subkey1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="n">paramCurrent</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">new_log_post</span><span class="p">,</span> <span class="n">new_grad</span> <span class="o">=</span> <span class="n">log_post</span><span class="p">(</span><span class="n">paramProp</span><span class="p">)</span>
    <span class="n">term1</span> <span class="o">=</span> <span class="n">paramProp</span> <span class="o">-</span> <span class="n">paramCurrent</span> <span class="o">-</span> <span class="n">dt</span><span class="o">*</span><span class="n">paramGradCurrent</span>
    <span class="n">term2</span> <span class="o">=</span> <span class="n">paramCurrent</span> <span class="o">-</span> <span class="n">paramProp</span> <span class="o">-</span> <span class="n">dt</span><span class="o">*</span><span class="n">new_grad</span>
    <span class="n">q_new</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.25</span><span class="o">*</span><span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">jnp</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">term1</span><span class="p">,</span> <span class="n">term1</span><span class="p">)</span>
    <span class="n">q_current</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.25</span><span class="o">*</span><span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="n">dt</span><span class="p">)</span><span class="o">*</span><span class="n">jnp</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">term2</span><span class="p">,</span> <span class="n">term2</span><span class="p">)</span>
    <span class="n">log_ratio</span> <span class="o">=</span> <span class="n">new_log_post</span> <span class="o">-</span> <span class="n">logpostCurrent</span> <span class="o">+</span> <span class="n">q_current</span> <span class="o">-</span> <span class="n">q_new</span>
    <span class="n">acceptBool</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">subkey2</span><span class="p">))</span> <span class="o"><</span> <span class="n">log_ratio</span>
    <span class="n">paramCurrent</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">acceptBool</span><span class="p">,</span> <span class="n">paramProp</span><span class="p">,</span> <span class="n">paramCurrent</span><span class="p">)</span>
    <span class="n">current_grad</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">acceptBool</span><span class="p">,</span> <span class="n">new_grad</span><span class="p">,</span> <span class="n">paramGradCurrent</span><span class="p">)</span>
    <span class="n">current_log_post</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">acceptBool</span><span class="p">,</span> <span class="n">new_log_post</span><span class="p">,</span> <span class="n">logpostCurrent</span><span class="p">)</span>
    <span class="n">accepts_add</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">acceptBool</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span><span class="mi">0</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">paramCurrent</span><span class="p">,</span> <span class="n">current_grad</span><span class="p">,</span> <span class="n">current_log_post</span><span class="p">,</span> <span class="n">accepts_add</span>
</code></pre></div></div>
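<p>For reference, the acceptance step in <code class="highlighter-rouge">mala_kernel</code> is the standard MALA Metropolis-Hastings ratio. Writing \(\theta\) for <code class="highlighter-rouge">paramCurrent</code>, \(\theta'\) for <code class="highlighter-rouge">paramProp</code>, \(\Delta t\) for <code class="highlighter-rouge">dt</code> and \(\pi\) for the posterior, the Langevin proposal \(\theta' = \theta + \Delta t\, \nabla \log \pi(\theta) + \sqrt{2 \Delta t}\, \xi\) with \(\xi \sim \mathcal{N}(0, I)\) has log-density (up to a constant)</p>
\[\log q(\theta' \mid \theta) = -\frac{1}{4 \Delta t} \left\lVert \theta' - \theta - \Delta t\, \nabla \log \pi(\theta) \right\rVert^2,\]
<p>which is <code class="highlighter-rouge">q_new</code> in the code, while <code class="highlighter-rouge">q_current</code> is the reverse move \(\log q(\theta \mid \theta')\). The proposal is accepted if \(\log u < \log \alpha\) with \(u \sim \mathcal{U}(0, 1)\) and</p>
\[\log \alpha = \log \pi(\theta') - \log \pi(\theta) + \log q(\theta \mid \theta') - \log q(\theta' \mid \theta),\]
<p>which is exactly <code class="highlighter-rouge">log_ratio</code>.</p>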
<p>We run the usual 3 versions: a Python sampler with JAX for the log-posterior, a Python loop with the JAX transition kernel, and a “full JAX” sampler.</p>
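<p>Here is a sketch of how the full JAX MALA sampler might look, reusing the kernel above inside <code class="highlighter-rouge">scan</code> and keeping a running count of accepted proposals in the carry. It assumes a toy standard Gaussian target wrapped with <code class="highlighter-rouge">jax.value_and_grad</code> (so that <code class="highlighter-rouge">log_post</code> returns both the log-density and its gradient, as the kernel expects), and <code class="highlighter-rouge">mala_sampler_full_jax</code> is a hypothetical name:</p>

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, lax, random, value_and_grad


def log_post_fn(x):
    # Toy target (assumption): standard Gaussian
    return -0.5 * jnp.dot(x, x)


# As the kernel expects, `log_post` returns (log-density, gradient)
log_post = value_and_grad(log_post_fn)


@partial(jit, static_argnums=(3, 5))
def mala_kernel(key, paramCurrent, paramGradCurrent, log_post, logpostCurrent, dt):
    # Same transition kernel as above
    key, subkey1, subkey2 = random.split(key, 3)
    paramProp = paramCurrent + dt * paramGradCurrent + jnp.sqrt(2 * dt) * random.normal(key=subkey1, shape=paramCurrent.shape)
    new_log_post, new_grad = log_post(paramProp)
    term1 = paramProp - paramCurrent - dt * paramGradCurrent
    term2 = paramCurrent - paramProp - dt * new_grad
    q_new = -0.25 * (1 / dt) * jnp.dot(term1, term1)
    q_current = -0.25 * (1 / dt) * jnp.dot(term2, term2)
    log_ratio = new_log_post - logpostCurrent + q_current - q_new
    acceptBool = jnp.log(random.uniform(key=subkey2)) < log_ratio
    paramCurrent = jnp.where(acceptBool, paramProp, paramCurrent)
    current_grad = jnp.where(acceptBool, new_grad, paramGradCurrent)
    current_log_post = jnp.where(acceptBool, new_log_post, logpostCurrent)
    accepts_add = jnp.where(acceptBool, 1, 0)
    return key, paramCurrent, current_grad, current_log_post, accepts_add


@partial(jit, static_argnums=(1, 2, 3))
def mala_sampler_full_jax(key, log_post, num_samples, dt, x_0):
    def mala_step(carry, x):
        key, param, grad, logpost, num_accepts = carry
        key, param, grad, logpost, accepted = mala_kernel(key, param, grad, log_post, logpost, dt)
        # Accumulate the number of accepted proposals in the carry
        return (key, param, grad, logpost, num_accepts + accepted.astype(jnp.int32)), param

    logpost_0, grad_0 = log_post(x_0)
    carry = (key, x_0, grad_0, logpost_0, jnp.array(0, dtype=jnp.int32))
    (_, _, _, _, num_accepts), samples = lax.scan(mala_step, carry, None, num_samples)
    return samples, num_accepts / num_samples


samples, accept_rate = mala_sampler_full_jax(random.PRNGKey(0), log_post, 5000, 0.1, jnp.zeros(5))
print(samples.shape, float(accept_rate))
```

<p>Carrying the current log-posterior and gradient means each iteration only evaluates <code class="highlighter-rouge">log_post</code> once, at the proposed point.</p>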
<h3 id="increase-amount-of-data-2">Increase amount of data:</h3>
<p>We run each sampler for <code class="highlighter-rouge">20 000</code> samples for a 5 dimensional parameter while varying the size of the dataset.</p>
<table>
<thead>
<tr>
<th>dataset size</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(10^3\)</td>
<td>38</td>
<td>7</td>
<td>0.93</td>
<td>0.19</td>
</tr>
<tr>
<td>\(10^4\)</td>
<td>38</td>
<td>7</td>
<td>2.6</td>
<td>1.9</td>
</tr>
<tr>
<td>\(10^5\)</td>
<td>56</td>
<td>35</td>
<td>27</td>
<td>26</td>
</tr>
<tr>
<td>\(10^6\)</td>
<td>330</td>
<td>310</td>
<td>270</td>
<td>272</td>
</tr>
</tbody>
</table>
<p>The story here is similar to the story in the case of ULA. The main difference is that the speedup for the full JAX sampler is more pronounced in this case (especially for the smaller datasets). Indeed, for 1000 data points the full JAX sampler (once compiled) is 200 times faster than the Python loop version. This is probably because the transition kernel is more complicated and so contributes more to the overall computational cost of the sampler. As a result compiling it brings a larger speed increase than for ULA.</p>
<p>Furthermore, as in the case of ULA, the timings of the samplers start to converge to the same value as the dataset size increases.</p>
<h3 id="increase-the-dimension-2">Increase the dimension:</h3>
<p>We now run the samplers with a fixed dataset size of 1000 data points, and run each sampler for 20K iterations while varying the dimension.</p>
<table>
<thead>
<tr>
<th>dimension</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(5\)</td>
<td>39</td>
<td>7.2</td>
<td>0.94</td>
<td>0.20</td>
</tr>
<tr>
<td>\(500\)</td>
<td>40</td>
<td>7.2</td>
<td>2.6</td>
<td>1.5</td>
</tr>
<tr>
<td>\(1000\)</td>
<td>41</td>
<td>8.0</td>
<td>4.8</td>
<td>3.4</td>
</tr>
<tr>
<td>\(2000\)</td>
<td>43</td>
<td>15</td>
<td>14</td>
<td>13</td>
</tr>
</tbody>
</table>
<p>We have a similar story to the case of increasing data: using full JAX speeds up the sampler a lot, but that gap gets smaller as you increase the dimensionality.</p>
<h1 id="conclusion">Conclusion</h1>
<p>We’ve seen that there are different ways to write MCMC samplers by having more or less of the code written in JAX. On one hand, you can use JAX to write the log-posterior function and use Python/NumPy for the rest. On the other hand you can use JAX to write the entire sampler. We’ve also seen that in general the full JAX sampler is faster than the Python loop version, but that this difference gets smaller as the amount of data and the dimensionality increase.</p>
<p>The main conclusion we take from this is that in general writing more things in JAX speeds up the code. However you have to make sure it’s well written so you don’t accidentally slow things down (for example by re-compiling a function at every iteration by mis-using <code class="highlighter-rouge">static_argnums</code> when jitting). You should therefore always benchmark code and compare different ways of writing it.</p>
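<p>One way to catch accidental recompilation is to put a Python side effect in the function being jitted: Python code only runs while JAX traces the function, so a counter tells you how many times it was (re)compiled. Here is a minimal sketch with hypothetical step functions, comparing <code class="highlighter-rouge">dt</code> passed as a static argument versus as a regular traced argument:</p>

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

num_traces = {"static": 0, "traced": 0}


@partial(jit, static_argnums=(1,))
def step_static(x, dt):
    # Python side effects only run while JAX traces (and then compiles) the function
    num_traces["static"] += 1
    return x + dt * x


@jit
def step_traced(x, dt):
    num_traces["traced"] += 1
    return x + dt * x


x = jnp.ones(3)
for dt in [0.1, 0.2, 0.3, 0.4]:
    step_static(x, dt)
    step_traced(x, dt)

print(num_traces)  # {'static': 4, 'traced': 1}
```

<p>Every new value of a static argument triggers a fresh compilation, so static arguments are best kept for things that don’t change between calls.</p>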
<p><em>All the code for this post is on <a href="https://github.com/jeremiecoullon/jax_MCMC_blog_post">Github</a></em></p>
<p><em>Thanks to <a href="http://vanderplas.com/">Jake VanderPlas</a> and <a href="https://rlouf.github.io/">Remi Louf</a> for useful feedback on this post as well as the High End Computing facility at Lancaster University for the GPU cluster (results in the appendix below)</em></p>
<h1 id="appendix-gpu-benchmarks">Appendix: GPU benchmarks</h1>
<p><em>edit: added this section on 9th December 2020</em></p>
<p>We show here the benchmarks on a single GPU compute node. For the runs where the dataset size increases, we run the samplers for <code class="highlighter-rouge">20 000</code> iterations with a 5-dimensional parameter. For the runs where the dimension increases, we generate 1000 data points and run <code class="highlighter-rouge">20 000</code> iterations.</p>
<p>Note that the ranges of dataset sizes and dimensions here are much larger, as the timings essentially didn’t vary over the ranges used in the previous benchmarks. Also notice how for small dataset sizes and dimensions the samplers are faster on CPU. This is because the GPU has a fixed overhead cost. However, as the datasets get larger the GPU does much better.</p>
<p>Timings are all in seconds.</p>
<h2 id="ula">ULA</h2>
<table>
<thead>
<tr>
<th>dataset size</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(10^3\)</td>
<td>18</td>
<td>8.4</td>
<td>2.2</td>
<td>1.5</td>
</tr>
<tr>
<td>\(10^6\)</td>
<td>18</td>
<td>12</td>
<td>5.8</td>
<td>5.2</td>
</tr>
<tr>
<td>\(10^7\)</td>
<td>49</td>
<td>50</td>
<td>43</td>
<td>42</td>
</tr>
<tr>
<td>\(2*10^7\)</td>
<td>90</td>
<td>92</td>
<td>84</td>
<td>82</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>dimension</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(100\)</td>
<td>18</td>
<td>8.2</td>
<td>2.2</td>
<td>1.5</td>
</tr>
<tr>
<td>\(10^4\)</td>
<td>33</td>
<td>10</td>
<td>4.0</td>
<td>3.0</td>
</tr>
<tr>
<td>\(2*10^4\)</td>
<td>47</td>
<td>14</td>
<td>6.5</td>
<td>5.0</td>
</tr>
<tr>
<td>\(3*10^4\)</td>
<td>61</td>
<td>18</td>
<td>9.1</td>
<td>7.1</td>
</tr>
</tbody>
</table>
<h2 id="sgld">SGLD</h2>
<p>The minibatch sizes for the increasing dataset sizes are \(10\%\), \(10\%\), \(1\%\), and \(0.1\%\) respectively.</p>
<table>
<thead>
<tr>
<th>dataset size</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(10^3\)</td>
<td>80</td>
<td>11</td>
<td>3.6</td>
<td>2.8</td>
</tr>
<tr>
<td>\(10^6\)</td>
<td>95</td>
<td>10</td>
<td>3.3</td>
<td>2.9</td>
</tr>
<tr>
<td>\(10^7\)</td>
<td>120</td>
<td>10</td>
<td>3.4</td>
<td>3.0</td>
</tr>
<tr>
<td>\(2*10^7\)</td>
<td>90</td>
<td>10</td>
<td>3.3</td>
<td>2.9</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>dimension</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(100\)</td>
<td>80</td>
<td>11</td>
<td>3.8</td>
<td>2.9</td>
</tr>
<tr>
<td>\(10^4\)</td>
<td>96</td>
<td>12</td>
<td>3.6</td>
<td>3.0</td>
</tr>
<tr>
<td>\(2*10^4\)</td>
<td>109</td>
<td>13</td>
<td>3.6</td>
<td>2.9</td>
</tr>
<tr>
<td>\(3*10^4\)</td>
<td>122</td>
<td>14</td>
<td>3.6</td>
<td>3.0</td>
</tr>
</tbody>
</table>
<h2 id="mala">MALA</h2>
<table>
<thead>
<tr>
<th>dataset size</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(10^3\)</td>
<td>57</td>
<td>14</td>
<td>3.2</td>
<td>2.4</td>
</tr>
<tr>
<td>\(10^6\)</td>
<td>56</td>
<td>14</td>
<td>6.9</td>
<td>5.8</td>
</tr>
<tr>
<td>\(10^7\)</td>
<td>83</td>
<td>54</td>
<td>46</td>
<td>44</td>
</tr>
<tr>
<td>\(2*10^7\)</td>
<td>126</td>
<td>98</td>
<td>89</td>
<td>86</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>dimension</th>
<th>python</th>
<th>JAX kernel</th>
<th>full JAX (1st run)</th>
<th>full JAX (2nd run)</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(100\)</td>
<td>57</td>
<td>14</td>
<td>3.6</td>
<td>2.7</td>
</tr>
<tr>
<td>\(10^4\)</td>
<td>72</td>
<td>16</td>
<td>5.4</td>
<td>3.6</td>
</tr>
<tr>
<td>\(2*10^4\)</td>
<td>88</td>
<td>17</td>
<td>9.4</td>
<td>5.7</td>
</tr>
<tr>
<td>\(3*10^4\)</td>
<td>101</td>
<td>19</td>
<td>12</td>
<td>7.8</td>
</tr>
</tbody>
</table>
<h1 id="implementing-natural-numbers-in-ocaml"><a href="/2020/04/06/NaturalNumbersOCaml">Implementing natural numbers in OCaml</a></h1>
<p><em>6 April 2020</em></p>
<p>In this post we’re going to implement natural numbers (non-negative integers, including zero) in <a href="https://ocaml.org/">OCaml</a> to see how we can define numbers from first
principles, namely without using OCaml’s built-in <code class="highlighter-rouge">int</code> type. We’ll then write a simple UI so that we have a basic (but inefficient) calculator. You can find all the code for this post on <a href="https://github.com/jeremiecoullon/natural_numbers_post">Github</a>.</p>
<h2 id="definition">Definition</h2>
<p>We’ll start with a recursive definition of natural numbers:</p>
\[n \in \mathcal{N} \iff n = \begin{cases}0 \\ S(m) \hspace{5mm} \text{for }m \in \mathcal{N}
\end{cases}\]
<p>We used the function \(S(m)\) which is called the <a href="https://en.wikipedia.org/wiki/Successor_function">successor function</a>. This simply returns the next natural number (for example \(S(0)=1\), and \(S(4)=5\)).</p>
<p>This definition means that a natural number is either \(0\) or the successor of another natural number. For example, \(0\) is a natural number (the first case in the definition), but \(1\) is also a natural number, as it’s the successor of \(0\) (you would write \(1=S(0)\)). Similarly, \(2\) can be written as \(2 = S(S(0))\), and so on. By using recursion (the definition of a natural number refers to another natural number) we can “bootstrap” the construction of numbers without relying on many other definitions.</p>
<p>We now write this definition as a type in OCaml, which looks a lot like the mathematical definition above:</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">type</span> <span class="n">nat</span> <span class="o">=</span>
<span class="o">|</span> <span class="nc">Zero</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="k">of</span> <span class="n">nat</span>
</code></pre></div></div>
<p>The vertical bars denote the two cases. Here you would write 1 as <code class="highlighter-rouge">Succ Zero</code>, 2 as <code class="highlighter-rouge">Succ (Succ Zero)</code>, and so on (the parentheses are needed because each constructor takes exactly one argument).</p>
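<p>To experiment with <code class="highlighter-rouge">nat</code> values in the toplevel, it can help to convert them to and from OCaml’s built-in <code class="highlighter-rouge">int</code>. The two helpers below, <code class="highlighter-rouge">to_int</code> and <code class="highlighter-rouge">of_int</code>, are hypothetical additions (not part of the post’s code) and lean on <code class="highlighter-rouge">int</code> only to check results, not to build the numbers themselves. This is a minimal sketch.</p>

```ocaml
(* The recursive definition of natural numbers from the post *)
type nat =
  | Zero
  | Succ of nat

(* Hypothetical helper: count the Succ constructors *)
let rec to_int (n : nat) : int =
  match n with
  | Zero -> 0
  | Succ m -> 1 + to_int m

(* Hypothetical helper: wrap Zero in i Succ constructors
   (negative inputs are clamped to Zero) *)
let rec of_int (i : int) : nat =
  if i <= 0 then Zero else Succ (of_int (i - 1))

(* Each constructor takes one argument, so nesting needs parentheses *)
let two = Succ (Succ Zero)

let () =
  Printf.printf "two = %d, of_int 3 = %d\n" (to_int two) (to_int (of_int 3))
```

<p>Because <code class="highlighter-rouge">to_int</code> walks the whole chain of constructors, its cost grows linearly with the number, which is a first hint at why this representation makes for an inefficient calculator.</p>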
<p>However, we haven’t yet said what these numbers are (what <em>is</em> zero? What <em>are</em> numbers?). To do
that we need to define how they act.</p>
<h2 id="some-operators">Some operators</h2>
<p>We’ll start off by defining how we can increment and decrement them.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="n">incr</span> <span class="n">n</span> <span class="o">=</span>
<span class="nc">Succ</span> <span class="n">n</span>
<span class="k">let</span> <span class="n">decr</span> <span class="n">n</span> <span class="o">=</span>
<span class="k">match</span> <span class="n">n</span> <span class="k">with</span>
<span class="o">|</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="nc">Zero</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="n">nn</span> <span class="o">-></span> <span class="n">nn</span>
</code></pre></div></div>
<p>The increment function simply adds a <code class="highlighter-rouge">Succ</code> in front of the number, which corresponds to adding 1: <code class="highlighter-rouge">incr (Succ Zero)</code> returns <code class="highlighter-rouge">Succ (Succ Zero)</code>. The decrement function checks whether the number <code class="highlighter-rouge">n</code> is <code class="highlighter-rouge">Zero</code> or the successor of a number. In the first case it simply returns <code class="highlighter-rouge">Zero</code>, so <code class="highlighter-rouge">decr Zero</code> returns <code class="highlighter-rouge">Zero</code> (this could be extended to include negative numbers). In the second case the function returns the number that precedes it, so <code class="highlighter-rouge">decr (Succ (Succ (Succ Zero)))</code> returns <code class="highlighter-rouge">Succ (Succ Zero)</code>.</p>
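<p>One way to see how <code class="highlighter-rouge">incr</code> and <code class="highlighter-rouge">decr</code> interact: <code class="highlighter-rouge">decr</code> always undoes <code class="highlighter-rouge">incr</code>, but <code class="highlighter-rouge">incr</code> does not undo <code class="highlighter-rouge">decr</code> at <code class="highlighter-rouge">Zero</code>, because <code class="highlighter-rouge">decr Zero</code> is <code class="highlighter-rouge">Zero</code>. The sketch below repeats the type and both functions so it runs on its own; <code class="highlighter-rouge">to_int</code> is a hypothetical printing helper, not part of the post’s code. Note that these definitions shadow OCaml’s <code class="highlighter-rouge">Stdlib.incr</code> and <code class="highlighter-rouge">Stdlib.decr</code>, which act on <code class="highlighter-rouge">int ref</code> values.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

(* These shadow Stdlib.incr / Stdlib.decr, which operate on int refs *)
let incr n = Succ n

let decr n =
  match n with
  | Zero -> Zero (* truncate: there is no natural number below Zero *)
  | Succ nn -> nn

(* Hypothetical helper, only used for printing *)
let rec to_int = function
  | Zero -> 0
  | Succ m -> 1 + to_int m

let three = Succ (Succ (Succ Zero))

let () =
  (* decr undoes incr for every n ... *)
  Printf.printf "decr (incr 3) = %d\n" (to_int (decr (incr three)));
  (* ... but incr does not undo decr at Zero *)
  Printf.printf "incr (decr 0) = %d\n" (to_int (incr (decr Zero)))
```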
<h3 id="addition">Addition</h3>
<p>We can now define addition as a recursive function which we denote by <code class="highlighter-rouge">++</code> (in OCaml we define <a href="https://en.wikipedia.org/wiki/Infix_notation">infix operators</a> using parentheses). So the addition function takes two elements <code class="highlighter-rouge">n</code> and <code class="highlighter-rouge">m</code> of type <code class="highlighter-rouge">nat</code> and returns an element of type <code class="highlighter-rouge">nat</code>. Note the <code class="highlighter-rouge">rec</code> added before the function name which means that it’s recursive.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="k">rec</span> <span class="p">(</span><span class="o">++</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="k">match</span> <span class="n">m</span> <span class="k">with</span>
<span class="o">|</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="n">n</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="n">mm</span> <span class="o">-></span> <span class="p">(</span><span class="nc">Succ</span> <span class="n">n</span><span class="p">)</span> <span class="o">++</span> <span class="n">mm</span>
</code></pre></div></div>
<p>Because we defined the function as an infix operator, we put it between its arguments (e.g. <code class="highlighter-rouge">Zero ++ (Succ Zero)</code>). The function checks whether <code class="highlighter-rouge">m</code> is <code class="highlighter-rouge">Zero</code> or the successor of a number. If <code class="highlighter-rouge">m</code> is <code class="highlighter-rouge">Zero</code> it returns <code class="highlighter-rouge">n</code>; if it’s the successor of <code class="highlighter-rouge">mm</code>, it returns the sum of <code class="highlighter-rouge">Succ n</code> and <code class="highlighter-rouge">mm</code>.</p>
<p>Let’s check that this definition behaves correctly by calculating 1+1, which we write as <code class="highlighter-rouge">(Succ Zero) ++ (Succ Zero)</code>. The first call to the function finds that the second argument is the successor of <code class="highlighter-rouge">Zero</code>, so it returns the sum <code class="highlighter-rouge">(Succ (Succ Zero)) ++ Zero</code>. This calls the function a second time, which finds that the second argument is <code class="highlighter-rouge">Zero</code>. As a result the function returns <code class="highlighter-rouge">Succ (Succ Zero)</code>, which is 2!</p>
<p>So in summary 1+1 is written as <code class="highlighter-rouge">(Succ Zero) ++ (Succ Zero)</code> = <code class="highlighter-rouge">(Succ (Succ Zero)) ++ Zero</code> = <code class="highlighter-rouge">Succ (Succ Zero)</code>. Math still works!</p>
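<p>For a bit more confidence than the single 1+1 example, one could compare <code class="highlighter-rouge">++</code> against OCaml’s built-in <code class="highlighter-rouge">+</code> on every pair of small numbers. This is a sketch: <code class="highlighter-rouge">to_int</code>, <code class="highlighter-rouge">of_int</code> and <code class="highlighter-rouge">check_addition</code> are hypothetical helpers (not part of the post’s code), and the bound of 20 is arbitrary.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

(* Addition as in the post: move one Succ from m onto n per step *)
let rec ( ++ ) n m =
  match m with
  | Zero -> n
  | Succ mm -> (Succ n) ++ mm

(* Hypothetical helpers, used only for checking *)
let rec to_int = function
  | Zero -> 0
  | Succ m -> 1 + to_int m

let rec of_int i = if i <= 0 then Zero else Succ (of_int (i - 1))

(* Compare nat addition with int addition for all 0 <= a, b <= bound *)
let check_addition bound =
  let ok = ref true in
  for a = 0 to bound do
    for b = 0 to bound do
      if to_int (of_int a ++ of_int b) <> a + b then ok := false
    done
  done;
  !ok

let () = Printf.printf "addition agrees up to 20: %b\n" (check_addition 20)
```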
<h3 id="subtraction">Subtraction</h3>
<p>We now define subtraction:</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="k">rec</span> <span class="p">(</span><span class="o">--</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="k">match</span> <span class="n">m</span> <span class="k">with</span>
<span class="o">|</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="n">n</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="n">mm</span> <span class="o">-></span> <span class="p">(</span><span class="n">decr</span> <span class="n">n</span><span class="p">)</span> <span class="o">--</span> <span class="n">mm</span>
</code></pre></div></div>
<p>This decrements both arguments until the second one is <code class="highlighter-rouge">Zero</code>. Note that if <code class="highlighter-rouge">m</code> is bigger than <code class="highlighter-rouge">n</code>, then <code class="highlighter-rouge">n -- m</code> equals <code class="highlighter-rouge">Zero</code>, since <code class="highlighter-rouge">decr Zero</code> is <code class="highlighter-rouge">Zero</code>.</p>
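<p>Because it bottoms out at <code class="highlighter-rouge">Zero</code>, <code class="highlighter-rouge">--</code> is what is sometimes called truncated subtraction (or monus): it behaves like ordinary subtraction when <code class="highlighter-rouge">m</code> is at most <code class="highlighter-rouge">n</code> and returns <code class="highlighter-rouge">Zero</code> otherwise. The sketch below checks this against <code class="highlighter-rouge">max 0 (a - b)</code> on small integers; <code class="highlighter-rouge">to_int</code>, <code class="highlighter-rouge">of_int</code> and <code class="highlighter-rouge">check_subtraction</code> are hypothetical helpers, not part of the post’s code.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

let decr n =
  match n with
  | Zero -> Zero
  | Succ nn -> nn

(* Subtraction as in the post: decrement both sides until m reaches Zero *)
let rec ( -- ) n m =
  match m with
  | Zero -> n
  | Succ mm -> (decr n) -- mm

(* Hypothetical helpers, used only for checking *)
let rec to_int = function
  | Zero -> 0
  | Succ m -> 1 + to_int m

let rec of_int i = if i <= 0 then Zero else Succ (of_int (i - 1))

(* n -- m should equal max 0 (n - m) *)
let check_subtraction bound =
  let ok = ref true in
  for a = 0 to bound do
    for b = 0 to bound do
      if to_int (of_int a -- of_int b) <> max 0 (a - b) then ok := false
    done
  done;
  !ok

let () =
  Printf.printf "5 -- 2 = %d\n" (to_int (of_int 5 -- of_int 2));
  Printf.printf "2 -- 5 = %d\n" (to_int (of_int 2 -- of_int 5));
  Printf.printf "monus agrees up to 20: %b\n" (check_subtraction 20)
```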
<h3 id="multiplication">Multiplication</h3>
<p>Moving on, we define multiplication:</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="p">(</span><span class="o">+*</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="k">let</span> <span class="k">rec</span> <span class="n">aux</span> <span class="n">n</span> <span class="n">m</span> <span class="n">acc</span> <span class="o">=</span>
<span class="k">match</span> <span class="n">m</span> <span class="k">with</span>
<span class="o">|</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="n">acc</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="n">mm</span> <span class="o">-></span> <span class="n">aux</span> <span class="n">n</span> <span class="n">mm</span> <span class="p">(</span><span class="n">n</span> <span class="o">++</span> <span class="n">acc</span><span class="p">)</span>
<span class="k">in</span>
<span class="n">aux</span> <span class="n">n</span> <span class="n">m</span> <span class="nc">Zero</span>
</code></pre></div></div>
<p>Here we use an auxiliary function (<code class="highlighter-rouge">aux</code>) which builds up the result in the accumulator <code class="highlighter-rouge">acc</code> by adding <code class="highlighter-rouge">n</code> to it <code class="highlighter-rouge">m</code> times. So applying this function to \(3\) and \(2\) gives: \(3*2 = 3*1 + 3 = 3*0 + 6 = 6\). And in code this is:</p>
<ul>
<li><code class="highlighter-rouge">(Succ (Succ (Succ Zero))) +* (Succ (Succ Zero))</code></li>
<li>Which returns <code class="highlighter-rouge">((Succ (Succ (Succ Zero))) +* (Succ Zero)) ++ (Succ (Succ (Succ Zero)))</code></li>
<li>Which returns <code class="highlighter-rouge">((Succ (Succ (Succ Zero))) +* Zero) ++ (Succ (Succ (Succ (Succ (Succ (Succ Zero)))))) </code></li>
<li>Which returns <code class="highlighter-rouge">(Succ (Succ (Succ (Succ (Succ (Succ Zero))))))</code> (namely \(6\))</li>
</ul>
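<p>In the actual code, the accumulator grows going forward: as <code class="highlighter-rouge">m</code> counts down 2, 1, 0, <code class="highlighter-rouge">acc</code> goes 0, 3, 6. The sketch below is the same <code class="highlighter-rouge">+*</code> with a print statement added inside <code class="highlighter-rouge">aux</code> to make that visible; the tracing and the <code class="highlighter-rouge">to_int</code> helper are hypothetical additions for illustration.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

let rec ( ++ ) n m =
  match m with
  | Zero -> n
  | Succ mm -> (Succ n) ++ mm

(* Hypothetical helper, only used for printing *)
let rec to_int = function
  | Zero -> 0
  | Succ m -> 1 + to_int m

(* Multiplication as in the post, printing (m, acc) at each call of aux *)
let ( +* ) n m =
  let rec aux n m acc =
    Printf.printf "m = %d, acc = %d\n" (to_int m) (to_int acc);
    match m with
    | Zero -> acc
    | Succ mm -> aux n mm (n ++ acc)
  in
  aux n m Zero

let three = Succ (Succ (Succ Zero))
let two = Succ (Succ Zero)

let () =
  let product = three +* two in
  Printf.printf "3 * 2 = %d\n" (to_int product)
```

<p>Running it should print the accumulator values 0, 3 and 6 before the final result.</p>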
<h3 id="division">Division</h3>
<p>We also define the ‘strictly less than’ operator which we then use to define integer division.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="k">rec</span> <span class="p">(</span><span class="o"><<</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="k">match</span> <span class="p">(</span><span class="n">n</span><span class="o">,</span> <span class="n">m</span><span class="p">)</span> <span class="k">with</span>
<span class="o">|</span> <span class="p">(</span><span class="n">p</span><span class="o">,</span> <span class="nc">Zero</span><span class="p">)</span> <span class="o">-></span> <span class="bp">false</span>
<span class="o">|</span> <span class="p">(</span><span class="nc">Zero</span><span class="o">,</span> <span class="n">q</span><span class="p">)</span> <span class="o">-></span> <span class="bp">true</span>
<span class="o">|</span> <span class="p">(</span><span class="n">p</span><span class="o">,</span> <span class="n">q</span><span class="p">)</span> <span class="o">-></span> <span class="p">(</span><span class="n">decr</span> <span class="n">n</span><span class="p">)</span> <span class="o"><<</span> <span class="p">(</span><span class="n">decr</span> <span class="n">m</span><span class="p">)</span>
<span class="k">let</span> <span class="p">(</span><span class="o">//</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="k">let</span> <span class="k">rec</span> <span class="n">aux</span> <span class="n">p</span> <span class="n">acc</span> <span class="o">=</span>
<span class="k">let</span> <span class="n">lt</span> <span class="o">=</span> <span class="n">p</span> <span class="o"><<</span> <span class="n">m</span> <span class="k">in</span>
<span class="k">match</span> <span class="n">lt</span> <span class="k">with</span>
<span class="o">|</span> <span class="bp">true</span> <span class="o">-></span> <span class="n">acc</span>
<span class="o">|</span> <span class="bp">false</span> <span class="o">-></span> <span class="n">aux</span> <span class="p">(</span><span class="n">p</span> <span class="o">--</span> <span class="n">m</span><span class="p">)</span> <span class="p">(</span><span class="nc">Succ</span> <span class="n">acc</span><span class="p">)</span>
<span class="k">in</span>
<span class="n">aux</span> <span class="n">n</span> <span class="nc">Zero</span>
</code></pre></div></div>
<p>As with multiplication, the division function defines an auxiliary function that builds up the result in the accumulator <code class="highlighter-rouge">acc</code>. This function checks whether its first argument <code class="highlighter-rouge">p</code> is less than <code class="highlighter-rouge">m</code>. If it isn’t, it increments the accumulator by 1 and calls <code class="highlighter-rouge">aux</code> again with <code class="highlighter-rouge">p -- m</code> as the first argument. Once <code class="highlighter-rouge">p</code> is less than <code class="highlighter-rouge">m</code>, it returns the accumulator. So this auxiliary function counts the number of times that <code class="highlighter-rouge">m</code> fits into <code class="highlighter-rouge">n</code>, which is exactly what integer division is. We start it with <code class="highlighter-rouge">n</code> as the first argument and <code class="highlighter-rouge">Zero</code> as the accumulator.</p>
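<p>One edge case is worth spelling out: with <code class="highlighter-rouge">m = Zero</code>, the first case of <code class="highlighter-rouge">&lt;&lt;</code> makes <code class="highlighter-rouge">p &lt;&lt; Zero</code> false for every <code class="highlighter-rouge">p</code>, and <code class="highlighter-rouge">p -- Zero</code> is just <code class="highlighter-rouge">p</code>, so <code class="highlighter-rouge">aux</code> recurses forever and <code class="highlighter-rouge">n // Zero</code> never returns. A minimal sketch of a guarded variant that returns an <code class="highlighter-rouge">option</code> is below; <code class="highlighter-rouge">safe_div</code>, <code class="highlighter-rouge">to_int</code> and <code class="highlighter-rouge">of_int</code> are hypothetical names, not part of the post’s code.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

let decr n =
  match n with
  | Zero -> Zero
  | Succ nn -> nn

let rec ( -- ) n m =
  match m with
  | Zero -> n
  | Succ mm -> (decr n) -- mm

(* Strictly-less-than, as in the post *)
let rec ( << ) n m =
  match (n, m) with
  | (_, Zero) -> false
  | (Zero, _) -> true
  | (_, _) -> (decr n) << (decr m)

(* Integer division as in the post: loops forever when m = Zero *)
let ( // ) n m =
  let rec aux p acc =
    if p << m then acc else aux (p -- m) (Succ acc)
  in
  aux n Zero

(* Hypothetical guarded variant: None instead of non-termination *)
let safe_div n m =
  match m with
  | Zero -> None
  | _ -> Some (n // m)

(* Hypothetical helpers, used only for printing and building inputs *)
let rec to_int = function
  | Zero -> 0
  | Succ m -> 1 + to_int m

let rec of_int i = if i <= 0 then Zero else Succ (of_int (i - 1))

let () =
  match safe_div (of_int 7) (of_int 2) with
  | Some q -> Printf.printf "7 // 2 = %d\n" (to_int q)
  | None -> print_endline "division by zero"
```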
<p>Finally we can define the modulo operator. As we use previous definitions of division, multiplication, and subtraction, this definition is abstracted away from our implementation of natural numbers. This function gives the remainder when dividing <code class="highlighter-rouge">n</code> by <code class="highlighter-rouge">m</code>.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">let</span> <span class="p">(</span><span class="o">%</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="k">let</span> <span class="n">p</span> <span class="o">=</span> <span class="n">n</span> <span class="o">//</span> <span class="n">m</span> <span class="k">in</span>
<span class="n">n</span> <span class="o">--</span> <span class="p">(</span><span class="n">p</span> <span class="o">+*</span> <span class="n">m</span><span class="p">)</span>
</code></pre></div></div>
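<p>A useful sanity check for <code class="highlighter-rouge">//</code> and <code class="highlighter-rouge">%</code> together is the division identity: for any <code class="highlighter-rouge">m</code> other than <code class="highlighter-rouge">Zero</code>, <code class="highlighter-rouge">n</code> equals <code class="highlighter-rouge">(n // m) +* m ++ (n % m)</code>, and the remainder <code class="highlighter-rouge">n % m</code> is strictly less than <code class="highlighter-rouge">m</code>. The sketch below repeats the post’s definitions (written slightly more compactly) so it runs on its own and checks the identity on small numbers; <code class="highlighter-rouge">of_int</code> and <code class="highlighter-rouge">check_division</code> are hypothetical helpers.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

let decr n = match n with Zero -> Zero | Succ nn -> nn

let rec ( ++ ) n m = match m with Zero -> n | Succ mm -> (Succ n) ++ mm
let rec ( -- ) n m = match m with Zero -> n | Succ mm -> (decr n) -- mm

let ( +* ) n m =
  let rec aux n m acc =
    match m with Zero -> acc | Succ mm -> aux n mm (n ++ acc)
  in
  aux n m Zero

let rec ( << ) n m =
  match (n, m) with
  | (_, Zero) -> false
  | (Zero, _) -> true
  | (_, _) -> (decr n) << (decr m)

let ( // ) n m =
  let rec aux p acc = if p << m then acc else aux (p -- m) (Succ acc) in
  aux n Zero

let ( % ) n m =
  let p = n // m in
  n -- (p +* m)

(* Hypothetical helper, only used to build inputs *)
let rec of_int i = if i <= 0 then Zero else Succ (of_int (i - 1))

(* For m <> Zero: n = (n // m) * m + (n % m), with (n % m) < m *)
let check_division bound =
  let ok = ref true in
  for a = 0 to bound do
    for b = 1 to bound do
      let n = of_int a and m = of_int b in
      let q = n // m and r = n % m in
      if (q +* m) ++ r <> n || not (r << m) then ok := false
    done
  done;
  !ok

let () =
  Printf.printf "division identity holds up to 15: %b\n" (check_division 15)
```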
<h2 id="a-basic-ui">A basic UI</h2>
<p>We’ve defined the natural numbers and the basic operators, but it’s a bit unwieldy to use them in their current form. So we’ll write some code to convert them to the usual number system (represented as strings) and back.</p>
<h3 id="from-type-nat-to-string-representation">From type <code class="highlighter-rouge">nat</code> to string representation</h3>
<p>We’ll write some code to convert numbers to base 10 and then represent them in the usual Arabic numerals.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="n">ten</span> <span class="o">=</span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">)))))))))</span>
<span class="k">let</span> <span class="n">base10</span> <span class="n">n</span> <span class="o">=</span>
<span class="k">let</span> <span class="k">rec</span> <span class="n">aux</span> <span class="n">q</span> <span class="n">acc</span> <span class="o">=</span>
<span class="k">let</span> <span class="n">r</span> <span class="o">=</span> <span class="n">q</span> <span class="o">%</span> <span class="n">ten</span> <span class="k">in</span>
<span class="k">let</span> <span class="n">p</span> <span class="o">=</span> <span class="n">q</span> <span class="o">//</span> <span class="n">ten</span> <span class="k">in</span>
<span class="k">match</span> <span class="n">p</span> <span class="k">with</span>
<span class="o">|</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="n">r</span><span class="o">::</span><span class="n">acc</span>
<span class="o">|</span> <span class="n">pp</span> <span class="o">-></span> <span class="n">aux</span> <span class="n">p</span> <span class="p">(</span><span class="n">r</span><span class="o">::</span><span class="n">acc</span><span class="p">)</span>
<span class="k">in</span>
<span class="n">aux</span> <span class="n">n</span> <span class="bp">[]</span>
</code></pre></div></div>
<p>This function returns a list of digits, most significant first: the last element is the number of 1s, the one before it the number of 10s, then 100s, and so on. So if <code class="highlighter-rouge">n</code> is 12 (that is, <code class="highlighter-rouge">Succ</code> applied twelve times to <code class="highlighter-rouge">Zero</code>), then <code class="highlighter-rouge">base10 n</code> returns <code class="highlighter-rouge">[Succ Zero; Succ (Succ Zero)]</code>.</p>
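<p>Here is a small check of <code class="highlighter-rouge">base10</code> on a number with a zero digit in the middle, 305, and on <code class="highlighter-rouge">Zero</code> itself. The block repeats the post’s operators so it runs standalone; <code class="highlighter-rouge">to_int</code> and <code class="highlighter-rouge">of_int</code> are hypothetical helpers used only to build the input and read back the digits, and <code class="highlighter-rouge">ten</code> is built with <code class="highlighter-rouge">of_int</code> rather than spelled out.</p>

```ocaml
type nat =
  | Zero
  | Succ of nat

let decr n = match n with Zero -> Zero | Succ nn -> nn

let rec ( ++ ) n m = match m with Zero -> n | Succ mm -> (Succ n) ++ mm
let rec ( -- ) n m = match m with Zero -> n | Succ mm -> (decr n) -- mm

let ( +* ) n m =
  let rec aux n m acc =
    match m with Zero -> acc | Succ mm -> aux n mm (n ++ acc)
  in
  aux n m Zero

let rec ( << ) n m =
  match (n, m) with
  | (_, Zero) -> false
  | (Zero, _) -> true
  | (_, _) -> (decr n) << (decr m)

let ( // ) n m =
  let rec aux p acc = if p << m then acc else aux (p -- m) (Succ acc) in
  aux n Zero

let ( % ) n m = n -- ((n // m) +* m)

(* Hypothetical helpers, used only for building inputs and printing *)
let rec to_int = function Zero -> 0 | Succ m -> 1 + to_int m
let rec of_int i = if i <= 0 then Zero else Succ (of_int (i - 1))

let ten = of_int 10

(* base10 as in the post: peel off the last digit until the quotient is Zero *)
let base10 n =
  let rec aux q acc =
    let r = q % ten in
    let p = q // ten in
    match p with
    | Zero -> r :: acc
    | _ -> aux p (r :: acc)
  in
  aux n []

let () =
  let digits = List.map to_int (base10 (of_int 305)) in
  print_endline (String.concat "; " (List.map string_of_int digits))
```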
<p>We then define the 10 digits (with a hack for the cases bigger than 9) and put it all together in the function <code class="highlighter-rouge">string_of_nat</code>.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="n">print_nat_digits</span> <span class="o">=</span> <span class="k">function</span>
<span class="o">|</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"0"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"1"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"2"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"3"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"4"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"5"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"6"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"7"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"8"</span>
<span class="o">|</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Succ</span> <span class="nc">Zero</span> <span class="o">-></span> <span class="s2">"9"</span>
<span class="o">|</span> <span class="n">_</span> <span class="o">-></span> <span class="s2">"bigger than 9"</span>
<span class="k">let</span> <span class="n">string_of_nat</span> <span class="n">n</span> <span class="o">=</span>
<span class="k">let</span> <span class="n">base_10_rep</span> <span class="o">=</span> <span class="n">base10</span> <span class="n">n</span> <span class="k">in</span>
<span class="k">let</span> <span class="n">list_strings</span> <span class="o">=</span> <span class="nn">List</span><span class="p">.</span><span class="n">map</span> <span class="n">print_nat_digits</span> <span class="n">base_10_rep</span> <span class="k">in</span>
<span class="nn">String</span><span class="p">.</span><span class="n">concat</span> <span class="s2">""</span> <span class="n">list_strings</span>
</code></pre></div></div>
<p><code class="highlighter-rouge">string_of_nat</code> converts the number of type <code class="highlighter-rouge">nat</code> to base 10, then maps each of the list element to a string and concatenates those strings.</p>
<p>So <code class="highlighter-rouge">string_of_nat (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ Zero))))))))))))</code> returns <code class="highlighter-rouge">"12"</code> which is easier to read!</p>
<h3 id="from-string-representation-to-type-nat">From string representation to type <code class="highlighter-rouge">nat</code></h3>
<p>We then write some code to go the other way: from the string representation to natural numbers.</p>
<div class="language-ocaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">let</span> <span class="n">string_to_list</span> <span class="n">s</span> <span class="o">=</span>
<span class="k">let</span> <span class="k">rec</span> <span class="n">loop</span> <span class="n">acc</span> <span class="n">i</span> <span class="o">=</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="k">then</span> <span class="n">acc</span>
<span class="k">else</span>
<span class="n">loop</span> <span class="p">((</span><span class="nn">String</span><span class="p">.</span><span class="n">make</span> <span class="mi">1</span> <span class="n">s</span><span class="o">.</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">::</span> <span class="n">acc</span><span class="p">)</span> <span class="p">(</span><span class="n">pred</span> <span class="n">i</span><span class="p">)</span>
<span class="k">in</span> <span class="n">loop</span> <span class="bp">[]</span> <span class="p">(</span><span class="nn">String</span><span class="p">.</span><span class="n">length</span> <span class="n">s</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">let</span> <span class="n">nat_of_listnat</span> <span class="n">l</span> <span class="o">=</span>
<span class="k">let</span> <span class="n">lr</span> <span class="o">=</span> <span class="nn">List</span><span class="p">.</span><span class="n">rev</span> <span class="n">l</span> <span class="k">in</span>
<span class="k">let</span> <span class="k">rec</span> <span class="n">aux</span> <span class="n">n</span> <span class="n">b</span> <span class="n">lr</span> <span class="o">=</span>
<span class="k">match</span> <span class="n">lr</span> <span class="k">with</span>
<span class="o">|</span> <span class="bp">[]</span> <span class="o">-></span> <span class="n">n</span>
<span class="o">|</span> <span class="n">h</span><span class="o">::</span><span class="n">t</span> <span class="o">-></span> <span class="n">aux</span> <span class="p">(</span><span class="n">n</span> <span class="o">++</span> <span class="p">(</span><span class="n">b</span><span class="o">+*</span><span class="n">h</span><span class="p">))</span> <span class="p">(</span><span class="n">b</span><span class="o">+*</span><span class="n">ten</span><span class="p">)</span> <span class="n">t</span>
<span class="k">in</span>
<span class="n">aux</span> <span class="nc">Zero</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">)</span> <span class="n">lr</span>
<span class="k">let</span> <span class="n">nat_of_string_digits</span> <span class="o">=</span> <span class="k">function</span>
<span class="o">|</span> <span class="s2">"0"</span> <span class="o">-></span> <span class="nc">Zero</span>
<span class="o">|</span> <span class="s2">"1"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="nc">Zero</span>
<span class="o">|</span> <span class="s2">"2"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">)</span>
<span class="o">|</span> <span class="s2">"3"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">))</span>
<span class="o">|</span> <span class="s2">"4"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">)))</span>
<span class="o">|</span> <span class="s2">"5"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">))))</span>
<span class="o">|</span> <span class="s2">"6"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">)))))</span>
<span class="o">|</span> <span class="s2">"7"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">))))))</span>
<span class="o">|</span> <span class="s2">"8"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">)))))))</span>
<span class="o">|</span> <span class="s2">"9"</span> <span class="o">-></span> <span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="p">(</span><span class="nc">Succ</span> <span class="nc">Zero</span><span class="p">))))))))</span>
<span class="o">|</span> <span class="n">_</span> <span class="o">-></span> <span class="k">raise</span> <span class="p">(</span><span class="nc">Failure</span> <span class="s2">"string must be less than 10"</span><span class="p">)</span>
<span class="c">(* Converts string to nat *)</span>
<span class="k">let</span> <span class="n">nat_of_string</span> <span class="n">s</span> <span class="o">=</span>
<span class="k">let</span> <span class="n">liststring</span> <span class="o">=</span> <span class="n">string_to_list</span> <span class="n">s</span> <span class="k">in</span>
<span class="k">let</span> <span class="n">listNatbase</span> <span class="o">=</span> <span class="nn">List</span><span class="p">.</span><span class="n">map</span> <span class="n">nat_of_string_digits</span> <span class="n">liststring</span> <span class="k">in</span>
<span class="n">nat_of_listnat</span> <span class="n">listNatbase</span>
<span class="c">(*
final (infix) functions for adding, subtracting, multiplying, and dividing
which take strings as arguments and return a string
*)</span>
<span class="k">let</span> <span class="p">(</span><span class="o">+++</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="n">string_of_nat</span> <span class="p">((</span><span class="n">nat_of_string</span> <span class="n">n</span><span class="p">)</span> <span class="o">++</span> <span class="p">(</span><span class="n">nat_of_string</span> <span class="n">m</span><span class="p">))</span>
<span class="k">let</span> <span class="p">(</span><span class="o">---</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="n">string_of_nat</span> <span class="p">((</span><span class="n">nat_of_string</span> <span class="n">n</span><span class="p">)</span> <span class="o">--</span> <span class="p">(</span><span class="n">nat_of_string</span> <span class="n">m</span><span class="p">))</span>
<span class="k">let</span> <span class="p">(</span><span class="o">+**</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="n">string_of_nat</span> <span class="p">((</span><span class="n">nat_of_string</span> <span class="n">n</span><span class="p">)</span> <span class="o">+*</span> <span class="p">(</span><span class="n">nat_of_string</span> <span class="n">m</span><span class="p">))</span>
<span class="k">let</span> <span class="p">(</span><span class="o">///</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="n">string_of_nat</span> <span class="p">((</span><span class="n">nat_of_string</span> <span class="n">n</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">nat_of_string</span> <span class="n">m</span><span class="p">))</span>
<span class="k">let</span> <span class="p">(</span><span class="o">%%</span><span class="p">)</span> <span class="n">n</span> <span class="n">m</span> <span class="o">=</span>
<span class="n">string_of_nat</span> <span class="p">((</span><span class="n">nat_of_string</span> <span class="n">n</span><span class="p">)</span> <span class="o">%</span> <span class="p">(</span><span class="n">nat_of_string</span> <span class="n">m</span><span class="p">))</span>
</code></pre></div></div>
<p>So putting it all together, we have a working calculator for natural numbers!</p>
<p>Let’s try it out:</p>
<ul>
<li><code class="highlighter-rouge">"3" +++ "17"</code> returns <code class="highlighter-rouge">"20"</code></li>
<li><code class="highlighter-rouge">"182" --- "93"</code> returns <code class="highlighter-rouge">"89"</code></li>
<li><code class="highlighter-rouge">"12" +** "3"</code> returns <code class="highlighter-rouge">"36"</code></li>
<li><code class="highlighter-rouge">"41" /// "3"</code> returns <code class="highlighter-rouge">"13"</code></li>
<li><code class="highlighter-rouge">"41" %% "3"</code> returns <code class="highlighter-rouge">"2"</code></li>
</ul>
<h2 id="conclusion">Conclusion</h2>
<p>We have built up natural numbers from first principles and now have a working calculator. However, these operators start getting very slow for numbers of around 7 digits or more, so sticking with built-in integers sounds preferable.</p>
<p><em>All the code for this post is on <a href="https://github.com/jeremiecoullon/natural_numbers_post">Github</a></em></p>
<p><em>Thanks to <a href="https://www.linkedin.com/in/james-jobanputra-62582669">James Jobanputra</a> for useful feedback on this post</em></p>In this post we’re going to implement natural numbers (non-negative integers) in OCaml to see how we can define numbers from first principles, namely without using OCaml’s built-in integer type. We’ll then write a simple UI so that we have a basic (but inefficient) calculator. You can find all the code for this post on Github.Testing MCMC code: the prior reproduction test2020-02-04T08:00:00+00:002020-02-04T08:00:00+00:00/2020/02/04/PriorReproductionTest<p><a href="https://darrenjw.wordpress.com/2010/08/15/metropolis-hastings-mcmc-algorithms/">Markov Chain Monte Carlo</a> (MCMC) is a class of algorithms for sampling from probability distributions. These are very useful algorithms, but it’s easy to go wrong and obtain samples from the wrong probability distribution. What’s more, it won’t be obvious if the sampler fails, so we need ways to check whether it’s working correctly.</p>
<p>This post is mainly aimed at MCMC practitioners and describes a powerful MCMC test called the Prior Reproduction Test (PRT). I’ll go over the context of the test, then explain how it works (and give some code). I’ll then explain how to tune it and discuss some limitations.</p>
<h2 id="why-should-we-test-mcmc-code-">Why should we test MCMC code?</h2>
<p>There are two main ways MCMC can fail: either the chain doesn’t mix or the sampler targets the wrong distribution. We say that a chain mixes if it explores the target distribution in its entirety without getting stuck or avoiding a certain subset of the space. To check that a chain mixes, we use diagnostics such as running the chain for a long time and examining the trace plots, calculating the \(\hat{R}\) (or <a href="https://mc-stan.org/docs/2_21/reference-manual/notation-for-samples-chains-and-draws.html">potential scale reduction factor</a>), and using the multistart heuristic. See the <a href="https://www.mcmchandbook.net/">Handbook of MCMC</a> for a good overview of these diagnostics. These help check that the chain converges to a distribution.</p>
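<p>As a rough illustration of the \(\hat{R}\) diagnostic, here is a minimal Python sketch of the basic (non-split) Gelman-Rubin formula, which compares between-chain and within-chain variance for one scalar parameter. It is a simplification under our own assumptions: Stan’s current \(\hat{R}\) additionally splits chains and rank-normalises the draws, and the function name <code>potential_scale_reduction</code> is hypothetical.</p>

```python
import numpy as np


def potential_scale_reduction(chains):
    """Basic (non-split) Gelman-Rubin R-hat for one scalar parameter.

    chains: array of shape (m, n) with m chains of n draws each.
    """
    chains = np.asarray(chains, dtype=float)
    _, n = chains.shape
    chain_means = chains.mean(axis=1)
    # B: between-chain variance (scaled by n); W: average within-chain variance
    B = n * chain_means.var(ddof=1)
    W = chains.var(axis=1, ddof=1).mean()
    # pooled estimate of the target variance
    var_plus = (n - 1) / n * W + B / n
    return np.sqrt(var_plus / W)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    good = rng.normal(size=(4, 1000))  # four chains targeting the same N(0, 1)
    bad = good.copy()
    bad[0] += 3.0                      # one chain stuck in a different region
    print(potential_scale_reduction(good))
    print(potential_scale_reduction(bad))
```

<p>Values close to 1 suggest the chains agree, while a chain stuck in one region pushes the ratio well above 1. Note that this only checks that the chains agree with each other, not that they target the right distribution.</p>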
<p>However, the target distribution of the sampler may not be the correct one. This could be due to a bug in the code or an error in the maths (for example, the Hastings correction in the Metropolis-Hastings algorithm could be wrong). To test the software, we can write unit tests, which check that individual functions behave as they should. We can also write integration tests, which test the entire software rather than just one component. One such test is to try to recover simulated values (as recommended by the <a href="https://github.com/stan-dev/stan/wiki/Stan-Best-Practices#recover-simulated-values">Stan documentation</a>): generate data given some “true” parameters (using your data model) and then fit the model using the sampler. The true parameter should then lie within the credible interval (loosely, within 2 posterior standard deviations of the posterior mean). This checks that the sampler can indeed recover the true parameter.</p>
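<p>A minimal sketch of the “recover simulated values” check, assuming the Gaussian data model used later in this post. To keep it self-contained, exact posterior draws stand in for the output of the sampler under test; they are available in closed form here only because we assume a flat prior on \(\theta\). The helper <code>recovers_true_value</code> is hypothetical.</p>

```python
import numpy as np


def recovers_true_value(posterior_samples, theta_true, level=0.95):
    """True if theta_true lies inside the central `level` credible interval."""
    tail = 100 * (1 - level) / 2
    lower, upper = np.percentile(posterior_samples, [tail, 100 - tail])
    return bool(lower <= theta_true <= upper)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    theta_true, sigma_data, n_data = 4.0, 3.0, 10
    # simulate data from the data model y = theta + eps, eps ~ N(0, sigma_data^2)
    data = theta_true + rng.normal(0.0, sigma_data, size=n_data)
    # Stand-in for MCMC output: with a flat prior, the posterior of this model is
    # N(mean(data), sigma_data^2 / n_data), so it can be sampled exactly.
    samples = rng.normal(data.mean(), sigma_data / np.sqrt(n_data), size=5000)
    print(recovers_true_value(samples, theta_true))
```

<p>A single run can fail by chance (about 5% of the time at the 95% level even for a correct sampler), which is one reason this is only a sanity check.</p>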
<p>However this test is only a “sanity check” and doesn’t check whether samples are truly from the target distribution. What’s needed here is a goodness of fit (GoF) test. As doing a GoF test for arbitrarily complex posterior distributions is hard, the PRT reduces the problem to testing that some samples are from the prior rather than the posterior. I had trouble finding books or articles written about this (a similar version of this test is described by Cook, Gelman, and Rubin <a href="http://www.stat.columbia.edu/~gelman/research/published/Cook_Software_Validation.pdf">here</a>, but they don’t call it PRT); if you know of any references let me know! <em>[Update April 2021: <a href="https://phylliswithdata.com/about/">Nianqiao (Phyllis) Ju</a> has pointed out some references for this in the literature: <a href="http://qed.econ.queensu.ca/pub/faculty/ferrall/quant/papers/04_04_29_geweke.pdf">Geweke</a> (2004) describes the same method. Other similar approaches: <a href="https://arxiv.org/pdf/1804.06788.pdf">Simulation Based Calibration</a> (2020), and <a href="https://arxiv.org/pdf/1412.5218.pdf">Testing MCMC Code</a> (2014)]</em>.</p>
<p>I know of this test from my PhD supervisor <a href="https://www.ucl.ac.uk/statistics/people/yvopokern">Yvo Pokern</a> who learnt it from another researcher during his postdoc. From talking to other researchers, it seems that this method has often been transmitted by word of mouth rather than from textbooks.</p>
<h2 id="the-prior-reproduction-test">The Prior Reproduction Test</h2>
<p>The prior reproduction test runs as follows: sample from the prior \(\theta_0 \sim \pi_0\), generate data using this prior sample \(X \sim p(X|\theta_0)\), and run the to-be-tested sampler long enough to get an independent sample from the posterior \(\theta_p \sim \pi(\theta|X)\). If the code is correct, the samples from the posterior should be distributed according to the prior.
One can repeat this procedure to obtain many samples \(\theta_p\) and test whether they are distributed according to the prior.</p>
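<p>One way to see why the test works: if \(\theta_0 \sim \pi_0\) and \(X \sim p(X|\theta_0)\), then the dataset \(X\) follows the prior predictive distribution \(p(X) = \int p(X|\theta)\,\pi_0(\theta)\,d\theta\). By Bayes’ theorem \(\pi(\theta|X)\,p(X) = p(X|\theta)\,\pi_0(\theta)\), so averaging the exact posterior over these simulated datasets gives back the prior:</p>
\[\int \pi(\theta|X)\, p(X)\, dX = \pi_0(\theta) \int p(X|\theta)\, dX = \pi_0(\theta)\]
<p>A sampler that really targets \(\pi(\theta|X)\) therefore produces draws \(\theta_p\) whose marginal distribution is \(\pi_0\), whereas a sampler targeting the wrong distribution will generally not.</p>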
<p>Here is the test in Python (code available on <a href="https://github.com/jeremiecoullon/PRT_post">Github</a>). First we define the observation operator \(\mathcal{G}\) (the mapping from parameter to data, in this case simply the identity) along with the log-likelihood, log-prior, and log-posterior. So here each data point is simply sampled from a Gaussian with mean \(\mathcal{G}(\theta) = \theta\) and standard deviation 3, and the prior on \(\theta\) is uniform on \([0, 10]\).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>

<span class="k">def</span> <span class="nf">G</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="s">"""
    G(theta): observation operator. Here it's just the identity function, but it could
    be a more complicated model.
    """</span>
    <span class="k">return</span> <span class="n">theta</span>

<span class="c1"># data noise:</span>
<span class="n">sigma_data</span> <span class="o">=</span> <span class="mi">3</span>

<span class="k">def</span> <span class="nf">build_log_likelihood</span><span class="p">(</span><span class="n">data_array</span><span class="p">):</span>
    <span class="s">"Builds the log_likelihood function given some data"</span>
    <span class="k">def</span> <span class="nf">log_likelihood</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
        <span class="s">"Data model: y = G(theta) + eps"</span>
        <span class="k">return</span> <span class="o">-</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">sigma_data</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span>
            <span class="p">[(</span><span class="n">elem</span> <span class="o">-</span> <span class="n">G</span><span class="p">(</span><span class="n">theta</span><span class="p">))</span><span class="o">**</span><span class="mi">2</span> <span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">data_array</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">log_likelihood</span>

<span class="k">def</span> <span class="nf">log_prior</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="s">"uniform prior on [0, 10]"</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="mi">0</span> <span class="o"><</span> <span class="n">theta</span> <span class="o"><</span> <span class="mi">10</span><span class="p">):</span>
        <span class="k">return</span> <span class="o">-</span><span class="mi">9999999</span>  <span class="c1"># effectively log(0): theta is outside the prior's support</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)</span>  <span class="c1"># log of the uniform density 1/10</span>

<span class="k">def</span> <span class="nf">build_log_posterior</span><span class="p">(</span><span class="n">log_likelihood</span><span class="p">):</span>
    <span class="s">"Builds the log_posterior function given a log_likelihood"</span>
    <span class="k">def</span> <span class="nf">log_posterior</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">log_prior</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span> <span class="o">+</span> <span class="n">log_likelihood</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">log_posterior</span>
</code></pre></div></div>
<p>We want to the test the code for a Metropolis sampler with Gaussian proposal (given in the <a href="https://github.com/jeremiecoullon/PRT_post/tree/master/MCMC"><code class="highlighter-rouge">MCMC</code> module</a>), so we run the PRT for it (the following code is in the <code class="highlighter-rouge">run_PRT()</code> function in <a href="https://github.com/jeremiecoullon/PRT_post/blob/master/PRT.py"><code class="highlighter-rouge">PRT.py</code></a>):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># MHSampler and GaussianMove are defined in the post's MCMC module (see Github)</span>
<span class="n">results</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">B</span> <span class="o">=</span> <span class="mi">200</span>
<span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">B</span><span class="p">):</span>
    <span class="c1"># sample from prior</span>
    <span class="n">sam_prior</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">10</span><span class="p">)</span>
    <span class="c1"># generate data points using the sampled prior</span>
    <span class="n">data_array</span> <span class="o">=</span> <span class="n">G</span><span class="p">(</span><span class="n">sam_prior</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">sigma_data</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
    <span class="c1"># build the posterior function</span>
    <span class="n">log_likelihood</span> <span class="o">=</span> <span class="n">build_log_likelihood</span><span class="p">(</span><span class="n">data_array</span><span class="o">=</span><span class="n">data_array</span><span class="p">)</span>
    <span class="n">log_posterior</span> <span class="o">=</span> <span class="n">build_log_posterior</span><span class="p">(</span><span class="n">log_likelihood</span><span class="p">)</span>
    <span class="c1"># define the sampler</span>
    <span class="n">ICs</span> <span class="o">=</span> <span class="p">{</span><span class="s">'theta'</span><span class="p">:</span> <span class="mi">1</span><span class="p">}</span>
    <span class="n">sd_proposal</span> <span class="o">=</span> <span class="mi">20</span>
    <span class="n">mcmc_sampler</span> <span class="o">=</span> <span class="n">MHSampler</span><span class="p">(</span><span class="n">log_post</span><span class="o">=</span><span class="n">log_posterior</span><span class="p">,</span> <span class="n">ICs</span><span class="o">=</span><span class="n">ICs</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># add a Gaussian proposal</span>
    <span class="n">mcmc_sampler</span><span class="p">.</span><span class="n">move</span> <span class="o">=</span> <span class="n">GaussianMove</span><span class="p">(</span><span class="n">ICs</span><span class="p">,</span> <span class="n">cov</span><span class="o">=</span><span class="n">sd_proposal</span><span class="p">)</span>
    <span class="c1"># Get a posterior sample.</span>
    <span class="c1"># Let the sampler run for 200 iterations to make sure it's independent from the initial condition</span>
    <span class="n">mcmc_sampler</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">n_iter</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">print_rate</span><span class="o">=</span><span class="mi">300</span><span class="p">)</span>
    <span class="n">last_sample</span> <span class="o">=</span> <span class="n">mcmc_sampler</span><span class="p">.</span><span class="n">all_samples</span><span class="p">.</span><span class="n">iloc</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">].</span><span class="n">theta</span>
    <span class="c1"># store the results. Keep the posterior sample as well as the prior that generated the data</span>
    <span class="n">results</span><span class="p">.</span><span class="n">append</span><span class="p">({</span><span class="s">'posterior'</span><span class="p">:</span> <span class="n">last_sample</span><span class="p">,</span> <span class="s">'prior'</span><span class="p">:</span> <span class="n">sam_prior</span><span class="p">})</span>
</code></pre></div></div>
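<p>The loop above relies on the <code>MHSampler</code> and <code>GaussianMove</code> classes from the post’s repository. For readers who want something runnable on its own, here is a minimal sketch of the same procedure in which a short random-walk Metropolis function (hypothetical, written only for this example) stands in for <code>MHSampler</code>. The step size of 2.0 is an assumed value, and the log-prior returns \(-\infty\) outside its support rather than a large negative constant.</p>

```python
import numpy as np

SIGMA_DATA = 3.0        # data noise, as in the post
LOW, HIGH = 0.0, 10.0   # support of the uniform prior


def log_prior(theta):
    """Uniform prior on (LOW, HIGH); -inf outside the support."""
    return -np.log(HIGH - LOW) if LOW < theta < HIGH else -np.inf


def make_log_posterior(data):
    """Log-posterior for y = G(theta) + eps with G the identity."""
    def log_posterior(theta):
        lp = log_prior(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp - 0.5 / SIGMA_DATA**2 * np.sum((data - theta) ** 2)
    return log_posterior


def random_walk_metropolis(log_post, theta0, n_iter, step, rng):
    """Hypothetical stand-in for MHSampler: returns the chain's final state."""
    theta, lp = theta0, log_post(theta0)
    for _ in range(n_iter):
        proposal = theta + step * rng.normal()
        lp_prop = log_post(proposal)
        # the Gaussian proposal is symmetric, so no Hastings correction is needed
        if np.log(rng.uniform()) < lp_prop - lp:
            theta, lp = proposal, lp_prop
    return theta


def run_prt(n_reps=200, n_data=10, n_iter=200, step=2.0, seed=0):
    rng = np.random.default_rng(seed)
    prior_draws = np.empty(n_reps)
    posterior_draws = np.empty(n_reps)
    for i in range(n_reps):
        theta_0 = rng.uniform(LOW, HIGH)                           # sample from the prior
        data = theta_0 + rng.normal(0.0, SIGMA_DATA, size=n_data)  # generate data
        log_post = make_log_posterior(data)
        prior_draws[i] = theta_0
        posterior_draws[i] = random_walk_metropolis(log_post, 1.0, n_iter, step, rng)
    return prior_draws, posterior_draws


if __name__ == "__main__":
    prior_draws, posterior_draws = run_prt()
    # if the sampler is correct, posterior_draws should look Uniform(0, 10)
    print(np.histogram(posterior_draws, bins=5, range=(LOW, HIGH))[0])
```

<p>Returning \(-\infty\) outside the support means out-of-range proposals are always rejected, which avoids any dependence on how large the negative constant is.</p>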
<p>We then check that the posterior samples are uniformly distributed, i.e. distributed according to the prior (see figure 1). Here we do this by eye, but we could have done this more formally, for example using the <a href="https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test">Kolmogorov-Smirnov test</a>.</p>
<figure class="post_figure">
<img src="/assets/PRT_post/empirical_CDF_data10.png" />
<figcaption>Figure 1: Empirical CDF of the output of PRT: these seem to be uniformly distributed</figcaption>
</figure>
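<p>A sketch of the more formal check, assuming SciPy is available: <code>scipy.stats.kstest</code> compares the PRT’s posterior draws against the Uniform(0, 10) prior. A small p-value is evidence that the sampler is not targeting the posterior; a large p-value is only an absence of evidence, and the test’s power grows with the number of PRT repetitions. The function name <code>prt_ks_pvalue</code> is hypothetical.</p>

```python
import numpy as np
from scipy import stats


def prt_ks_pvalue(posterior_draws, low=0.0, high=10.0):
    """Kolmogorov-Smirnov p-value of the PRT output against the Uniform(low, high) prior.

    SciPy's 'uniform' distribution is parameterised by loc and scale, so
    Uniform(low, high) corresponds to args=(low, high - low).
    """
    return stats.kstest(posterior_draws, "uniform", args=(low, high - low)).pvalue


if __name__ == "__main__":
    # With the post's results list this would be:
    # posterior_draws = [r["posterior"] for r in results]
    rng = np.random.default_rng(0)
    print(prt_ks_pvalue(rng.uniform(0, 10, size=200)))  # draws that really are from the prior
    print(prt_ks_pvalue(rng.normal(5, 1, size=200)))    # draws that are too concentrated
```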
<h2 id="tuning-the-prt">Tuning the PRT</h2>
<p>Notice how we let the sampler run for 200 iterations to make sure that the posterior sample we get is independent of the initial condition (<code class="highlighter-rouge">mcmc_sampler.run(n_iter=200, print_rate=300)</code>). The number of iterations needs to be tuned to the sampler: if it mixes slowly then you’ll need more iterations. This means that a slowly mixing sampler makes the PRT more computationally expensive. We also needed to tune the scale of the Gaussian proposal (set by <code class="highlighter-rouge">sd_proposal</code>); ideally a single tuning works well for every dataset generated in the PRT, but this may not always be the case. Sometimes the sampler needs hand tuning for each generated dataset, in which case it may be too expensive to run the entire test. We’ll see later what other tests we can do in this case.</p>
<p>Finally, how do we choose the amount of data to generate (here we chose <code class="highlighter-rouge">10</code> data points)? Consider two extremes. If we generate too much data, the posterior will have a very low variance and will be centred around the true parameter. Almost any posterior sample we obtain will then be close to the true parameter (which we sampled from the prior), so the PRT will trivially produce samples from the prior. This doesn’t test the statistical properties of the sampler; it only tests that the posterior is centred around the true parameter. At the other extreme, if we generate too little data, the likelihood will have a weak effect on the posterior, which will then essentially be the prior. The MCMC sampler will then sample from a distribution that is very close to the prior, and again the PRT becomes weaker. We therefore need to choose an amount of data somewhere in between.</p>
<p>To tune the amount of data to generate, we can plot the posterior samples against the prior samples from the PRT, as in figure 2 below. Ideally there is a moderate amount of variation around the line <code class="highlighter-rouge">y=x</code>, as in the middle plot (for <code class="highlighter-rouge">N=10</code> data points). In the other two cases the PRT will trivially recover prior samples and not test the software properly.</p>
<figure class="post_figure">
<img src="/assets/PRT_post/3_data_comparison.png" />
<figcaption>Figure 2: We need to tune the amount of data to generate in PRT</figcaption>
</figure>
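<p>Alongside the scatter plot, a single number can help compare choices of data size. The sketch below uses the correlation between the prior draws and the posterior draws as a rough heuristic (an assumption of this example, not a rule from the post): values extremely close to 1 indicate that the posterior has collapsed onto the true parameter, and values near 0 indicate that the posterior is barely informed by the data. To isolate the effect of the data size, it draws exact samples from the posterior, which for this model is a normal distribution truncated to \([0, 10]\), instead of running MCMC; <code>exact_prt_draws</code> is a hypothetical helper.</p>

```python
import numpy as np
from scipy import stats


def prior_posterior_correlation(prior_draws, posterior_draws):
    """Pearson correlation between the PRT's prior draws and posterior draws."""
    return np.corrcoef(prior_draws, posterior_draws)[0, 1]


def exact_prt_draws(n_data, n_reps=200, sigma_data=3.0, low=0.0, high=10.0, seed=0):
    """PRT draws using exact posterior samples instead of MCMC (illustration only).

    With a Uniform(low, high) prior and N(theta, sigma_data^2) data, the posterior
    is N(mean(data), sigma_data^2 / n_data) truncated to (low, high).
    """
    rng = np.random.default_rng(seed)
    prior_draws = rng.uniform(low, high, size=n_reps)
    posterior_draws = np.empty(n_reps)
    sd = sigma_data / np.sqrt(n_data)
    for i, theta_0 in enumerate(prior_draws):
        mean = np.mean(theta_0 + rng.normal(0.0, sigma_data, size=n_data))
        # truncnorm takes its truncation points in standardised units
        a, b = (low - mean) / sd, (high - mean) / sd
        posterior_draws[i] = stats.truncnorm.rvs(a, b, loc=mean, scale=sd, random_state=rng)
    return prior_draws, posterior_draws


if __name__ == "__main__":
    for n_data in (1, 10, 1000):
        prior_draws, posterior_draws = exact_prt_draws(n_data)
        print(n_data, prior_posterior_correlation(prior_draws, posterior_draws))
```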
<h2 id="limitations-and-alternatives">Limitations and alternatives</h2>
<p>In some cases, however, it’s not possible to run the PRT. The likelihood may be too computationally expensive; for example, it might require numerically solving a differential equation. It’s also possible that the proposal distribution needs to be tuned for each dataset, in which case you have to tune the proposal manually at each iteration of the PRT.</p>
<p>A way to deal with these problems is to only test conditionals of the posterior (in the case of higher dimensional posteriors).
For example, if the posterior is \(\pi(\theta_1, \theta_2)\), then run the test on \(\pi(\theta_1 | \theta_2)\). In some cases this can solve the problem of needing to retune the proposal distribution for every dataset. This also helps with the problem of expensive likelihoods: as the conditional posterior has a lower dimension than the original one, fewer samples are needed to run the test.</p>
<p>Another very simple alternative is to use the sampler to sample from the prior, i.e. by simply commenting out the likelihood function in the posterior. This completely bypasses the problem of expensive likelihoods and the need to retune the proposal at every step. This test checks that the MCMC proposal is correct (the Hastings correction, for example), so it is good for testing complicated proposals. However, if the proposal needed to sample from the prior is qualitatively different from the proposal needed to sample from the posterior, then it’s not a useful test.</p>
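<p>A minimal sketch of this prior-only check, under assumptions of our own: the target is just the Uniform(0, 10) log-prior, and the proposal is a multiplicative (log-normal) random walk, chosen because, unlike the Gaussian proposal used above, it needs a Hastings correction of \(\theta'/\theta\). Many short independent chains are run and their final states are compared with the prior using a Kolmogorov-Smirnov statistic. Dropping the correction (<code>hastings=False</code>) should make the final states pile up near 0 rather than look uniform. The name <code>run_chains</code> is hypothetical.</p>

```python
import numpy as np
from scipy import stats


def log_prior(theta):
    """Vectorised Uniform(0, 10) log-density; -inf outside the support."""
    return np.where((theta > 0) & (theta < 10), np.log(0.1), -np.inf)


def run_chains(log_target, n_chains, n_iter, step, rng, hastings=True):
    """Run independent Metropolis-Hastings chains in parallel; return final states.

    Proposal: theta' = theta * exp(step * z), z ~ N(0, 1). It is not symmetric:
    the Hastings correction q(theta | theta') / q(theta' | theta) = theta' / theta.
    """
    theta = np.full(n_chains, 5.0)
    for _ in range(n_iter):
        proposal = theta * np.exp(step * rng.normal(size=n_chains))
        log_ratio = log_target(proposal) - log_target(theta)
        if hastings:
            log_ratio += np.log(proposal) - np.log(theta)
        accept = np.log(rng.uniform(size=n_chains)) < log_ratio
        theta = np.where(accept, proposal, theta)
    return theta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for hastings in (True, False):
        final = run_chains(log_prior, n_chains=300, n_iter=500, step=1.0, rng=rng, hastings=hastings)
        ks = stats.kstest(final, "uniform", args=(0, 10))
        print(hastings, ks.statistic, ks.pvalue)
```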
<p>As mentioned in the introduction, the PRT reduces to testing goodness of fit of prior samples, the idea being that this is easier to test as prior distributions are often chosen for their simplicity. One can of course test goodness of fit on the MCMC samples directly (without the PRT) using a method such as the <a href="http://proceedings.mlr.press/v48/chwialkowski16.html">Kernel Goodness-of-fit test</a>. This avoids the problems discussed above, but it requires gradients of the log target density, whereas the PRT makes no assumptions about the target distribution.</p>
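<p>To illustrate where the gradients come in, here is a sketch of the statistic behind kernel goodness-of-fit tests, the squared kernel Stein discrepancy, for a one-dimensional target with an RBF kernel. It only computes the V-statistic; an actual test also needs a null distribution (for example via a bootstrap), which is omitted here. The bandwidth <code>h=1.0</code> and the function name <code>ksd_vstat</code> are assumptions. The score function (the gradient of the log target density) is the only information about the target that the statistic uses, so the normalising constant is not needed.</p>

```python
import numpy as np


def ksd_vstat(x, score, h=1.0):
    """V-statistic estimate of the squared kernel Stein discrepancy (1D, RBF kernel).

    x: samples; score: function returning d/dx log p(x) for the target p.
    """
    x = np.asarray(x, dtype=float)
    s = score(x)
    d = x[:, None] - x[None, :]              # pairwise differences x_i - x_j
    k = np.exp(-d**2 / (2 * h**2))           # k(x_i, x_j)
    dk_dx = -d / h**2 * k                    # derivative of k in its first argument
    dk_dy = d / h**2 * k                     # derivative of k in its second argument
    d2k = (1 / h**2 - d**2 / h**4) * k       # mixed second derivative
    # Stein kernel u_p(x_i, x_j), averaged over all pairs
    u = s[:, None] * s[None, :] * k + s[:, None] * dk_dy + s[None, :] * dk_dx + d2k
    return u.mean()


def score_std_normal(x):
    """Gradient of the log density of N(0, 1)."""
    return -x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    print(ksd_vstat(rng.normal(0.0, 1.0, size=300), score_std_normal))  # samples from the target
    print(ksd_vstat(rng.normal(1.0, 1.0, size=300), score_std_normal))  # shifted samples
```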
<h2 id="conclusions">Conclusions</h2>
<p>The Prior Reproduction Test is a powerful way to test MCMC code, but it can be computationally expensive. This test, along with its simplified versions described above, can be included in an arsenal of diagnostics to check that MCMC samples are from the correct distribution.</p>
<p><em>Code to reproduce the figures is on <a href="https://github.com/jeremiecoullon/PRT_post">Github</a></em></p>
<p><em>Thanks to <a href="http://herrstrathmann.de/">Heiko Strathmann</a> and <a href="https://uk.linkedin.com/in/lea-goetz-neuroscience">Lea Goetz</a> for useful feedback on this post</em></p>The DjangoVerse2019-11-27T11:01:52+00:002019-11-27T11:01:52+00:00/2019/11/27/DjangoVerse<p>The <a href="https://www.londondjangocollective.com/djangoverse/">DjangoVerse</a> is a 3D graph of gypsy jazz players around the world. I designed this with <a href="https://www.mattholborn.com">Matt Holborn</a> (he got the idea from <a href="https://www.coreymwamba.co.uk/resources/rhizome/">the Rhizome</a>) and built it using React and Django.</p>
<h2 id="how-does-it-work-">How does it work?</h2>
<p>As anyone can modify it, people can <a href="https://www.londondjangocollective.com/djangoverse/forms/player/list">add themselves or players</a> they know to it. If you click on a player you get information about them: what instrument they play, a picture of them, a short bio, and a link to a YouTube video of them. As the names are coloured by country, you can immediately see how many players there are in the different countries around the world. You can try out the DjangoVerse in the figure below:</p>
<figure style="text-align:center">
<iframe src="https://djangoversereact.s3.eu-west-2.amazonaws.com/index.html" style="width:96%; margin-left:2%; height:400px;"></iframe>
<figcaption><a href="https://www.londondjangocollective.com/djangoverse/">The DjangoVerse</a></figcaption>
</figure>
<p>The players have a link between them if they have gigged together, and if you click on a player you get those links highlighted in red. This allows you to see at a glance who they’ve played with and whether they’ve played with people from different countries. You can also filter the graph to only display players from chosen countries, based on the instruments they play, or whether or not they’re active. We started out by adding around 60 players ourselves, and then shared it on Facebook and Instagram; the gypsy jazz community added the rest (there are 220 players across 21 countries at the time of writing).</p>
<h2 id="tech-stack">Tech stack</h2>
<p>I built the graph with React and <a href="https://github.com/vasturiano/react-force-graph">D3 force directed graph</a> and hosted it on S3 (<a href="https://github.com/jeremiecoullon/DjangoVerse-react">see code</a>). The API is built using Django and Postgres and is hosted on Heroku (with S3 for static files). As the DjangoVerse is part of the <a href="https://www.londondjangocollective.com/">London Django Collective</a>, I used the same <a href="https://github.com/jeremiecoullon/ldc">Django application</a> to serve the pages for the Collective as well as the API. As the React app with the graph is hosted on S3, the <a href="https://www.londondjangocollective.com/djangoverse/">page</a> in the Collective website simply has an iframe that points to it.</p>
<h1 id="the-design-process">The design process</h1>
<h2 id="a-first-attempt">A first attempt</h2>
<p>The main motivation was that I’ve wanted for a long time to create a 3D graph mapping links between related things (and had ideas about doing this for academic disciplines, jazz standards, and more). So this project was a way to scratch that itch. More specifically, the objective was to visualise the gypsy jazz scene in one place, discover new players and bands, and let people promote their music and bands.</p>
<p>As a result we started off with many different types of nodes: players, bands, festivals, albums, and venues. Each of these would be added to the graph along with links between them. A link between a player and a band would mean that the player is in the band, a link between a band and a festival would mean that the band has played at the festival, and so on. Each node would be a sphere whose size depended on its type, and the name would appear on hover; this was inspired by <a href="https://steemverse.com/">Steemverse</a> (a visualisation of a social network).</p>
<p>Furthermore, the links between nodes would also carry information, such as the year a band played at a festival, or the years a player was active in a band. You would then be able to filter the graph to only show what happened in a given year, which would give a “snapshot” of the gypsy jazz scene at that moment in time.</p>
<h2 id="too-much-stuff">Too much stuff</h2>
<p>However, it quickly became clear that this was too much information: having all these types of nodes and information about the links would make the graph overwhelming. So we removed the venue and album types, along with the information about each link. We kept only the active/inactive tags, which let us differentiate between the gypsy jazz scene in the past and in the present.</p>
<p>We then tested a prototype (with players, bands, and festivals all represented as spheres of different sizes) with some friends (see the classic <a href="https://www.amazon.co.uk/Dont-Make-Me-Think-Usability/dp/0321344758">Don’t Make Me Think</a> for an overview of user testing), and it turned out that it wasn’t very clear what the DjangoVerse was. For example, one reaction was <em>“I’m guessing it’s a simulation of a molecule or something”</em>, which makes sense given that it essentially looked like <a href="https://vasturiano.github.io/3d-force-graph/example/async-load/">this</a>. This could perhaps be fixed by adding names next to the nodes, but doing so makes D3 start lagging quite quickly as you add more players.</p>
<p>Another problem was that festivals naturally ended up being at the centre of the graph, as they were the nodes with the most connections. The players and bands themselves then ended up seeming less important, even though we think a style of music is mainly about the players themselves rather than the festivals. As a visualisation is supposed to bring out the aspects of the data that the designer thinks are most important, we needed to make the players more prominent.</p>
<h2 id="simplifying-the-design">Simplifying the design</h2>
<p>A fix to both of these problems was to simplify the graph again: we removed festivals and bands and kept just the players. We also showed only the names of the players rather than spheres. As the names are immediately visible, a user can recognise some of the players and guess immediately what this is about (this was confirmed with testing). However, a downside of this is that having all the names rather than just spheres causes the graph to lag when there are more than 100 or so players. <a href="https://steemverse.com/">Steemverse</a> gets around this problem by only showing names for the “category” types of nodes (which are rare); all other spheres only have names on hover.</p>
<p>For users adding players, there is no authentication, so anyone can add or modify a player without needing to log in. The benefit is that there is less of a barrier for people to add to the graph, but with the risk of people posting spam (or deleting all the players!). To mitigate this, I set up daily backups (easy to do with Heroku), which would allow us to restore the graph to its state before a problem occurred. If the problem persisted, I would simply have added authentication (for example OAuth with Google, Facebook, etc.).</p>
<h1 id="outcomes-and-comparison-to-other-graphs">Outcomes and comparison to other graphs</h1>
<p>Players on the gypsy jazz scene around the world added lots of players to the graph: there are 220 players spanning 21 countries, with 9 instruments represented. A feature that was used a lot was the possibility of adding a YouTube video: this allows each player to showcase their music. The short bio for each player was also interesting; when we added the bio we didn’t think much of it or consider how it would be used. However, some users added information such as which players were related to each other (father, cousin, etc.), which was really interesting!</p>
<h2 id="lessons">Lessons</h2>
<p>In terms of design, an important lesson from graph visualisations such as this one is how much information to include. Although much of the appeal of these visualisations is “eye candy” (i.e. they look fun), ideally they should also be informative or insightful. At one end of the spectrum, if there is too little information then there is not much to learn from the visualisation. At the other extreme, if there is too much information (and the design isn’t done carefully) then it’s easy to get overwhelmed. For me, some examples of these pitfalls are <a href="https://www.wikiverse.io/">Wikiverse</a> (it has a huge amount of information, as it’s a subset of Wikipedia, and I find the interface very confusing), <a href="https://steemverse.com/">Steemverse</a> (it looks great, but there’s not much information in it), and the <a href="https://www.coreymwamba.co.uk/resources/rhizome/">Rhizome</a> (as it’s in only two dimensions, it’s hard to see what’s going on in the graph).</p>
<p>In contrast, an example of a simple graph that I think works well is this <a href="https://www.quantamagazine.org/frontier-of-physics-interactive-map-20150803/">map of “theories of everything”</a>. I don’t understand these theories (they are areas of theoretical physics), but the design is very well done and classifies them clearly.</p>
<p>Other examples of very well-designed graphs are the ones built by <a href="http://concept.space/">concept.space</a>, such as this <a href="http://map.philosophies.space/">map of philosophy</a>. It has a huge amount of information, but most of it is hidden when you are zoomed out. As you zoom into a specific area of philosophy you get more and more detail about it, down to individual papers. When you click on a paper you then get its abstract and a link to it.</p>
<p>Notice also the minimap in the lower right-hand corner, which shows where you currently are in the map. Finally, it seems that they have automated the process of adding and clustering the papers (judging from the software <a href="http://philosophies.space/credits/">credited</a> on their website). They seem to have scraped <a href="https://philpapers.org/">PhilPapers</a>, used <a href="https://code.google.com/archive/p/word2vec/">Word2Vec</a> to get word embeddings for each paper, <a href="https://github.com/lmcinnes/umap">reduced the dimension</a> of the space, and finally <a href="https://hdbscan.readthedocs.io/en/latest/">clustered</a> the result to find the location of each paper in the two-dimensional map. This automated workflow then let them create similar maps for <a href="http://map.climate.space/">climate science</a> and <a href="http://concept.space/projects/biomap/">biomedicine</a>.</p>
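<p>To make the embedding step of that workflow a bit more concrete, here is a minimal Python sketch of one common way to turn word embeddings into a single vector per paper: averaging the vectors of the words it contains. This is an assumption, as concept.space’s credits don’t say how they combined the word vectors. The tiny three-dimensional <code>WORD_VECTORS</code> table and the example paper titles are hypothetical stand-ins for real Word2Vec vectors and PhilPapers data, and the dimension-reduction (UMAP) and clustering (HDBSCAN) steps are not shown.</p>

```python
import math

# Hypothetical toy "word vectors" standing in for real Word2Vec embeddings
# (real ones typically have a few hundred dimensions and are learnt from a corpus).
WORD_VECTORS = {
    "ethics": [0.9, 0.1, 0.0],
    "moral":  [0.8, 0.2, 0.1],
    "virtue": [0.7, 0.1, 0.2],
    "logic":  [0.1, 0.9, 0.1],
    "proof":  [0.0, 0.8, 0.3],
    "modal":  [0.2, 0.7, 0.2],
}


def paper_embedding(text):
    """Average the vectors of the known words in a paper's text (an assumed choice)."""
    vectors = [WORD_VECTORS[w] for w in text.lower().split() if w in WORD_VECTORS]
    if not vectors:
        raise ValueError("no known words in text")
    dim = len(vectors[0])
    # Component-wise mean over all word vectors found in the text.
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: close to 1 means similar direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical paper titles; words missing from WORD_VECTORS are ignored.
papers = {
    "A": "Virtue ethics and moral luck",
    "B": "Moral realism and ethics",
    "C": "A proof system for modal logic",
}
embeddings = {name: paper_embedding(text) for name, text in papers.items()}
```

<p>The two ethics papers (A and B) end up with much more similar vectors than A and the logic paper (C), and that kind of closeness is what a later dimension-reduction and clustering step would turn into nearby positions on the map.</p>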
<p>In conclusion, the idea of a visual map showing the links between different things in a discipline (players in gypsy jazz, papers in philosophy, etc.) is a very appealing one. However, getting it right is surprisingly difficult; for me, the best example of doing it well is the map of philosophy described above.</p>
<p><em>Thanks to <a href="https://www.lukas.derungs.de/">Lukas DeRungs</a> for reading a draft of this post</em></p>The DjangoVerse is a 3D graph of gypsy jazz players around the world. I designed this with Matt Holborn (he got the idea from the Rhizome) and built it using React and Django.