跳转至

树链剥分

问题

题目链接

题目描述见上方链接。在学习树链剥分之前,需要先掌握线段树的基础知识(区间修改、区间求和与懒标记)。

AC Code

#include <bits/stdc++.h>
// Debug helper: prints the text of the expression, a space, its value and a newline.
#define debug(x) cout << #x << ' ' << x << '\n'
// Makes `endl` expand to a newline character, so writing it does not flush the stream.
#define endl '\n'
// From here on every `int` token expands to `long long`;
// this is why main() further down is declared with `signed` instead.
#define int long long
using namespace std;

// Pair of two integers (64-bit because of the macro above).
typedef pair<int, int> PII;
typedef long long ll;

// N: capacity of the per-node arrays (the segment tree array is sized 4 * N).
// mod, inf: not referenced anywhere in the visible code; the modulus actually
// used by the solution is the runtime value p.
const int N = 100010, mod = 1e9 + 7, inf = 0x3f3f3f3f;

// n: number of nodes, m: number of operations, r: root node, p: modulus.
// All four are read from input in solve().
int n, m, r, p;
// w[i]: initial weight of node i (nodes are numbered from 1).
int w[N];
/*
depth   depth of a node (the root is given depth 1 in solve())
siz     size of the subtree rooted at a node
father  parent of a node (0 for the root)
son     heavy child of a node (stays 0 when the node has no child)
id      position of a node in the DFS order produced by dfs2
nw      node weights re-indexed by DFS-order position
top     head (topmost node) of the chain that contains a node
*/
int depth[N], siz[N], father[N], son[N], id[N], nw[N], top[N];
// Adjacency list of the undirected tree.
vector<int> G[N];

// Segment tree
// One node of the segment tree built over DFS-order positions.
// Member order matters: build() fills a node with {l, r, sum, tag}.
struct segment
{
    int l;   // left end of the covered interval (inclusive)
    int r;   // right end of the covered interval (inclusive)
    int sum; // sum of the interval, kept modulo p by add_tag/pushup
    int tag; // pending range-add that has not been pushed to the children
} tr[4 * N];

void add_tag(int u, int tag) 
{
    tr[u].tag = (tr[u].tag + tag) % p;
    tr[u].sum = (tr[u].sum + tag * (tr[u].r - tr[u].l + 1)) % p;
}

// Recomputes the sum of node u from its two children (modulo p).
void pushup(int u)
{
    const int left = u << 1;
    const int right = u << 1 | 1;
    tr[u].sum = (tr[left].sum + tr[right].sum) % p;
}

// Pushes the pending lazy add of node u down to both children and clears it.
// Does nothing when there is no pending add.
void pushdown(int u)
{
    const int pending = tr[u].tag;
    if (pending == 0)
        return;

    add_tag(u << 1, pending);
    add_tag(u << 1 | 1, pending);
    tr[u].tag = 0;
}

