Пишем try-catch в C не привлекая внимания санитаров

f0bb7505d31fba4ddd35c6e176b72f5c

Всё началось с безобидного пролистывания GCC расширений для C. Мой глаз зацепился за вложенные функции. Оказывается, в C можно определять функции внутри функций:

int main() {
    void foo(int a) {
        printf("%d\n", a);
    }
    for(int i = 0; i < 10; i ++)
        foo(i);
    return 0;
}

Более того, во вложенных функциях можно менять переменные из внешней функции и переходить по меткам из неё, но для этого необходимо, чтобы переменные были объявлены до вложенной функции, а метки явно указаны через __label__

int main() {
    __label__ end;
    int i = 1;

    void ret() {
        goto end;
    }
    void inc() {
        i ++;
    }
    
    while(1) {
        if(i > 10)
            ret();
        printf("%d\n", i);
        inc();
    }

  end:
    printf("Done\n");
    return 0;
}

Документация говорит, что обе внутренние функции валидны, пока валидны все переменные и мы не вышли из области внешней функции, то есть эти внутренни функции можно передавать как callback-и.

Приступим к написанию try-catch. Определим вспомогательные типы данных:

// Данными, как и выкинутой ошибкой может быть что угодно
typedef void *data_t;
typedef void *err_t;

// Определяем функцию для выкидывания ошибок
typedef void (*throw_t)(err_t);

// try и catch. Они тоже будут функциями
typedef data_t (*try_t)(data_t, throw_t);
typedef data_t (*catch_t)(data_t, err_t);

Подготовка завершена, напишем основную функцию. К сожалению на хабре нельзя выбрать отдельно язык C, поэтому будем писать try_, catch_, throw_ чтобы их подсвечивало как функции, а не как ключевые слова C++

data_t try_catch(try_t try_, catch_t catch_, data_t data) {
    __label__ fail;
    err_t err;
    // Объявляем функцию выбрасывания ошибки
    void throw_(err_t e) {
        err = e;
        goto fail;
    }
    // Передаём в try данные и callback для ошибки
    return try_(data, throw_);
    
  fail:
    // Если есть catch, передаём данные, над которыми 
    // работал try и ошибку, которую он выбросил
    if(catch_ != NULL)
        return catch_(data, err);
    // Если нет catch, возвращаем пустой указатель
    return NULL;
}

Напишем тестовую функцию взятия квадратного корня, с ошибкой в случае отрицательного числа

data_t try_sqrt(data_t ptr, throw_t throw_) {
    float *arg = (float *)ptr;
    if(*arg < 0)
        throw_("Error, negative number\n");
  
    // Выделяем кусок памяти для результата
    float *res = malloc(sizeof(float));
    *res = sqrt(*arg);
    return res;
}

data_t catch_sqrt(data_t ptr, err_t err) {
    // Если возникла ошибка, печатает её и ничего не возвращаем
    fputs(err, stderr);
    return NULL;
}

Добавляем функцию main, посчитаем в ней корень от 1 и от -1

int main() {
    printf("------- sqrt(1) --------\n");
    float a = 1;
    float *ptr = (float *) try_catch(try_sqrt, catch_sqrt, &a);

    if(ptr != NULL) {
        printf("Result of sqrt is: %f\n", *ptr);
        // Не забываем освободить выделенную память
        free(ptr);
    } else
        printf("An error occured\n");
    

    printf("------- sqrt(-1) -------\n");
    a = -1;
    ptr = (float *)try_catch(try_sqrt, catch_sqrt, &a);

    if(ptr != NULL) {
        printf("Result of sqrt is: %f\n", *ptr);
        // Аналогично
        free(ptr);
    } else
        printf("An error occured\n");
  
    return 0;
}

И, как и ожидалось, получаем

------- sqrt(1) --------
Result of sqrt is: 1.000000
------- sqrt(-1) -------
Error, negative number
An error occured

Try-catch готов, господа.

