跳舞链解数独 静态数组优化

前几天有人问我之前写的那个跳舞链解数独的程序的内存泄漏问题如何解决,因此回顾了一下我的那个程序。现在看来那个程序简直不忍直视,于是大刀阔斧地改了。主要是把动态内存分配都改为了静态预分配,这样就可以避免频繁地调用malloc和free。同时静态分配的好处就是内存访问局部性比较好,cache不容易miss。而且在一行四个节点连续分配的情况下,就没有必要存储左右指针了。而且在连续分配的时候,指针都可以蜕变为数组索引,访问就比较简单了。还有一个好处就是整个程序可读性大大增强。现在这个版本的代码如下,利用http://staffhome.ecm.uwa.edu.au/~00013890/sudokumin.php中的测试文件,总耗时 16967ms。这个耗时是在不输出结果数独的情况下测试的,带输出的测试没有去做,屏幕太闪。整个代码如下,虽说挂着C++的招牌,其实全是C……

  1 #include <iostream>
  2 #include <vector>
  3 #include <stack>
  4 #include <map>
  5 #include <ctime>
  6 #include <fstream>
  7 using namespace std;
  8 #define shift_base 0x80000000
  9 struct basic_node
 10 {
 11     //这里之所以没有left和right是因为我们每次分配一行的时候,是四个点一起分的,所以可以直接通过加减1来搞定左右关系
 12     int down;
 13     int up;
 14     int column;
 15 };
 16 struct basic_node total_nodes[324 + 81 * 9 * 4];//324个头节点,81个格子,每个格子有9种情况,每种情况有四个点。
 17 int avail_node_index = 324;//分配节点时的编号
 18 int node_stack[81];
 19 int stack_index = 0;
 20 struct node_heap
 21 {
 22     int cul_value;//代表这个列中的1的个数
 23     int position_index;//代表着个点所指示的列的索引
 24 };
 25 struct node_heap mutual_index[324];//这个是堆
 26 int current_heap_number = 323;//这个是当前可用的堆中的节点数
 27 int available_column = 323;//这个是当前可用列数
 28 int position_index[324];//这个是列在堆中的位置
 29 int out[9][9] = { { 8, 0, 0, 0, 0, 0, 0, 0, 0 }, { 0, 0, 3, 6, 0, 0, 0, 0, 0 }, { 0, 7, 0, 0, 9, 0, 2, 0, 0 }, 
 30 {0, 5, 0, 0, 0, 7, 0, 0, 0}, { 0, 0, 0, 0, 4, 5, 7, 0, 0 }, { 0, 0, 0, 1, 0, 0, 0, 3, 0 }, { 0, 0, 1, 0, 0, 0, 0, 6, 8 }, 
 31 {0, 0, 8, 5, 0, 0, 0, 1, 0}, { 0, 9, 0, 0, 0, 0, 4, 0, 0 } };
 32 
 33 void initial(void)
 34 {
 35     for (int i = 0; i < 324; i++)
 36     {
 37         total_nodes[i].column = i;
 38         total_nodes[i].down = i;
 39         total_nodes[i].up = i;
 40         mutual_index[i].cul_value= 0;
 41         mutual_index[i].position_index = i;
 42         position_index[i] = i;
 43     }
 44     stack_index = 0;
 45     available_column = 323;
 46     current_heap_number = 323;
 47     avail_node_index = 324;
 48 }
 49 void swap_heap(int index_one, int index_two)//交换在堆中的两个元素的值,及相关数据索引
 50 {
 51     int intermidate_one, intermidate_two;
 52     intermidate_one = mutual_index[index_one].cul_value;
 53     intermidate_two = mutual_index[index_one].position_index;
 54     mutual_index[index_one].cul_value = mutual_index[index_two].cul_value;
 55     mutual_index[index_one].position_index = mutual_index[index_two].position_index;
 56     mutual_index[index_two].cul_value = intermidate_one;
 57     mutual_index[index_two].position_index = intermidate_two;
 58     position_index[mutual_index[index_two].position_index] = index_two;
 59     position_index[mutual_index[index_one].position_index] = index_one;
 60 }
 61 void heap_initial()//初始化堆,这个动作是在所有的行插入完成之后做的
 62 {
 63     int k, i = 0;
 64     int current_min;
 65     for (i = (current_heap_number - 1) / 2; i >= 0; i--)
 66     {
 67         k = i;
 68         while (2 * k + 1 <= current_heap_number)
 69         {
 70             current_min = mutual_index[k].cul_value;
 71             current_min = current_min < mutual_index[2 * k + 1].cul_value ? current_min : mutual_index[2 * k + 1].cul_value;
 72             if (2 * k + 2 <= current_heap_number)
 73             {
 74                 current_min = current_min < mutual_index[2 * k + 2].cul_value ? current_min : mutual_index[2 * k + 2].cul_value;
 75             }
 76             if (current_min == mutual_index[k].cul_value)
 77             {
 78                 break;
 79             }
 80             else
 81             {
 82                 if (current_min == mutual_index[2 * k + 1].cul_value)
 83                 {
 84                     swap_heap(k, 2 * k + 1);
 85                     k = 2 * k + 1;
 86                 }
 87                 else
 88                 {
 89                     swap_heap(k, 2 * k + 2);
 90                     k = 2 * k + 2;
 91                 }
 92             }
 93         }
 94     }
 95 }
 96 void delete_minimal()//删除堆中最小的元素
 97 {
 98     int k;
 99     int current_min;
100     if (current_heap_number != 0)
101     {
102         swap_heap(0, current_heap_number);//交换最高元素与最低元素
103         current_heap_number--;//然后将堆的大小进行缩减
104         k = 0;
105         while (2 * k + 1 <= current_heap_number)//然后,下面便是一些维护性的工作,用来维护最小堆
106         {
107             current_min = mutual_index[k].cul_value;
108             current_min = current_min < mutual_index[2 * k + 1].cul_value ? current_min : mutual_index[2 * k + 1].cul_value;
109             if (2 * k + 2 <= current_heap_number)
110             {
111                 current_min = current_min < mutual_index[2 * k + 2].cul_value ? current_min : mutual_index[2 * k + 2].cul_value;
112             }
113             if (current_min == mutual_index[k].cul_value)
114             {
115                 return;
116             }
117             else
118             {
119                 if (current_min == mutual_index[2 * k + 1].cul_value)
120                 {
121                     swap_heap(k, 2 * k + 1);
122                     k = 2 * k + 1;
123                 }
124                 else
125                 {
126                     swap_heap(k, 2 * k + 2);
127                     k = 2 * k + 2;
128                 }
129             }
130         }
131     }
132     else//如果只剩下一个元素,那就不需要进行交换,直接将堆元素的个数降低一
133     {
134         current_heap_number = -1;
135     }
136 }
137 void heap_renew(int target_position, int new_value)//对于第target_position列,进行度数更新
138 {
139     int heap_target_position, k, current_min;
140     heap_target_position = position_index[target_position];//这个是这一列在堆中所在的位置
141     k = heap_target_position;
142     if (new_value < mutual_index[k].cul_value)//如果值是减少的,就直接进行赋值,然后维护堆的性质
143     {
144         mutual_index[k].cul_value = new_value;
145         while (k > 0 && (mutual_index[(k - 1) / 2].cul_value > mutual_index[k].cul_value))//维护堆
146         {
147             swap_heap((k - 1) / 2, k);
148             k = (k - 1) / 2;
149         }
150         if (new_value == 0)//如果是赋值为0,则从堆中进行删除,因为我们每次操纵一个元素,所以最多会有一个元素为0,所以肯定是最小值。
151         {
152             delete_minimal();
153         }
154     }
155     else//对于值增大的情况
156     {
157         mutual_index[k].cul_value = new_value;
158         if (new_value == 1)//如果新的值是1,则把这个元素重新加入堆中
159         {
160             current_heap_number++;//扩大堆的范围,我们可以证明重新加入堆中的元素一定是排在堆的末尾,当然条件是删除与插入的顺序是对应相反的
161             while (k > 0 && (mutual_index[(k - 1) / 2].cul_value > mutual_index[k].cul_value))//由于新的值是1,所以不可能比上一个数大
162             {
163                 swap_heap((k - 1) / 2, k);
164                 k = (k - 1) / 2;
165             }
166         }
167         else//如果不是1,说明已经在堆中,所以不需要扩大堆的范围,直接赋值之后进行维护堆结构就行
168         {
169             while (2 * k + 1 <= current_heap_number)
170             {
171                 current_min = mutual_index[k].cul_value;
172                 current_min = current_min < mutual_index[2 * k + 1].cul_value ? current_min : mutual_index[2 * k + 1].cul_value;
173                 if (2 * k + 2 <= current_heap_number)
174                 {
175                     current_min = current_min < mutual_index[2 * k + 2].cul_value ? current_min : mutual_index[2 * k + 2].cul_value;
176                 }
177                 if (current_min == mutual_index[k].cul_value)
178                 {
179                     break;
180                 }
181                 else
182                 {
183                     if (current_min == mutual_index[2 * k + 1].cul_value)
184                     {
185                         swap_heap(k, 2 * k + 1);
186                         k = 2 * k + 1;
187                     }
188                     else
189                     {
190                         swap_heap(k, 2 * k + 2);
191                         k = 2 * k + 2;
192                     }
193                 }
194             }
195         }
196     }
197 }
198 void node_heap_decrease(int node_index)//对于一个点进行她所在的行的删除,因为一行中一定有四个元素,所以有四列,我们对这四列的度数都进行减少1
199 {
200     int leftmost_node;//当前节点所在行的最左节点的索引
201     leftmost_node = node_index - (node_index % 4);
202     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value -1);
203     leftmost_node++;
204     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value -1);
205     leftmost_node++;
206     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value -1);
207     leftmost_node++;
208     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value -1);
209 }
210 void node_heap_increase(int node_index)//增加与减少的顺序是刚好相反的
211 {
212     int leftmost_node;//当前节点所在行的最右节点的索引
213     leftmost_node = node_index - (node_index % 4)+3;
214     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value + 1);
215     leftmost_node--;
216     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value + 1);
217     leftmost_node--;
218     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value + 1);
219     leftmost_node--;
220     heap_renew(total_nodes[leftmost_node].column, mutual_index[position_index[total_nodes[leftmost_node].column]].cul_value + 1);
221 }
222 void insert_row(int current_row_index, int current_column_index, int value)
223 {
224     int current_leftmost = avail_node_index;
225     avail_node_index += 4;
226     int column_index;
227     column_index = current_row_index * 9 + value - 1;
228     total_nodes[current_leftmost].column = column_index;
229     total_nodes[current_leftmost].down = column_index;
230     total_nodes[current_leftmost].up = total_nodes[column_index].up;
231     total_nodes[total_nodes[column_index].up].down = current_leftmost;
232     total_nodes[column_index].up = current_leftmost;
233     mutual_index[column_index].cul_value++;
234     current_leftmost++;
235     column_index = 81 + current_column_index * 9 + value - 1;
236     total_nodes[current_leftmost].column = column_index;
237     total_nodes[current_leftmost].down = column_index;
238     total_nodes[current_leftmost].up = total_nodes[column_index].up;
239     total_nodes[total_nodes[column_index].up].down = current_leftmost;
240     total_nodes[column_index].up = current_leftmost;
241     mutual_index[column_index].cul_value++;
242     current_leftmost++;
243     column_index= 162 + ((current_row_index / 3) * 3 + current_column_index / 3) * 9 + value - 1;
244     total_nodes[current_leftmost].column = column_index;
245     total_nodes[current_leftmost].down = column_index;
246     total_nodes[current_leftmost].up = total_nodes[column_index].up;
247     total_nodes[total_nodes[column_index].up].down = current_leftmost;
248     total_nodes[column_index].up = current_leftmost;
249     mutual_index[column_index].cul_value++;
250     current_leftmost++;
251     column_index = 243 + current_row_index * 9 + current_column_index;
252     total_nodes[current_leftmost].column = column_index;
253     total_nodes[current_leftmost].down = column_index;
254     total_nodes[current_leftmost].up = total_nodes[column_index].up;
255     total_nodes[total_nodes[column_index].up].down = current_leftmost;
256     total_nodes[column_index].up = current_leftmost;
257     mutual_index[column_index].cul_value++;
258 }
259 void print_result()//打印出结果
260 {
261     int i, j, k, current_index;
262     int m, n;
263     int output[9][9];
264     for (i = 0; i < 9; i++)
265     {
266         for (j = 0; j < 9; j++)
267         {
268             output[i][j] = 0;
269         }
270     }
271     for (m = 0; m < stack_index; m++)
272     {
273         current_index = node_stack[m]-node_stack[m]%4;
274         k = total_nodes[current_index].column % 9;
275         i = (total_nodes[current_index].column-total_nodes[current_index].column % 9)/9;
276         current_index ++;
277         j = (total_nodes[current_index].column - total_nodes[current_index].column % 9-81) / 9;
278         output[i][j] = k + 1;
279     }
280     printf("***********************
");
281     for (m = 0; m < 9; m++)
282     {
283         for (n = 0; n < 9; n++)
284         {
285             printf("%d ", output[m][n]);
286         }
287         printf("
");
288     }
289 }
290 
291 void creat_dlx_sudoku()//利用矩阵来建立十字网格
292 {
293     int i, j, k;
294     int row_position[9][9];//这个是行
295     int column_position[9][9];//这个是列
296     int small_position[9][9];//这个是每一个小方格
297     initial();
298     for (i = 0; i < 9; i++)
299     {
300         for (j = 0; j < 9; j++)
301         {
302             row_position[i][j] = 1;
303             column_position[i][j] = 1;
304             small_position[i][j] = 1;
305         }
306         
307     }
308     for (i = 0; i < 9; i++)
309     {
310         for (j = 0; j < 9; j++)
311         {
312             if (out[i][j] != 0)
313             {
314                 row_position[i][out[i][j]-1] = 0;
315                 column_position[j][out[i][j]-1] = 0;
316                 small_position[(i / 3) * 3 + j / 3][out[i][j]-1] = 0;
317             }
318         }
319     }
320     for (i = 0; i < 9; i++)
321     {
322         for (j = 0; j < 9; j++)
323         {
324             if (out[i][j] != 0)
325             {
326                 insert_row(i, j, out[i][j]);
327             }
328             else
329             {
330                 for (k = 0; k < 9; k++)
331                 {
332                     if ((row_position[i][k] * column_position[j][k] * small_position[(i / 3) * 3 + j / 3][k])==1)
333                     {
334                         insert_row(i, j, k + 1);
335                     }
336                     else
337                     {
338                         //do nothing
339                     }
340                 }
341             }
342         }
343     }
344     heap_initial();
345 }
346 void in_stack(int target_to_stack)
347 {
348     int leftmost = target_to_stack - target_to_stack % 4;
349     for (int i = 0; i < 4; i++)//对于当前行的每一列
350     {
351         int current_column_traversal = leftmost + i;
352         current_column_traversal = total_nodes[current_column_traversal].down;
353         while (current_column_traversal != leftmost + i)//删除当前列相交的行
354         {
355             if (current_column_traversal != total_nodes[current_column_traversal].column)//即不是头行
356             {
357                 int temp_node = current_column_traversal - current_column_traversal % 4-1;
358                 for (int j = 0; j < 4; j++)
359                 {
360                     temp_node++;
361                     if (temp_node != current_column_traversal)
362                     {
363                         total_nodes[total_nodes[temp_node].down].up = total_nodes[temp_node].up;
364                         total_nodes[total_nodes[temp_node].up].down = total_nodes[temp_node].down;
365                     }
366                 }
367                 node_heap_decrease(temp_node);
368             }
369             current_column_traversal = total_nodes[current_column_traversal].down;
370         }
371     }
372     node_heap_decrease(target_to_stack);//最后对当前行进行删除
373     node_stack[stack_index++] = target_to_stack;//然后才是入栈
374     available_column -= 4;
375     //print_result();
376 }
377 void out_stack()//注意出栈的时候是相反的操作,所有删除都相反
378 {
379     int target_to_stack = node_stack[--stack_index];
380     int rightmost = target_to_stack - target_to_stack % 4+3;
381     for (int i = 0; i < 4; i++)//对于当前行的每一列
382     {
383         int current_column_traversal = rightmost - i;
384         current_column_traversal = total_nodes[current_column_traversal].up;
385         while (current_column_traversal != rightmost - i)//删除当前列相交的行
386         {
387             if (current_column_traversal != total_nodes[current_column_traversal].column)//即不是头行
388             {
389                 int temp_node = current_column_traversal - current_column_traversal % 4+4;
390                 for (int j = 0; j < 4; j++)
391                 {
392                     temp_node --;
393                     if (temp_node != current_column_traversal)
394                     {
395                         total_nodes[total_nodes[temp_node].down].up = temp_node;
396                         total_nodes[total_nodes[temp_node].up].down = temp_node;
397                     }
398                 }
399                 node_heap_increase(temp_node);
400             }
401             current_column_traversal = total_nodes[current_column_traversal].up;
402         }
403     }
404     node_heap_increase(target_to_stack);//最后对当前行进行回复
405     available_column += 4;
406     //print_result();
407 }
408 int find_next()//用来找下一个可以入栈的元素,如果无法入栈或者已经找到了解,则返回并进行回溯操作
409 {
410     int target_position;
411     int temp_node_one;
412     if (available_column == current_heap_number)
413     {
414         if (available_column == -1)
415         {
416             //print_result();
417             return 2;
418         }
419         else
420         {
421             target_position = mutual_index[0].position_index;
422             temp_node_one = total_nodes[target_position].down;
423             in_stack(temp_node_one);
424             return 1;
425         }
426     }
427     else
428     {
429         return 0;
430     }
431 }
432 void seek_sudoku()
433 {
434     int find_result = 0;
435     int temp_node_one;
436     while (1)
437     {
438         find_result = find_next();
439         if (!find_result)//如果无法入栈且目前没有找到解,则出栈
440         {
441             temp_node_one = node_stack[stack_index - 1];
442             out_stack();
443             temp_node_one = total_nodes[temp_node_one].down;
444             while ((temp_node_one==total_nodes[temp_node_one].column))//如果当前元素是当前列头节点,则递归出栈
445             {
446                 if (stack_index == 0)//如果栈空,则所有的搜索空间已经搜索完全 返回
447                 {
448                     return;
449                 }
450                 else
451                 {
452                     temp_node_one = node_stack[stack_index - 1];
453                     out_stack();
454                     temp_node_one = total_nodes[temp_node_one].down;
455                 }
456             }
457             in_stack(temp_node_one);//将所选元素入栈
458         }
459         else
460         {
461             if (find_result / 2)//如果已经找到结果,则返回,事实上我们可以更改这个逻辑来应对有多个解的情况,并把它全部打印
462             {
463                 return;
464             }
465         }
466     }
467 }
468 int main()
469 {
470     clock_t clock_one, clock_two, clock_three;
471     ifstream suduko_file("sudoku.txt");
472     char temp[82];
473     clock_one = clock();
474     int line = 1;
475     while (line!=49152)
476     {
477         suduko_file.getline(temp, 82);
478         for (int i = 0; i < 9; i++)
479         {
480             for (int j = 0; j < 9; j++)
481             {
482                 out[i][j] = temp[i * 9 + j] - '0';
483             }
484         }
485         creat_dlx_sudoku();
486         seek_sudoku();
487         line++;
488     }
489     clock_three = clock();
490     printf("%d mscond passed in seek_sudoku
", clock_three - clock_one);
491 }