树套树, 这个东西我想学了好久了,之前虽说一只在听说线段树里套个set就是树套树,(一只也是这样认为的)但是我写树状数组套主席树的时候总感觉有那里很迷惑,去找dalao想要问问具体的实现方法
然后被dis回来,接下来就咕咕咕了很久,直到最近, 发现树套树真就是好简单(依然没有想明白树状数组套主席树的我, QAQ) 不够线段树套一些其他的树感觉还可以,
线段树只会维护一堆树根root,因此我就可以得到2n个root,每个根都生成一颗平衡树,每次加点就直接往加logn棵平衡树上加点,
因为线段树有很好的分割性质,可以把任意一段区间均匀的分割到logn个结点上
询问一 查询k在区间排名
我依次查询分割好的线段树结点上的平衡树,并将比k小的个数依次加起来,就是整个区间的排名, 最后答案注意+1或不加,
线段树查询复杂度 logn, 平衡树维护 logn, 查询复杂度 log^2n
询问二 查询区间排名为k的元素
因为无论是线段树还是平衡树都不是很好的能够维护好线段树合并后的区间信息,因此此项操作就很很难维护好,所以就只能用其他方法,根据说明我们可以选用二分答案,然后利用询问一查看是否合法
询问一复杂度 log^2n 二分复杂度 logn 查询复杂度 log^3n
询问三 单点修改
将在线段树上所有包含的点都删除原来的元素和插入新的元素
单点修改的线段树复杂度 logn 修改平衡树复杂度 logn 总题复杂度 log^2n
询问四 区间查询x的前驱和后继
注意取min或者max即可,原理同上
复杂度 log^n
因此线段树套平衡树复杂度 nlog^3n
对于FHQ来说, 核心代码Split和Merge,我们可以完善整棵树的很多操作,详情看上篇文章(上篇文章也不是很详细)
对于线段树,此时的线段树是需要维护区间的信息,和这个区间的平衡树的根
显然线段树只需要这样建树即可
struct node{int l, r, root};
平平无奇的定义
因此我们只需要注意的是建树操作 build(1, 1, n)
void build(int k, int l, int r)
{
a[k].l = l; a[k].r = r;
for(int i = l ; i <= r ; i ++) ins(v[i], a[k].root);
ins(inf, a[k].root);
ins(-inf, a[k].root); // 因为插入了正负inf,因此前驱就比正常的多一
if(l == r) return;
int mid = l + r >> 1;
build(k << 1, l, mid);
build(k << 1|1, mid + 1, r );// 线段树常规操作
}
对于询问三的操作显然可以当成单点修改
对于大部分FHQ平衡树的操作, 其中传递根的基本上都要去修改根
void change(int k, int pos, int val)
{
del(v[pos], a[k].root);
ins(val, a[k].root); // 因为信息都存到平衡树上了,因此必须要删除原来的点和插入新的点
// 对于信息若是已经处理完了,就覆盖掉原来的旧值
if(a[k].l == a[k].r && a[k].l == pos) {v[pos] = val; return;}
int mid = (a[k].l + a[k].r) >> 1;
if(mid >= pos) change(k << 1, pos, val);
else change(k << 1|1, pos, val);
}
// 是不是和线段树的信息维护一模一样,出来没有pushdown之类的需要向上维护的操作, 其实已经向上维护了
对于询问一,四五的操作显然在线段树上的依然是类似的
因此这里就跳出来询问一的来注释说明
void srank(int k, int l, int r, int val)
{
if(l <= a[k].l && a[k].r <= r) return ranks(val, a[k].root); // 区间被完全覆盖就直接返回这个区间的排名
int mid = (a[k].l + a[k].r) >> 1;
int ans = 0; // 记录排名值
if(l <= mid) ans += srank(k << 1, l, r, val);
if(r > mid) ans += srank(k << 1|1, l, r, val);
return ans;
}
// 是不是这个函数除了第一行的返回值不一样之外就完全和普通的线段树区间查询操作没区别
void skth(int l, int r, int val)
{
int L = 0, R = 1e8, mid, ans; //R可以只用取到max{v[i]}
while(L < R)
{
mid = L + R >> 1;
if(srank(1, l, r, mid) + 1 <= val){L = mid + 1; ans = mid;} // srank查询显然没有把val本身的大小算进来
else R = mid;
}
return ans;
}
// 简单的二分答案
下附AC代码(没有卡常, 不开O2会T三个点)
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#pragmaGCC optimize(2) // 不开过不了
#define eps 1e-9
#define endl '\n'
#define gcd __gcd
#define pi acos(-1)
#define ll long long
#define LL long long
#define IDX(x) x - 'A'
#define idx(x) x - 'a'
#define idz(x) x - '0'
#define ld long double
#define lowebit(x) x&(-x)
#define rint register int
#define Len(x) (int)(x).size()
#define all(s) (s).begin(), (s).end()
using namespace std;
inline int read()
{
register int x = 0, f = 1, ch = getchar();
while( !isdigit(ch) ){if(ch == '-') f = -1; ch = getchar();}
while( isdigit(ch) ){x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
const int inf = 2147483647;
const int maxn = 5e4 + 5;
struct fhq
{
int l, r, sie, rnd, val;
}tr[maxn * 50];
struct node
{
int l, r, root;
}a[maxn << 2];
int v[maxn];
mt19937 rnd(116551);
int tot;
int newnode(int val)
{
int now = ++ tot;
tr[now].sie = 1;
tr[now].val = val;
tr[now].rnd = rnd();
return now;
}
void update(int k) {tr[k].sie = tr[tr[k].l].sie + tr[tr[k].r].sie + 1;}
void split(int u, int val, int& x, int& y)
{
if(!u){x = y = 0; return ;}
if(tr[u].val <= val){x = u; split(tr[u].r, val, tr[u].r, y);}
else {y = u; split(tr[u].l, val, x, tr[u].l);}
update(u);
}
int Merge(int x, int y)
{
if(!x || !y) return x + y;
if(tr[x].rnd > tr[y].rnd){tr[x].r = Merge(tr[x].r, y); update(x); return x;}
tr[y].l = Merge(x, tr[y].l); update(y); return y;
}
int x, y, z;
void ins(int val, int& root) // 这里也有引用
{
split(root, val, x, y);
root = Merge(Merge(x, newnode(val)), y);
}
void del(int val, int& root)// 这里也有引用
{
split(root, val, x, z);
split(x, val - 1, x, y);
y = Merge(tr[y].l, tr[y].r);
root = Merge(Merge(x, y), z);
}
int pre(int val, int& root)// 这里也有引用
{
split(root, val - 1, x, y);
int now = x;
while(tr[now].r) now = tr[now].r;
now = tr[now].val;
root = Merge(x, y);
return now;
}
int nex(int val, int& root)// 这里也有引用
{
split(root, val, x, y);
int now = y;
while(tr[now].l) now = tr[now].l;
now = tr[now].val;
root = Merge(x, y);
return now;
}
int ranks(int val, int& root)// 这里也有引用
{
split(root, val - 1, x, y);
int tmp = tr[x].sie - 1;//消去 -inf的影响
root = Merge(x, y);
return tmp;
}
void build(int k, int l, int r)
{
a[k].l = l; a[k].r = r;
ins(inf, a[k].root);
ins(-inf, a[k].root);
for(int i = l ; i <= r ; i ++) ins(v[i], a[k].root);
if(l == r) return;
int mid = l + r >> 1;
build(k << 1, l, mid);
build(k << 1|1, mid + 1, r );
}
void sinsert(int k, int pos, int val)
{
ins(val, a[k].root);
if(pos == a[k].l && a[k].l == a[k].r) return ;
int mid = (a[k].l + a[k].r) >> 1;
if(mid >= pos) sinsert(k << 1, pos, val);
else sinsert(k << 1|1, pos, val);
}
int srank(int k, int l, int r, int val)
{
if(l <= a[k].l && a[k].r <= r) return ranks(val, a[k].root);
int ans = 0;
int mid = (a[k].l + a[k].r) >> 1;
if(l <= mid) ans += srank(k << 1, l, r, val);
if(r > mid) ans += srank(k << 1|1, l,r, val);
//if(ans) cout << a[k].l << " " << a[k].r <<" "<< ans << endl;
return ans;
}
void change(int k, int pos, int val)
{
del(v[pos], a[k].root);
ins(val, a[k].root);
if(a[k].l == a[k].r && a[k].r == pos) {return (void)(v[pos] = val);}
int mid = a[k].l + a[k].r >> 1;
if(mid >= pos) change(k << 1, pos, val);
else change(k << 1|1, pos, val);
}
int spre(int k, int l, int r, int val)
{
if(l <= a[k].l && a[k].r <= r) return pre(val, a[k].root);
int mid = (a[k].l + a[k].r) >> 1;
int ans = -inf;
if(l <= mid) ans = max(ans, spre(k << 1, l, r, val));
if(r > mid) ans = max(ans, spre(k << 1|1,l,r, val));
return ans;
}
int snxt(int k, int l, int r, int val)
{
if(l <= a[k].l && a[k].r <= r) return nex(val, a[k].root);
int mid = (a[k].l + a[k].r) >> 1;
int ans = inf;
if(l <= mid) ans = min(ans, snxt(k << 1, l, r, val));
if(r > mid) ans = min(ans, snxt(k << 1|1,l,r, val));
return ans;
}
int skth(int l, int r, int k)
{
int L = 0, R = 1e8+1, mid, ans;
while(L < R)
{
mid = (L + R) >> 1;
if(srank(1, l, r, mid) + 1<= k) L = mid + 1, ans = mid;
else R = mid;
}
return ans;
}
int32_t main()
{
ios_base::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
clock_t c1ockck = clock();
// .....................................
int n, m; n = read(); m = read();
for(int i = 1 ; i <= n ; i ++) v[i] = read();
build(1, 1, n);
for(int i = 1 ; i <= m ; i ++)
{
int op = read();
if(op == 1)
{
int l = read(), r = read(), k = read();
cout << srank(1, l, r, k) + 1 << endl;
}
else if(op == 2)
{
int l = read(), r = read(), k = read();
cout << skth(l, r, k) << endl;
}
else if(op == 3)
{
int pos = read(), k = read();
change(1, pos, k);
}
else if(op == 4)
{
int l = read(), r = read(), k = read();
cout << spre(1, l, r, k) << endl;
}
else if(op == 5)
{
int l = read(), r = read(), k = read();
cout << snxt(1, l, r, k) << endl;
}
}
// .....................................
cerr << endl << "Time:" << clock() - c1ockck << "ms" <<endl;
return 0;
}