На этом статью можно было бы и закончить, но тут внимательный читатель заметит, что функция throw остаётся валидной в блоке catch. Можно вызвать её и там, и тогда мы уйдём в рекурсию. Заметим также, что функция throw, это не обычная функция, она noreturn и разворачивает стек, поэтому, даже если вызвать её в catch пару сотен раз, на стеке будет только последний вызов. Мы получаем хвостовую оптимизацию рекурсии.

Попробуем посчитать факториал на нашем try-catch. Для этого передадим указатель на функцию throw в функцию catch. Сделаем это через структуру, в которой также будет лежать аккумулятор вычислений.

struct args {
    uint64_t acc;
    throw_t throw_;
};

В функции try инициализируем поле throw у структуры, и заводим переменную num для текущего шага рекурсии.

data_t try_(data_t ptr, throw_t throw_) {
    struct args *args = ptr;
    // Записываем функцию в структуру, чтобы catch мог её pf,hfnm
    args->throw_ = throw_;
  
    // Заводим переменную для хранения текущего шага рекурсии
    uint64_t *num = malloc(sizeof(uint64_t));
    // Изначально в acc лежит начальное число, в нашем случае 10
    *num = args->acc; 
    // Уменьшаем число
    (*num) --;
    // Уходим в рекурсию
    throw_(num);
}

В функции catch будем принимать структуру и указатель на num, а дальше действуем как в обычном рекурсивном факториале.

data_t catch_(data_t ptr, err_t err) {
    struct args *args = ptr;
    // В err на самом деле лежит num
    uint64_t *num = err;
    // Печатаем num, будем отслеживать рекурсию
    printf("current_num: %"PRIu64"\n", *num);
    
    if(*num > 0) {
        args->acc *= *num;
        (*num) --;
        // Рекурсивный вызов
        args->throw_(num);
    }
    // Конец рекурсии
    // Не забываем осовободить выделенную память
    free(num);
    
    // Выводим результат
    printf("acc is: %"PRIu64"\n", args->acc);
    return &args->acc;
}
int main() {
    struct args args = { .acc = 10 };
    try_catch(try_, catch_, &args);

    return 0;
}

Вызываем, и получаем, как и ожидалось:

current_num: 9
current_num: 8
current_num: 7
current_num: 6
current_num: 5
current_num: 4
current_num: 3
current_num: 2
current_num: 1
current_num: 0
acc is: 3628800

main.c

#include 
#include 
#include 
#include 
#include 

typedef void *err_t;
typedef void *data_t;
typedef void (*throw_t)(err_t);
typedef data_t (*try_t)(data_t, throw_t);
typedef data_t (*catch_t)(data_t, err_t);


data_t try_catch(try_t try, catch_t catch, data_t data) {
    __label__ fail;
    err_t err;
    void throw(err_t e) {
        err = e;
        goto fail;
    }

    return try(data, throw);
    
  fail:
    if(catch != NULL)
        return catch(data, err);
    return NULL;
}

struct args {
    uint64_t acc;
    throw_t throw_;
};

data_t try_(data_t ptr, throw_t throw_) {
    struct args *args = ptr;
    args->throw_ = throw_;

    uint64_t *num = malloc(sizeof(uint64_t));
    *num = args->acc;
    (*num) --;
    
    throw_(num);
}

data_t catch_(data_t args_ptr, err_t num_ptr) {
    struct args *args = args_ptr;
    uint64_t *num = num_ptr;
    
    printf("current_num: %"PRIu64"\n", *num);
    
    if(*num > 0) {
        args->acc *= *num;
        (*num) --;
        args->throw_(num);
    }
    free(num);
    printf("acc is: %"PRIu64"\n", args->acc);
    return &args->acc;
}

int main() {
    struct args args = { .acc = 10 };
    try_catch(try_, catch_, &args);

    return 0;
}

Спасибо за внимание.

P.S. Текст попытался вычитать, но, так как русского в школе не было, могут быть ошибки. Прошу сильно не пинать и по возможности присылать всё в ЛС, постараюсь реагировать оперативно.

© Habrahabr.ru