MPI实现FFT的迭代算法——源于《并行计算——结构·算法·编程》中的伪码(更新3)
MPI实现FFT的迭代算法——源于《并行计算——结构·算法·编程》中的伪码(更新3)
Allgather 的开销太大,所以可以考虑用输入数据闭包的办法来避免通信。但是 FFT 蝶形算法的输入闭包是全部数据,这个办法行不通。不过在每次迭代中,节点间的通信一定是成块交换:每个数据块的起始下标是 2 的整数次幂的倍数,长度也是 2 的整数次幂。所以,可以对并行度加以规定(进程数取 2 的整数次幂),让通信按块进行。
#include "mpi.h"
#include <limits.h>
#include <math.h>
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>

/*
 * Iterative FFT distributed over MPI processes.
 *
 * Both the number of input points and the number of processes must be
 * powers of two, and there must be at least one point per process.
 *
 * File layout used by readBinary()/writeBinary(): one `int` holding the
 * number of doubles that follow, then the doubles themselves as interleaved
 * (real, img) pairs.
 *
 * Usage: <program> <input file> <output file>
 */

typedef struct {
    double real;
    double img;
} com;

double PI;

/* Named readBinary rather than read: a function called read would clash
 * with the existing one and get miscalled. */
int readBinary(char* file,void *buf,int fsize);
int writeBinary(char *file,com *array,int size);

/* No shared omega table is used: not every node needs a whole copy of it,
 * so twiddle factors are computed on demand instead. */
static inline void cp(com f,com *t);  /* copy complex */
static inline void add(com a,com b,com* r);
static inline void mul(com a,com b,com* r);
static inline void sub(com a,com b,com* r);
int br(int src,int size);             /* bit reverse */

/* Declared by the original source; no definition is visible in this file. */
void send(com *c,com *r,int s,int t);

/* Print one complex value as "real img" with four decimals. */
void show(com a)
{
    printf("%.4f %.4f \n", a.real, a.img);
}

/* Return non-zero when v is a positive power of two. */
static int is_pow2(int v)
{
    return v > 0 && (v & (v - 1)) == 0;
}

/*
 * Load the whole input file into a freshly allocated buffer.
 *
 * On success stores the buffer in *out (caller frees) and returns the number
 * of complex points, i.e. half of the double count stored in the header.
 * Returns -1 and leaves *out NULL when the file cannot be read or when the
 * header claims more data than the file actually contains.
 */
static int load_input(char *file, void **out)
{
    struct stat info;

    *out = NULL;
    if (stat(file, &info) != 0) {
        return -1;
    }
    /* readBinary() takes the size as an int, so larger files are rejected. */
    if (info.st_size < (off_t)sizeof(int) || info.st_size > INT_MAX) {
        return -1;
    }

    int fsize = (int)info.st_size;
    void *buf = malloc((size_t)fsize);
    if (buf == NULL) {
        return -1;
    }

    int doubles = readBinary(file, buf, fsize);
    int points = doubles / 2;
    size_t payload = (size_t)fsize - sizeof(int);
    if (doubles < 0 || points < 1 || (size_t)points > payload / sizeof(com)) {
        free(buf);
        return -1;
    }

    *out = buf;
    return points;
}

/*
 * Twiddle factor for global index idx at the stage with half-span p:
 * cos(2*PI*e/n) + i*sin(2*PI*e/n) with e = p * (br(idx, l) % q).
 */
static com twiddle(int idx, int p, int q, int l, int n)
{
    int e = p * (br(idx, l) % q);
    com w;

    w.real = cos(2 * PI * e / n);
    w.img = sin(2 * PI * e / n);
    return w;
}

