[ZJOI2019]Minimax搜索

[ZJOI2019]Minimax搜索

  • 给定一棵 (n) 个点的以 (1) 为根的有根树,每个点有个权值 (w_i),权值的计算方式如下:

    [w_i = egin{cases} i & ext{if } x ext{ is leaf}\ max_{j in son_i}{w_{j}} & ext{if dep}_x ext{is odd}\ min_{j in son_i}{w_{j}} & ext{if dep}_x ext{is even} end{cases} ]

    其中 ( ext{dep}_x)(x) 的深度,根的深度为 1。

  • 记根节点的权值为 (W)

  • 你可以选择一个非空的叶子节点集合 (S),然后修改这个集合中叶子节点的权值 (w_i o nw_i),花费的代价为 (min_{i in S}|w_i - nw_i|)

  • 一个叶子节点集合的价值 (V_S) 定义为能使 (W) 发生改变所花费的最小代价。如果这个集合里的叶子节点无论怎么修改都不能使 (W) 改变,则 (V_S = n)

  • 对于 (k = L,L+1,dots, R),求 (V_S = k) 的非空叶子节点集合的数量。

  • (L leq R leq n leq 2 imes 10^5)

叶子节点 (W) 到根的路径上的节点权值都为 (W),我们称这条路径为主路径

改变主路径上任意一个点的权值就会改变根的权值。

所有包含叶子节点 (W) 的集合只需要花费 (1) 的代价修改 (W) 即可。

下面只考虑不包含叶子节点 (W) 的情况。

断开主路径上的边,单独考虑每一个联通子树,每一个联通子树的根的权值都是 (W)

对于叶子节点 (i)(w_i = i eq W) ,想要使 (W) 改变,必然要满足以下条件:

  • 如果 (w_i < W) 则它必须改成大于 (W) 的值,且它所在联通块的根的深度为奇数;
  • 如果 (w_i > W) 则它必须改成小于 (W) 的值,且它所在联通块的根的深度为偶数。

最小的花费显然是改成 (W+1)(W-1)

(f[x]) 表示在代价不超过 (k) 的前提下修改 (x) 子树中的叶子节点的权值,不能使 (W) 改变的概率

具体地,对于奇数深度的 (x)(f[x])(w_x < W) 的概率;对于偶数深度的 (x)(f[x])(w_x > W) 的概率。

然后就有 dp 式

[f[x] = prod_{y in son_x} (1 - f[y]) ]

能使 (W) 改变的总概率就是

[P = 1 - prod_{i in ext{MainChain}} f[i] ]

枚举每一个 (k) 分别 dp,差分计算答案,(O(n^2))

动态DP

发现当 (k) 变成 (k+1) 时,概率发生改变的叶子只有 (W - k)(W + k) 两个,可以用动态 DP。

动态DP写得太多了所以没写。

线段树合并

考虑直接用线段树维护DP值,即 (f[x]) 是一个 (n - 1) 维向量 ((c_1, c_2, c_3, dots, c_{n-1})) 分别表示 (k = 1, 2, 3, dots, n - 1) 时的DP值。

每个叶子节点最多只有两个区间(概率只会在 (k = |W - w_i| + 1) 时改变),用线段树合并对应项相乘。需要乘法和加法标记。

#include <bits/stdc++.h>
#define perr(a...) fprintf(stderr, a)
#define dbg(a...) perr(" 33[32;1m"), perr(a), perr(" 33[0m")
template <class T, class U>
inline bool smin(T &x, const U &y) {
  return y < x ? x = y, 1 : 0;
}
template <class T, class U>
inline bool smax(T &x, const U &y) {
  return x < y ? x = y, 1 : 0;
}

using LL = long long;
using PII = std::pair<int, int>;

constexpr int N(2e5 + 5), P(998244353);
inline void inc(int &x, int y) {
  x += y;
  if (x >= P) x -= P;
}

