自己实现的一个光流算法

自己实现的一个光流算法

自己实现的一个光流算法,通过模式搜索匹配的方式计算相邻两张图片的平移量。

模式匹配:选择方块模式或者X形模式,在两张图片中,将该模式覆盖的像素灰度值逐点做差、取绝对值后求和,差值和最小的位置被认为是最匹配的。

多模式匹配:在图片中选择多个观测位置,分别检索最匹配的位置,最后将多个位置的匹配结果取平均值。

经过测试,在草地、柏油路面、地毯等非规则图形的粗糙表面上表现良好。

// Optical flow using a Multi-Pattern-Match (MPM) algorithm.
// Around each point of a small watch grid, the location whose pattern has the largest
// inner contrast is chosen in the first image and searched for in the second image;
// the per-point offsets are averaged, and their variance decides the result quality.
// Implemented patterns: square pattern (COpticalFlow_MPMS) and X pattern (COpticalFlow_MPMX).
//
// The result (m_quality, m_offset_x, m_offset_y) and the implementation table are static,
// so every implementation writes to the same shared result.
class COpticalFlow_MPM
{
public:
    COpticalFlow_MPM()
        : m_width(0), m_height(0), m_lineBytes(0), m_minAccess(0), m_maxAccess(0)
    {}
    virtual ~COpticalFlow_MPM(){}

    // Register an implementation; returns false when the table already holds c_maxImpNum entries.
    static bool AddImplementation(COpticalFlow_MPM* imp)
    {
        if(m_impNum < c_maxImpNum){
            m_impTbl[m_impNum++] = imp;
            return true;
        }
        return false;
    }

    // Configure every registered implementation for the given image geometry and rebuild its
    // lookup tables. width/height are in pixels, lineBytes is the row stride in bytes.
    // Implementations registered after this call are not configured.
    // (The misspelled name is kept for existing callers.)
    static void SetImageDimesion(int width, int height, int lineBytes)
    {
        for(int i = 0; i < m_impNum; ++i){
            COpticalFlow_MPM* imp = m_impTbl[i];
            imp->m_width = width;
            imp->m_height = height;
            imp->m_lineBytes = lineBytes;
            imp->GenerateSearchTable();
            imp->GeneratePatternTable();
            imp->GenerateAccessBounds();   // depends on both tables generated above
        }
    }

    // Run the currently selected implementation on (image1 -> image2) and store the result in
    // m_quality / m_offset_x / m_offset_y. Every 31 calls, switch to the next registered
    // implementation when the failures (quality 0) are more than half of the successes.
    static void AutoOpticalFlow(uint8_t* image1, uint8_t* image2)
    {
        // Nothing registered: report an invalid result instead of dereferencing an empty slot.
        if(m_impNum <= 0){
            ResetResult();
            return;
        }

        m_impTbl[m_impCurr]->calcOpticalFlow(image1, image2);

        // check if need switch pattern
        static int s_goodCount = 0;
        static int s_badCount = 0;
        if(m_quality > 0){
            s_goodCount++;
        }else{
            s_badCount++;
        }
        if(s_goodCount + s_badCount > 30){
            if(s_badCount * 2 > s_goodCount){
                m_impCurr = m_impCurr < (m_impNum - 1) ? m_impCurr + 1 : 0;
            }
            s_goodCount = s_badCount = 0;
        }
    }

    // the result
    static uint8_t m_quality;    // 0 ~ 255, 0 means the optical flow is invalid.
    static float m_offset_x;    // unit is pixel
    static float m_offset_y;

protected:

    virtual const char* Name() = 0;

    // Fill m_patternOffsets (c_patternS entries) with byte offsets relative to the pattern center.
    virtual void GeneratePatternTable() = 0;