int main(int argc, char *argv[])
{
    if (argc < 3) {
        printf("wtf\n");
        return 1;
    }

    PI = atan(1) * 4;

    int self, size; /* process id and total number of processes */
    MPI_Init(&argc, &argv);
    double st = MPI_Wtime();
    MPI_Comm_rank(MPI_COMM_WORLD, &self);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    /* Only rank 0 touches the input file; n stays -1 if loading fails. */
    int n = -1; /* total number of complex points */
    void *buf = NULL;
    if (0 == self) {
        n = load_input(argv[1], &buf);
    }

    /* Every process has to know the total size (or learn about the failure). */
    MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
    if (n < 0) {
        if (0 == self) {
            printf("error reading \n");
        }
        MPI_Finalize();
        return 1;
    }

    /* The block exchange below only works for power-of-two sizes. */
    if (!is_pow2(n) || !is_pow2(size) || size > n) {
        if (0 == self) {
            printf("point count and process count must be powers of two, "
                   "with at least one point per process \n");
        }
        free(buf);
        MPI_Finalize();
        return 1;
    }

    int psize = n / size;    /* points owned by each process */
    int off = self * psize;  /* global index of this process's first point */

    int l = 0;               /* log2(n), computed without floating point */
    while ((1 << l) < n) {
        ++l;
    }

    /* `in` is full-length on every rank because the final MPI_Allgather
     * collects all points everywhere; `t` holds one local chunk. */
    com *in = malloc((size_t)n * sizeof *in);
    com *t = malloc((size_t)psize * sizeof *t);
    if (in == NULL || t == NULL) {
        printf("out of memory \n");
        MPI_Abort(MPI_COMM_WORLD, 1);
        return 1;
    }

    if (0 == self) {
        /* Skip the int header; the payload is the interleaved doubles. */
        memcpy(in, (char *)buf + sizeof(int), (size_t)n * sizeof(com));
        free(buf);
        buf = NULL;
    }

    MPI_Request sr, rr;
    MPI_Status s;

    /* Hand every rank its chunk; the working copy lives at the front of `in`.
     * Rank 0 already has its own chunk there. */
    MPI_Scatter(in, psize * 2, MPI_DOUBLE, t, psize * 2, MPI_DOUBLE, 0,
                MPI_COMM_WORLD);
    if (0 != self) {
        memcpy(in, t, (size_t)psize * sizeof(com));
    }

    for (int h = l - 1; h >= 0; --h) {
        int p = 1 << h; /* distance between butterfly partners */
        int q = n / p;
        int inter = (psize <= p);

        if (inter) {
            /* The partner block lives on another rank: swap whole chunks.
             * A block whose start has bit h clear pairs with the block p
             * points further on, otherwise with the block p points back. */
            int next = ((off & p) == 0) ? (off + p) / psize
                                        : (off - p) / psize;
            MPI_Issend(in, psize * 2, MPI_DOUBLE, next, 0, MPI_COMM_WORLD,
                       &sr);
            MPI_Irecv(t, psize * 2, MPI_DOUBLE, next, 0, MPI_COMM_WORLD,
                      &rr);
            MPI_Wait(&sr, &s);
            MPI_Wait(&rr, &s);
        } else {
            /* Both partners are local: snapshot the chunk so updates to
             * `in` do not disturb values still needed by later butterflies. */
            if (0 == self) {
                printf("memcpy \n");
            }
            memcpy(t, in, (size_t)psize * sizeof(com));
        }

        for (int k = off; k < off + psize; ++k) {
            int j = k - off;
            int lower = ((k & p) == 0); /* same as k%p == k%(2*p) */
            /* After an exchange the partner sits at the same local slot;
             * locally it sits p slots away inside the snapshot. */
            int partner = inter ? j : (lower ? j + p : j - p);
            com w, m;

            if (lower) {
                w = twiddle(k, p, q, l, n);
                mul(t[partner], w, &m);
                add(in[j], m, &in[j]);
            } else {
                w = twiddle(k - p, p, q, l, n);
                mul(in[j], w, &m);
                sub(t[partner], m, &in[j]);
            }
        }

        /* Per-stage trace of rank 0's chunk, kept from the original. */
        if (0 == self) {
            printf("%d \n", h);
            for (int i = 0; i < psize; ++i) {
                printf("b%d :%.4f %.4f \n", i, in[i].real, in[i].img);
            }
        }
    }

    /* Collect every chunk on every rank so that each rank can pick its
     * share of the bit-reversed output. */
    memcpy(t, in, (size_t)psize * sizeof(com));
    MPI_Allgather(t, 2 * psize, MPI_DOUBLE, in, 2 * psize, MPI_DOUBLE,
                  MPI_COMM_WORLD);

    for (int k = off; k < off + psize; ++k) {
        t[k - off] = in[br(k, l)];
    }
    MPI_Gather(t, psize * 2, MPI_DOUBLE, in, psize * 2, MPI_DOUBLE, 0,
               MPI_COMM_WORLD);

    int rc = 0;
    if (0 == self) {
        if (writeBinary(argv[2], in, n) != 0) {
            rc = 1;
        }
    }

    free(t);
    free(in);

    /* Elapsed time is measured but, as in the original, not printed. */
    double et = MPI_Wtime();
    (void)st;
    (void)et;

    MPI_Finalize();
    return rc;
}

