《算法》第六章部分程序 part 3

▶ 书中第六章部分程序,加上自己补充的代码,包括后缀数组的两种实现

● 后缀数组实现一

  1 package package01;
  2 
  3 import java.util.Arrays;
  4 import edu.princeton.cs.algs4.StdIn;
  5 import edu.princeton.cs.algs4.StdOut;
  6 
  7 public class class01
  8 {
  9     private Suffix[] suffixes;                                  // 后缀数组
 10 
 11     public class01(String text)
 12     {
 13         int n = text.length();
 14         suffixes = new Suffix[n];
 15         for (int i = 0; i < n; i++)
 16             suffixes[i] = new Suffix(text, i);
 17         Arrays.sort(suffixes);
 18     }
 19 
 20     private static class Suffix implements Comparable<Suffix>   // 后缀类,包含原字符串和该后缀在原字符串中的起始索引,text[0] = originText[index]
 21     {                                                           // 每个后缀元素都保存了原字符串,浪费
 22         private final String text;
 23         private final int index;
 24 
 25         private Suffix(String inputText, int inputIndex)
 26         {
 27             text = inputText;
 28             index = inputIndex;
 29         }
 30 
 31         private int length()                                    // 求后缀元素的长度,要用原数组长度减去该后缀的起始索引
 32         {
 33             return text.length() - index;
 34         }
 35 
 36         private char charAt(int i)                              // 取后缀元素的第 i 字符,要用起始索引开始数
 37         {
 38             return text.charAt(index + i);
 39         }
 40 
 41         public int compareTo(Suffix that)                       // 比较两个后缀元素
 42         {
 43             if (this == that)
 44                 return 0;
 45             int n = Math.min(this.length(), that.length());
 46             for (int i = 0; i < n; i++)
 47             {
 48                 if (this.charAt(i) < that.charAt(i))
 49                     return -1;
 50                 if (this.charAt(i) > that.charAt(i))
 51                     return +1;
 52             }
 53             return this.length() - that.length();
 54         }
 55 
 56         public String toString()
 57         {
 58             return text.substring(index);
 59         }
 60     }
 61 
 62     public int length()
 63     {
 64         return suffixes.length;
 65     }
 66 
 67     public int index(int i)                         // 注意 index 表示后缀元素的起始索引,同时也是各后缀元素的原始序号
 68     {
 69         if (i < 0 || i >= suffixes.length)
 70             throw new IllegalArgumentException();
 71         return suffixes[i].index;
 72     }
 73 
 74     public int lcp(int i)                           // 计算排序的后缀数组中第 i 元素与第 i-1 元素的公共前缀长度
 75     {
 76         if (i < 1 || i >= suffixes.length)
 77             throw new IllegalArgumentException();
 78         Suffix s = suffixes[i], t = suffixes[i - 1];
 79         int n = Math.min(s.length(), t.length());
 80         for (int j = 0; j < n; j++)
 81         {
 82             if (s.charAt(j) != t.charAt(j))
 83                 return j;
 84         }
 85         return n;
 86     }
 87 
 88     public String select(int i)                     // 返回排序后的后缀数组中第 i 元素
 89     {
 90         if (i < 0 || i >= suffixes.length)
 91             throw new IllegalArgumentException();
 92         return suffixes[i].toString();
 93     }
 94 
 95     public int rank(String query)                   // 查找 query 在排序后的后缀数组中的排名
 96     {
 97         int lo = 0, hi = suffixes.length - 1;
 98         for (; lo <= hi;)
 99         {
100             int mid = lo + (hi - lo) / 2;
101             int cmp = compare(query, suffixes[mid]);
102             if (cmp < 0)
103                 hi = mid - 1;
104             else if (cmp > 0)
105                 lo = mid + 1;
106             else
107                 return mid;
108         }
109         return lo;
110     }
111 
112     private static int compare(String query, Suffix suffix) // 比较两个后缀类
113     {
114         int n = Math.min(query.length(), suffix.length());
115         for (int i = 0; i < n; i++)
116         {
117             if (query.charAt(i) < suffix.charAt(i))
118                 return -1;
119             if (query.charAt(i) > suffix.charAt(i))
120                 return +1;
121         }
122         return query.length() - suffix.length();
123     }
124 
125     public static void main(String[] args)
126     {
127         String s = StdIn.readAll().replaceAll("
", " ").trim();
128         class01 suffix = new class01(s);
129 
130         StdOut.println("  i ind lcp rnk select
---------------------------");
131         for (int i = 0; i < s.length(); i++)
132         {
133             int index = suffix.index(i);
134             String ith = """ + s.substring(index, Math.min(index + 50, s.length())) + """;
135             assert s.substring(index).equals(suffix.select(i));
136             StdOut.printf("%3d %3d %3d %3d %s
", i, index, (i == 0) ? 0 : suffix.lcp(i), suffix.rank(s.substring(index)), ith);
137         }
138     }
139 }

● 后缀数组实现二

  1 package package01;
  2 
  3 import edu.princeton.cs.algs4.StdIn;
  4 import edu.princeton.cs.algs4.StdOut;
  5 
  6 public class class01
  7 {
  8     private static final int CUTOFF = 5;    // 插入排序分界
  9 
 10     private final char[] text;              // 原字符串,只保存一份
 11     private final int[] index;              // 索引数组,序后第 i 子串的首字符是原字符串中第 index[i] 字符
 12     private final int n;                    // 字符串长度
 13 
 14     public class01(String text)
 15     {
 16         n = text.length();
 17         text += ' ';
 18         text = text.toCharArray();
 19         index = new int[n];
 20         for (int i = 0; i < n; i++)         // 初始化 index,排序交换是对 index 进行的
 21             index[i] = i;
 22         sort(0, n - 1, 0);
 23     }
 24 
 25     private void sort(int lo, int hi, int d)// 基于字符串的第 d 位进行三路排序
 26     {
 27         if (hi <= lo + CUTOFF)
 28         {
 29             insertion(lo, hi, d);
 30             return;
 31         }
 32         int lt = lo, gt = hi;
 33         char v = text[index[lo] + d];
 34         for (int i = lo + 1; i <= gt;)      // 排序后 a[lo..lt-1] < v = a[lt..gt] < a[gt+1..hi]
 35         {
 36             char t = text[index[i] + d];
 37             if (t < v)
 38                 exch(lt++, i++);
 39             else if (t > v)
 40                 exch(i, gt--);
 41             else
 42                 i++;
 43         }
 44         sort(lo, lt - 1, d);                // 分别对前段,中段,后段进行排序
 45         if (v > 0)
 46             sort(lt, gt, d + 1);
 47         sort(gt + 1, hi, d);
 48     }
 49 
 50     private void insertion(int lo, int hi, int d)
 51     {
 52         for (int i = lo; i <= hi; i++)
 53         {
 54             for (int j = i; j > lo && less(index[j], index[j - 1], d); j--)
 55                 exch(j, j - 1);
 56         }
 57     }
 58 
 59     private boolean less(int i, int j, int d)// 比较
 60     {
 61         if (i == j)
 62             return false;
 63         for (i = i + d, j = j + d; i < n && j < n; i++, j++)
 64         {
 65             if (text[i] < text[j])
 66                 return true;
 67             if (text[i] > text[j])
 68                 return false;
 69         }
 70         return i > j;
 71     }
 72 
 73     private void exch(int i, int j)
 74     {
 75         int swap = index[i];
 76         index[i] = index[j];
 77         index[j] = swap;
 78     }
 79 
 80     public int length()
 81     {
 82         return n;
 83     }
 84 
 85     public int index(int i)
 86     {
 87         if (i < 0 || i >= n)
 88             throw new IllegalArgumentException();
 89         return index[i];
 90     }
 91 
 92     public int lcp(int i)
 93     {
 94         if (i < 1 || i >= n)
 95             throw new IllegalArgumentException();
 96         int s = index[i], t = index[i - 1], n = 0;
 97         for (; s < n && t < n; s++, t++, n++)
 98         {
 99             if (text[s] != text[t])
100                 return j;
101         }
102         return n;
103     }
104 
105     public String select(int i)
106     {
107         if (i < 0 || i >= n)
108             throw new IllegalArgumentException();
109         return new String(text, index[i], n - index[i]);
110     }
111 
112     public int rank(String query)
113     {
114         int lo = 0, hi = n - 1;
115         for (; lo <= hi;)
116         {
117             int mid = lo + (hi - lo) / 2;
118             int cmp = compare(query, index[mid]);
119             if (cmp < 0)
120                 hi = mid - 1;
121             else if (cmp > 0)
122                 lo = mid + 1;
123             else
124                 return mid;
125         }
126         return lo;
127     }
128 
129     private int compare(String query, int i)
130     {
131         int m = query.length(), j = 0;
132         for (; i < n && j < m; i++, j++)
133         {
134             if (query.charAt(j) != text[i])
135                 return query.charAt(j) - text[i];
136         }
137         if (i < n)
138             return -1;
139         if (j < m)
140             return +1;
141         return 0;
142     }
143 
144     public static void main(String[] args)
145     {
146         String s = StdIn.readAll().replaceAll("
", " ").trim();
147         class01 suffix = new class01(s);
148 
149         StdOut.println("  i ind lcp rnk  select
---------------------------");
150         for (int i = 0; i < s.length(); i++)
151         {
152             int index = suffix.index(i);
153             String ith = """ + s.substring(index, Math.min(index + 50, s.length())) + """;
154             assert s.substring(index).equals(suffix.select(i));
155             StdOut.printf("%3d %3d %3d %3d  %s
", i, index, (i == 0) ? 0 : suffix.lcp(i), suffix.rank(s.substring(index)), ith);
156         }
157     }
158 }