Ошибка наследования функции стоимости Ceres-Solver: шаблоны могут быть не виртуальными

Я использую Церера-решатель уже давно и это удивительный инструмент. Мое использование до сих пор не было основано на повторно используемом коде, и я пытаюсь улучшить это. Церера использует определенную структуру с определенным шаблонным методом в качестве интерфейса к ее автоматическое дифференцирование. В проблеме, которую я пытаюсь решить, наследование имеет смысл, потому что разные функции стоимости, которые мне нужны, очень похожи друг на друга. Я создал пример, который похож (но это не имеет смысла, извините). Представьте, что мы хотим найти многоугольник с заданной областью. В моем примере полигоны могут быть треугольниками или прямоугольниками. Имея это в виду, имеет смысл иметь базовый класс, который реализует все, и конкретные классы, которые реализуют, в данном случае, вычисления площади для каждого конкретного многоугольника:

ShapeCostFunction

class shapeAreaCostFunction
{
public:
shapeAreaCostFunction(double desired_area): desired_area_(desired_area){}

template<typename T>
bool operator()(const T* shape, T* residual) const{
residual[0] = T(desired_area_) - area(shape);
return true;
}

template<typename T>
virtual T area(const T* shape) const = 0;

protected:
double desired_area_;
};

RectangleCostFunction

#include "shapeAreaCostFunction.h"#include "areaLibrary.h"
class rectangleAreaCostFunction : public shapeAreaCostFunction
{
public:
rectangleAreaCostFunction(double desired_area): shapeAreaCostFunction(desired_area){}

template<typename T>
T area(const T* triangle) const
{
return rectangleArea(triangle);
}
};

TriangleCostFunction

#include "shapeAreaCostFunction.h"#include "areaLibrary.h"
class triangleAreaCostFunction : public shapeAreaCostFunction
{
public:
triangleAreaCostFunction(double desired_area): shapeAreaCostFunction(desired_area){}

template<typename T>
T area(const T* triangle) const
{
return triangleArea(triangle);
}
};

AreaLibrary

template<typename T>
T rectangleArea(const T* rectangle)
{
return rectangle[0]*rectangle[1];
}

template<typename T>
T triangleArea(const T* triangle)
{
return rectangleArea(triangle)/T(2);
}

Главный

#include <ceres/ceres.h>
#include <iostream>

#include "rectangleAreaCostFunction.h"#include "triangleAreaCostFunction.h"#include "areaLibrary.h"
int main(int argc, char** argv){

// Initialize glogging
//google::InitGoogleLogging(argv[0]);

// Get values
/// Get total area
double total_area;
std::cout<<"Enter the desired area: ";
std::cin>>total_area;
/// Get initial rectangle
double rect[2];
std::cout<<"Enter initial rectangle base: ";
std::cin>>rect[0];
std::cout<<"Enter initial rectangle height: ";
std::cin>>rect[1];
/// Get initial triagnle
double tri[2];
std::cout<<"Enter initial triangle base: ";
std::cin>>tri[0];
std::cout<<"Enter initial triangle height: ";
std::cin>>tri[1];

// Copy initial values
double rect_ini[] = {rect[0],rect[1]};
double tri_ini[] = {tri[0],tri[1]};

// Create problem
ceres::Problem problem;
ceres::CostFunction* cost_function_rectangle = new ceres::AutoDiffCostFunction<rectangleAreaCostFunction, 1, 2>(
new rectangleAreaCostFunction(total_area));
ceres::CostFunction* cost_function_triangle = new ceres::AutoDiffCostFunction<triangleAreaCostFunction, 1, 2>(
new triangleAreaCostFunction(total_area));
problem.AddResidualBlock(cost_function_rectangle, NULL, rect);
problem.AddResidualBlock(cost_function_triangle, NULL, tri);

// Solve
ceres::Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
options.max_num_iterations = 10;
ceres::Solver::Summary summary;
ceres::Solve(options, &problem, &summary);

// Compute final areas
double rect_area = rectangleArea(rect);
double tri_area = triangleArea(tri);

// Display results
std::cout << summary.FullReport() << std::endl;
std::cout<<"Rectangle: ("<<rect_ini[0]<<","<<rect_ini[1]<<") -> ("<<rect[0]<<","<<rect[1]<<") total area: "<<rect_area<<"("<< rect_area - total_area<<")"<<std::endl;
std::cout<<"Triangle: ("<<tri_ini[0]<<","<<tri_ini[1]<<") -> ("<<tri[0]<<","<<tri[1]<<") total area: "<<tri_area<<"("<< tri_area - total_area<<")"<<std::endl;

// Exit
return 0;
}

Проблема в том, что шаблонные функции не могут быть виртуальными, как объяснялось несколько раз в stackoverflowВот а также Вот). Тем не менее, кажется, есть некоторые обходные с помощью boost::any, Я пытался использовать это в моем примере, но безуспешно. Я также попытался переместить шаблон из метода класса в класс, аналогично Вот, но они не принимают это как функцию стоимости.

Мои вопросы (и, пожалуйста, имейте в виду, что я ограничен, чтобы иметь метод template<typename T> bool operator()(...)const иначе я не могу взаимодействовать с Церерой)

  1. Имеет ли смысл иметь такую ​​систему наследования (представьте, что это гораздо более сложная проблема, чем в примере)?
  2. Есть ли способ сохранить эту систему наследования и заставить код работать или я просто перенесу все в отдельные функции и вызову нужную функцию из каждой template<typename T> bool operator()(...)const метод класса?

Заранее спасибо.

1

Решение

Я могу думать о двух подходах.

Сначала составьте лямбды. Во-вторых, используйте CRTP.

Это лучше всего сделать с .

template<class Area>
auto cost_function(Area area, double desired){
return [=](auto const* shape, auto* residual){
using T=std::decay_t<decltype(*shape)>;
residual[0] = T(desired_area_) - area(shape);
return true;
};
}
auto triangle = [](auto* shape){return triangleArea(shape);};

Чтобы создать функцию стоимости площади треугольника:

auto tri_cost = cost_function(triangle, 3.14159);

и чтобы получить тип, decltype(tri_cost),

Так:

auto tri_cost = cost_function(triangle, 3.14159);
ceres::CostFunction* cost_function_triangle = new ceres::AutoDiffCostFunction<decltype(tri_cost), 1, 2>(
new decltype(tri_cost)(tri_cost));

Вы можете сделать подобную технику композиции без лямбд, но это более утомительно. Вы также можете обернуть некоторые из этих обнаженных новинок в вспомогательные функции.

template<class D>
class shapeAreaCostFunction {
public:
shapeAreaCostFunction(double desired_area): desired_area_(desired_area){}

template<typename T>
bool operator()(const T* shape, T* residual) const{
residual[0] = T(desired_area_) - static_cast<D const*>(this)->area(shape);
return true;
}
protected:
double desired_area_;
};

измените производные типы следующим образом:

class triangleAreaCostFunction :
public shapeAreaCostFunction<triangleAreaCostFunction>
{
using base=shapeAreaCostFunction<triangleAreaCostFunction>;
public:
triangleAreaCostFunction(double desired_area): base(desired_area){}

template<typename T>
T area(const T* triangle) const
{
return triangleArea(triangle);
}
};

это известно как использование CRTP для реализации статического полиморфизма.

2

Другие решения

Других решений пока нет …