int n, m = 1;
std::vector<int> g[N];
int dep[N], val[N];
void dfs0(int x, int fa) {
  dep[x] = dep[fa] ^ 1;
  if (x > 1 && g[x].size() == 1) {
    val[x] = x;
    inc(m, m);
    return;
  }
  val[x] = dep[x] ? 0 : 666666;
  for (int y : g[x]) {
    if (y == fa) continue;
    dfs0(y, x);
    dep[x] ? smax(val[x], val[y]) : smin(val[x], val[y]);
  }
}
struct Node {
  Node *ls, *rs;
  int mul, val;
  Node(int m = 1, int v = 0) : ls(nullptr), rs(nullptr), mul(m), val(v) {}
  void times(int x) { mul = 1LL * mul * x % P, val = 1LL * val * x % P; }
  void add(int x) { inc(val, x); }
  void pushdown() {
    if (!ls) {
      ls = new Node(mul, val);
    } else {
      ls->times(mul);
      ls->add(val);
    }
    if (!rs) {
      rs = new Node(mul, val);
    } else {
      rs->times(mul);
      rs->add(val);
    }
    mul = 1, val = 0;
  }
};
int ask(Node *o, int l, int r, int x) {
  if (!o->ls && !o->rs) return o->val;
  int m = l + r >> 1;
  o->pushdown();
  return x <= m ? ask(o->ls, l, m, x) : ask(o->rs, m + 1, r, x);
}
void update(Node *&o, int l, int r, int x, int y, int u, int v) {
  if (!o) o = new Node;
  if (x <= l && r <= y) {
    o->times(u), o->add(v);
    return;
  }
  o->pushdown();
  int m = l + r >> 1;
  if (x <= m) update(o->ls, l, m, x, y, u, v);
  if (y > m) update(o->rs, m + 1, r, x ,y, u, v);
}
Node *merge(Node *x, Node *y) {
  if (!y->ls && !y->rs) {
    x->times(y->val);
    delete y;
    return x;
  }
  if (!x->ls && !x->rs) {
    y->times(x->val);
    delete x;
    return y;
  }
  x->pushdown(), y->pushdown();
  x->ls = merge(x->ls, y->ls);
  x->rs = merge(x->rs, y->rs);
  delete y;
  return x;
}
// int k, f[N];
Node *root[N];
void dp(int x, int fa, int p) {
  if (val[x] == x) {
    if (p ? x < val[1] : x > val[1]) {
      int k = std::abs(x - val[1]);
      update(root[x], 1, n - 1, 1, k, 0, dep[x] ?  x < val[1] : x > val[1]);
      update(root[x], 1, n - 1, k + 1, n, 0, P + 1 >> 1);
      // f[x] = P + 1 >> 1;
    } else {
      update(root[x], 1, n - 1, 1, n - 1, 0, dep[x] ?  x < val[1] : x > val[1]);
      // f[x] = dep[x] ? x < val[1] : x > val[1];
    }
    return;
  }
  update(root[x], 1, n - 1, 1, n - 1, 0, 1);
  // f[x] = 1;
  for (int y : g[x]) {
    if (y == fa) continue;
    dp(y, x, p);
    update(root[y], 1, n - 1, 1, n - 1, P - 1, 1);
    root[x] = merge(root[x], root[y]);
    // f[x] = 1LL * f[x] * (1 + P - f[y]) % P;
  }
}
void dfs(int x, int fa) {
  update(root[x], 1, n - 1, 1, n - 1, 0, val[x] == x ? P + 1 >> 1 : 1);
  // f[x] = val[x] == x && k ? P + 1 >> 1 : 1;
  for (int y : g[x]) {  
    if (y == fa) continue;
    if (val[x] == val[y]) {
      dfs(y, x);
      root[1] = merge(root[1], root[y]);
      // f[1] = 1LL * f[1] * f[y] % P;
    } else {
      dp(y, x, dep[x]);
      update(root[y], 1, n - 1, 1, n - 1, P - 1, 1);
      root[x] = merge(root[x], root[y]);
      // f[x] = 1LL * f[x] * (1 + P - f[y]) % P;
    }
  }
}
int ans[N];
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int l, r;
  std::cin >> n >> l >> r;
  for (int i = 1, x, y; i < n; i++) {
    std::cin >> x >> y;
    g[x].push_back(y), g[y].push_back(x);
  }
  dfs0(1, 0);
  dfs(1, 0);
  ans[0] = 0, ans[n] = m - 1;
  for (int i = std::max(l - 1, 1); i <= r && i < n; i++) ans[i] = (1LL + P - ask(root[1], 1, n - 1, i)) * m % P;
  for (int i = l; i <= r; i++) std::cout << (ans[i] - ans[i - 1] + P) % P << " ";
  return 0;
}