    // Prepare the address offset tables, which keep the per-frame calculation simple and fast.
    void GenerateSearchTable()
    {
        // Search offsets, ordered by increasing street (L1) distance 1..c_searchD from the
        // search center; distance d contributes the 4*d points with |x| + |y| == d.
        int index = 0;
        int yNum, ay[2];
        for (int dist = 1; dist <= c_searchD; ++dist){
            for (int x = -dist; x <= dist; ++x){
                // for each x, there are only 1 or 2 dy choices.
                ay[0] = dist - abs(x);
                if (ay[0] == 0){
                    yNum = 1;
                }
                else{
                    yNum = 2;
                    ay[1] = -ay[0];
                }
                for (int iy = 0; iy < yNum; ++iy){
                    m_searchOffsets[index++] = ay[iy] * m_lineBytes + x;
                }
            }
        }

        // Watch grid: (2*c_watchN+1)^2 points around the image center. Rows are c_watchG apart;
        // columns are spaced c_watchG * width / height apart so the grid follows the aspect ratio.
        // The center is computed from the row stride so it stays correct when lineBytes != width
        // and when the height is odd.
        index = 0;
        const int center = (m_height / 2) * m_lineBytes + m_width / 2;
        for (int y = -c_watchN; y <= c_watchN; ++y){
            for (int x = -c_watchN; x <= c_watchN; ++x){
                const int colOffset = (m_height > 0) ? x * c_watchG * m_width / m_height : 0;
                m_watchPoints[index++] = center + y * c_watchG * m_lineBytes + colOffset;
            }
        }
    }

    // Compute the byte range, relative to a watch point, that calcOpticalFlow() can read:
    // SearchMaxInnerDiff() moves at most one search offset away, SearchBestMatch() moves at most
    // one more, and every probe reads all pattern offsets.
    void GenerateAccessBounds()
    {
        int patMin = 0, patMax = 0;
        for (int i = 0; i < c_patternS; ++i){
            if (m_patternOffsets[i] < patMin) patMin = m_patternOffsets[i];
            if (m_patternOffsets[i] > patMax) patMax = m_patternOffsets[i];
        }

        int searchMin = 0, searchMax = 0;   // offset 0 (the search center) is always probed
        for (int i = 0; i < c_searchS; ++i){
            if (m_searchOffsets[i] < searchMin) searchMin = m_searchOffsets[i];
            if (m_searchOffsets[i] > searchMax) searchMax = m_searchOffsets[i];
        }

        m_minAccess = 2 * searchMin + patMin;
        m_maxAccess = 2 * searchMax + patMax;
    }

    static void ResetResult()
    {
        m_quality = 0;
        m_offset_x = 0;
        m_offset_y = 0;
    }

    // Compute the flow from image1 to image2 and store it in the static result.
    // The result stays zeroed (quality 0) when fewer than 4 watch points match or when the
    // matched offsets disagree (max variance > 2 pixel^2).
    void calcOpticalFlow(uint8_t* image1, uint8_t* image2)
    {
        ResetResult();

        // Not configured (SetImageDimesion not called, or invalid geometry): result stays invalid.
        if (m_lineBytes <= 0 || m_height <= 0){
            return;
        }

        // Bytes addressable in each image, assuming the buffer covers every row up to the last
        // pixel of the last row. NOTE(review): confirm callers' buffers are at least this large.
        const int imageBytes = (m_height - 1) * m_lineBytes + m_width;

        int offset_x[c_watchS];
        int offset_y[c_watchS];
        int sumX = 0;
        int sumY = 0;
        int matchedCount = 0;

        for (int i = 0; i < c_watchS; ++i){
            const int watch = m_watchPoints[i];

            // Skip watch points whose search footprint would read outside the image buffer.
            // (This bounds the linear address only; it does not stop a probe from wrapping
            // across the left/right edge into a neighboring row.)
            if (watch + m_minAccess < 0 || watch + m_maxAccess >= imageBytes){
                continue;
            }

            int betterStart;
            if (!SearchMaxInnerDiff(image1, watch, betterStart)){
                continue;   // too little texture around this watch point
            }

            int matchedOffset;
            const int32_t minDiff = SearchBestMatch(image1 + betterStart, m_patternOffsets, c_patternS, image2, betterStart, matchedOffset);
            if (minDiff >= c_patternS * c_rejectDiff){
                continue;   // best candidate is still too different: no match
            }

            const int dx = matchedOffset % m_lineBytes - betterStart % m_lineBytes;
            const int dy = matchedOffset / m_lineBytes - betterStart / m_lineBytes;
            offset_x[matchedCount] = dx;
            offset_y[matchedCount] = dy;
            sumX += dx;
            sumY += dy;
            matchedCount++;
        }

        // Too few matches to trust the average; leave the result invalid (zeroed).
        if (matchedCount < 4){
            return;
        }

        const float meanX = static_cast<float>(sumX) / matchedCount;
        const float meanY = static_cast<float>(sumY) / matchedCount;

        // Sample variance of the per-point offsets; the larger axis decides the quality.
        float varX = 0, varY = 0;
        for (int i = 0; i < matchedCount; ++i){
            varX += (offset_x[i] - meanX) * (offset_x[i] - meanX);
            varY += (offset_y[i] - meanY) * (offset_y[i] - meanY);
        }
        varX /= (matchedCount - 1);
        varY /= (matchedCount - 1);
        const float varMax = varX > varY ? varX : varY;
        m_quality = static_cast<uint8_t>(varMax > 2 ? 0 : (2 - varMax) * 255 / 2);

        // Only publish the offset together with a non-zero quality.
        if (m_quality > 0){
            m_offset_x = meanX;
            m_offset_y = meanY;
        }
    }

