cdq分治学习笔记

这里是简介

前言

感谢$__stdcall$的讲解,感谢伟大的导师$_tham$提供一系列练手题

cdq分治是什么?

国人(陈丹琦)引进的算法,不同于一般的分治,我们常说的分治是将问题分成互不影响的几个区间,递归进行处理,而所谓$cdq$分治,在处理一个区间时,还要计算它对其他区间的贡献。

二维偏序问题

给定$n$个二元组$[a,b]$,$m$次询问,每次给定其中的一个二元组$[c,d]$,求满足条件$c<a\&d<b$的二元组的个数

不知道怎么做?逆序对你总会求吧?逆序对就是一种经典的二维偏序问题,我们不妨这样转换逆序对问题:

给定$n$个数,定义一个二元组为$[$元素下标,元素值$]$,则共有$n$个这样的二元组

我们只需将约束条件改为:$cb$就行了。

那么,解决二维偏序的一般模式,也只需要改一下合并时的那一句话就好了。

PS:啊?你忘了怎么用归并排序求逆序对?戳我

相同的,我们也可以用树状数组来求解。复杂度同样为$O(nlogn)$


既然我们能用树状数组来解决用$cdq$分治的题,那我们能不能用$cdq$分治来解决树状数组的题目呢?当然可以,比如这道:Luogu3374 树状数组1

给定一个$n​$个元素的序列$a​$,初始值全部为$0​$,对这个序列进行以下两种操作

操作$1$:格式为$1\ x\ k$,把所有位置$x$的元素加上$k$

操作$2$:格式为$2 x y$,求出区间$[x,y]$内所有元素的和。

这显然是一道树状数组模板题,考虑如何用$cdq$分治来解决它。

我们不妨以修改的时间为第一关键字,修改元素的位置为第二关键字。由于时间已经有序,我们定义结构体包含$3$个元素:$opt,ind,val$,其中$ind$表示操作的位置,$opt$为$1$表示修改,$val$表示“加上的值”。而对于查询,我们用前缀和的思想把他分解成两个操作:$sum[1,y]-sum[1,x-1]$,即分解成两次前缀和的查询。在合并的过程中,$opt$为$2$表示遇到了一个查询的左端点$x-1$,对结果作负贡献,$opt$为$3$表示遇到了一个查询的右端点$y$,对结果作正贡献,$val$表示“是第几个查询”。这样,我们就把每个操作转换成了带有附加信息的有序对(时间,位置),然后对整个序列进行$cdq$分治。

#include <cstdio>
#include <cstring>
#include <algorithm>
using std::min;
using std::max;
using std::swap;
using std::sort;
typedef long long ll;

const int N = 5e5 + 10, M = 5e5 + 10;
int n, m, aid, qid;
ll ans[M];
struct Query {
    int ind, opt; ll val;
    inline bool operator < (const Query a) const {
        return ind == a.ind ? opt < a.opt : ind < a.ind;
    }
}q[(M << 1) + N], tmp[(M << 1) + N];

inline void cdq (int l, int r) {
    if (l == r) return ;
    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    int i = l, j = mid + 1, p = l; ll sum = 0;
    while (i <= mid && j <= r)
        if (q[i] < q[j]) {
            if (q[i].opt == 1) sum += q[i].val;
            tmp[p++] = q[i++];
        } else {
            if (q[j].opt == 2) ans[q[j].val] -= sum;
            if (q[j].opt == 3) ans[q[j].val] += sum;
            tmp[p++] = q[j++];
        }
    while (i <= mid) { if (q[i].opt == 1) sum += q[i].val; tmp[p++] = q[i++]; }
    while (j <= r) {
        if (q[j].opt == 2) ans[q[j].val] -= sum;
        if (q[j].opt == 3) ans[q[j].val] += sum;
        tmp[p++] = q[j++];
    }
    for (int k = l; k <= r; ++k) q[k] = tmp[k];    
}

int main () {
    scanf ("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) {
        q[++qid].ind = i, q[qid].opt = 1;
        scanf("%lld", &q[qid].val);
    }
    int opt, ind, l, r; ll val;
    for (int i = 1; i <= m; ++i) {
        scanf("%d", &opt);
        if (opt == 1) scanf("%d%lld", &ind, &val), q[++qid] = (Query){ind, 1, val};
        else {
            scanf ("%d%d", &l, &r);
            q[++qid] = (Query){l - 1, 2, ++aid}, q[++qid] = (Query){r, 3, aid};
        }
    }
    cdq(1, qid);
    for (int i = 1; i <= aid; ++i)
        printf("%lld\n", ans[i]);
    return 0;
}

三维偏序问题

给定$n$个三元组$[a,b,c]$,$m$次询问,每次给定其中的一个二元组$[d,e,f]$,求满足条件$d<a\&e<b\&f<c$的二元组的个数

相同的,我们也可以采取用其他方法来解决三位偏序问题,如$bitset$、$KD\ Tree$、树套树等…比如我们可以以$a$为关键字排序,同时用$BIT$套平衡树来维护剩下的两个元素。

接着考虑如何用$cdq$分治来解决这个问题,我们可以考虑先以$a$为关键字对数组排序,这样我们的问题就成了维护后两个元素了。接下来,我们以一个经典的三维偏序题:陌上花开来做具体说明(由于这道题较为经典,在各大$OJ$都能找到,不给出链接)


题面

