Stopping All Threads in a VM

Why would you ever want to stop all the threads in a VM? If you've ever read anything about garbage collection, you might have heard of stop-the-world garbage collectors: the VM pauses the running program, performs some VM-specific task, and then resumes execution. Another instance is when you want to JIT compile some code, or deoptimize [paper] it back to a normal execution state. Before you invalidate or overwrite any compiled code, you have to ensure that no thread is currently executing in the block of memory you are about to overwrite. How do we do this? In this post, I'm going to focus on the use case where you want to overwrite a block of memory and patch in new JIT-compiled code.

Conceptually, there are two ways to stop all threads in a running system. Both rely on signals: you mprotect some area of memory and catch the signal that is raised when a thread touches it. You can either a) mark the target code page as non-executable, or b) instrument the code to load from a dedicated page of memory and read-protect that page when you want the threads to stop. Once all the threads are stopped, you are free to change the code however you want.

Why does this work? Threads share one address space, but each thread has its own registers and stack and is executing at its own program counter (PC). All we really need is to ensure that no thread's current PC is inside the memory we want to change. When a thread takes a signal, the kernel saves its registers and redirects its PC to the signal handler, which moves the thread out of the target memory. Now all that's left is implementing it!

Conceptually, we have five steps:

  1. Install the signal handler
  2. mprotect the target memory page 
  3. Catch the signal (stops all threads) 
  4. Alter the code in memory 
  5. Resume all threads
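Before going through each step, here is a minimal single-threaded sketch of all five, assuming Linux or macOS. The names runFiveSteps, onPollFault, pollingPage, and the stored value 42 are made up for illustration, "altering the code" is only a flag here, and the handler resumes the thread by calling mprotect, which is common practice even though POSIX does not list mprotect as async-signal-safe.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <signal.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

static void *pollingPage;                      // dedicated page the code loads from
static volatile sig_atomic_t handlerRuns = 0;
static volatile sig_atomic_t patched = 0;

// Steps 3-5 run here, on the thread whose load faulted.
static void onPollFault(int signum, siginfo_t *info, void *context) {
    (void) context;
    if (info->si_addr != pollingPage) {        // a real bug: crash as usual
        signal(signum, SIG_DFL);               // signal() is async-signal-safe
        return;
    }
    handlerRuns++;                             // step 3: the thread is stopped
    patched = 1;                               // step 4: stand-in for patching code
    // Step 5: make the page readable; returning retries the faulting load.
    mprotect(pollingPage, (size_t) getpagesize(), PROT_READ);
}

int runFiveSteps(void) {
    size_t pageSize = (size_t) getpagesize();
    pollingPage = mmap(NULL, pageSize, PROT_READ | PROT_WRITE,
                       MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (pollingPage == MAP_FAILED) return -1;
    *(int *) pollingPage = 42;

    // Step 1: install the handler.
    struct sigaction sa;
    memset(&sa, 0, sizeof sa);
    sa.sa_sigaction = onPollFault;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_SIGINFO;
    sigaction(SIGSEGV, &sa, NULL);
    sigaction(SIGBUS, &sa, NULL);              // macOS may report SIGBUS

    // Step 2: read-protect the page.
    mprotect(pollingPage, pageSize, PROT_NONE);

    // The instrumented load: it faults, the handler runs, then it completes.
    int value = *(volatile int *) pollingPage;
    munmap(pollingPage, pageSize);
    return value;
}
```

Calling runFiveSteps() returns 42: after the handler returns, the faulting load is simply retried against the now-readable page.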

Installing the Signal Handler

For the following example, we're going to do what the HotSpot Java VM and the Maxine JVM do: assume the compiled code regularly performs a load from a dedicated page of memory. When we want the threads to stop, we read-protect that page, so the next load from it raises a signal. The next section explains why we use this technique.
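A sketch of that dedicated page, assuming a POSIX system with anonymous mmap. allocatePollingPage, safepointPoll, armPollingPage, and disarmPollingPage are hypothetical names; a JIT would emit the poll as a single load instruction rather than calling a C function.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <sys/mman.h>
#include <unistd.h>

// One page used only for safepoint polls. mmap returns page-aligned memory,
// which mprotect needs; malloc makes no such promise.
void *allocatePollingPage(void) {
    void *page = mmap(NULL, (size_t) getpagesize(), PROT_READ,
                      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    return page == MAP_FAILED ? NULL : page;
}

// What the instrumented code does at each poll: a load from the page.
// Compiled code would ignore the value; it is returned here for testing.
static inline char safepointPoll(const void *pollingPage) {
    return *(const volatile char *) pollingPage;
}

// Request a stop: from now on every poll faults.
int armPollingPage(void *page) {
    return mprotect(page, (size_t) getpagesize(), PROT_NONE);
}

// Back to normal: polls are ordinary loads again.
int disarmPollingPage(void *page) {
    return mprotect(page, (size_t) getpagesize(), PROT_READ);
}
```

While the page is disarmed, a poll costs one ordinary load that reads zero from the freshly mapped page.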

There are two functions you can use to install a signal handler: signal() or sigaction(). The Linux signal(2) man page recommends avoiding signal(), because its behavior varies across UNIX versions, and using sigaction() instead. sigaction() is also more flexible, because it gives you more control over what happens when a signal is delivered.

The other piece of information we need is which signal to catch. On Linux, a load from a read-protected page raises SIGSEGV; on OS X the same fault may be reported as SIGBUS, so a portable VM installs its handler for both. If you've ever seen a segfault crash, it's because the program didn't handle a SIGSEGV signal. To install a handler for SIGSEGV, let's use the following code snippet:

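    #include <signal.h>
    #include <string.h>

    void catchSignal(int signum, siginfo_t* sigInfo, void* context);

    void installSignalHandlers(void) {
        struct sigaction signalInfo;
        memset(&signalInfo, 0, sizeof(signalInfo));  // clear fields we don't set
        signalInfo.sa_sigaction = catchSignal;       // called when the signal arrives
        sigemptyset(&signalInfo.sa_mask);            // block no extra signals in the handler
        signalInfo.sa_flags = SA_SIGINFO;            // use the 3-argument sa_sigaction form
        sigaction(SIGSEGV, &signalInfo, NULL);
        sigaction(SIGBUS, &signalInfo, NULL);        // OS X may report the fault as SIGBUS
    }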

What is this code doing? First, struct sigaction is provided by <signal.h>. Its sa_sigaction field points to the function that will handle the signal when a SIGSEGV is delivered. sa_mask lists additional signals to block while the handler runs. Setting SA_SIGINFO in sa_flags tells the kernel to call the three-argument handler in sa_sigaction (here, catchSignal) and pass it a siginfo_t describing the fault, instead of the one-argument sa_handler. Finally, the call to sigaction() installs the handler. You can read more in the sigaction(2) man page.
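In practice a VM's handler usually has two more jobs: catching SIGBUS as well on platforms like OS X, and not swallowing genuine crashes. Here is a minimal sketch, assuming Linux or macOS. The handler compares si_addr (the faulting address) with the dedicated polling page, treats a match as a safepoint stop, and otherwise restores the default action so the faulting instruction crashes the process as usual. pollPage, onFault, installPollHandler, and crashSignalForStrayFault are hypothetical names, and the handler's mprotect call is common practice although POSIX does not list mprotect as async-signal-safe.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <signal.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <unistd.h>

static void *pollPage;                        // the dedicated polling page
static volatile sig_atomic_t pollFaults = 0;

static void onFault(int signum, siginfo_t *info, void *context) {
    (void) context;
    if (info->si_addr == pollPage) {
        pollFaults++;
        // Unprotect so the retried load succeeds.
        mprotect(pollPage, (size_t) getpagesize(), PROT_READ);
        return;
    }
    // Not a safepoint poll: restore the default action and return. The
    // faulting instruction runs again and the process crashes as usual.
    signal(signum, SIG_DFL);
}

int installPollHandler(void) {
    struct sigaction sa;
    memset(&sa, 0, sizeof sa);
    sa.sa_sigaction = onFault;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_SIGINFO;
    if (sigaction(SIGSEGV, &sa, NULL) != 0) return -1;  // Linux
    return sigaction(SIGBUS, &sa, NULL);                // OS X may use this
}

// Forks a child that touches a protected page other than pollPage and
// returns the signal that killed it (0 if it exited normally).
int crashSignalForStrayFault(void) {
    pid_t pid = fork();
    if (pid < 0) return -1;
    if (pid == 0) {
        void *other = mmap(NULL, (size_t) getpagesize(), PROT_NONE,
                           MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
        (void) *(volatile char *) other;      // a genuine bug
        _exit(0);
    }
    int status = 0;
    waitpid(pid, &status, 0);
    return WIFSIGNALED(status) ? WTERMSIG(status) : 0;
}
```

The poll fault is handled and the load completes, while the child's stray access still terminates it with SIGSEGV (or SIGBUS on OS X).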

Mprotecting the Target Page

Now that we can handle the signal, we can let the threads continue executing as is. Now let's say we want to change the currently executing code for a function foo. We need to stop all the threads, so we have to trigger the signal in each of them. We could call raise(), but in a multithreaded program raise() only delivers the signal to the calling thread, and signaling each thread directly (for example with pthread_kill()) would interrupt it at an arbitrary PC. This is where the load from a page gives VMs a nice trick: the load is placed at a safepoint, a point where the VM can safely stop the world and then restore it once its business is finished. A VM cannot stop the world at every PC location, only at safepoints. Safepoints are points in the program that have metadata about the program state, so that we can restore the program to its correct state after some VM work. That's why loads are used instead of just raising the signal. Now let's mprotect the target page:

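    #include <assert.h>
    #include <stdint.h>
    #include <sys/mman.h>
    #include <unistd.h>

    void markPageUnreadable(void* location) {
      uintptr_t address = (uintptr_t) location;
      uintptr_t pageSize = (uintptr_t) getpagesize();
      // Round down to the start of the page that contains location.
      void* pageStart = (void*)(address - (address % pageSize));
      // PROT_NONE: any later load from this page faults.
      int result = mprotect(pageStart, pageSize, PROT_NONE);
      assert(result == 0);
    }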

mprotect operates on whole pages, and the address you pass it must be page aligned. The address we want to protect may not be, so you have to round it down to the page boundary before you can mprotect the page. Most Unix-like systems provide getpagesize(), which tells you how big a page of memory is (POSIX also offers sysconf(_SC_PAGESIZE)). So all you need to do is pass in an address inside the polling page, round it down to the start of the page, and mprotect the page with no access. Now the next time a thread tries to load from the page, a signal is raised automatically! There is a small nuance here: once we mprotect the page, the threads do not all stop immediately. Each thread stops only the next time it loads from the protected page, so the VM must wait and check that every thread has stopped, which usually happens very quickly. Another note: if the page also held ordinary data, legitimate loads of that data would fault too, so we can't let the instrumented load share its page. Hence, VMs allocate a whole page (for example with mmap, which returns page-aligned memory, unlike malloc) and dedicate it solely to this purpose.
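The "each thread stops only at its next poll" behavior can be sketched with a loop whose only poll sits at the back-edge, assuming Linux or macOS. The loop protects the page itself at iteration stopAt to simulate another thread requesting a stop; the thread keeps running until it reaches the poll, where the handler sees a consistent snapshot of the loop state. Publishing i and sum before each poll is only a stand-in for the per-safepoint metadata a real VM records, and all names here are hypothetical.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <signal.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

static void *pollPage;
// State published at each safepoint, standing in for safepoint metadata.
static volatile long safepointI, safepointSum;
static volatile long seenI = -1, seenSum = -1;

static void onPoll(int signum, siginfo_t *info, void *context) {
    (void) context;
    if (info->si_addr != pollPage) { signal(signum, SIG_DFL); return; }
    seenI = safepointI;      // the thread stopped at a known point,
    seenSum = safepointSum;  // so this snapshot is consistent
    mprotect(pollPage, (size_t) getpagesize(), PROT_READ);  // resume
}

int setupPolling(void) {
    pollPage = mmap(NULL, (size_t) getpagesize(), PROT_READ,
                    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (pollPage == MAP_FAILED) return -1;
    struct sigaction sa;
    memset(&sa, 0, sizeof sa);
    sa.sa_sigaction = onPoll;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_SIGINFO;
    if (sigaction(SIGSEGV, &sa, NULL) != 0) return -1;
    return sigaction(SIGBUS, &sa, NULL);
}

// A loop as a JIT might emit it: the only poll is at the back-edge.
long sumWithSafepoints(long n, long stopAt) {
    long sum = 0;
    for (long i = 0; i < n; i++) {
        sum += i;
        if (i == stopAt)  // simulate a stop request arriving mid-iteration
            mprotect(pollPage, (size_t) getpagesize(), PROT_NONE);
        safepointI = i;
        safepointSum = sum;
        (void) *(volatile char *) pollPage;  // safepoint poll
    }
    return sum;
}
```

With n = 100 and stopAt = 10, the loop still returns 4950, and the handler observes i = 10 with sum = 55, exactly the state at that iteration's poll.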

Catching the Signal

Remember that when we installed the handler, we set signalInfo.sa_sigaction = catchSignal? This means we need a catchSignal function that actually does some work. The nice thing is that each thread that loads from the protected page runs this handler itself, so while it is in the handler it is no longer executing the code we want to change. In the signal handler, you can deopt, collect garbage, and so on. Typically, one thread does the work while the others wait (for example, by busy waiting). catchSignal must match the handler signature that sa_sigaction expects:

    void handler(int signum, siginfo_t* info, void* context);  // context points to a ucontext_t

Now let's see a sample implementation:

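    #include <setjmp.h>
    #include <signal.h>

    // cleanup, JIT::compile and wait() stand for VM-specific code. pcFromContext()
    // and resumePoint (a sigjmp_buf saved with sigsetjmp) are hypothetical.
    void catchSignal(int signum, siginfo_t* sigInfo, void* context) {
      // sigInfo->si_addr is the faulting data address (the polling page);
      // the interrupted PC is saved in the context instead.
      void* pc = pcFromContext(context);
      if (cleanup) {               // set for exactly one thread
        JIT::compile(pc);          // recompile the method that contains pc
      } else {
        wait();                    // VM-level wait, not POSIX wait(2)
      }
      siglongjmp(resumePoint, 1);  // resume at a safe point, not in invalidated code
    }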

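For a SIGSEGV, sigInfo->si_addr holds the memory address whose access faulted (here, an address in the polling page), not the thread's PC. The PC at which the thread was interrupted is saved in the context argument, in a platform-specific ucontext_t field, and from it the VM can look up the method it wants to recompile. A flag directs a single thread to perform the work while every other thread waits for it to finish. Finally, the handler must not simply return: that would resume the thread at its old machine PC, inside the code we just invalidated. sigsetjmp/siglongjmp are useful for this reason. sigsetjmp saves the calling environment (on typical implementations the stack pointer, PC, and callee-saved registers) at a point the VM chooses, and siglongjmp restores it, so execution continues after that sigsetjmp call instead. With a nonzero second argument, sigsetjmp also saves the signal mask, which siglongjmp restores; otherwise SIGSEGV could stay blocked after leaving the handler. So now we have one thread that recompiles the function while every other thread waits for the code to be patched.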
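Here is a runnable sketch of the resume part, assuming Linux or macOS. The handler records si_addr and then uses siglongjmp to divert execution to a sigsetjmp point instead of returning to the faulting load. The page is polled twice; the second fault is caught because sigsetjmp(resumePoint, 1) saved the signal mask and siglongjmp restored it, unblocking SIGSEGV again. Reading the PC out of the context is platform specific and is left out, and all names here are hypothetical.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <setjmp.h>
#include <signal.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

static void *pollPage;
static sigjmp_buf resumePoint;           // where the VM resumes a stopped thread
static void *volatile faultAddress;      // what si_addr reported
static volatile sig_atomic_t stops = 0;

static void onPoll(int signum, siginfo_t *info, void *context) {
    (void) context;                      // the saved PC lives here (unused)
    if (info->si_addr != pollPage) { signal(signum, SIG_DFL); return; }
    faultAddress = info->si_addr;        // the polling page, not the PC
    stops++;
    // Never return into the interrupted code: jump to the resume point.
    // siglongjmp also restores the signal mask saved by sigsetjmp(..., 1).
    siglongjmp(resumePoint, 1);
}

// Polls a protected page twice; each fault is diverted back to sigsetjmp.
int pollTwiceAndResume(void) {
    size_t pageSize = (size_t) getpagesize();
    pollPage = mmap(NULL, pageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (pollPage == MAP_FAILED) return -1;

    struct sigaction sa;
    memset(&sa, 0, sizeof sa);
    sa.sa_sigaction = onPoll;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_SIGINFO;
    sigaction(SIGSEGV, &sa, NULL);
    sigaction(SIGBUS, &sa, NULL);

    volatile int attempts = 0;           // volatile: modified across sigsetjmp
    sigsetjmp(resumePoint, 1);           // a stopped thread continues here
    if (attempts < 2) {
        attempts++;
        (void) *(volatile char *) pollPage;  // faults; never completes
    }
    return attempts;
}
```

pollTwiceAndResume() returns 2, and faultAddress equals pollPage, confirming that si_addr reports the data address of the poll.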

Alter the Code in Memory

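Since the world is effectively stopped and only one thread is executing, there are no races with code that is currently running. We can now use standard JIT compilation techniques and patch any area of memory we want. Keep in mind that code pages are usually mapped read+execute, so they have to be made writable while we patch them, and on some architectures (such as ARM) the instruction cache must be flushed afterwards. Once we've recompiled the code, we just have to start the world again.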
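A sketch of the patching step itself, assuming a POSIX system and GCC or Clang (for the __builtin___clear_cache builtin). patchCode is a hypothetical helper: it makes the affected pages writable, copies the new bytes, restores read+execute, and flushes the instruction cache. Platforms with strict W^X policies (for example macOS on Apple silicon) require platform-specific APIs instead of plain mprotect.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

// Copies new machine code over old code while the world is stopped.
// Returns 0 on success, -1 if a protection change fails.
int patchCode(void *dst, const void *src, size_t len) {
    uintptr_t pageSize = (uintptr_t) getpagesize();
    // Cover every page the patch touches, rounded to page boundaries.
    uintptr_t start = (uintptr_t) dst & ~(pageSize - 1);
    uintptr_t end = ((uintptr_t) dst + len + pageSize - 1) & ~(pageSize - 1);

    if (mprotect((void *) start, end - start, PROT_READ | PROT_WRITE) != 0)
        return -1;
    memcpy(dst, src, len);
    if (mprotect((void *) start, end - start, PROT_READ | PROT_EXEC) != 0)
        return -1;
    // A no-op on x86; required on ARM before the new code is executed.
    __builtin___clear_cache((char *) dst, (char *) dst + len);
    return 0;
}
```

The bytes used to check this sketch are placeholders and are never executed; the point is the protection dance around the copy.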

Resume All Threads

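Now that we have JIT compiled or deoptimized a piece of code, or cleaned up garbage, we want to resume the actual program execution. First, make the polling page readable again; otherwise every thread would fault again at its next poll. Then tell each waiting thread to stop waiting and resume execution. There are multiple ways to stop and start threads, depending on your threading library. Pthreads condition variables offer pthread_cond_wait() and pthread_cond_signal()/pthread_cond_broadcast(), which are unrelated to OS signals: signaling a condition variable means "wake up" and resume execution. If threads wait inside the signal handler, though, note that pthread_cond_wait() is not async-signal-safe, so the handler needs something that is, such as spinning on a lock-free atomic flag.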
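Putting the pieces together, here is a sketch of a full stop, patch, and resume cycle with several threads, assuming Linux or macOS with pthreads and lock-free C11 atomic_int. Workers poll a dedicated page in a loop. The coordinator sets a flag, protects the page, waits until every worker has parked in the handler, "patches" by bumping a version number, unprotects the page, and clears the flag. Parked workers spin on an atomic rather than calling pthread_cond_wait(), since only async-signal-safe operations belong in the handler. All names, and NUM_WORKERS, are hypothetical.

```c
#define _DEFAULT_SOURCE  // MAP_ANONYMOUS and getpagesize() on glibc
#include <assert.h>
#include <pthread.h>
#include <signal.h>
#include <stdatomic.h>
#include <stdint.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

#define NUM_WORKERS 4

static void *pollPage;
static atomic_int stopRequested;   // 1 while the world is stopped
static atomic_int threadsStopped;  // workers that have parked in the handler
static atomic_int codeVersion;     // stand-in for the code being patched
static atomic_int done;
static int observedVersion[NUM_WORKERS];

static void onPoll(int signum, siginfo_t *info, void *context) {
    (void) context;
    if (info->si_addr != pollPage) { signal(signum, SIG_DFL); return; }
    atomic_fetch_add(&threadsStopped, 1);
    // Park by spinning on a lock-free atomic. Returning retries the poll,
    // which succeeds because the page is made readable before release.
    while (atomic_load(&stopRequested)) { }
}

static void *worker(void *arg) {
    int id = (int) (intptr_t) arg;
    while (!atomic_load(&done)) {
        (void) *(volatile char *) pollPage;  // safepoint poll
        observedVersion[id] = atomic_load(&codeVersion);
    }
    return NULL;
}

// Stops all workers, patches, resumes them; returns how many were stopped.
int stopPatchResume(void) {
    size_t pageSize = (size_t) getpagesize();
    pollPage = mmap(NULL, pageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (pollPage == MAP_FAILED) return -1;

    struct sigaction sa;                                // step 1: install handler
    memset(&sa, 0, sizeof sa);
    sa.sa_sigaction = onPoll;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_SIGINFO;
    sigaction(SIGSEGV, &sa, NULL);
    sigaction(SIGBUS, &sa, NULL);

    pthread_t threads[NUM_WORKERS];
    for (int i = 0; i < NUM_WORKERS; i++)
        pthread_create(&threads[i], NULL, worker, (void *) (intptr_t) i);

    atomic_store(&stopRequested, 1);          // set the flag before protecting
    mprotect(pollPage, pageSize, PROT_NONE);  // step 2: polls now fault
    while (atomic_load(&threadsStopped) < NUM_WORKERS) { }  // step 3: wait for all
    int stopped = atomic_load(&threadsStopped);

    atomic_store(&codeVersion, 1);            // step 4: "alter the code"

    mprotect(pollPage, pageSize, PROT_READ);  // step 5: polls succeed again,
    atomic_store(&stopRequested, 0);          // then release parked workers

    atomic_store(&done, 1);
    for (int i = 0; i < NUM_WORKERS; i++)
        pthread_join(threads[i], NULL);
    return stopped;
}
```

After stopPatchResume() returns, every worker has observed codeVersion == 1: each one was parked before the patch and read the version only after it resumed. Unprotecting the page before clearing the flag keeps released workers from faulting a second time.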

Conclusion

Today I Learned how to stop the world and patch some code. Done and done.