# algorithm - Nelder-Mead optimization in C++ - Code Review Stack Exchange

Mon Jan 30 2023 15:22:55 GMT+0000 (Coordinated Universal Time)

Saved by @yoonchoi #cpp

```#include "Eigen/Dense"
#include <vector>
#include <limits>
#include <cstring>
#include <iostream>

using namespace Eigen;

namespace Optimization
{
// NOTE(review): lines appear to be missing from this paste. `public:` directly follows
// `template<int nDims>` with no `class <Name> {` head, and the constructor below begins
// mid-parameter-list: its name and leading parameters are not shown. The initializer
// list and body use `errorFunction` and `initial`, presumably declared as
// `ErrorFunction errorFunction, FunctionParameters initial` — confirm against the
// original post. The class does not compile until those lines are restored.
template<int nDims>
public:
// A point in the search space: a fixed-size row of nDims doubles.
typedef Array<double, 1, nDims> FunctionParameters;
// The objective being minimised. NOTE(review): it receives its argument by value;
// Eigen recommends passing fixed-size objects by const reference.
typedef std::function<double(FunctionParameters)> ErrorFunction;

// Constructor (tail only; its opening line belongs directly above `double minError`).
// Builds the initial simplex of nDims + 1 vertices: vertex 0 is `initial`; vertex i + 1
// is `initial` offset by `b` in every coordinate except coordinate i, which is offset by `a`.
//   minError          - optimize() keeps stepping while the best cached error exceeds this.
//   initialEdgeLength - scale of the starting simplex.
//   shrinkCoeff       - factor used by shrink(). NOTE(review): with the default 1,
//                       shrink() leaves every vertex where it is (x_best + 1*(x - x_best) == x).
//   contractionCoeff, reflectionCoeff, expansionCoeff - coefficients used by the
//                       contraction, reflection and expansion helpers.
double minError, double initialEdgeLength,
double shrinkCoeff = 1, double contractionCoeff = 0.5,
double reflectionCoeff = 1, double expansionCoeff = 1)
: errorFunction(errorFunction), minError(minError),
shrinkCoeff(shrinkCoeff), contractionCoeff(contractionCoeff),
reflectionCoeff(reflectionCoeff), expansionCoeff(expansionCoeff),
worstValueId(-1), secondWorstValueId(-1), bestValueId(-1)
{
// Cached errors start at the largest double; optimize() evaluates every vertex
// before the first step().
this->errors = std::vector(nDims + 1, std::numeric_limits<double>::max());

// Per-coordinate vertex offsets.
// NOTE(review): a regular simplex with edge length l uses
//   b = l / (n * sqrt(2)) * (sqrt(n + 1) - 1)   and   a = b + l / sqrt(2).
// Here a = l / sqrt(2), so the edges between vertices 1..n have length
// sqrt(2) * (a - b) = l - sqrt(2) * b rather than l. Confirm whether a regular simplex was intended.
const double b = initialEdgeLength / (nDims * SQRT2) * (sqrt(nDims + 1) - 1);
const double a = initialEdgeLength / SQRT2;

// Every row starts as `initial`; rows 1..nDims are overwritten below, row 0 is kept.
this->values = initial.replicate(nDims + 1, 1);

for (int i = 0; i < nDims; i++)
{
// Vertex i + 1: offset `b` everywhere, `a` in coordinate i, relative to `initial`.
FunctionParameters simplexRow;
simplexRow.setConstant(b);
simplexRow(0, i) = a;
simplexRow += initial;

this->values.row(i+1) = simplexRow;
}
}

/// Runs Nelder-Mead until the best vertex's error is <= minError, or until
/// `maxIterations` calls to step() have been made.
///
/// Evaluates errorFunction at every vertex of the current simplex, then repeatedly
/// calls step(), printing the best and worst vertex (with their cached errors) to
/// std::cout after each step.
///
/// @param maxIterations Upper bound on the number of step() calls. The default
///        (the largest size_t) keeps the original "run until converged" behaviour;
///        pass a finite value when the objective may never reach minError.
void optimize(std::size_t maxIterations = std::numeric_limits<std::size_t>::max())
{
    for (int i = 0; i < nDims + 1; i++)
    {
        this->errors.at(i) = this->errorFunction(this->values.row(i));
    }

    this->invalidateIdsCache();

    for (std::size_t iteration = 0;
         iteration < maxIterations && this->errors.at(this->bestValueId) > this->minError;
         ++iteration)
    {
        step();

        // Report the cached errors instead of calling errorFunction again: each extra
        // evaluation costs as much as the evaluations the optimisation itself needs.
        const double bestError = this->errors.at(this->bestValueId);
        const double worstError = this->errors.at(this->worstValueId);

        std::cout << "Best error " << std::to_string(bestError) << " with : " << this->best();
        std::cout << " Worst error " << std::to_string(worstError) << " with : " << this->worst();
        std::cout << '\n';
    }
}

void step()
{
auto meanWithoutWorst = this->getMeanWithoutWorst();

auto reflectionOfWorst = this->getReflectionOfWorst(meanWithoutWorst);
auto reflectionError = this->errorFunction(reflectionOfWorst);

FunctionParameters newValue = reflectionOfWorst;

bool shrink = false;

if (reflectionError < this->errors.at(this->bestValueId))
{
auto expansionValue = this->expansion(meanWithoutWorst, reflectionOfWorst);
double expansionError = this->errorFunction(expansionValue);

if (expansionError < this->errors.at(this->bestValueId))
{
newValue = expansionValue;
}
}
else if (reflectionError > this->errors.at(this->worstValueId))
{
newValue = this->insideContraction(meanWithoutWorst);

if (newError > this->errors.at(this->worstValueId)) { shrink = true; }
}
else if (reflectionError > this->errors.at(this->secondWorstValueId))
{
newValue = this->outsideContraction(meanWithoutWorst);

if (newError > reflectionError) { shrink = true; }
}
else
{
newValue = reflectionOfWorst;
}

if (shrink)
{
this->shrink();
this->invalidateIdsCache();
return;
}

this->values.row(this->worstValueId) = newValue;
this->invalidateIdsCache();
}

/// Copy of the vertex whose cached error is largest (as of the last invalidateIdsCache()).
FunctionParameters worst() const { return this->values.row(this->worstValueId); }
/// Copy of the vertex whose cached error is smallest (as of the last invalidateIdsCache()).
FunctionParameters best() const { return this->values.row(this->bestValueId); }

private:

void shrink()
{
auto bestVertex = this->values.row(this->bestValueId);

for (int i = 0; i < nDims + 1; i++)
{
if (i == this->bestValueId) { continue; }

this->values.row(i) = bestVertex + this->shrinkCoeff * (this->values.row(i) - bestVertex);
this->errors.at(i) = this->errorFunction(this->values.row(i));
}
}

/// Expansion point: continues past the reflected point, away from the centroid, by
/// expansionCoeff times the centroid-to-reflection vector.
/// Arguments are taken by const reference: Eigen discourages passing fixed-size
/// (possibly vectorizable) objects by value.
FunctionParameters expansion(const FunctionParameters& meanWithoutWorst, const FunctionParameters& reflection)
{
    return reflection + this->expansionCoeff * (reflection - meanWithoutWorst);
}

/// Inside contraction: the point contractionCoeff of the way from the centroid
/// towards the current worst vertex.
/// The argument is taken by const reference (Eigen fixed-size objects should not be
/// passed by value).
FunctionParameters insideContraction(const FunctionParameters& meanWithoutWorst)
{
    return meanWithoutWorst - this->contractionCoeff * (meanWithoutWorst - this->worst());
}

/// Outside contraction: the point contractionCoeff of the way from the centroid to
/// the reflected point, i.e. c + contractionCoeff * (x_r - c) with
/// x_r - c = reflectionCoeff * (c - worst).
/// Including reflectionCoeff keeps this point between the centroid and the reflection
/// for any reflection coefficient; for the default reflectionCoeff == 1 the result is
/// unchanged. The argument is taken by const reference (Eigen fixed-size objects should
/// not be passed by value).
FunctionParameters outsideContraction(const FunctionParameters& meanWithoutWorst)
{
    return meanWithoutWorst
        + this->contractionCoeff * this->reflectionCoeff * (meanWithoutWorst - this->worst());
}

/// Reflects the current worst vertex through `meanWithoutWorst`:
/// mean + reflectionCoeff * (mean - worst).
/// The argument is taken by const reference (Eigen fixed-size objects should not be
/// passed by value).
FunctionParameters getReflectionOfWorst(const FunctionParameters& meanWithoutWorst)
{
    return meanWithoutWorst + this->reflectionCoeff * (meanWithoutWorst - this->worst());
}

/// Centroid of all simplex vertices except the current worst one; this is the point
/// the worst vertex gets reflected through.
FunctionParameters getMeanWithoutWorst()
{
    FunctionParameters centroid = FunctionParameters::Zero();

    int row = 0;
    while (row < nDims + 1)
    {
        if (row != this->worstValueId)
        {
            centroid += this->values.row(row);
        }
        ++row;
    }

    // nDims + 1 vertices minus the excluded worst one leaves nDims summands.
    centroid /= nDims;
    return centroid;
}

void invalidateIdsCache() {
double worstError = std::numeric_limits<double>::min();
int worstId = -1;
double secondWorstError = std::numeric_limits<double>::max();
int secondWorstId = -1;
double bestError = std::numeric_limits<double>::max();
int bestId = -1;

for (int i = 0; i < nDims + 1; i++)
{
auto error = this->errors.at(i);

if (error > worstError)
{
secondWorstError = worstError;
secondWorstId = worstId;
worstError = error;
worstId = i;
}
else if (error > secondWorstError)
{
secondWorstError = error;
secondWorstId = i;
}

if (error < bestError)
{
bestError = error;
bestId = i;
}
}

// If we deal with a problem in 1D, it won't be set.
if (secondWorstId == -1)
{
secondWorstId = worstId;
}

this->bestValueId = bestId;
this->worstValueId = worstId;
this->secondWorstValueId = secondWorstId;
}

// Objective evaluated at the simplex vertices.
ErrorFunction errorFunction;

// Stopping threshold: optimize() stops once the best cached error is <= minError.
double minError;

// Nelder-Mead step coefficients. They are declared in the same order as the
// constructor's initializer list, so the actual initialisation order matches the
// written one (no -Wreorder).
double shrinkCoeff;
double contractionCoeff;
double reflectionCoeff;
double expansionCoeff;

// Row indices into `values` / `errors`, refreshed by invalidateIdsCache().
int worstValueId;
int secondWorstValueId;
int bestValueId;

// The simplex: nDims + 1 vertices, one per row, each with nDims coordinates.
Array<double, nDims + 1, nDims> values;
// Cached errorFunction value for each row of `values`.
std::vector<double> errors;

// sqrt(2) as a class-wide compile-time constant. The previous non-static `const` member
// stored the same value in every instance and made the class non-copy-assignable.
static constexpr double SQRT2 = 1.41421356237309504880;
};
}
```
content_copyCOPY