有n朵花,每朵花有三个属性:花形(s)、颜色(c)、气味(m),由三个整数表示。现要对每朵花评级,一朵花的级别是它拥有的美丽能超过的花的数量。定义一朵花A比花B要美丽,当且仅Sa>=Sb,Ca>=Cb,Ma>=Mb。显然,两朵花可能有同样的属性。需要统计出评出每个等级的花的数量。

题解

  1. 就如刚才所说的,以$a$为关键字进行排序
struct Node {
    int a, b, c, mult, ans;
    inline void Init() {
        read(a), read(b), read(c);
    }
} v[N], d[N];
inline bool cmpx (Node x, Node y) {
    return (x.a < y.a) || (x.a == y.a && x.b < y.b) || (x.a == y.a && x.b == y.b && x.c < y.c);
}
int main () {
    read(n), read(k);
    for (int i = 1; i <= n; ++i) v[i].Init();
    sort(&v[1], &v[n + 1], cmpx);
}
  1. 然后,我们会发现,普通的三位偏序只用处理小于,而不是小于等于,根据题意,完全相同属性的花是不计算在内的,所以我们得考虑将其去重。
for (int i = 1; i <= n; ++i) {
    ++mul;//相同元素的个数
    //这里的异或你可以理解为不等于,由于之前已经排过序(见函数cmpx),可以线性比较,mult表示重复元素的个数
    if ((v[i].a ^ v[i + 1].a) || (v[i].b ^ v[i + 1].b) || (v[i].c ^ v[i + 1].c))
        d[++m] = v[i], d[m].mult = mul, mul = 0;
}
  1. 接着,我们考虑如何进行$cdq$分治,同样是在计算左区间时,处理右区间的询问,不妨采用$two-pointers$,两个指针$i,j$分别指向左右两个区间,这时候我们以$b$为关键字进行比较,如果$d[i].b<=d[j].b$,则将$d[i].c$插入权值$BIT$中,反之则在$BIT$中查询比$d[j].c$小的数的个数,作正贡献。在两个区间都扫完后,我们要考虑清空$BIT$,防止在接下来的递归回溯中被添加多次。
inline bool cmpy (Node x, Node y) {
    return (x.b < y.b) || (x.b == y.b && x.c < y.c);
}
inline void cdq (int l, int r) {
    if (l == r) return ;
    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    int i = l;
    for (int j = mid + 1; j <= r; ++j) {
        while (d[i].b <= d[j].b && i <= mid) update(d[i].c, d[i].mult), ++i;
        d[j].ans += query(d[j].c);
        //ans表示小于等于它的个数
    }
    //清空BIT
    for (int k = l; k < i; ++k)
        update(d[k].c, -d[k].mult);
    inplace_merge(&d[l], &d[mid + 1], &d[r + 1], cmpy);
    //这个函数表示将区间[l,mid+1)和[mid+1,r+1)按照cmpy方法合并
}
  1. 计算答案。
for (int i = 1; i <= m; ++i) ans[d[i].ans + d[i].mult - 1] += d[i].mult;
for (int i = 0; i < n; ++i) printf("%d\n", ans[i]);

代码

#include <cstdio>
#include <algorithm>
using std::sort;
using std::inplace_merge;
typedef long long ll;

template<typename T>
inline void read (T &x) {
    char ch = getchar(); int flag = 1;
    while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') flag = -flag, ch = getchar();
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    x *= flag;
}

const int N = 1e5 + 10, K = 2e5 + 10;
int n, m, k, mul, ans[N], bit[K];
struct Node {
    int a, b, c, mult, ans;
    inline void Init() {
    read(a), read(b), read(c);
    }
} v[N], d[N];
inline bool cmpx (Node x, Node y) {
    return (x.a < y.a) || (x.a == y.a && x.b < y.b) || (x.a == y.a && x.b == y.b && x.c < y.c);
}
inline bool cmpy (Node x, Node y) {
    return (x.b < y.b) || (x.b == y.b && x.c < y.c);
}

inline int lowbit (int x) { return x & (-x); }
inline void update (int pos, int val) {
    while (pos <= k) bit[pos] += val, pos += lowbit(pos);
}
inline int query (int pos) {
    int val = 0;
    while (pos) val += bit[pos], pos -= lowbit(pos);
    return val;
}

inline void cdq (int l, int r) {
    if (l == r) return ;
    int mid = (l + r) >> 1;
    cdq(l, mid), cdq(mid + 1, r);
    int i = l;
    for (int j = mid + 1; j <= r; ++j) {
        while (d[i].b <= d[j].b && i <= mid) update(d[i].c, d[i].mult), ++i;
        d[j].ans += query(d[j].c);
    }
    for (int k = l; k < i; ++k)   
        update(d[k].c, -d[k].mult);
    inplace_merge(&d[l], &d[mid + 1], &d[r + 1], cmpy);
}

int main () {
    read(n), read(k);
    for (int i = 1; i <= n; ++i) v[i].Init();
    sort(&v[1], &v[n + 1], cmpx);
    for (int i = 1; i <= n; ++i) {
        ++mul;
        if ((v[i].a ^ v[i + 1].a) || (v[i].b ^ v[i + 1].b) || (v[i].c ^ v[i + 1].c))
            d[++m] = v[i], d[m].mult = mul, mul = 0;
    }
    cdq(1, m);
    for (int i = 1; i <= m; ++i) ans[d[i].ans + d[i].mult - 1] += d[i].mult;
    for (int i = 0; i < n; ++i) printf("%d\n", ans[i]);
    return 0;
}