MPI兑现fft的迭代算法 源于并行计算——结构。算法。编程中伪码

MPI实现fft的迭代算法 源于并行计算——结构。算法。编程中伪码

#include "mpi.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define T 0

typedef struct {
	double real;
	double img;
} com;

double PI;
//don't use the omega array,not every node needs a whole copy of omega,not efficient
static inline void cp(com f,com *t);//copy complex
static inline void add(com a,com b,com* r);
static inline void mul(com a,com b,com* r);
static inline void sub(com a,com b,com* r);
int br(int src,int size);//bit reverse
void show(com a) {
	printf("%.4f %.4f \n",a.real,a.img);
}
int main(int argc,char *argv[]) {
	PI=atan(1)*4;
	int k,n;//process id and total number of processes
	MPI_Init(&argc,&argv);
	MPI_Comm_rank(MPI_COMM_WORLD,&k);
	MPI_Comm_size(MPI_COMM_WORLD,&n);
	//mpi communication obj
	MPI_Request isReq;
	MPI_Request reReq;
	MPI_Status s;
	//data
	com a;//input
	a.real=k+1;
	a.img=0;
	com c;//temp
	cp(a,&c);
	com w;//omega	
	com r;//recieve

	int l=log(n)/log(2);

	for(int h=l-1;h>=0;--h) {
		int p=pow(2,h);
		int q=n/p;
		if(k%p==k%(2*p)) {
			MPI_Issend(&c,2,MPI_DOUBLE,k+p,T,MPI_COMM_WORLD,&isReq);
			MPI_Irecv(&r,2,MPI_DOUBLE,k+p,T,MPI_COMM_WORLD,&reReq);
			int time=p*(br(k,l)%q);//compute while recieving and sending
			w.real=cos(2*PI*time/n);
			w.img=sin(2*PI*time/n);
			MPI_Wait(&reReq,&s);
			mul(r,w,&r);
			MPI_Wait(&isReq,&s);
			add(c,r,&c);
		} else {
			MPI_Issend(&c,2,MPI_DOUBLE,k-p,T,MPI_COMM_WORLD,&isReq);
			MPI_Irecv(&r,2,MPI_DOUBLE,k-p,T,MPI_COMM_WORLD,&reReq);
			int time=p*(br(k-p,l)%q);//compute while recieving and sending
			w.real=cos(2*PI*time/n);
			w.img=sin(2*PI*time/n);
			MPI_Wait(&reReq,&s);
			MPI_Wait(&isReq,&s);
			mul(c,w,&c);//can't modify until sending comes to an end
			sub(r,c,&c);
		}
		MPI_Barrier(MPI_COMM_WORLD);
	}

	MPI_Issend(&c,2,MPI_DOUBLE,br(k,l),T,MPI_COMM_WORLD,&isReq);
	MPI_Recv(&a,2,MPI_DOUBLE,MPI_ANY_SOURCE,T,MPI_COMM_WORLD,&s);
	printf("b%d:%.4f %.4f \n",k,a.real,a.img);
	MPI_Wait(&isReq,&s);
	MPI_Finalize();
	return 0;
}

void cp(com f,com *t) {
	t->real=f.real;
	t->img=f.img;
}
void add(com a,com b,com *c)   
{
	c->real=a.real+b.real;   
	c->img=a.img+b.img;   
}   
    
void mul(com a,com b,com *c)   
{   
	c->real=a.real*b.real-a.img*b.img;   
	c->img=a.real*b.img+a.img*b.real;   
}
void sub(com a,com b,com *c)   
{   
	c->real=a.real-b.real;   
	c->img=a.img-b.img;   
}   
int br(int src,int size) {
	int tmp=src;
	int des=0;
	for(int i=size-1;i>=0;--i) {
		des=((tmp&1)<<i)|des;
		tmp=tmp>>1;
	}
	return des;
}

计算过程和本书的串行算法有一定的区别,平衡了节点的计算量,很巧妙。实现过程中,尽量让通信和计算时间重叠,缩短计算时间。改写自本人另一个文章串行fft。

没有使用omega数组,而是直接在使用时计算。因为每个节点最多使用log(n)个omega,计算数组需要计算n个。

碟形中上下两个复数的计算和通信顺序有少许差别,因为可能产生读后写错误。