Аналитическое вычисление производных на шаблонах C++

Тут на днях писали про аналитическое нахождение производных, что напомнило мне об одной моей маленькой библиотечке на C++, которая делает почти то же, но во время компиляции.

7d7d2bf029364d65ad444fd10d692fe0.png

В чём профит? Ответ прост: мне нужно было запрогать нахождение минимума достаточно сложной функции, считать производные этой функции по её параметрам ручкой на бумажке было лень, проверять потом, что я не опечатался при написании кода, и поддерживать этот самый код — лень вдвойне, поэтому было решено написать штуковину, которая это сделает за меня. Ну, чтобы в коде можно было написать что-то такое:

using Formula_t = decltype (k * (_1 - r0) / (_1 + r0) * (g0 / (alpha0 - logr0 / Num<300>) - _1));    // сама формула
const auto residual = Formula_t::Eval (datapoint) - knownValue;    // регрессионный остаток

// производные по параметрам:
const auto dg0 = VarDerivative_t::Eval (datapoint);
const auto dalpha0 = VarDerivative_t::Eval (datapoint);
const auto dk = VarDerivative_t::Eval (datapoint);

вместо крокодилов, которые получатся, если брать частные производные функции на картинке вначале (вернее, некоторого её упрощённого варианта, но он выглядит не так страшно).

Ещё неплохо быть достаточно уверенным, что компилятор это соптимизирует так, как если бы соответствующие производные и функции были написаны руками. А уверенным быть бы хотелось — находить минимум нужно было очень много раз (действительно много, где-то от сотни миллионов до миллиарда, в этом была суть некоего вычислительного эксперимента), поэтому вычисление производных было бы бутылочным горлышком, происходи оно во время выполнения через какую-нибудь рекурсию по древообразной структуре. Если же заставить компилятор вычислять производную, собственно, во время компиляции, то есть шанс, что он по получившемуся коду ещё пройдётся оптимизатором, и мы не потеряем по сравнению с ручным выписыванием всех производных. Шанс реализовался, кстати.

Под катом — небольшое описание, как оно там всё работает.

Начнём с представления функции в программе. Почему-то так получилось, что каждая функция — это тип. Функция — это ещё и дерево выражений, и узел этого дерева представляется типом Node:

template
struct Node;

Здесь NodeClass — тип узла (переменная, число, унарная функция, бинарная функция), Args — параметры этого узла (индекс переменной, значение числа, дочерние узлы).

Узлы умеют себя дифференцировать, печатать и вычислять для данных значений свободных переменных и параметров. Так, если определен тип для представления узла с обычным числом:

using NumberType_t = long long;

template
struct Number {};

то специализация узла для чисел тривиальна:
template
struct Node>
{
	template
	using Derivative_t = Node>;

	static std::string Print ()
	{
		return std::to_string (N);
	}

	template
	static typename Vec::value_type Eval (const Vec&)
	{
		return N;
	}

	constexpr Node () {}
};

Производная любого числа по любой переменной — ноль (за это отвечает тип Derivative_t, оставим пока его шаблонные параметры). Распечатать число — тоже просто (см. Print()). Вычислить узел с числом — вернуть это число (см. Eval(), шаблонный параметр Vec обсудим позже).

Переменная представляется похожим образом:

template
struct Variable {};

Здесь Family и Index — «семейство» и индекс переменной. Так, для dabfa625489b4c9896ae6465fea6e913.png они будут равняться 'w' и 1, а для a713b7bb0d24499aaf59143a8e432714.png — 'x' и 2 соответственно.

Узел для переменной определяется чуть интереснее, чем для числа:

template
struct Node>
{
	template
	using Derivative_t = std::conditional_t>,
			Node>>;

	static std::string Print ()
	{
		return std::string { Family, '_' } + std::to_string (Index);
	}

	template
	static typename Vec::value_type Eval (const Vec& values)
	{
		return values (Node {});
	}

	constexpr Node () {}
};

Так, производная переменной по ей же самой равна единице, а по любой другой — нулю. Собственно, параметры FPrime и IPrime для типа Derivative_t — это семейство и индекс переменной, по которой требуется взять производную.

Вычисление значения функции, состоящей из одной переменной сводится к её нахождению в словаре значений values, который передаётся в функцию Eval(). Словарь сам умеет находить значение нужной переменной по её типу, поэтому ему мы просто передадим тип нашей переменной и вернём соответствующее значение. Как словарь это делает, мы рассмотрим позже.

С унарными функциями всё становится ещё интереснее.

enum class UnaryFunction
{
	Sin,
	Cos,
	Ln,
	Neg
};

template
struct UnaryFunctionWrapper;

В специализации UnaryFunctionWrapper мы запихнём логику по взятию производных каждой конкретной унарной функции. Чтобы минимально дублировать код, будем брать производную унарной функции по её аргументу, за дальнейшее дифференцирование аргумента по целевой переменной через chain rule будет отвечать вызывающий код:

template<>
struct UnaryFunctionWrapper
{
	template
	using Derivative_t = Node;
};

template<>
struct UnaryFunctionWrapper
{
	template
	using Derivative_t = Node>;
};

template<>
struct UnaryFunctionWrapper
{
	template
	using Derivative_t = Node>, Child>;
};

template<>
struct UnaryFunctionWrapper
{
	template
	using Derivative_t = Node>;
};

Тогда сам узел выглядит следующим образом:

template
struct Node, Node>
{
	using Child_t = Node;

	template
	using Derivative_t = Node::template Derivative_t,
			typename Node::template Derivative_t>;

	static std::string Print ()
	{
		return FunctionName (UF) + "(" + Node::Print () + ")";
	}

	template
	static typename Vec::value_type Eval (const Vec& values)
	{
		const auto child = Child_t::Eval (values);
		return EvalUnary (UnaryFunctionWrapper {}, child);
	}
};

Считаем производную через chain rule — выглядит страшно, идея простая. Вычисляем тоже просто: считаем значение дочернего узла, затем вычисляем значение нашей унарной функции на этом значении при помощи функции EvalUnary(). Вернее, семейства функций: первым аргументом функции идёт тип, определяющий нашу унарную функцию, чтобы гарантировать выбор нужной перегрузки во время компиляции. Да, можно было бы передавать само значение UF, и умный компилятор почти наверняка сделал бы все нужные constant propagation passes, но здесь проще перестраховаться.

Кстати, отдельную унарную операцию отрицания можно было бы и не вводить, заменив её на умножение на минус единицу.

С бинарными узлами всё аналогично, только производные выглядят совсем страшно. Для деления, например:

template<>
struct BinaryFunctionWrapper
{
	template
	using Derivative_t = Node,
						V
					>,
					Node
						>
					>
				>,
				Node
			>;
};

Тогда искомая метафункция VarDerivative_t определяется довольно просто, ибо по факту лишь вызывает Derivative_t у переданного ей узла:

template
struct VarDerivative;

template
struct VarDerivative>>
{
	using Result_t = typename Expr::template Derivative_t;
};

template
using VarDerivative_t = typename VarDerivative>::Result_t;

Если теперь определить вспомогательные переменные и типы, например:

// алиасы для типов унарных и бинарных функций:
using Sin = UnaryFunctionWrapper;
using Cos = UnaryFunctionWrapper;
using Neg = UnaryFunctionWrapper;
using Ln = UnaryFunctionWrapper;

using Add = BinaryFunctionWrapper;
using Mul = BinaryFunctionWrapper;
using Div = BinaryFunctionWrapper;
using Pow = BinaryFunctionWrapper;

// variable template из C++14 для определения переменной в общем виде:
template
constexpr Node> Var {};

// определим переменную x0 для удобства, авось, ей часто пользоваться будут:
using X0 = Node>;
constexpr X0 x0;
// и так далее для других переменных

// константа для единицы, единица часто встречается в формулах:
constexpr Node> _1;

// перегрузки операторов, им даже не нужно тело, достаточно типа:
template
Node, std::decay_t> operator+ (T1, T2);

template
Node, std::decay_t> operator* (T1, T2);

template
Node, std::decay_t> operator/ (T1, T2);

template
Node, Node>> operator- (T1, T2);

// не совсем операторы, но тоже чтобы удобно писать было:
template
Node> Sin (T);

template
Node> Cos (T);

template
Node> Ln (T);

то можно будет писать код прямо как в самом начале поста.

Что осталось?

Во-первых, разобраться с тем типом, который передаётся в функцию Eval(). Во-вторых, упомянуть про возможность преобразований искомого выражения с заменой одного поддерева на другое. Начнём со второго, оно проще.

Мотивация (можно пропустить): если немного попрофилировать код, который получится с текущей версией, то в глаза бросится, что довольно много времени уходит на вычисление 50b5af0e231d43feaccf093b6cf4b9a9.png, который, вообще говоря, один и тот же для каждой экспериментальной точки. Не беда! Введём отдельную переменную, которую посчитаем один раз перед расчётом значений нашей формулы на каждой из экспериментальных точек, и заменим все вхождения 50b5af0e231d43feaccf093b6cf4b9a9.png на эту переменную (собственно, в мотивационном коде в самом начале это уже и сделано). Однако, когда мы будем брать производную по a7a558927bbc4000a8a81788f75e1ae5.png, нам придётся вспомнить, что 50b5af0e231d43feaccf093b6cf4b9a9.png, вообще говоря, не свободный параметр, а функция от a7a558927bbc4000a8a81788f75e1ae5.png. Вспомнить очень просто: заменим 50b5af0e231d43feaccf093b6cf4b9a9.png на a7a558927bbc4000a8a81788f75e1ae5.png (для этого используется метафункция ApplyDependency_t, хотя правильнее было бы её назвать Rewrite_t или вроде того), продифференцируем, вернём a7a558927bbc4000a8a81788f75e1ae5.png на 50b5af0e231d43feaccf093b6cf4b9a9.png обратно:

