To the moon 晓风の个人博客

题目链接

继续学习可持久化数据结构…

这次学习支持区间更新的可持久化线段树。具体来说，每次修改操作先复制一份根节点作为新版本的根，之后在 down（下传标记）时为当前节点复制（或新建）左右子结点即可，这样旧版本的节点不会被改动。

#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for interval sums and pending additions.
typedef long long ll;
// NPOS: child index meaning "this child has not been created yet".
// N: capacity of the array `a` below.
const int NPOS = -1, N = 1e5 + 9;
// Prefix sums of the input, filled in by main (a[i] holds the sum of the
// first i values); SegmentTree::build reads interval sums from it.
ll a[N];
struct SegmentTree
{
	struct Seg
	{
		int l, r;
		ll sum;
		void upd(ll add) { sum += add * (r - l + 1); }
		friend Seg up(const Seg &lc, const Seg &rc) { return {lc.l, rc.r, lc.sum + rc.sum}; }
	};
	struct Node : Seg
	{
		int lc, rc;
		ll add;
	};
	vector<Node> v;
	SegmentTree(int l, int r) { build(l, r); }
	void build(int l, int r)
	{
		int rt = v.size();
		v.push_back({});
		v[rt].Seg::operator=({l, r, a[r] - a[l - 1]});
		v[rt].lc = v[rt].rc = NPOS;
		v[rt].add = 0;
	}
	void down(int rt)
	{
		int m = v[rt].l + v[rt].r >> 1;
		if (v[rt].lc == NPOS)
			v[rt].lc = v.size(), build(v[rt].l, m);
		else
			v.push_back(v[v[rt].lc]), v[rt].lc = v.size() - 1;
		if (v[rt].rc == NPOS)
			v[rt].rc = v.size(), build(m + 1, v[rt].r);
		else
			v.push_back(v[v[rt].rc]), v[rt].rc = v.size() - 1;
		upd(v[v[rt].lc].l, v[v[rt].lc].r, v[rt].add, v[rt].lc);
		upd(v[v[rt].rc].l, v[v[rt].rc].r, v[rt].add, v[rt].rc);
		v[rt].add = 0;
	}
	void upd(int l, int r, ll add, int rt)
	{
		if (l <= v[rt].l && v[rt].r <= r)
			return v[rt].add += add, v[rt].upd(add);
		down(rt);
		if (r <= v[v[rt].lc].r)
			upd(l, r, add, v[rt].lc);
		else if (l >= v[v[rt].rc].l)
			upd(l, r, add, v[rt].rc);
		else
			upd(l, v[v[rt].lc].r, add, v[rt].lc), upd(v[v[rt].rc].l, r, add, v[rt].rc);
		v[rt].Seg::operator=(up(v[v[rt].lc], v[v[rt].rc]));
	}
	Seg ask(int l, int r, int rt)
	{
		if (l <= v[rt].l && v[rt].r <= r)
			return v[rt];
		down(rt);
		if (r <= v[v[rt].lc].r)
			return ask(l, r, v[rt].lc);
		if (l >= v[v[rt].rc].l)
			return ask(l, r, v[rt].rc);
		return up(ask(l, v[v[rt].lc].r, v[rt].lc), ask(v[v[rt].rc].l, r, v[rt].rc));
	}
};
int main()
{
	for (int n, m, l, r, d; ~scanf("%d%d", &n, &m);)
	{
		for (int i = 1; i <= n; ++i)
			scanf("%lld", &a[i]), a[i] += a[i - 1];
		SegmentTree t(1, n);
		vector<int> root{0}, size{t.v.size()};
		for (char s[9]; m--;)
		{
			scanf("%s%d", s, &l);
			if (s[0] == 'C')
			{
				scanf("%d%d", &r, &d);
				t.v.push_back(t.v[root.back()]);
				root.push_back(t.v.size() - 1);
				t.upd(l, r, d, root.back());
				size.push_back(t.v.size());
			}
			else if (s[0] == 'B')
				root.resize(l + 1), size.resize(l + 1);
			else
			{
				scanf("%d", &r);
				if (s[0] == 'Q')
					d = root.size() - 1;
				else
					scanf("%d", &d);
				t.v.push_back(t.v[root[d]]);
				printf("%lld\n", t.ask(l, r, t.v.size() - 1).sum);
			}
			t.v.resize(size.back(), t.v[0]);
		}
	}
}

