Coroutine

教材

setjmp/longjmp

cross-function goto (non-local goto)

#include <stdio.h>
#include <setjmp.h>

/* Jump target shared by main(), which records it with setjmp(), and
 * second(), which jumps back to it with longjmp(). */
static jmp_buf buf;

void second() {
    const int resume_value = 1; /* what setjmp(buf) will appear to return in main() */

    printf("second\n");
    /* Unwind straight back to main()'s setjmp; nothing after this call runs. */
    longjmp(buf, resume_value);
}

void first() {
    second();
    /* Never reached: second() longjmps back to main()'s setjmp, so control
     * does not return here. Kept to show that this output is skipped.
     * (Uses printf: `print` is not a C library function.) */
    printf("first\n");
}

int main() {
    /* setjmp() yields 0 on the direct call and 1 when second() longjmps back. */
    switch (setjmp(buf)) {
    case 0:
        first();
        break;
    default:
        printf("main\n");
        break;
    }

    return 0;
}

compile and run

gcc example.c -o example && ./example
  • int setjmp(jmp_buf buf):
The functions described on this page are used for performing
"nonlocal gotos": transferring execution from one function to a
predetermined location in another function.  The setjmp() function
dynamically establishes the target to which control will later be
transferred, and longjmp() performs the transfer of execution.

setjmp 把 callee-saved 暫存器的狀態存到 jmp_buf;這些暫存器的值在函式呼叫前後必須保持不變。

Callee-saved registers (AKA non-volatile registers, or call-preserved) are used to hold long-lived values that should be preserved across calls.

也會把返回位址(rip)與堆疊指標(rsp)存到 jmp_buf,longjmp 才能回到 setjmp 呼叫之後的位置。

  • void longjmp(jmp_buf env, int val);
The longjmp() function shall restore the environment saved by the
most recent invocation of setjmp() in the same process, with the
corresponding jmp_buf argument.

恢復 setjmp 當時儲存的暫存器狀態,
並讓程式從 setjmp 返回的位置繼續執行;此時 setjmp 的回傳值為 val(若 val 為 0 則回傳 1)。

Non-preemptive schedulers

#include <stdbool.h>
#include <stdio.h>
#include <setjmp.h>
#include <unistd.h> /* sleep() */

/* Slot where the running task saves its position (points into main's buffers[]). */
jmp_buf *current_buffer;
/* Where YIELD() returns control to: the scheduler loop in main(). */
jmp_buf main_buffer;

/* Number of elements of a true array (not a pointer). */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))

/* Save the running task's position in *current_buffer, then jump back to the
 * scheduler. When the scheduler later does longjmp(*current_buffer, 1), setjmp
 * returns non-zero and the task continues after YIELD().
 * do/while(0) makes `YIELD();` a single statement, so it is safe in if/else. */
#define YIELD() do { if (!setjmp(*current_buffer)) longjmp(main_buffer, 1); } while (0)

void task_0()
{
    /* Cooperative task: print, pause one second, then hand control back to
     * the scheduler in main(). Loops forever. */
    for (;;) {
        printf("0\n");
        sleep(1);
        YIELD();
    }
}

void task_1()
{
    /* Cooperative task: print, pause one second, then hand control back to
     * the scheduler in main(). Loops forever. */
    for (;;) {
        printf("1\n");
        sleep(1);
        YIELD();
    }
}

/* Round-robin scheduler built on setjmp/longjmp.
 * First pass (started == false): call each task normally; it runs until its
 * YIELD() longjmps back to main_buffer.
 * Later passes: longjmp into the position each task saved in buffers[i].
 * NOTE(review): every task runs on main's own stack, and its frame is unwound
 * by longjmp(main_buffer, 1). Jumping back into buffers[i] therefore re-enters
 * a function that has already terminated, which the C standard leaves
 * undefined. It presumably appears to work only because task_0/task_1 keep
 * no locals. */
int main()
{
    void(*tasks[])(void) = {task_0, task_1}; // array of function ptr
    jmp_buf buffers[ARRAY_SIZE(tasks)]; // one saved position per task
    bool started = false; // false only during the first pass over the tasks
    
    while(true) {
        // NOTE(review): int i vs. size_t ARRAY_SIZE(...) triggers -Wsign-compare
        for (int i=0; i < ARRAY_SIZE(tasks); i++) {
            if (setjmp(main_buffer)) { // a task's YIELD() longjmps back to here
                continue; // go on with the next task
            }
            
            current_buffer = &buffers[i]; // YIELD() in task i saves into this slot
            if (!started) {
                tasks[i]();
            } else {
                printf("longjmp %d\n", i);
                longjmp(buffers[i], 1); // back to YIELD()'s setjmp
            }
        }
        printf("end for\n");
        started = true;
    }
    
    return 0;
}

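共用堆疊問題:任務中的區域變數在 yield 之後可能被改寫
理由:所有 coroutine 都在 main 的同一個 stack 上執行;YIELD() 透過 longjmp 回到 main 時,stack pointer 被調回 main 的位置,
之後 main 呼叫的 printf 或下一個任務會重複使用同一段堆疊空間,把先前任務的區域變數覆寫。再 longjmp 回已經結束的函式框架,在 C 標準中屬於未定義行為。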

#include <stdbool.h>
#include <stdio.h>
#include <setjmp.h>
#include <unistd.h> /* sleep() */

/* Slot where the running task saves its position (points into main's buffers[]). */
jmp_buf *current_buffer;
/* Where YIELD() returns control to: the scheduler loop in main(). */
jmp_buf main_buffer;

/* Number of elements of a true array (not a pointer). */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))

/* Save the running task's position in *current_buffer, then jump back to the
 * scheduler. When the scheduler later does longjmp(*current_buffer, 1), setjmp
 * returns non-zero and the task continues after YIELD().
 * do/while(0) makes `YIELD();` a single statement, so it is safe in if/else. */
#define YIELD() do { if (!setjmp(*current_buffer)) longjmp(main_buffer, 1); } while (0)

void task_0()
{
    /* Two counters that should print 0 0, 1 1, 2 2, ... across yields.
     * Both must be initialized: reading an uninitialized local is undefined.
     * In this shared-stack version the values printed after a YIELD() may be
     * corrupted, which is exactly what this example demonstrates. */
    int i = 0, j = 0;
    while(true) {
        printf("0: %d %d\n", i++, j++);
        sleep(1);
        YIELD();
    }
}

void task_1()
{
    /* Cooperative task: print, pause one second, then hand control back to
     * the scheduler in main(). Loops forever. */
    for (;;) {
        printf("1\n");
        sleep(1);
        YIELD();
    }
}

/* Same round-robin scheduler as before; now task_0 keeps locals (i, j).
 * First pass (started == false): call each task normally; it runs until its
 * YIELD() longjmps back to main_buffer.
 * Later passes: longjmp into the position each task saved in buffers[i].
 * NOTE(review): all tasks still share main's stack, so once a task yields,
 * main's printf() and the next task reuse the memory that held task_0's
 * locals; re-entering the unwound frame via longjmp is undefined behaviour. */
int main()
{
    void(*tasks[])(void) = {task_0, task_1}; // array of function ptr
    jmp_buf buffers[ARRAY_SIZE(tasks)]; // one saved position per task
    bool started = false; // false only during the first pass over the tasks
    
    while(true) {
        // NOTE(review): int i vs. size_t ARRAY_SIZE(...) triggers -Wsign-compare
        for (int i=0; i < ARRAY_SIZE(tasks); i++) {
            if (setjmp(main_buffer)) { // a task's YIELD() longjmps back to here
                continue; // go on with the next task
            }
            
            current_buffer = &buffers[i]; // YIELD() in task i saves into this slot
            if (!started) {
                tasks[i]();
            } else {
                printf("longjmp %d\n", i);
                longjmp(buffers[i], 1); // resume task i at its YIELD()'s setjmp
            }
        }
        printf("end for\n");
        started = true;
    }
    
    return 0;
}

解決辦法:為每個任務配置獨立的 stack

#include <stdbool.h>
#include <stdio.h>
#include <setjmp.h>
#include <unistd.h> /* sleep() */

/* Slot where the running task saves its position (points into main's buffers[]). */
jmp_buf *current_buffer;
/* Where YIELD() returns control to: the scheduler loop in main(). */
jmp_buf main_buffer;

/* Number of elements of a true array (not a pointer). */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))

/* Save the running task's position in *current_buffer, then jump back to the
 * scheduler. When the scheduler later does longjmp(*current_buffer, 1), setjmp
 * returns non-zero and the task continues after YIELD().
 * do/while(0) makes `YIELD();` a single statement, so it is safe in if/else. */
#define YIELD() do { if (!setjmp(*current_buffer)) longjmp(main_buffer, 1); } while (0)

void task_0()
{
    /* Two counters that should print 0 0, 1 1, 2 2, ... across yields.
     * Both must be initialized: reading an uninitialized local is undefined.
     * With a private stack per task (see main), main's own calls no longer
     * overwrite these locals between yields. */
    int i = 0, j = 0;
    while(true) {
        printf("0: %d %d\n", i++, j++);
        sleep(1);
        YIELD();
    }
}

void task_1()
{
    /* Cooperative task: print, pause one second, then hand control back to
     * the scheduler in main(). Loops forever. */
    for (;;) {
        printf("1\n");
        sleep(1);
        YIELD();
    }
}

int main()
{
    void(*tasks[])(void) = {task_0, task_1}; // array of function ptr
    jmp_buf buffers[ARRAY_SIZE(tasks)]; // one saved position per task
    /* One private stack per task. printf()/sleep() inside the tasks can easily
     * need more than 1 KiB, and an overflow of stacks[1] would run into the
     * live frame at the top of stacks[0], so use a generous size.
     * The x86-64 SysV ABI requires %rsp to be 16-byte aligned before `call`;
     * _Alignas(16) plus a row size that is a multiple of 16 keeps the end of
     * every row aligned. */
    enum { TASK_STACK_SIZE = 64 * 1024 };
    _Alignas(16) char stacks[ARRAY_SIZE(tasks)][TASK_STACK_SIZE];
    bool started = false;

    while(true) {
        for (int i=0; i < ARRAY_SIZE(tasks); i++) {
            if (setjmp(main_buffer)) { // a task's YIELD() longjmps back to here
                continue;
            }

            current_buffer = &buffers[i]; // YIELD() in task i saves into this slot
            if (!started) {
                // The stack grows down, so start one past the end of the row.
                char* stack = stacks[i] + sizeof(stacks[i]);
                // x86-64 only: point %rsp at the task's stack and call the task.
                // Tasks never return (they loop forever and leave via YIELD()),
                // so main's %rsp is never restored by this code.
                // "memory" forces the store to current_buffer (read by YIELD())
                // to be committed before the task runs.
                __asm__ volatile("movq %0, %%rax;"
                                 "movq %1, %%rsp;"
                                 "call *%%rax"
                                 :: "rm" (tasks[i]), "rm" (stack)
                                 : "rax", "memory");
            } else {
                printf("longjmp %d\n", i);
                longjmp(buffers[i], 1); // resume task i at its YIELD()'s setjmp
            }
        }
        printf("end for\n");
        started = true;
    }

    return 0;
}

Explanation of asm call

Load the address of the task function (tasks[i]) into the general-purpose register rax.

rax is only a scratch register here. It is listed as clobbered, so the compiler will not keep anything else in it across the asm statement.

  • Replace the current stack pointer (rsp) with the top of the task's private stack (stack).

    rsp points to the top of the stack the CPU is currently using. Because the stack grows downward, stack is set to one past the end of stacks[i].

    From now on, every push, call, and local variable of the task lives in stacks[i] instead of main's stack. On x86-64, rsp must be 16-byte aligned at this point.

  • Call (call) the function whose address is in rax.

    call pushes the return address onto the new stack and jumps to the task. The task never returns: it loops forever and leaves only through YIELD(). As a result, main's rsp is never restored by this code.

coro

source: https://github.com/sysprog21/concurrent-programs/tree/master/coro

Preemptive scheduling

目標

  • 學習使用ucontext.h
  • Linux signal API
  • 用SIGALRM完成搶佔式排程

source

API 學習

ucontext.h

  • #include <ucontext.h>

  • struct ucontext_t

/* Excerpt from getcontext(3)/makecontext(3); the real definition has more fields. */
typedef struct ucontext_t {
   struct ucontext_t *uc_link;     // context resumed when this context's function returns
   sigset_t          uc_sigmask;   // signals blocked while this context is active
   stack_t           uc_stack;     // stack used by this context
   mcontext_t        uc_mcontext;  // machine-specific representation of the saved context
   /* ... */
} ucontext_t;
  • getcontext(ucontext_t *ucp): initialize the structure pointed by ucp to current active context: https://man7.org/linux/man-pages/man3/getcontext.3.html
  • setcontext(const ucontext_t *ucp): restore the user context pointed to by ucp
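  • void makecontext(ucontext_t *ucp, typeof(void (int arg0, ...)) *func,
    int argc, ...): modify the context pointed to by ucp so that activating it calls func with the argc arguments that follow. Before calling it, initialize ucp with getcontext(), allocate its stack (uc_stack), and set its successor context (uc_link). The arguments can only be ints.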
  • int swapcontext(ucontext_t *restrict oucp,
    const ucontext_t *restrict ucp): save the current context to oucp and activate the context pointed by ucp (主要switch)

signal.h

  • #include<signal.h> man
  • struct sigaction
/* Excerpt from sigaction(2). On some architectures sa_handler and sa_sigaction
 * share storage (a union), so assign only one of them. */
struct sigaction {
    void (*sa_handler)(int);                        // handler (or SIG_DFL / SIG_IGN) used when SA_SIGINFO is not set
    void (*sa_sigaction)(int, siginfo_t *, void *); // handler used when SA_SIGINFO is set in sa_flags
    sigset_t sa_mask;                               // extra signals blocked while the handler runs
    int sa_flags;
    void (*sa_restorer)(void);                      // not intended for application use
};
  • sigaddset(sigset_t *set, int signo): add signo to the set
  • int sigdelset(sigset_t *set, int signo): delete a signal from the set
  • int sigfillset(sigset_t *set): initialize set to full, i.e. containing every signal defined by the system
  • int sigaction(int signum,
    const struct sigaction *_Nullable restrict act,
    struct sigaction *_Nullable restrict oldact): change the action taken by a process on receipt of signal signum (any signal except SIGKILL and SIGSTOP)
  • int sigprocmask(int how, const sigset_t *_Nullable restrict set,
    sigset_t *_Nullable restrict oldset)
 // query which signals are currently blocked (how is ignored when set is NULL)
sigprocmask(0, NULL, &mask);

// install *sig_set as the current signal mask
// https://man7.org/linux/man-pages/man2/sigprocmask.2.html
sigprocmask(SIG_SETMASK, sig_set, NULL);

// add block_set to the currently blocked signals;
// the previous signal mask is saved in sig_set
sigprocmask(SIG_BLOCK, &block_set, sig_set);

unistd

  • #include <unistd.h>
  • useconds_t ualarm(useconds_t usecs, useconds_t interval): send SIGALRM to the invoking process after usecs. If interval is not 0, will trigger SIGALRM signal every interval usec

Task resched

/* Create a task that will run func(param) on its own stack and append it to
 * the run queue (task_main.list). */
static void task_add(task_callback_t *func, void *param) {
  struct task_struct *task = task_alloc(func, param);
  // https://linux.die.net/man/3/getcontext
  // getcontext() fills task->context from the currently active context;
  // makecontext() below needs an initialized context to modify.
  if (getcontext(&task->context) == -1)
    abort();

  // Run the new context on the task's private stack.
  // NOTE(review): ss_size is hard-coded to 1 MiB; it must match the size of
  // task->stack as allocated by task_alloc (not visible here) — confirm.
  task->context.uc_stack.ss_sp = task->stack;
  task->context.uc_stack.ss_size = 1 << 20;
  task->context.uc_stack.ss_flags = 0;
  // No successor context: task_trampoline() is noreturn.
  task->context.uc_link = NULL;

  // makecontext() passes only int arguments, so the task pointer is split into
  // two ints via union task_ptr and reassembled in task_trampoline().
  union task_ptr ptr = {.p = task};
  makecontext(&task->context, (void (*)(void))task_trampoline, 2, ptr.i[0],
              ptr.i[1]);

  /* When we switch to it for the first time, timer signal must be blocked.
   * Paired with task_trampoline().
   */
  sigaddset(&task->context.uc_sigmask, SIGALRM);

  // Add the task to the task list. Preemption is disabled meanwhile, so
  // timer_handler() will not call schedule() while the list is being edited.
  preempt_disable();
  list_add_tail(&task->list, &task_main.list);
  preempt_enable();
}

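設定 task 的 context、stack 與進入點(task_trampoline,再由它呼叫 callback)
並在 context 的 signal mask 加入 SIGALRM:第一次切換到此 task 時 timer signal 是 blocked,直到 task_trampoline 將其解除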

/* Carries a task_struct pointer through makecontext(), whose extra arguments
 * are ints: task_add() stores .p and passes .i[0], .i[1]; task_trampoline()
 * rebuilds the pointer from the two ints. */
union task_ptr {
    void *p;
    int i[2];
};
/* The trick would silently truncate pointers if two ints cannot hold one. */
_Static_assert(sizeof(void *) <= sizeof(int[2]),
               "union task_ptr: a pointer must fit in two ints");

static void local_irq_restore_trampoline(struct task_struct *task)
{
    /* Undo the SIGALRM block that task_add() put into the task's initial
     * context, then install that mask so the timer can preempt this task. */
    sigset_t *mask = &task->context.uc_sigmask;

    sigdelset(mask, SIGALRM);
    local_irq_restore(mask);
}

/* Entry point of every task context (installed by makecontext() in task_add()).
 * Rebuilds the task pointer, unblocks SIGALRM, runs the task's callback, then
 * marks the task for reaping and switches away for good. */
__attribute__((noreturn)) static void task_trampoline(int i0, int i1)
{
    // union trick to get the task_struct pointer back from two int arguments
    union task_ptr ptr = {.i = {i0, i1}};
    struct task_struct *task = ptr.p;

    /* We switch to trampoline with blocked timer.  That is safe.
     * So the first thing that we have to do is to unblock timer signal.
     * Paired with task_add().
     */
    local_irq_restore_trampoline(task);
    task->callback(task->arg);
    // Ask schedule() to move this task onto task_reap; a later schedule() call
    // destroys it, so this context is never resumed.
    task->reap_self = true;
    schedule();

    __builtin_unreachable();
}

透過 task_ptr union 把兩個 int 參數重新組回 task_struct 指標,因為 makecontext 只能傳 int 參數
這裡重新允許timer signal 處理
執行callback(sort) 並排程

/* Switch from task_current to the next task in its list, with every signal
 * except SIGINT blocked (local_irq_save). Once this task is resumed, destroy
 * every task parked on task_reap and restore the saved signal mask. */
static void schedule(void)
{
    sigset_t set;
    local_irq_save(&set);

    // Successor of the running task in the task list.
    // NOTE(review): if list_first_entry follows Linux list.h semantics it never
    // yields NULL, so the check below is always true — confirm.
    struct task_struct *next_task = 
        list_first_entry(&task_current->list, struct task_struct, list);
    if (next_task) {
        // A finished task (see task_trampoline) leaves the run queue here.
        if (task_current->reap_self)
            list_move(&task_current->list, &task_reap);
        task_switch_to(task_current, next_task);
    }

    // Execution continues here only when this task is switched back to.
    struct task_struct *task, *tmp;
    // _safe iteration because task_destroy() presumably unlinks the entry
    list_for_each_entry_safe(task, tmp, &task_reap, list)
        task_destroy(task);
    
    local_irq_restore(&set);
}

scheduler主要邏輯

static void task_switch_to(struct task_struct *from, struct task_struct *to)
{
    /* Publish `to` as the running task before leaving `from`. */
    task_current = to;

    /* https://linux.die.net/man/3/swapcontext
     * Save the current registers/stack/signal mask into from->context and
     * resume to->context; this returns only when `from` is switched back to. */
    (void) swapcontext(&from->context, &to->context);
}

context 轉換

// timer part
static void timer_handler(int signo, siginfo_t *info, ucontext_t *ctx)
{
    /* SIGALRM tick: the current time slice is used up. Switch tasks unless
     * preemption is disabled (preempt_count non-zero).
     *
     * Scheduling directly from the signal handler is acceptable here because
     * the Linux kernel only cares about a proper sigreturn frame on the stack.
     * NOTE(review): preempt_count is read from a signal handler; it should be a
     * volatile sig_atomic_t (its declaration is not visible here) — confirm. */
    (void) signo;
    (void) info;
    (void) ctx;

    if (!preempt_count)
        schedule();
}

static void timer_init(void)
{
    struct sigaction sa = {
        .sa_handler = (void (*)(int)) timer_handler,
        .sa_flags = SA_SIGINFO,
    };
    // https://man7.org/linux/man-pages/man3/sigfillset.3p.html
    // fill in all signals to the set
    sigfillset(&sa.sa_mask);
    // https://man7.org/linux/man-pages/man2/sigaction.2.html
    // sigaction() examines the new action pointed to by act and fills in the
    // fields of the sigaction structure pointed to by oldact with the
    // signum specifies the signal and can be any valid signal except
    // SIGKILL and SIGSTOP.
    sigaction(SIGALRM, &sa, NULL);
}

static void timer_create(unsigned int usecs)
{
    /* Periodic tick: first SIGALRM after `usecs` microseconds, then one every
     * `usecs` microseconds until timer_cancel().
     * NOTE(review): POSIX <time.h> also declares a timer_create(); this static
     * definition conflicts if that declaration is visible — confirm. */
    const unsigned int first_tick = usecs;
    const unsigned int period = usecs;

    ualarm(first_tick, period);
}

static void timer_cancel(void)
{
    /* A zero initial value disarms the ualarm() timer; a zero interval means
     * no further periodic ticks. */
    const unsigned int disarm = 0;

    ualarm(disarm, disarm);
}

static void timer_wait(void)
{
    sigset_t mask;
    // what current sinal is blocked
    sigprocmask(0, NULL, &mask);
    // remove SIGALRM from the mask
    sigdelset(&mask, SIGALRM);
    // sigsuspend() temporarily changes the signal mask of the process and waits for a signal to arrive.
    // the process awake when SIGALRM arrives
    sigsuspend(&mask);
}

main

int main()
{
    timer_init();
    task_init();

    /* Queue three tasks running sort(), each given its own name string. */
    task_add(sort, "task1");
    task_add(sort, "task2");
    task_add(sort, "task3");

    /* Preemption is allowed only while main is parked in timer_wait(), so the
     * exit check below cannot be interrupted by schedule(). */
    preempt_disable();
    timer_create(10000); /* 10 ms time slice */

    for (;;) {
        /* Finished once no task is queued and nothing is left to reap. */
        if (list_empty(&task_main.list) && list_empty(&task_reap))
            break;
        preempt_enable();
        timer_wait();
        preempt_disable();
    }

    preempt_enable();
    timer_cancel();

    return 0;
}

主要流程:

  • 建立SIGALRM timer
  • 加入task
  • task 透過註冊的 task_trampoline 執行真正的 callback;在 task_trampoline 重新組回 task 指標並解除 SIGALRM 之前,timer signal 都是 blocked,推測是為了避免 task 尚未初始化完成就被搶佔
  • schedule函式是主要從全域變數task_current找尋串列是否有可以執行的task
  • main函式透過每一次timer interrupt來檢查全域變數任務串列task_main和回收任務串列task_reap是否為空, 全為空代表排程器完成它應有工作
  • 為什麼 task_trampoline 最後是 unreachable?理由:task 被標記為 reap_self 後,schedule() 會把它移到 task_reap 並切換到下一個 task;之後其他 task 的 schedule() 會呼叫 task_destroy 釋放它,所以這個 context 不會再被恢復

防搶佔和irq set/restore

這裡模仿實際kernel會使用到的方式

static void preempt_disable(void)
{
    /* Nesting counter: timer_handler() skips schedule() while it is non-zero. */
    preempt_count += 1;
}
static void preempt_enable(void)
{
    /* Undo one preempt_disable(); preemption resumes once the count is 0 again. */
    preempt_count -= 1;
}

static void local_irq_save(sigset_t *sig_set)
{
    /* User-space analogue of the kernel's local_irq_save(): block every signal
     * except SIGINT (presumably so Ctrl-C can still stop the program) and hand
     * the previous mask back through sig_set for local_irq_restore(). */
    sigset_t everything_but_sigint;

    sigfillset(&everything_but_sigint);
    sigdelset(&everything_but_sigint, SIGINT);
    sigprocmask(SIG_BLOCK, &everything_but_sigint, sig_set);
}

static void local_irq_restore(sigset_t *sig_set)
{
    /* Make *sig_set the current signal mask again (undoes local_irq_save()).
     * https://man7.org/linux/man-pages/man2/sigprocmask.2.html */
    sigset_t *const previous = sig_set;

    sigprocmask(SIG_SETMASK, previous, NULL);
}

實例:防搶佔 printf

/* printf() that cannot be preempted halfway through: timer_handler() does not
 * call schedule() while preempt_count is non-zero. Like printf(), it evaluates
 * to the number of characters written (negative on error), so existing
 * statement-style uses keep working. Uses a GNU statement expression. */
#define task_printf(...)                              \
    ({                                                \
        preempt_disable();                            \
        int task_printf_ret_ = printf(__VA_ARGS__);   \
        preempt_enable();                             \
        task_printf_ret_;                             \
    })

Tinync