洛谷P4074 糖果公园 树上带修莫队

这是一道莫队算法的综合题,细节较多,处理起来比较繁琐。

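把树按括号序(欧拉序,每个点进入和离开时各记录一次)压成一维序列,然后在该序列上做带修改的莫队。设 first[x]、last[x] 分别为 x 第一次、第二次出现的位置,并交换端点使 first[u] < first[v]:若 u 是 v 的祖先(即 u 就是 lca),查询区间为 [first[u], first[v]];否则查询区间为 [last[u], first[v]]。后一种情况下区间内不包含 lca,统计答案时要临时把 lca 加入,记录答案后再撤销。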

#include <bits/stdc++.h>
// Large sentinel value. 0x3f3f3f3f is also the int obtained when memset
// fills every byte of an int with 0x3f, so full(arr, 0x3f) pairs with it.
#define INF 0x3f3f3f3f
// full(a, b): memset every byte of a with b. `sizeof a` is only the whole
// storage when a is an actual array (not a pointer).
#define full(a, b) memset(a, b, sizeof a)
using namespace std;
typedef long long int ll;  // wide integer type used for answers/accumulators
// Returns the value of the lowest set bit of x (two's-complement trick:
// -x keeps the lowest set bit and flips every bit above it).
inline int lowbit(int x)
{
    const int negated = -x;
    return negated & x;
}
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; if any of them is
// '-', the result is negated. Returns 0 (instead of looping forever) when
// EOF is reached before a digit is found.
// ch is an int so it can hold EOF and so isdigit() only ever receives an
// unsigned-char value or EOF (passing a negative char is undefined).
inline int read(){
    int X = 0, neg = 0;   // `neg` avoids shadowing the global array w[]
    int ch = getchar();
    while(ch != EOF && !isdigit(ch)) { neg |= ch == '-'; ch = getchar(); }
    while(isdigit(ch)) X = X * 10 + (ch - '0'), ch = getchar();
    return neg ? -X : X;
}
// Greatest common divisor by the Euclidean algorithm (iterative form).
inline int gcd(int a, int b){
    while(b != 0){
        const int rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Least common multiple of a and b.
// Returns 0 when either argument is 0; otherwise lcm(0, 0) would divide by
// gcd(0, 0) == 0. Dividing before multiplying keeps the intermediate value
// no larger than the result.
inline int lcm(int a, int b){
    if(a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}
// Largest of three values; same result as max(max(x, y), z) built on
// std::max, including which argument wins on ties (the earlier one).
template<typename T>
inline T max(T x, T y, T z){
    T best = (x < y) ? y : x;
    return (best < z) ? z : best;
}
// Smallest of three values; same result as min(min(x, y), z) built on
// std::min, including which argument wins on ties (the earlier one).
template<typename T>
inline T min(T x, T y, T z){
    T best = (y < x) ? y : x;
    return (z < best) ? z : best;
}
// Computes x^p mod lyd by binary exponentiation.
// - x is first reduced into [0, lyd), so 1LL * x * x cannot overflow when x
//   is a large long long, and negative bases give a non-negative residue.
// - The accumulator starts at 1 % lyd, so lyd == 1 correctly yields 0.
// Assumes p >= 0, lyd > 0, and (lyd - 1)^2 fits in a long long.
template<typename A, typename B, typename C>
inline A fpow(A x, B p, C lyd){
    x %= lyd;
    if(x < 0) x += lyd;
    A ans = 1 % lyd;
    for(; p; p >>= 1, x = 1LL * x * x % lyd)
        if(p & 1) ans = 1LL * x * ans % lyd;
    return ans;
}
const int N = 300010;   // capacity shared by every global array below

// --- input read in main ---
int n, m, q;            // n nodes, m entries in v[], q operations
int v[N], w[N];         // v[1..m] and w[1..n]
int c[N];               // c[x]: index into v[] currently assigned to node x

// --- adjacency list (edge indices start at 0, -1 terminates a list) ---
int cnt, head[N];
struct Edge { int v, next; };
Edge edge[N << 1];

// --- bracket sequence and binary lifting, filled by dfs() ---
int t, dfn;             // highest lifting level, running timestamp
int f[N], g[N];         // first / second timestamp of each node
int id[N];              // timestamp -> node
int p[N][20];           // p[x][i]: 2^i-th ancestor of x (0 above the root)
int depth[N];

// --- Mo's algorithm state ---
int pos[N];             // block number of a query endpoint position
int mt, qt, k;          // modification count, query count, block size
int num[N];             // num[type]: selected nodes currently holding that type
ll ans, res[N];         // running answer, answer per query id
bool vis[N];            // vis[x]: node x is currently in the selected set

// One recorded modification: node p changes from value pre to value suc.
struct Modify { int p, pre, suc; };
Modify modify[N];
// One path query: bracket-sequence range [l, r], the number of
// modifications t applied before it, and its original index id.
// Ordering: by block of l, then by block of r, then by time.
struct Query{
    int l, r, t, id;
    bool operator < (const Query &rhs) const {
        const bool sameLeftBlock = pos[l] == pos[rhs.l];
        if(!sameLeftBlock) return l < rhs.l;
        const bool sameRightBlock = pos[r] == pos[rhs.r];
        if(!sameRightBlock) return r < rhs.r;
        return t < rhs.t;
    }
}query[N];

void addEdge(int a, int b){
    edge[cnt].v = b, edge[cnt].next = head[a], head[a] = cnt ++;
}

// Builds the bracket sequence: each node s gets an entry stamp f[s] and an
// exit stamp g[s], and id[] maps every stamp back to its node. Also fills
// depth[] and the binary-lifting table p[s][0..t].
// Recursive: the recursion depth equals the depth of the tree.
void dfs(int s, int fa){
    f[s] = ++dfn;
    id[dfn] = s;
    depth[s] = depth[fa] + 1;
    p[s][0] = fa;
    for(int lvl = 1; lvl <= t; ++lvl){
        const int half = p[s][lvl - 1];
        p[s][lvl] = p[half][lvl - 1];
    }
    for(int e = head[s]; e != -1; e = edge[e].next){
        const int child = edge[e].v;
        if(child != fa) dfs(child, s);
    }
    g[s] = ++dfn;
    id[dfn] = s;
}

int lca(int x, int y){
    if(depth[x] < depth[y]) swap(x, y);
    for(int i = t; i >= 0; i --){
        if(depth[p[x][i]] >= depth[y]) x = p[x][i];
    }
    if(x == y) return y;
    for(int i = t; i >= 0; i --){
        if(p[x][i] == p[y][i]) continue;
        x = p[x][i], y = p[y][i];
    }
    return p[y][0];
}

// Toggles node `node` in or out of the selected set. The j-th selected node
// of type c contributes v[c] * w[j], so ans always equals
// sum over types c of v[c] * (w[1] + ... + w[num[c]]).
void add(int node){
    const int type = c[node];
    if(!vis[node]){
        ++num[type];
        ans += 1LL * v[type] * w[num[type]];
    }
    else{
        ans -= 1LL * v[type] * w[num[type]];
        --num[type];
    }
    vis[node] = !vis[node];
}

// Sets node `node`'s type to `color`. If the node is currently selected,
// it is removed under the old type and re-added under the new one so that
// ans and num[] stay consistent.
void rev(int node, int color){
    if(!vis[node]){
        c[node] = color;
        return;
    }
    add(node);
    c[node] = color;
    add(node);
}

int main(){

    full(head, -1);
    n = read(), m = read(), q = read();
    t = (int)(log(n) / log(2)) + 1;
    for(int i = 1; i <= m; i ++) v[i] = read();
    for(int i = 1; i <= n; i ++) w[i] = read();
    for(int i = 1; i <= n - 1; i ++){
        int u = read(), v = read();
        addEdge(u, v), addEdge(v, u);
    }
    dfs(1, 0);
    k = (int)pow(dfn, 2.0 / 3);
    for(int i = 1; i <= n; i ++) c[i] = read();
    for(int i = 1; i <= q; i ++){
        int opt = read();
        if(opt == 0){
            ++ mt;
            modify[mt].p = read(), modify[mt].suc = read();
            modify[mt].pre = c[modify[mt].p];
            c[modify[mt].p] = modify[mt].suc;
        }
        else{
            ++ qt;
            int l = read(), r = read();
            if(f[l] > f[r]) swap(l, r);
            query[qt].l = lca(l, r) == l ? f[l] : g[l];
            query[qt].r = f[r];
            query[qt].id = qt, query[qt].t = mt;
            pos[query[qt].l] = (query[qt].l - 1) / k + 1;
            pos[query[qt].r] = (query[qt].r - 1) / k + 1;
        }
    }
    for(int i = mt; i >= 1; i --) c[modify[i].p] = modify[i].pre;
    sort(query + 1, query + qt + 1);
    int l = 1, r = 0, ti = 0;
    for(int i = 1; i <= qt; i ++){
        int curL = query[i].l, curR = query[i].r, curT = query[i].t;
        while(l < curL) add(id[l++]);
        while(r < curR) add(id[++r]);
        while(l > curL) add(id[--l]);
        while(r > curR) add(id[r--]);
        while(ti < curT) ti ++, rev(modify[ti].p, modify[ti].suc);
        while(ti > curT) rev(modify[ti].p, modify[ti].pre), ti --;
        int f = lca(id[l], id[r]);
        if(f != id[l] && f != id[r]){
            add(f), res[query[i].id] = ans, add(f);
        }
        else res[query[i].id] = ans;
    }
    for(int i = 1; i <= qt; i ++){
        printf("%lld
", res[i]);
    }
    return 0;
}