网上常见的解法是标记永久化：区间更新时不把标记下传，而是在查询时沿途把标记的贡献直接累加进返回结果，这样可以减少空间消耗（然而这里偷懒直接用 vector，还直接把左右端点维护在线段树节点内…）。

#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for node sums and lazy additions.
typedef long long ll;
// NPOS: child index stored in leaf nodes, which have no children.
// N: capacity of the input array.
const ll NPOS = -1, N = 1e5 + 9;
struct HJTree
{
	struct Node
	{
		int l, r, lc, rc;
		ll sum, add;
		void up(Node &lc, Node &rc)
		{
			l = lc.l;
			r = rc.r;
			sum = lc.sum + rc.sum + add * (r - l + 1);
		}
		void _add(ll a)
		{
			add += a, sum += a * (r - l + 1);
		}
	};
	vector<Node> v;
	vector<int> root;
	HJTree(ll a[], int l, int r) : root(1, build(a, l, r)) {}
	int build(ll a[], int l, int r)
	{
		if (l >= r)
			return v.push_back({l, r, NPOS, NPOS, a[l], 0}), v.size() - 1;
		int m = l + r >> 1;
		return v.push_back({l, r, build(a, l, m), build(a, m + 1, r), 0, 0}), v.back().up(v[v.back().lc], v[v.back().rc]), v.size() - 1;
	}
	int add(int l, int r, int val, int rt)
	{
		if (l <= v[rt].l && v[rt].r <= r)
			return v.push_back(v[rt]), v.back()._add(val), v.size() - 1;
		int m = v[rt].l + v[rt].r >> 1, lc = v[rt].lc, rc = v[rt].rc;
		if (r <= m)
			lc = add(l, r, val, lc);
		else if (l > m)
			rc = add(l, r, val, rc);
		else
		{
			lc = add(l, m, val, lc);
			rc = add(m + 1, r, val, rc);
		}
		return v.push_back(v[rt]), v.back().lc = lc, v.back().rc = rc, v.back().up(v[lc], v[rc]), v.size() - 1;
	}
	ll ask(int l, int r, int rt)
	{
		if (l <= v[rt].l && v[rt].r <= r)
			return v[rt].sum;
		int m = v[rt].l + v[rt].r >> 1;
		ll ret = v[rt].add * (r - l + 1);
		if (r <= m)
			return ret + ask(l, r, v[rt].lc);
		if (l > m)
			return ret + ask(l, r, v[rt].rc);
		return ret + ask(l, m, v[rt].lc) + ask(m + 1, r, v[rt].rc);
	}
};
// Input values; main fills a[1..n] and HJTree builds its initial version from them.
ll a[N];
// Driver for the tag-keeping persistent tree.
//
// Input: repeated test cases until the header can no longer be read. Each
// case is "n m", then n values, then m commands whose first letter selects:
//   C l r d  - add d to [l, r] and create a new version
//   Q l r    - print the sum of [l, r] in the newest version
//   H l r t  - print the sum of [l, r] in version t
//   B t      - roll back to version t, dropping all newer versions
int main()
{
	// Comparing against 2 (instead of testing for EOF only) also ends the
	// loop on a malformed header rather than spinning forever.
	for (int n, m, l, r, d; scanf("%d%d", &n, &m) == 2;)
	{
		for (int i = 1; i <= n; ++i)
			scanf("%lld", &a[i]);
		HJTree t(a, 1, n);
		for (char s[9]; m--;)
		{
			// %8s keeps the command word inside the 9-byte buffer.
			if (scanf("%8s%d", s, &l) != 2)
				break;
			if (s[0] == 'C')
			{
				scanf("%d%d", &r, &d);
				t.root.push_back(t.add(l, r, d, t.root.back()));
			}
			else if (s[0] == 'B')
			{
				t.root.resize(l + 1);
				// A version's root is appended after all of its other nodes,
				// so everything past root[l] belongs to newer versions.
				t.v.resize(t.root[l] + 1);
			}
			else
			{
				scanf("%d", &r);
				if (s[0] == 'Q')
					d = t.root.size() - 1;
				else
					scanf("%d", &d);
				printf("%lld\n", t.ask(l, r, t.root[d]));
			}
		}
	}
}