using Unwrapped_t = ApplyDependency_t;
using Derivative_t = VarDerivative_t;
using CacheLog_t = ApplyDependency_t;

Реализация многословна, но идейно проста. Рекурсивно спускаемся по дереву формулы, подменяя элемент дерева, если он в точности совпадает с шаблоном, иначе ничего не меняем. Итого три специализации: для спуска по дочернему узлу унарной функции, для спуска по дочерним узлам бинарной функции, и собственно для замены, при этом специализации для спуска по дочерним узлам должны проверять, что шаблон не совпадает с поддеревом, соответствующим рассматриваемой подфункции:

template
struct ApplyDependency
{
	using Result_t = Formula;
};

template
using ApplyDependency_t = typename ApplyDependency, std::decay_t, Formula>::Result_t;

template
struct ApplyDependency, Child>,
		std::enable_if_t, Child>>::value>>
{
	using Result_t = Node<
				UnaryFunctionWrapper,
				ApplyDependency_t
			>;
};

template
struct ApplyDependency, FirstNode, SecondNode>,
		std::enable_if_t, FirstNode, SecondNode>>::value>>
{
	using Result_t = Node<
				BinaryFunctionWrapper,
				ApplyDependency_t,
				ApplyDependency_t
			>;
};

template
struct ApplyDependency
{
	using Result_t = Expr;
};

Ффух. Осталось разобраться с передачей значений параметров.

Вспомним, что каждый параметр имеет свой собственный тип, поэтому если мы построим семейство функций, перегруженных по типу параметров, каждая из которых возвращает соответствующее значение, то снова (прямо как с вычислением унарных функций чуть ранее) есть шанс, что компилятор это дело свернёт и соптимизирует (а он, кстати, и соптимизирует, умница такой). Ну, что-то вроде

auto GetValue (Variable<'x', 0>)
{
    return value_for_x0;
}

auto GetValue (Variable<'x', 1>)
{
    return value_for_x1;
}

...

Только мы хотим сделать это красиво, чтобы можно было написать, например:

BuildFunctor (g0, someValue,
        alpha0, anotherValue,
        k, yetOneMoreValue,
        r0, independentVariable,
        logr0, logOfTheIndependentVariable);

где g0, alpha0 и компания — объекты, имеющие типы соответствующих переменных, а следом за ними идут соответствующие значения.

Как мы можем скрестить ужа и ежа, сделав в общем виде функцию, тип параметра которой задаётся в компил-тайме, а значение — в рантайме? Лямбды спешат на помощь!

template
auto BuildFunctor (NodeType, ValueType val)
{
    return [val] (NodeType) { return val; };
}

Пусть у нас есть две таких функции, как мы можем получить семейство функций в одном пространстве имён, чтобы нужная выбиралась перегрузкой? Наследование спешит на помощь!

template
struct Map : F, S
{
	using F::operator();
	using S::operator();

	Map (F f, S s)
	: F { std::forward (f) }
	, S { std::forward (s) }
	{
	}
};

Мы наследуемся от обеих лямбд (ведь лямбда разворачивается в структуру со сгенерированным компилятором именем, а значит, от неё можно наследоваться) и приносим в скоуп их операторы-круглые-скобочки.

Более того, можно наследоваться не только от лямбд, но и от произвольных структур, имеющих какие-либо операторы-круглые-скобочки. Опа, получили алгебру. Таким образом, если есть N лямбд, можно отнаследовать первую Map от первых двух лямбд, следующую Map — от первой Map и следующей лямбды, и так далее. Оформим это в виде кода:

template
auto Augment (F&& f)
{
	return f;
}

template
auto Augment (F&& f, S&& s)
{
	return Map, std::decay_t> { f, s };
}

template
auto BuildFunctor ()
{
	struct
	{
		ValueType operator() () const
		{
			return {};
		}

		using value_type = ValueType;
	} dummy;
	return dummy;
}

template
auto BuildFunctor (NodeType, ValueType val, Tail&&... tail)
{
	return detail::Augment ([val] (NodeType) { return val; },
			BuildFunctor (std::forward (tail)...));
}

Собственно, всё.

Разве что, один мой приятель, которому я это показывал в своё время, предложил, на мой взгляд, более элегантное решение на constexpr-функциях, но у меня до него уже 9 месяцев не доходят руки.

Ну и линк на библиотечку: I Am Mad. К продакшену не готово, пуллреквесты принимаются, и всё такое.

Ну и ещё можно поудивляться, насколько умны современные компиляторы, которые могут продраться сквозь вот эти все слои шаблонов поверх шаблонов поверх лямбд поверх шаблонов и сгенерировать достаточно оптимальный код.

Комментарии (0)

© Habrahabr.ru