Three errors in Baidu Baike's DTW entry and their corrections
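DTW (Dynamic Time Warping) warps the time axis dynamically to measure the similarity between two sequences of different lengths. For a detailed introduction, see the Baidu Baike entry http://baike.baidu.com/view/1647336.htm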
The entry describes the principle as follows:
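If the path has already passed through grid point $(n_{i-1}, m_{i-1})$, the next grid point $(n_i, m_i)$ it passes through can only be one of the following three:
$(n_i, m_i) = (n_{i-1}+1,\ m_{i-1}+2)$
$(n_i, m_i) = (n_{i-1}+1,\ m_{i-1}+1)$
$(n_i, m_i) = (n_{i-1}+1,\ m_{i-1})$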
This is wrong. It should read:
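If the path has already passed through grid point $(n_{i-1}, m_{i-1})$, the next grid point $(n_i, m_i)$ it passes through can only be one of the following three:
$(n_i, m_i) = (n_{i-1},\ m_{i-1}+1)$
$(n_i, m_i) = (n_{i-1}+1,\ m_{i-1}+1)$
$(n_i, m_i) = (n_{i-1}+1,\ m_{i-1})$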
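With the corrected step pattern, a grid point $(i, j)$ can only be entered from $(i, j-1)$, $(i-1, j-1)$ or $(i-1, j)$, so the accumulated distance obeys the recurrence below. This is a restatement, not a formula from the Baike entry; it assumes $d(i,j)$ is the local frame distance and $D(i,j)$ the accumulated distance, as in the MATLAB program, with $n$ and $m$ the lengths of the two sequences:

$$D(1,1) = d(1,1)$$
$$D(i,j) = d(i,j) + \min\{D(i,j-1),\ D(i-1,j-1),\ D(i-1,j)\}$$

Out-of-range terms count as $+\infty$, so the first row reduces to $D(1,j) = \sum_{k=1}^{j} d(1,k)$, the first column to $D(i,1) = \sum_{k=1}^{i} d(k,1)$, and the result is $D(n,m)$.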
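Because the principle is stated incorrectly, the MATLAB program that follows it is wrong as well. The corrected program, with the Baike errors marked in comments: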
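```matlab
function dist = dtw(t, r)
% t: n-by-k test sequence, r: m-by-k reference sequence (one frame per row)
n = size(t, 1);
m = size(r, 1);

% Frame-matching distance matrix (squared Euclidean distance between frames)
d = zeros(n, m);
for i = 1:n
    for j = 1:m
        d(i, j) = sum((t(i, :) - r(j, :)).^2);
    end
end

% Accumulated distance matrix
D = ones(n, m) * realmax;
% Baike error 1: the first row of the DTW matrix was never computed; fill it like this
D(1, :) = cumsum(d(1, :));

% Dynamic programming
for i = 2:n
    for j = 1:m
        D1 = D(i-1, j);            % predecessor (i-1, j)
        if j > 1
            D2 = D(i-1, j-1);      % predecessor (i-1, j-1)
        else
            D2 = realmax;
        end
        % Baike error 2: the wrong step pattern of the principle shows up here.
        % The third predecessor is (i, j-1), which exists whenever j > 1.
        if j > 1
            D3 = D(i, j-1);
        else
            D3 = realmax;
        end
        D(i, j) = d(i, j) + min([D1, D2, D3]);
    end
end
dist = D(n, m);
```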
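To check the corrected recurrence outside MATLAB, here is a minimal Java sketch of the same dynamic program. It assumes scalar frames (one number per time step, so `d(i,j)` is just the squared difference), and the class name `DtwSketch` and method `dtwSquared` are hypothetical. Like the MATLAB function, it returns the accumulated squared distance `D(n,m)`:

```java
// Hypothetical sketch: corrected DTW recurrence for scalar sequences.
final class DtwSketch {

    /** Returns D(n,m), the accumulated sum of squared differences along the best path. */
    static double dtwSquared(double[] t, double[] r) {
        int n = t.length, m = r.length;
        if (n == 0 || m == 0) {
            throw new IllegalArgumentException("sequences must be non-empty");
        }
        double[][] D = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                double diff = t[i] - r[j];
                double d = diff * diff;                  // local frame distance d(i,j)
                if (i == 0 && j == 0) {
                    D[i][j] = d;                         // path starts here
                    continue;
                }
                double best = Double.POSITIVE_INFINITY;  // missing predecessors count as infinity
                if (i > 0) best = Math.min(best, D[i - 1][j]);              // from (i-1, j)
                if (i > 0 && j > 0) best = Math.min(best, D[i - 1][j - 1]); // from (i-1, j-1)
                if (j > 0) best = Math.min(best, D[i][j - 1]);              // from (i, j-1)
                D[i][j] = d + best;  // row 0 and column 0 become cumulative sums
            }
        }
        return D[n - 1][m - 1];
    }

    public static void main(String[] args) {
        double sq = dtwSquared(new double[] {1, 2, 3}, new double[] {1, 2, 2, 3});
        System.out.println("squared: " + sq + ", DTW: " + Math.sqrt(sq));
    }
}
```

With `t = {1, 2, 3}` and `r = {1, 2, 2, 3}`, the repeated 2 is absorbed by the horizontal step from `(i, j-1)`, so the accumulated cost is 0.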
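Because the program accumulates squared frame distances, it returns the square of the DTW distance; take its square root to get the DTW distance itself. The Java implementation is given below: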
```java
package cn.edu.xjtu;

import java.util.List;

public class DTW {

    /*
     * Fills distanceMatrix in L-shaped sweeps: row startRow (from startColumn to the last
     * column), then column startColumn (from startRow + 1 to the last row), then both
     * indices advance by one. Each cell's three predecessors are filled before the cell.
     * Rows index `second`, columns index `first`; the arguments are 1-based.
     */
    private void spreadCalc(int startColumn, int startRow, int endColumn, int endRow) {
        if (startColumn < 1 || startRow < 1 || startColumn > endColumn || startRow > endRow) {
            throw new IllegalArgumentException("The start position must lie inside the matrix");
        }
        // For readable array handling, convert the conceptual indices (1-based) to Java's (0-based)
        startColumn = startColumn - 1;
        startRow = startRow - 1;
        endColumn = endColumn - 1;
        endRow = endRow - 1;

        double diredist = 0.0; // direct distance between the two current elements
        double acumdist = 0.0; // smallest accumulated distance among the predecessors
        do {
            // Fill row startRow
            for (int i = startColumn; i <= endColumn; i++) {
                diredist = first.get(i).doubleValue() - second.get(startRow).doubleValue();
                acumdist = minAcumDist(startRow, i);
                distanceMatrix[startRow][i] = Math.sqrt(diredist * diredist + acumdist * acumdist);
            }
            // Fill column startColumn
            for (int i = startRow + 1; i <= endRow; i++) {
                diredist = first.get(startColumn).doubleValue() - second.get(i).doubleValue();
                acumdist = minAcumDist(i, startColumn);
                distanceMatrix[i][startColumn] = Math.sqrt(diredist * diredist + acumdist * acumdist);
            }
            startColumn++;
            startRow++;
            // Once the sweep passes the last row or the last column, every cell has been filled
        } while (startColumn <= endColumn && startRow <= endRow);
    }

    // Smallest accumulated distance among the diagonal, left and upper neighbours
    private double minAcumDist(int rowIndex, int columnIndex) {
        double pre = getNumber(rowIndex - 1, columnIndex - 1);
        double left = getNumber(rowIndex, columnIndex - 1);
        double up = getNumber(rowIndex - 1, columnIndex);
        return min(pre, left, up);
    }

    // Matrix value; 0 for the virtual origin (-1, -1), "infinity" for any other out-of-range cell
    private double getNumber(int rowIndex, int columnIndex) {
        if (rowIndex >= 0 && columnIndex >= 0) {
            return distanceMatrix[rowIndex][columnIndex];
        } else if (rowIndex == -1 && columnIndex == -1) {
            return 0.0;
        } else {
            return Double.MAX_VALUE;
        }
    }

    /**
     * Computes the DTW distance between the two full sequences.
     *
     * @return the DTW distance between first and second
     */
    public double calcDTWDistance() {
        // Allocate here so the matrix always matches the current sequences, even after the setters
        distanceMatrix = new double[second.size()][first.size()];
        spreadCalc(1, 1, first.size(), second.size());
        return distanceMatrix[second.size() - 1][first.size() - 1];
    }

    private List<Integer> first;
    private List<Integer> second;
    private double[][] distanceMatrix;

    public DTW(List<Integer> first, List<Integer> second) {
        this.first = first;
        this.second = second;
    }

    private double min(double... inputs) {
        double minValue = inputs[0];
        for (double each : inputs) {
            minValue = Math.min(minValue, each);
        }
        return minValue;
    }

    public List<Integer> getFirst() {
        return first;
    }

    public void setFirst(List<Integer> first) {
        this.first = first;
    }

    public List<Integer> getSecond() {
        return second;
    }

    public void setSecond(List<Integer> second) {
        this.second = second;
    }
}
```
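The Java class takes square roots at every cell instead of once at the end. A short induction (a sketch added here, not from the original post) shows the two programs agree for scalar sequences. Write $\delta(r,c)$ for `first.get(c) - second.get(r)`, $\tilde D(r,c)$ for `distanceMatrix[r][c]`, and $D(r,c) = \delta(r,c)^2 + \min\{D(r-1,c-1),\ D(r,c-1),\ D(r-1,c)\}$ for the MATLAB-style squared recurrence on the same grid, with $D(-1,-1) = 0$ and other out-of-range terms $+\infty$. The Java code computes

$$\tilde D(r,c) = \sqrt{\delta(r,c)^2 + \Big(\min\{\tilde D(r-1,c-1),\ \tilde D(r,c-1),\ \tilde D(r-1,c)\}\Big)^2}$$

If $\tilde D = \sqrt{D}$ holds for the three predecessors, then, because $\sqrt{\cdot}$ is increasing on $[0, \infty)$,

$$\min\{\tilde D(\cdot)\} = \min\{\sqrt{D(\cdot)}\} = \sqrt{\min\{D(\cdot)\}} \quad\Rightarrow\quad \tilde D(r,c)^2 = \delta(r,c)^2 + \min\{D(r-1,c-1),\ D(r,c-1),\ D(r-1,c)\} = D(r,c)$$

The base case holds because `getNumber(-1, -1)` returns 0, giving $\tilde D(0,0) = |\delta(0,0)| = \sqrt{D(0,0)}$, and every other cell has at least one in-range predecessor, so `Double.MAX_VALUE` never wins the minimum. The grid is the transpose of the MATLAB one (rows follow `second`), which does not change the result because the step pattern is symmetric. Hence `calcDTWDistance()` returns $\sqrt{D}$ at the last cell, the same value as `sqrt(dtw(t, r))`.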