    // Sum of absolute deviations of the pattern pixels from their (integer) mean;
    // "center" is the pattern center.
    inline int32_t InnerDiff(const uint8_t* center, const int* patternPoints, const int patternSize)
    {
        int32_t sum = 0;
        int32_t mean = 0;

        for (int i = 0; i < patternSize; ++i){
            sum += center[patternPoints[i]];
        }
        mean = sum / patternSize;

        int32_t sumDiff = 0;
        for (int i = 0; i < patternSize; ++i){
            sumDiff += abs(center[patternPoints[i]] - mean);
        }

        return sumDiff;
    }

    // Sum of absolute differences between two patterns, each given by its center pointer.
    inline int32_t PatternDiff(const uint8_t* center1, const uint8_t* center2, const int* patternPoints, const int patternSize)
    {
        int32_t sumDiff = 0;
        for (int i = 0; i < patternSize; ++i){
            sumDiff += abs(center1[patternPoints[i]] - center2[patternPoints[i]]);
        }
        return sumDiff;
    }

    // Search around searchStart for the location with the largest inner diff. "image" is the
    // image beginning; betterStart is returned relative to the image beginning.
    // Returns false when no location has enough contrast to be matched reliably.
    inline bool SearchMaxInnerDiff(const uint8_t* image, int searchStart, int& betterStart)
    {
        // if the inner diff is less than this number, this pattern cannot be used for the search.
        const int c_minInnerDiff = c_patternS * 4;
        // stop searching early once the inner diff exceeds this number.
        const int c_acceptInnerDiff = c_patternS * 12;

        const uint8_t* searchCenter = image + searchStart;
        int32_t currDiff = InnerDiff(searchCenter, m_patternOffsets, c_patternS);
        int32_t maxDiff = currDiff;
        betterStart = 0;

        for (int i = 0; i < c_searchS; ++i){
            currDiff = InnerDiff(searchCenter + m_searchOffsets[i], m_patternOffsets, c_patternS);
            if (currDiff > maxDiff){
                maxDiff = currDiff;
                betterStart = m_searchOffsets[i];
            }
            if (maxDiff > c_acceptInnerDiff){
                break;
            }
        }

        if (maxDiff < c_minInnerDiff){
            return false;
        }

        betterStart += searchStart;
        return true;
    }

    // Get the minimum pattern diff against the 8 one-pixel-shifted neighbors; returns as soon as
    // a neighbor diff is below c_acceptDiff per pixel.
    inline int32_t MinNeighborDiff(const uint8_t* pattern)
    {
        const int32_t threshDiff = c_patternS * c_acceptDiff;

        // eight neighbors of a pattern
        const int neighborOffsets[8] = { -1, 1, -m_lineBytes, m_lineBytes, -m_lineBytes - 1, -m_lineBytes + 1, m_lineBytes - 1, m_lineBytes + 1 };

        int minDiff = PatternDiff(pattern, pattern + neighborOffsets[0], m_patternOffsets, c_patternS);
        if (minDiff < threshDiff){
            return minDiff;
        }

        int diff;
        for (int i = 1; i < 8; ++i){
            diff = PatternDiff(pattern, pattern + neighborOffsets[i], m_patternOffsets, c_patternS);
            if (diff < minDiff){
                minDiff = diff;
                if (minDiff < threshDiff){
                    return minDiff;
                }
            }
        }

        return minDiff;
    }

