Luogu P2486 染色(树链剖分+线段树)

这里是简介

题解

不妨采取重链剖分的方式把路径剖成区间,然后用线段树维护,考虑如何合并一个区间

struct Node {
    int lf, rg, tot;
}seg[N << 2]; int col[N << 2];
inline Node merge(const Node &lc, const Node &rc) {
    if(!lc.tot) return rc;
    if(!rc.tot) return lc;
    Node ret = (Node){lc.lf, rc.rg, lc.tot + rc.tot};
    if(lc.rg == rc.lf) --ret.tot;
    return ret;
}

其中$Node$表示线段树中的一个节点,共有三个参数,左端点颜色,右端点颜色以及区间内颜色段数。$col$数组用于下方染色标记。

但是我们要考虑这个区间合并后是否存在相同的颜色其应该只有$1$的贡献却被记了$2$的贡献。这种情况存在当且仅当左区间的右端点颜色与右区间左端点颜色相同。

接着,有关于线段树的其他操作也没有什么好担心的了,接着考虑如何查询。

inline int doit(int x, int y) {
    int fx = top[x], fy = top[y];
    Node disx = (Node){0, 0, 0}, disy = (Node){0, 0, 0};
    while(fx != fy) {
        if(dep[fx] >= dep[fy]) disx = merge(query(dfn[fx], dfn[x]), disx), x = fa[fx], fx = top[x];
        else disy = merge(query(dfn[fy], dfn[y]), disy), y = fa[fy], fy = top[y];
    } if(dfn[x] > dfn[y]) swap(x, y), swap(disx, disy);
    swap(disx.lf, disx.rg);
    Node ret = merge(merge(disx, query(dfn[x], dfn[y])), disy);
    return ret.tot;
}

由于重链剖分跳$top$时,两个端点的路径是独立的,所以不能像普通查询那样直接累加贡献,要分开处理,最后存在一个特殊情况,要将左区间的左右端点反置。(画图即可明白)

代码

#include <cstdio>
#include <algorithm>
using std::swap;
typedef long long ll;

const int N = 1e5 + 10;
int n, m, c[N], w[N];
int fa[N], son[N], siz[N], dep[N];
int time, dfn[N], top[N];
int cnt, from[N], to[N << 1], nxt[N << 1];
    struct Node {
        int lf, rg, tot;
    }seg[N << 2]; int col[N << 2];

void addEdge(int u, int v) {
    to[++cnt] = v, nxt[cnt] = from[u], from[u] = cnt;
}

void dfs(int u) {
    dep[u] = dep[fa[u]] + 1, siz[u] = 1;
    for(int i = from[u]; i; i = nxt[i]) {
        int v = to[i]; if(v == fa[u]) continue;
        fa[v] = u, dfs(v), siz[u] += siz[v];
        if(siz[v] > siz[son[u]]) son[u] = v;
    }
}
void dfs(int u, int t) {
    dfn[u] = ++time, top[u] = t, w[time] = c[u];
    if(!son[u]) return ; dfs(son[u], t);
    for(int i = from[u]; i; i = nxt[i]) {
        int v = to[i];
        if(v != fa[u] && v != son[u])
            dfs(v, v);
    }
}

inline Node merge(const Node &lc, const Node &rc) {
    if(!lc.tot) return rc;
    if(!rc.tot) return lc;
    Node ret = (Node){lc.lf, rc.rg, lc.tot + rc.tot};
    if(lc.rg == rc.lf) --ret.tot;
    return ret;
}
inline void pushdown(int o, int lc, int rc) {
    if(col[o]) {
        seg[lc] = (Node){col[o], col[o], 1};
        seg[rc] = (Node){col[o], col[o], 1};
        col[lc] = col[rc] = col[o], col[o] = 0;
    }
}
void build(int o = 1, int l = 1, int r = n) {
    if(l == r) { seg[o] = (Node){w[l], w[l], 1}; return ; }
    int mid = (l + r) >> 1, lc = o << 1, rc = lc | 1;
    build(lc, l, mid), build(rc, mid + 1, r), seg[o] = merge(seg[lc], seg[rc]);
}
void color(int cl, int cr, int k, int o = 1, int l = 1, int r = n) {
    if(l >= cl && r <= cr) {
        seg[o] = (Node){k, k, 1}, col[o] = k;
        return ;
    }
    int mid = (l + r) >> 1, lc = o << 1, rc = lc | 1;
    pushdown(o, lc, rc);
    if(cl <= mid) color(cl, cr, k, lc, l, mid);
    if(cr > mid) color(cl, cr, k, rc, mid + 1, r);
    seg[o] = merge(seg[lc], seg[rc]);
}
Node query(int ql, int qr, int o = 1, int l = 1, int r = n) {
    if(l >= ql && r <= qr) return seg[o];
    int mid = (l + r) >> 1, lc = o << 1, rc = lc | 1;
    Node ret = (Node){0, 0, 0};
    pushdown(o, lc, rc);
    if(ql <= mid) ret = query(ql, qr, lc, l, mid);
    if(qr > mid) ret = merge(ret, query(ql, qr, rc, mid + 1, r));
    return ret;
}

inline void upt(int x, int y, int k) {
    int fx = top[x], fy = top[y];
    while(fx != fy) {
        if(dep[fx] >= dep[fy]) color(dfn[fx], dfn[x], k), x = fa[fx], fx = top[x];
        else color(dfn[fy], dfn[y], k), y = fa[fy], fy = top[y];
    } if(dfn[x] > dfn[y]) swap(x, y);
    color(dfn[x], dfn[y], k);
}
inline int doit(int x, int y) {
    int fx = top[x], fy = top[y];
    Node disx = (Node){0, 0, 0}, disy = (Node){0, 0, 0};
    while(fx != fy) {
        if(dep[fx] >= dep[fy]) disx = merge(query(dfn[fx], dfn[x]), disx), x = fa[fx], fx = top[x];
        else disy = merge(query(dfn[fy], dfn[y]), disy), y = fa[fy], fy = top[y];
    } if(dfn[x] > dfn[y]) swap(x, y), swap(disx, disy);
    swap(disx.lf, disx.rg);
    Node ret = merge(merge(disx, query(dfn[x], dfn[y])), disy);
    return ret.tot;
}

int main () {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", c + i);
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        addEdge(u, v), addEdge(v, u);
    }
    dfs(1), dfs(1, 1), build();
    char opt; int a, b, c;
    while(m--) {
        scanf("\n%c%d%d", &opt, &a, &b);
        if(opt == 'C') {
            scanf("%d", &c);
            upt(a, b, c);
        } else printf("%d\n", doit(a, b));
    }
    return 0;
}