/*
 * Read exactly fsize bytes of `file` into buf.
 *
 * Returns the int stored at the start of the file (the number of doubles
 * that follow), or -1 when the file cannot be opened, is shorter than
 * fsize bytes, or is too small to hold the header.
 */
int readBinary(char* file,void *buf,int fsize)
{
    FILE *in = fopen(file, "rb");
    if (in == NULL) {
        printf("can't open \n");
        return -1;
    }

    size_t got = 0;
    if (fsize > 0) {
        got = fread(buf, sizeof(char), (size_t)fsize, in);
    }
    fclose(in);

    if (fsize < (int)sizeof(int) || got != (size_t)fsize) {
        return -1;
    }

    int size;
    memcpy(&size, buf, sizeof size);
    return size;
}

/*
 * Write `size` complex values to `file`: an int header holding 2*size (the
 * number of doubles) followed by the values themselves.
 *
 * Returns 0 on success and -1 when the file cannot be opened or written.
 */
int writeBinary(char *file,com *array,int size)
{
    FILE *out = fopen(file, "wb");
    if (out == NULL) {
        printf("can't open \n");
        return -1;
    }

    int rc = 0;
    int doubles = 2 * size;
    if (fwrite(&doubles, sizeof doubles, 1, out) != 1) {
        rc = -1;
    }
    if (rc == 0 && size > 0 &&
        fwrite(array, sizeof *array, (size_t)size, out) != (size_t)size) {
        rc = -1;
    }
    /* fclose flushes, so a failure here also means lost output. */
    if (fclose(out) != 0) {
        rc = -1;
    }
    return rc;
}

/* Copy complex value f into *t. */
static inline void cp(com f,com *t)
{
    t->real = f.real;
    t->img = f.img;
}

/* *c = a + b */
static inline void add(com a,com b,com *c)
{
    c->real = a.real + b.real;
    c->img = a.img + b.img;
}

/* *c = a * b (complex product) */
static inline void mul(com a,com b,com *c)
{
    c->real = a.real * b.real - a.img * b.img;
    c->img = a.real * b.img + a.img * b.real;
}

/* *c = a - b */
static inline void sub(com a,com b,com *c)
{
    c->real = a.real - b.real;
    c->img = a.img - b.img;
}

/* Reverse the lowest `size` bits of src; any higher bits are dropped. */
int br(int src,int size)
{
    int tmp = src;
    int des = 0;

    for (int i = size - 1; i >= 0; --i) {
        des |= (tmp & 1) << i;
        tmp >>= 1;
    }
    return des;
}

/* Author's note: the cluster went down, so this version could not be tested. */