    // Search for the pattern whose minimum neighbor diff is largest. "image" is the image
    // beginning; betterStart is returned relative to the image beginning.
    // Not used by calcOpticalFlow().
    inline bool SearchMaxNeighborDiff(const uint8_t* image, int searchStart, int& betterStart)
    {
        const uint8_t* searchCenter = image + searchStart;
        int32_t currDiff = MinNeighborDiff(searchCenter);
        int32_t maxDiff = currDiff;
        betterStart = 0;

        for (int i = 0; i < c_searchS; ++i){
            currDiff = MinNeighborDiff(searchCenter + m_searchOffsets[i]);
            if (currDiff > maxDiff){
                maxDiff = currDiff;
                betterStart = m_searchOffsets[i];
            }
        }
        if (maxDiff <= c_patternS * c_acceptDiff){
            return false;
        }

        betterStart += searchStart;
        return true;
    }

    // Match the target pattern in "image" around searchStart. Returns the smallest pattern diff
    // found and, via matchedOffset, its location relative to the image beginning. Stops early once
    // the diff is below c_acceptDiff per pixel.
    inline int32_t SearchBestMatch(const uint8_t* target, const int* patternPoints, const int patternSize, const uint8_t* image, int searchStart, int& matchedOffset)
    {
        const int thinkMatchedDiff = patternSize * c_acceptDiff;
        const uint8_t* searchCenter = image + searchStart;
        const uint8_t* matched = searchCenter;
        int32_t currDiff = PatternDiff(target, matched, patternPoints, patternSize);
        int32_t minDiff = currDiff;

        for (int i = 0; i < c_searchS; ++i){
            currDiff = PatternDiff(target, searchCenter + m_searchOffsets[i], patternPoints, patternSize);
            if (currDiff < minDiff){
                minDiff = currDiff;
                matched = searchCenter + m_searchOffsets[i];
            }
            if (minDiff < thinkMatchedDiff){
                break;
            }
        }

        matchedOffset = static_cast<int>(matched - image);
        return minDiff;
    }

    int m_width, m_height, m_lineBytes;

    // Byte range [m_minAccess, m_maxAccess] read around a watch point (see GenerateAccessBounds).
    int m_minAccess, m_maxAccess;

    static const int c_acceptDiff = 2;    // if the average pixel error is less than this number, consider it already matched
    static const int c_rejectDiff = 8;    // if the average pixel error is at least this number, consider it not matched

    // address offsets relative to the pattern key location; the size follows the square pattern.
    static const int c_patternN = 3;
    static const int c_patternS = (2 * c_patternN + 1) * (2 * c_patternN + 1);
    int m_patternOffsets[c_patternS];

    // offsets from the image start of each watch (seed) point; matching is done around these points.
    static const int c_watchN = 2;
    static const int c_watchS = (2 * c_watchN + 1) * (2 * c_watchN + 1);
    static const int c_watchG = 30;        // The gap of the watch grid in height direction
    int m_watchPoints[c_watchS];

    // search offsets relative to the search center, ordered from the nearest to the max distance (distance 0 not included).
    static const int c_searchD = 10;    // search street-distance from the key location
    static const int c_searchS = 2 * c_searchD * c_searchD + 2 * c_searchD;
    int m_searchOffsets[c_searchS];

    // The table of implementations, each using a different pattern
    static int m_impCurr;
    static int m_impNum;
    static const int c_maxImpNum = 16;
    static COpticalFlow_MPM* m_impTbl[c_maxImpNum];
};

// Out-of-class definitions of COpticalFlow_MPM's static data members.

// Registration table of pattern implementations, how many are registered, and which one is active.
COpticalFlow_MPM* COpticalFlow_MPM::m_impTbl[COpticalFlow_MPM::c_maxImpNum] = {};
int COpticalFlow_MPM::m_impNum = 0;
int COpticalFlow_MPM::m_impCurr = 0;

// Shared result of the most recent optical-flow computation.
uint8_t COpticalFlow_MPM::m_quality = 0;   // 0 ~ 255; 0 marks an invalid result
float COpticalFlow_MPM::m_offset_x = 0;    // pixels
float COpticalFlow_MPM::m_offset_y = 0;    // pixels

// Multi-Pattern-Match-Square
class COpticalFlow_MPMS : public COpticalFlow_MPM
{
public:
    COpticalFlow_MPMS(){}
    virtual ~COpticalFlow_MPMS(){}

