algorithm - Nelder-Mead optimization in C++ - Code Review Stack Exchange
Mon Jan 30 2023 15:22:55 GMT+0000 (Coordinated Universal Time)
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "Eigen/Dense"
using namespace Eigen;
namespace Optimization
{
template<int nDims>
class NelderMead {
public:
typedef Array<double, 1, nDims> FunctionParameters;
typedef std::function<double(FunctionParameters)> ErrorFunction;
NelderMead(ErrorFunction errorFunction, FunctionParameters initial,
double minError, double initialEdgeLength,
double shrinkCoeff = 1, double contractionCoeff = 0.5,
double reflectionCoeff = 1, double expansionCoeff = 1)
: errorFunction(errorFunction), minError(minError),
shrinkCoeff(shrinkCoeff), contractionCoeff(contractionCoeff),
reflectionCoeff(reflectionCoeff), expansionCoeff(expansionCoeff),
worstValueId(-1), secondWorstValueId(-1), bestValueId(-1)
{
this->errors = std::vector(nDims + 1, std::numeric_limits<double>::max());
const double b = initialEdgeLength / (nDims * SQRT2) * (sqrt(nDims + 1) - 1);
const double a = initialEdgeLength / SQRT2;
this->values = initial.replicate(nDims + 1, 1);
for (int i = 0; i < nDims; i++)
{
FunctionParameters simplexRow;
simplexRow.setConstant(b);
simplexRow(0, i) = a;
simplexRow += initial;
this->values.row(i+1) = simplexRow;
}
}
void optimize()
{
for (int i = 0; i < nDims+1; i++)
{
this->errors.at(i) = this->errorFunction(this->values.row(i));
}
this->invalidateIdsCache();
while (this->errors.at(this->bestValueId) > this->minError)
{
step();
auto bestError = this->errorFunction(this->best());
auto worstError = this->errorFunction(this->worst());
std::cout << "Best error " << std::to_string(bestError) << " with : " << this->best();
std::cout << " Worst error " << std::to_string(worstError) << " with : " << this->worst();
std::cout << '\n';
}
}
void step()
{
auto meanWithoutWorst = this->getMeanWithoutWorst();
auto reflectionOfWorst = this->getReflectionOfWorst(meanWithoutWorst);
auto reflectionError = this->errorFunction(reflectionOfWorst);
FunctionParameters newValue = reflectionOfWorst;
double newError = reflectionError;
bool shrink = false;
if (reflectionError < this->errors.at(this->bestValueId))
{
auto expansionValue = this->expansion(meanWithoutWorst, reflectionOfWorst);
double expansionError = this->errorFunction(expansionValue);
if (expansionError < this->errors.at(this->bestValueId))
{
newValue = expansionValue;
newError = expansionError;
}
}
else if (reflectionError > this->errors.at(this->worstValueId))
{
newValue = this->insideContraction(meanWithoutWorst);
newError = this->errorFunction(newValue);
if (newError > this->errors.at(this->worstValueId)) { shrink = true; }
}
else if (reflectionError > this->errors.at(this->secondWorstValueId))
{
newValue = this->outsideContraction(meanWithoutWorst);
newError = this->errorFunction(newValue);
if (newError > reflectionError) { shrink = true; }
}
else
{
newValue = reflectionOfWorst;
newError = reflectionError;
}
if (shrink)
{
this->shrink();
this->invalidateIdsCache();
return;
}
this->values.row(this->worstValueId) = newValue;
this->errors.at(this->worstValueId) = newError;
this->invalidateIdsCache();
}
inline FunctionParameters worst() { return this->values.row(this->worstValueId); }
inline FunctionParameters best() { return this->values.row(this->bestValueId); }
private:
void shrink()
{
auto bestVertex = this->values.row(this->bestValueId);
for (int i = 0; i < nDims + 1; i++)
{
if (i == this->bestValueId) { continue; }
this->values.row(i) = bestVertex + this->shrinkCoeff * (this->values.row(i) - bestVertex);
this->errors.at(i) = this->errorFunction(this->values.row(i));
}
}
inline FunctionParameters expansion(FunctionParameters meanWithoutWorst, FunctionParameters reflection)
{
return reflection + this->expansionCoeff * (reflection - meanWithoutWorst);
}
inline FunctionParameters insideContraction(FunctionParameters meanWithoutWorst)
{
return meanWithoutWorst - this->contractionCoeff * (meanWithoutWorst - this->worst());
}
inline FunctionParameters outsideContraction(FunctionParameters meanWithoutWorst)
{
return meanWithoutWorst + this->contractionCoeff * (meanWithoutWorst - this->worst());
}
FunctionParameters getReflectionOfWorst(FunctionParameters meanWithoutWorst)
{
return meanWithoutWorst + this->reflectionCoeff * (meanWithoutWorst - this->worst());
}
FunctionParameters getMeanWithoutWorst()
{
FunctionParameters mean(0);
for (int i = 0; i < nDims + 1; i++)
{
if (i == this->worstValueId) { continue; }
mean += this->values.row(i);
}
// Not divided by nDims+1 because there's one ignored value.
mean /= nDims;
return mean;
}
void invalidateIdsCache() {
double worstError = std::numeric_limits<double>::min();
int worstId = -1;
double secondWorstError = std::numeric_limits<double>::max();
int secondWorstId = -1;
double bestError = std::numeric_limits<double>::max();
int bestId = -1;
for (int i = 0; i < nDims + 1; i++)
{
auto error = this->errors.at(i);
if (error > worstError)
{
secondWorstError = worstError;
secondWorstId = worstId;
worstError = error;
worstId = i;
}
else if (error > secondWorstError)
{
secondWorstError = error;
secondWorstId = i;
}
if (error < bestError)
{
bestError = error;
bestId = i;
}
}
// If we deal with a problem in 1D, it won't be set.
if (secondWorstId == -1)
{
secondWorstId = worstId;
}
this->bestValueId = bestId;
this->worstValueId = worstId;
this->secondWorstValueId = secondWorstId;
}
ErrorFunction errorFunction;
Array<double, nDims + 1, nDims> values;
std::vector<double> errors;
int worstValueId;
int secondWorstValueId;
int bestValueId;
double minError;
double shrinkCoeff;
double expansionCoeff;
double contractionCoeff;
double reflectionCoeff;
const double SQRT2 = sqrt(2);
};
}



Comments