void build(int u, int l, int r)
{
    tr[u] = {l, r, 0, 0};
    if (l == r)
    {
        tr[u].sum = nw[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Adds w to every position in [l, r], starting the descent at node u.
// Nodes fully inside the range are tagged lazily; partially covered nodes
// push their tag down, recurse into the overlapping children and then
// refresh their own sum.
void update(int u, int l, int r, int w)
{
    const int lo = tr[u].l;
    const int hi = tr[u].r;
    if (l <= lo && hi <= r)
    {
        add_tag(u, w);
        return;
    }

    pushdown(u);
    const int mid = (lo + hi) >> 1;
    if (l <= mid)
        update(u << 1, l, r, w);
    if (r > mid)
        update(u << 1 | 1, l, r, w);
    pushup(u);
}

int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)   
        return tr[u].sum;
    pushdown(u);
    ll res = 0, mid = (tr[u].l + tr[u].r) >> 1;
    if (mid >= l)   res = (res + query(u << 1, l, r)) % p;
    if (mid < r)    res = (res + query(u << 1 | 1, l, r)) % p;
    return res;
}

// Heavy-light decomposition
// First pass: records the parent, the depth of each child and the subtree
// size of every node, and picks the heavy child (the child with the largest
// subtree; on ties the one met first in G[u] is kept).
void dfs1(int u, int fa)
{
    father[u] = fa;
    siz[u] = 1;
    for (const auto child : G[u])
    {
        if (child != fa)
        {
            depth[child] = depth[u] + 1;
            dfs1(child, u);
            siz[u] += siz[child];
            if (son[u] == 0 || siz[child] > siz[son[u]])
                son[u] = child;
        }
    }
}

int idx = 0; // last DFS-order position handed out
// Second pass: assigns DFS-order positions and chain heads.
// Note that `fa` here is the head of the chain u belongs to, not u's parent.
// The heavy child is visited first so that each chain occupies consecutive
// positions; every light child starts a new chain with itself as head.
void dfs2(int u, int fa)
{
    top[u] = fa;
    id[u] = ++idx;
    nw[id[u]] = w[u];

    const int heavy = son[u];
    if (heavy == 0)
        return;

    dfs2(heavy, fa);
    for (const auto child : G[u])
    {
        if (child != father[u] && child != heavy)
            dfs2(child, child);
    }
}

// Adds w to every node on the path between x and y.
// While the two nodes are on different chains, the one whose chain head is
// deeper has its chain segment updated and then jumps above that head.
// Once both are on the same chain, the span between them is updated.
void update_path(int x, int y, int w)
{
    for (; top[x] != top[y]; x = father[top[x]])
    {
        if (depth[top[x]] < depth[top[y]])
            swap(x, y);
        update(1, id[top[x]], id[x], w);
    }

    const bool x_is_upper = depth[x] < depth[y];
    const int upper = x_is_upper ? x : y;
    const int lower = x_is_upper ? y : x;
    update(1, id[upper], id[lower], w);
}

// Adds w to every node in the subtree rooted at u.
void update_tree(int u, int w)
{
    // The subtree occupies siz[u] consecutive DFS-order positions starting at id[u].
    const int first = id[u];
    const int last = first + siz[u] - 1;
    update(1, first, last, w);
}

int query_path(int x, int y) // 查询两节点路径
{
    ll res = 0;
    while (top[x] != top[y])
    {
        if (depth[top[x]] < depth[top[y]])  swap(x, y);
        res = (res + query(1, id[top[x]], id[x])) % p;
        x = father[top[x]];
    }
    if (depth[x] < depth[y])    swap(x, y);
    res = (res + query(1, id[y], id[x])) % p;
    return res;
}

// Returns the sum of the node values in the subtree rooted at u,
// exactly as produced by query().
int query_tree(int u)
{
    // The subtree occupies siz[u] consecutive DFS-order positions starting at id[u].
    const int first = id[u];
    const int last = first + siz[u] - 1;
    return query(1, first, last);
}

// Reads the tree and the operations from standard input and answers them.
//
// Input: n m r p, then n node weights, then n - 1 edges, then m operations:
//   1 x y z  add z to every node on the path x..y
//   2 x y    print the path sum x..y modulo p
//   3 x z    add z to every node in the subtree of x
//   other x  print the subtree sum of x
void solve()
{
    cin >> n >> m >> r >> p;
    for (int node = 1; node <= n; ++node)
        cin >> w[node];

    for (int edge = 1; edge < n; ++edge)
    {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // Decompose the tree first; the segment tree is built over the resulting order.
    depth[r] = 1;
    dfs1(r, 0);
    dfs2(r, r);
    build(1, 1, n);

    while (m--)
    {
        int op;
        cin >> op;
        switch (op)
        {
        case 1:
        {
            int x, y, z;
            cin >> x >> y >> z;
            update_path(x, y, z);
            break;
        }
        case 2:
        {
            int x, y;
            cin >> x >> y;
            cout << query_path(x, y) << endl;
            break;
        }
        case 3:
        {
            int x, z;
            cin >> x >> z;
            update_tree(x, z);
            break;
        }
        default:
        {
            int x;
            cin >> x;
            cout << query_tree(x) << endl;
            break;
        }
        }
    }
}

// Program entry point. Declared `signed` because `int` is redefined as a macro.
signed main()
{
    // Detach iostreams from C stdio and untie the streams for faster I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    solve();
    return 0;
}