    virtual const char* Name()    { return "Square"; }

protected:
    // prepare the address offset tables, that can make the calculation simple and fast.
    virtual void GeneratePatternTable()
    {
        // generate the address offset of the match area to the center of the area.
        int index = 0;
        for (int y = -c_patternN; y <= c_patternN; ++y){
            for (int x = -c_patternN; x <= c_patternN; ++x){
                m_patternOffsets[index++] = y * m_lineBytes + x;
            }
        }
    }
};

// Multi-Pattern-Match-X: the pattern is the two diagonals crossing at the key location
// (an "X"), using the same number of points (c_patternS) as the square pattern.
class COpticalFlow_MPMX : public COpticalFlow_MPM
{
public:
    COpticalFlow_MPMX(){}
    virtual ~COpticalFlow_MPMX(){}

    virtual const char* Name()    { return "X"; }

protected:
    // Fill m_patternOffsets row by row from the top. The center row contributes the single
    // offset 0; every other row contributes the point at column -row, then the point at column +row.
    virtual void GeneratePatternTable()
    {
        // c_patternS is (2N+1)^2, so c_patternS - 1 is divisible by 4: 1 center + 4 arms of armLen.
        const int armLen = (c_patternS - 1) / 4;
        int* out = m_patternOffsets;
        for (int row = -armLen; row <= armLen; ++row){
            const int rowBase = row * m_lineBytes;
            *out++ = rowBase - row;
            if (row != 0){
                *out++ = rowBase + row;
            }
        }
    }
};

// File-local instances of each pattern implementation; OpticalFlow::init() registers them.
namespace {
COpticalFlow_MPMS of_mpms;
COpticalFlow_MPMX of_mpmx;
}

void OpticalFlow::init()
{
    // set the optical flow implementation table
    COpticalFlow_MPM::AddImplementation(&of_mpms);
    COpticalFlow_MPM::AddImplementation(&of_mpmx);
    COpticalFlow_MPM::SetImageDimesion(m_width, m_height, m_lineBytes);
}

// Feed one camera frame and compute the flow relative to the previously accepted frame.
//   buf, len      : the new frame; len bytes are copied into an internal buffer of m_pixelNum bytes.
//   quality       : out, 0 ~ 255; 0 means no valid flow for this frame.
//   centi_pixel_x : out, x offset in hundredths of a pixel (non-zero only when *quality > 0).
//   centi_pixel_y : out, y offset in hundredths of a pixel (non-zero only when *quality > 0).
// The first accepted frame only primes the history. Frames that are NULL or larger than the
// internal buffer are rejected without touching the history. Always returns 0.
uint32_t OpticalFlow::flow_image_in(const uint8_t *buf, int len, uint8_t *quality, int32_t *centi_pixel_x, int32_t *centi_pixel_y)
{
    static uint8_t s_imageBuff1[m_pixelNum];
    static uint8_t s_imageBuff2[m_pixelNum];
    static uint8_t* s_imagePre = NULL;
    static uint8_t* s_imageCurr = s_imageBuff1;

    *quality = 0;
    *centi_pixel_x = 0;
    *centi_pixel_y = 0;

    // Reject frames that would overflow the internal buffers (previously an unchecked memcpy).
    if(buf == NULL || len < 0 || static_cast<unsigned>(len) > sizeof(s_imageBuff1)){
        return 0;
    }

    // NOTE(review): a frame shorter than m_pixelNum leaves the buffer tail with stale data
    // from an older frame; confirm callers always pass full frames.
    memcpy(s_imageCurr, buf, len);

    // first image
    if(s_imagePre == NULL){
        s_imagePre = s_imageCurr;
        s_imageCurr = s_imageCurr == s_imageBuff1 ? s_imageBuff2 : s_imageBuff1;    // switch image buffer
        return 0;
    }

    COpticalFlow_MPM::AutoOpticalFlow(s_imagePre, s_imageCurr);
    if(COpticalFlow_MPM::m_quality > 0){
        *quality = COpticalFlow_MPM::m_quality;
        *centi_pixel_x = (int32_t)(COpticalFlow_MPM::m_offset_x * 100);
        *centi_pixel_y = (int32_t)(COpticalFlow_MPM::m_offset_y * 100);
    }

    s_imagePre = s_imageCurr;
    s_imageCurr = s_imageCurr == s_imageBuff1 ? s_imageBuff2 : s_imageBuff1;    // switch image buffer
    return 0;
}