【HNOI】 c tree-dp

  【题目描述】给定一个n个节点的树,每个节点有两个属性值a[i],b[i],我们可以在树中选取一个连通块G,这个连通块的值为(Σa[x])(Σb[x]) x∈G,求所有连通块的值的和,输出答案对1000000007取余。

  【数据范围】n<=10^5.

  首先我们任选一个点作为根,变成一颗有根树。观察答案为(Σa[x])(Σb[x]),那么我们可以将这个答案展开成为每一个b[x]乘上所有可能情况下的a[y],这个可能情况就是x点在连通块中时,b[x]乘上连通块内所有点的a值去和,再枚举所有的连通块,就可以求出来b[x]对答案的贡献,那么我们现在问题就转化为了求出来一个节点,所有包括这个节点的连通块的a值和,每个连通块的a值为连通块内所有点的a值和,设这个值为sum_[x]。

  我们的sum_[x]的值的求方法可以为x子树中每个a被累加的次数加上非x子树节点a值被累加的次数,那么我们可以依次求出来这两个,然后求出sum_[x]。

  我们设w[x]为以x为根的子树中,包含x节点的连通块的数量,sum[x]为以x为根的子树中,包含x的所有连通块的a值和,w_[x]为所有包含x节点的连通块的数量。

  有了这些量,我们就可以求出sum_[x],先考虑这些量的转移。

  w[x]=π(w[son of x]+1).

  sum[x]=Σ(w[x]/(w[son of x]+1)*sum[son of x]).

  这两个量的转移是由子节点到根的,比较容易考虑,现在我们有了这两个量之后,考虑用这两个量转移其余的两个量。

  w_[x]=(w_[father of x]/(w[x]+1)+1)*w[x].

  那么sum_[x]就等于之前说的两部分相加,则

  sum_[x]=w_[father of x]/(w[x]+1)+1)*sum[x]+(sum_[father of x]-w_[father of x]/(w[x]+1)*sum[x])/(w[x]+1)*w[x].

  反思:为了提高速度没开LL,用到的地方强转的LL,然后有的地方忘加了,纠结了好久= =。

//By BLADEVIL
#include <cstdio>
#define d39 1000000007
#define maxn 100010
#define LL long long

using namespace std;

int n,l;
int last[maxn],other[maxn<<1],pre[maxn<<1],a[maxn],b[maxn],que[maxn],dis[maxn];
int sum[maxn],w[maxn],sum_[maxn],w_[maxn];

void connect(int x,int y) {
    pre[++l]=last[x];
    last[x]=l;
    other[l]=y;
}

int pw(int x,int k) {
    int ans=1;
    while (k) {
        if (k&1) ans=((LL)ans*x)%d39; 
        x=((LL)x*x)%d39;
        k>>=1;
    }
    return ans;
}

int main() {
    freopen("c.in","r",stdin); freopen("c.out","w",stdout);
    scanf("%d",&n);
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        connect(x,y); connect(y,x);
    }
    for (int i=1;i<=n;i++) scanf("%d",&a[i]);
    for (int i=1;i<=n;i++) scanf("%d",&b[i]);
    int h=0,t=1; que[1]=1; dis[1]=1;
    while (h<t) {
        int cur=que[++h];
        for (int p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]) continue;
            que[++t]=other[p]; dis[other[p]]=dis[cur]+1;
        }
    }
    //for (int i=1;i<=n;i++) printf("%d ",que[i]); printf("
");
    for (int i=n;i;i--) {
        int cur=que[i];
        w[cur]=1;
        for (int p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]<dis[cur]) continue;
            w[cur]=((LL)w[cur]*(w[other[p]]+1))%d39;
        }
        sum[cur]=((LL)w[cur]*a[cur])%d39;
        for (int p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]<dis[cur]) continue;
            sum[cur]=(sum[cur]+((LL)((LL)w[cur]*pw(w[other[p]]+1,d39-2)%d39)*sum[other[p]])%d39)%d39;
        }
    }
    //for (int i=1;i<=n;i++) printf("%d %d %d
",i,sum[i],w[i]);
    for (int i=1;i<=n;i++) {
        int cur=que[i];
        if (cur==1) {
            w_[cur]=w[cur];
            sum_[cur]=sum[cur];
        }
        for (int p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]<dis[cur]) continue;        
            //printf("%d
",pw(w_[cur]*w[other[p]]+1,d39-2)%d39);
            int tot=(LL)w_[cur]*pw(w[other[p]]+1,d39-2)%d39;
            //printf("%d
",tot);
            w_[other[p]]=((LL)(tot+1)%d39*w[other[p]]%d39);
            sum_[other[p]]=((LL)(tot+1)*sum[other[p]]%d39+(LL)((LL)(sum_[cur]-(LL)tot*sum[other[p]]%d39+d39))%d39*w[other[p]]%d39*pw(w[other[p]]+1,d39-2)%d39)%d39;
        }
    }
    //for (int i=1;i<=n;i++) printf("%d %d %d %d
",i,w[i],sum[i],sum_[i]);
    int ans=0;
    for (int i=1;i<=n;i++) ans=(ans+(LL)sum_[i]*b[i])%d39;
    printf("%d
",ans);
    fclose(stdin); fclose(stdout);
    return 0;
}