From b1ccf74484f7400ebd6827fa6aa897b735f40d8e Mon Sep 17 00:00:00 2001 From: mwturvey Date: Mon, 13 Feb 2017 17:52:09 -0700 Subject: Optimize gradient descent for shallow valley --- tools/lighthousefind_tori/torus_localizer.c | 81 ++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) (limited to 'tools') diff --git a/tools/lighthousefind_tori/torus_localizer.c b/tools/lighthousefind_tori/torus_localizer.c index 15177f2..b276e99 100644 --- a/tools/lighthousefind_tori/torus_localizer.c +++ b/tools/lighthousefind_tori/torus_localizer.c @@ -640,7 +640,7 @@ static Point RefineEstimateUsingGradientDescent(Point initialEstimate, PointsAnd for (FLT f = startingPrecision; f > 0.0001; f *= descent) { - Point gradient = getGradient(lastPoint, pna, pnaCount, f/1000 /*somewhat arbitrary*/); + Point gradient = getGradient(lastPoint, pna, pnaCount, f / 1000 /*somewhat arbitrary*/); gradient = getNormalizedVector(gradient, f); //printf("Gradient: (%f, %f, %f)\n", gradient.x, gradient.y, gradient.z); @@ -679,6 +679,78 @@ static Point RefineEstimateUsingGradientDescent(Point initialEstimate, PointsAnd return lastPoint; } +// This is modifies the basic gradient descent algorithm to better handle the shallow valley case, +// which appears to be typical of this convergence. +static Point RefineEstimateUsingModifiedGradientDescent1(Point initialEstimate, PointsAndAngle *pna, size_t pnaCount, FILE *logFile) +{ + int i = 0; + FLT lastMatchFitness = getPointFitness(initialEstimate, pna, pnaCount); + Point lastPoint = initialEstimate; + //Point lastGradient = getGradient(lastPoint, pna, pnaCount, .00000001 /*somewhat arbitrary*/); + + + for (FLT g = 0.1; g > 0.00001; g *= 0.99) + { + i++; + Point point1 = lastPoint; + // let's get 3 iterations of gradient descent here. + Point gradient1 = getGradient(point1, pna, pnaCount, g / 1000 /*somewhat arbitrary*/); + Point gradientN1 = getNormalizedVector(gradient1, g); + + Point point2; + point2.x = point1.x + gradientN1.x; + point2.y = point1.y + gradientN1.y; + point2.z = point1.z + gradientN1.z; + + Point gradient2 = getGradient(point2, pna, pnaCount, g / 1000 /*somewhat arbitrary*/); + Point gradientN2 = getNormalizedVector(gradient2, g); + + Point point3; + point3.x = point2.x + gradientN2.x; + point3.y = point2.y + gradientN2.y; + point3.z = point2.z + gradientN2.z; + + Point specialGradient = { .x = point3.x - point1.x, .y = point3.y - point1.y, .z = point3.y - point1.y }; + + specialGradient = getNormalizedVector(specialGradient, g/2); + + Point point4; + + point4.x = point3.x + specialGradient.x; + point4.y = point3.y + specialGradient.y; + point4.z = point3.z + specialGradient.z; + + //point4.x = (point1.x + point2.x + point3.x) / 3; + //point4.y = (point1.y + point2.y + point3.y) / 3; + //point4.z = (point1.z + point2.z + point3.z) / 3; + + FLT newMatchFitness = getPointFitness(point4, pna, pnaCount); + + if (newMatchFitness > lastMatchFitness) + { + if (logFile) + { + writePoint(logFile, lastPoint.x, lastPoint.y, lastPoint.z, 0xFFFFFF); + } + + lastMatchFitness = newMatchFitness; + lastPoint = point4; + printf("+"); + } + else + { + printf("-"); + g *= 0.7; + + } + + + } + printf("\ni=%d\n", i); + + return lastPoint; +} + Point SolveForLighthouse(TrackedObject *obj, char doLogOutput) { PointsAndAngle pna[MAX_POINT_PAIRS]; @@ -714,7 +786,12 @@ Point SolveForLighthouse(TrackedObject *obj, char doLogOutput) //Point refinedEstimatePc = RefineEstimateUsingPointCloud(initialEstimate, pna, pnaCount, obj, logFile); - Point refinedEstimageGd = RefineEstimateUsingGradientDescent(initialEstimate, pna, pnaCount, logFile, 0.95, 0.1); + //Point refinedEstimageGd = RefineEstimateUsingGradientDescent(initialEstimate, pna, pnaCount, logFile, 0.95, 0.1); + + Point refinedEstimageGd = RefineEstimateUsingModifiedGradientDescent1(initialEstimate, pna, pnaCount, logFile); + + //Point p = { .x = 8 }; + //Point refinedEstimageGd = RefineEstimateUsingModifiedGradientDescent1(p, pna, pnaCount, logFile); //FLT fitPc = getPointFitness(refinedEstimatePc, pna, pnaCount); FLT fitGd = getPointFitness(refinedEstimageGd, pna, pnaCount); -- cgit v1.2.3