Three errors in the Baidu Baike entry on DTW, and their corrections


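DTW stands for Dynamic Time Warping. It warps the time axis dynamically in order to measure the similarity between two sequences of different lengths. For a detailed introduction, see the Baidu Baike entry: http://baike.baidu.com/view/1647336.htm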

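Baike describes the principle as follows:
If the path has already passed through grid point (n, m), then the next grid point (n', m') it passes through can only be one of the following three cases:
  (n', m') = (n+1, m+2)
  (n', m') = (n+1, m+1)
  (n', m') = (n+1, m)
This is wrong. It should read:
If the path has already passed through grid point (n, m), then the next grid point (n', m') it passes through can only be one of the following three cases:
  (n', m') = (n, m+1)
  (n', m') = (n+1, m+1)
  (n', m') = (n+1, m)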
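
The corrected step rule can be sketched in a few lines of Java. This is only an illustration: the class name DtwStepRule, the method nextPoints and the 1-based grid size parameters rows and cols are assumptions made for this example, not part of the Baike entry or of the code further down.

```java
import java.util.ArrayList;
import java.util.List;

class DtwStepRule {

    /**
     * Returns the grid points reachable in one step from (n, m) under the
     * corrected rule, dropping any point outside a grid of rows x cols
     * points (indices are 1-based, as in the text).
     */
    static List<int[]> nextPoints(int n, int m, int rows, int cols) {
        int[][] candidates = {
            {n, m + 1},      // advance along m only
            {n + 1, m + 1},  // advance along both axes
            {n + 1, m}       // advance along n only
        };
        List<int[]> result = new ArrayList<>();
        for (int[] c : candidates) {
            if (c[0] <= rows && c[1] <= cols) {
                result.add(c);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Prints (2, 4), (3, 4) and (3, 3), one per line.
        for (int[] p : nextPoints(2, 3, 4, 4)) {
            System.out.println("(" + p[0] + ", " + p[1] + ")");
        }
    }
}
```

Note that every step advances each index by at most 1, so the path can never skip a frame of either sequence.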

 

The error in the description of the principle carries over into the MATLAB program that follows it:

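function dist = dtw(t,r)
n = size(t,1);
m = size(r,1);
% Frame-to-frame distance matrix (squared Euclidean distance)
d = zeros(n,m);
for i = 1:n
    for j = 1:m
        d(i,j) = sum((t(i,:)-r(j,:)).^2);
    end
end
% Accumulated distance matrix
D = ones(n,m) * realmax;
% Baike error 1: the first row of the DTW matrix was never computed; fix it as follows
D(1,:) = cumsum(d(1,:));

% Dynamic programming
for i = 2:n
    for j = 1:m
        D1 = D(i-1,j);
        if j>1
            D2 = D(i-1,j-1);
        else
            D2 = realmax;
        end
        % Baike error 2: the error in the principle shows up in the program; fix it as follows.
        % The guard is j>1 because the left neighbour D(i,j-1) exists for every j>=2.
        if j>1
            D3 = D(i,j-1);
        else
            D3 = realmax;
        end
        D(i,j) = d(i,j) + min([D1,D2,D3]);
    end
end
dist = D(n,m);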
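
For readers without MATLAB, the corrected recurrence can be ported to Java roughly as follows. This is a minimal sketch, not the Baike code: the names DtwSquared, frameDistance and dtwSquared are made up for this example, and it assumes both inputs are non-empty and hold one frame per row with equal frame dimension. Like the MATLAB function, it returns the accumulated squared distance.

```java
class DtwSquared {

    /** Squared Euclidean distance between two frames of equal dimension. */
    static double frameDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int k = 0; k < a.length; k++) {
            double diff = a[k] - b[k];
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Accumulated squared distance along the best warping path.
     * Assumes t and r are non-empty, with one frame per row.
     */
    static double dtwSquared(double[][] t, double[][] r) {
        int n = t.length;
        int m = r.length;
        double[][] D = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                double best;
                if (i == 0 && j == 0) {
                    best = 0.0;              // start of the path
                } else if (i == 0) {
                    best = D[0][j - 1];      // first row: cumulative sum
                } else if (j == 0) {
                    best = D[i - 1][0];      // first column: cumulative sum
                } else {
                    best = Math.min(D[i - 1][j],
                            Math.min(D[i - 1][j - 1], D[i][j - 1]));
                }
                D[i][j] = frameDistance(t[i], r[j]) + best;
            }
        }
        return D[n - 1][m - 1];
    }

    public static void main(String[] args) {
        // A single frame against three frames exercises only the first row:
        // 1 + 4 + 9 = 14.0
        double[][] t = {{0}};
        double[][] r = {{1}, {2}, {3}};
        System.out.println(dtwSquared(t, r));
    }
}
```

The single-frame case in main depends entirely on the first row of the accumulated matrix, which is the part that error 1 left uncomputed.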

 

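The function above returns the square of the DTW distance, because each d(i,j) is a squared Euclidean distance; take the square root of the result to get the DTW distance itself. A Java implementation is given below: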

 

package cn.edu.xjtu;

import java.util.List;

public class DTW {

	private void spreadCalc(int startColumn, int startRow, int endColumn,
			int endRow) {

		// The start position must not lie beyond the end position in either dimension.
		if (startColumn > endColumn || startRow > endRow) {
			throw new IllegalArgumentException("The start position must lie inside the matrix");
		}

		/*
		 * For more intuitive array handling, convert the conceptual indices
		 * (starting at 1) to Java indices (starting at 0).
		 */
		startColumn = startColumn - 1;
		startRow = startRow - 1;
		endColumn = endColumn - 1;
		endRow = endRow - 1;

		double diredist = 0.0;
		double acumdist = 0.0;
		boolean finished = false;

		do {

			// Compute the distances in row startRow
			for (int i = startColumn; i < endColumn + 1; i++) {

				diredist = first.get(i) - second.get(startRow);

				acumdist = minAcumDist(startRow, i);

				distanceMatrix[startRow][i] = Math.sqrt(diredist * diredist
						+ acumdist * acumdist);
			}

			// Compute the distances in column startColumn
			for (int i = startRow + 1; i < endRow + 1; i++) {

				diredist = first.get(startColumn) - second.get(i);

				acumdist = minAcumDist(i, startColumn);

				distanceMatrix[i][startColumn] = Math.sqrt(diredist * diredist
						+ acumdist * acumdist);
			}

			// Once the last row or the last column has been filled in, every cell is done.
			finished = !(startColumn < endColumn) || !(startRow < endRow);

			startColumn = startColumn + 1;
			startRow = startRow + 1;

		} while (!finished);
	}

	private double minAcumDist(int rowIndex, int columnIndex) {

		double pre = getNumber(rowIndex - 1, columnIndex - 1);
		double left = getNumber(rowIndex, columnIndex - 1);
		double up = getNumber(rowIndex - 1, columnIndex);

		return min(pre, left, up);
	}

	private double getNumber(int rowIndex, int columnIndex) {
		if (rowIndex >= 0 && columnIndex >= 0) {
			return distanceMatrix[rowIndex][columnIndex];
		} else if (rowIndex == -1 && columnIndex == -1) {
			return 0.0;
		} else {
			return Double.MAX_VALUE;
		}
	}


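	/**
	 * Computes the DTW distance between the two sequences over their full length.
	 * 
	 * @return the DTW distance, read from the last cell of the distance matrix
	 */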
	public double calcDTWDistance() {

		spreadCalc(1, 1, first.size(), second.size());

		return distanceMatrix[second.size()-1][first.size()-1];
	}

	private List<Integer> first;
	private List<Integer> second;
	private double[][] distanceMatrix;

	public DTW(List<Integer> first, List<Integer> second) {
		this.first = first;
		this.second = second;
		distanceMatrix = new double[second.size()][first.size()];
	}

	private double min(double... inputs) {

		double minValue = inputs[0];
		for (double each : inputs) {
			minValue = Math.min(minValue, each);
		}
		return minValue;
	}

	public List<Integer> getFirst() {
		return first;
	}

	public void setFirst(List<Integer> first) {
		this.first = first;
	}

	public List<Integer> getSecond() {
		return second;
	}

	public void setSecond(List<Integer> second) {
		this.second = second;
	}

}
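
The DTW class takes the square root in every cell, whereas the MATLAB function accumulates squared distances and leaves the square root to the caller. The two give the same result, because the square root is monotonic and so does not change which predecessor is the minimum. The following sketch checks this on one-dimensional integer sequences; the class DtwRootCheck and its methods rootAtEnd, rootPerCell and bestPrevious are hypothetical names introduced only for this check, and both inputs are assumed to be non-empty.

```java
class DtwRootCheck {

    /** Accumulates squared differences and takes the square root once, at the end. */
    static double rootAtEnd(int[] a, int[] b) {
        double[][] D = new double[b.length][a.length];
        for (int i = 0; i < b.length; i++) {
            for (int j = 0; j < a.length; j++) {
                double diff = a[j] - b[i];
                D[i][j] = diff * diff + bestPrevious(D, i, j);
            }
        }
        return Math.sqrt(D[b.length - 1][a.length - 1]);
    }

    /** Takes the square root in every cell, in the manner of the DTW class. */
    static double rootPerCell(int[] a, int[] b) {
        double[][] D = new double[b.length][a.length];
        for (int i = 0; i < b.length; i++) {
            for (int j = 0; j < a.length; j++) {
                double diff = a[j] - b[i];
                double prev = bestPrevious(D, i, j);
                D[i][j] = Math.sqrt(diff * diff + prev * prev);
            }
        }
        return D[b.length - 1][a.length - 1];
    }

    /** Smallest of the three allowed predecessors; 0 at the origin. */
    static double bestPrevious(double[][] D, int i, int j) {
        if (i == 0 && j == 0) {
            return 0.0;
        }
        double best = Double.POSITIVE_INFINITY;
        if (i > 0) best = Math.min(best, D[i - 1][j]);
        if (j > 0) best = Math.min(best, D[i][j - 1]);
        if (i > 0 && j > 0) best = Math.min(best, D[i - 1][j - 1]);
        return best;
    }

    public static void main(String[] args) {
        int[] a = {1, 3, 4, 9, 8};
        int[] b = {1, 3, 7, 8, 9, 2};
        System.out.println(rootAtEnd(a, b));
        System.out.println(rootPerCell(a, b));
    }
}
```

The two printed values should agree up to floating-point rounding, so a comparison between them should use a small tolerance rather than exact equality.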