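问题特点:需要对一棵树中某个子树内的所有节点进行整体操作,此时可以将树上问题转化为区间问题。
通过 DFS 求出每个节点 $x$ 第一次到达的时间 $a$ 和最后一次到达的时间 $b$。显然 $x$ 及其所有后代节点的时间戳都落在一个连续范围内:若 $b$ 记为子树中最后一个被访问节点的时间戳,则范围是闭区间 $[a, b]$;若 $b$ 记为离开子树时的下一个时间戳,则范围是半开区间 $[a, b)$。这样,操作一个子树就相当于操作一个连续区间。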
vector<vector<int>> g(n+1); // tree g: g[i] is the list of children of node i
vector<int> l(n+1), r(n+1); // node x owns the closed interval [ l[x], r[x] ]
int clk = 0;                // last timestamp handed out so far
// Pre-order DFS: l[x] is x's own timestamp, r[x] the largest timestamp inside x's subtree.
function<void(int)> dfs = [&](int node) {
    clk += 1;
    l[node] = clk;
    for (size_t k = 0; k < g[node].size(); ++k)
        dfs(g[node][k]);
    r[node] = clk;
};
dfs(1);
/* ******************************************************* */
vector<vector<int>> g(n+1); // tree g: g[i] is the list of children of node i
vector<int> l(n+1), r(n+1); // node x owns the half-open interval [ l[x], r[x] )
int clk = 1;                // next timestamp to hand out
// Pre-order DFS: l[x] is x's own timestamp, r[x] is one past the last timestamp in x's subtree.
function<void(int)> dfs = [&](int node) {
    l[node] = clk;
    ++clk;
    for (auto it = g[node].begin(); it != g[node].end(); ++it)
        dfs(*it);
    r[node] = clk;
};
dfs(1);
**LCP 05. 发 LeetCoin 模板题**
#define MOD 1000000007
// #define MOD 998244353
#define ll long long
#define mint Modint<MOD>
// Integer modulo kMod, always stored as the canonical residue in [0, kMod).
// Division and negative exponents rely on Fermat's little theorem, so they are
// only meaningful when kMod is prime and the divisor/base is non-zero.
template <const int kMod>
struct Modint {
    int v;
    Modint() { v = 0; }
    // Implicit on purpose so plain integers mix freely with Modint.
    // Negative inputs are normalized as well (e.g. Modint(-1).v == kMod - 1).
    Modint(long long o) { v = normalize(o); }
    int val() const { return v; }
    // Returns v^o mod kMod by binary exponentiation.
    // A negative o raises the modular inverse to the power -o (so v must be
    // non-zero and kMod prime for that case to be meaningful).
    int pow(long long o) const {
        long long base = v;
        if (o < 0) {
            base = pow(kMod - 2);  // v^{-1} via Fermat's little theorem
            o = -o;
        }
        int ret = 1;
        while (o) {
            if (o & 1)
                ret = (long long)ret * base % kMod;
            o >>= 1;
            base = base * base % kMod;  // base < kMod < 2^31, so the product fits in long long
        }
        return ret;
    }
    Modint& operator=(long long o) {
        v = normalize(o);
        return *this;
    }
    bool operator==(long long o) const { return v == o; }
    bool operator==(Modint o) const { return v == o.v; }
    bool operator!=(long long o) const { return v != o; }
    bool operator!=(Modint o) const { return v != o.v; }
    bool operator<(long long o) const { return v < o; }
    bool operator<(Modint o) const { return v < o.v; }
    bool operator>(long long o) const { return v > o; }
    bool operator>(Modint o) const { return v > o.v; }
    bool operator<=(long long o) const { return v <= o; }
    bool operator<=(Modint o) const { return v <= o.v; }
    bool operator>=(long long o) const { return v >= o; }
    bool operator>=(Modint o) const { return v >= o.v; }
    Modint operator+(long long o) const { return *this + Modint(o); }
    Modint operator+(Modint o) const { return ((long long)v + o.v) % kMod; }
    Modint operator*(long long o) const { return *this * Modint(o); }
    Modint operator*(Modint o) const { return (long long)v * o.v % kMod; }
    Modint operator-(long long o) const { return *this - Modint(o); }
    Modint operator-(Modint o) const {
        return ((long long)v - o.v + kMod) % kMod;
    }
    Modint operator/(long long o) const { return *this / Modint(o); }
    // Multiplies by the modular inverse of o (requires prime kMod, o != 0).
    Modint operator/(Modint o) const {
        return ((long long)v * o.pow(kMod - 2)) % kMod;
    }
    Modint& operator+=(long long o) { return *this = *this + o; }
    Modint& operator+=(Modint o) { return *this = *this + o; }
    Modint& operator*=(long long o) { return *this = *this * o; }
    Modint& operator*=(Modint o) { return *this = *this * o; }
    Modint& operator-=(long long o) { return *this = *this - o; }
    Modint& operator-=(Modint o) { return *this = *this - o; }
    Modint& operator/=(long long o) { return *this = *this / o; }
    Modint& operator/=(Modint o) { return *this = *this / o; }
    // Exponentiation (note: ^ is not a power operator for built-in types).
    Modint operator^(long long o) const { return Modint(pow(o)); }
    Modint operator^(Modint o) const { return Modint(pow(o.v)); }
    // Mixed-operand forms with the plain value on the left-hand side.
    template <class T>
    friend bool operator==(T o, Modint u) {
        return u == o;
    }
    template <class T>
    friend Modint operator+(T o, Modint u) {
        return u + o;
    }
    template <class T>
    friend Modint operator*(T o, Modint u) {
        return u * o;
    }
    template <class T>
    friend Modint operator-(T o, Modint u) {
        return Modint(o) - u;
    }
    template <class T>
    friend Modint operator/(T o, Modint u) {
        return Modint(o) / u;
    }
    Modint& operator++() { return *this = *this + 1; }
    Modint& operator--() { return *this = *this - 1; }
    Modint operator++(int) {
        Modint old = *this;
        ++*this;
        return old;
    }
    Modint operator--(int) {
        Modint old = *this;
        --*this;
        return old;
    }
    // Non-template friends: each Modint<kMod> gets its own overload. The former
    // in-class friend *templates* were redefined by every instantiation, which
    // broke any program using two different moduli.
    friend std::istream& operator>>(std::istream& in, Modint& modint) {
        long long x = 0;
        in >> x;
        modint = Modint(x);
        return in;
    }
    friend std::ostream& operator<<(std::ostream& os, const Modint& modint) {
        return os << modint.v;
    }
    // Maps any long long (including negatives) into [0, kMod).
    static int normalize(long long o) {
        o %= kMod;
        if (o < 0)
            o += kMod;
        return static_cast<int>(o);
    }
};
class Solution {
public:
#define N 50005
    // Initial coin count at each DFS timestamp; never filled here, so every member starts at 0.
    mint a[N];
    // Segment-tree node over timestamp range [l, r]:
    //   val = sum of the range (already includes this node's own tag),
    //   tag = per-position addend not yet pushed down to the children.
    struct Seg{
        int l, r;
        mint val, tag;
    } seg[N<<2];
    // Recomputes a node's sum from its two children.
    void push_up(Seg& u, const Seg& l, const Seg& r) {
        u.val = l.val + r.val;
    }
    // Hands u's pending per-position addend down to both children, then clears it.
    void push_down(Seg& u, Seg& l, Seg& r) {
        l.val += u.tag*(l.r-l.l+1);
        l.tag += u.tag;
        r.val += u.tag*(r.r-r.l+1);
        r.tag += u.tag;
        u.tag = 0;
    }
    // Builds node `id` over [l, r] from a[].
    // The lazy tag is reset here: otherwise a second bonus() call on the same
    // object would push tags left over from the previous run into the new tree.
    void seg_build(int id, int l, int r) {
        seg[id].l = l; seg[id].r = r;
        seg[id].tag = 0;
        if (l == r) {
            seg[id].val = a[l];
            // cin >> seg[id].val;
            return ;
        }
        int m = (l + r) >> 1;
        seg_build(id<<1, l, m);
        seg_build(id<<1|1, m+1, r);
        push_up(seg[id], seg[id<<1], seg[id<<1|1]);
    }
    // Adds `val` to every position in [l, r].
    void seg_update(int id, int l, int r, mint val) {
        if (l <= seg[id].l && seg[id].r <= r) {
            seg[id].val += val*(seg[id].r-seg[id].l+1);
            seg[id].tag += val;
            return ;
        }
        push_down(seg[id], seg[id<<1], seg[id<<1|1]);
        int m = (seg[id].l + seg[id].r) >> 1;
        if (l <= m) seg_update(id<<1, l, r, val);
        if (m < r) seg_update(id<<1|1, l, r, val);
        push_up(seg[id], seg[id<<1], seg[id<<1|1]);
    }
    // Returns the sum of positions [l, r].
    mint seg_query(int id, int l, int r) {
        if (l <= seg[id].l && seg[id].r <= r) {
            return seg[id].val;
        }
        push_down(seg[id], seg[id<<1], seg[id<<1|1]);
        mint rt = 0;
        int m = (seg[id].l + seg[id].r) >> 1;
        if (l <= m) rt += seg_query(id<<1, l, r);
        if (m < r) rt += seg_query(id<<1|1, l, r);
        return rt;
    }
    // n: members numbered 1..n, member 1 is the root.
    // leadership: pairs [parent, child].
    // operations: [1, x, v] adds v to x alone; [2, x, v] adds v to everyone in x's
    // subtree; [3, x] appends the (mod MOD) total of x's subtree to the answer.
    vector<int> bonus(int n, vector<vector<int>>& leadership, vector<vector<int>>& operations) {
        vector<vector<int>> g(n+1);
        for (auto& i:leadership) {
            g[i[0]].push_back(i[1]);
        }
        // Closed DFS interval [l[x], r[x]] covers exactly x's subtree.
        vector<int> l(n+1), r(n+1);
        int clk = 0;
        function<void(int)> dfs = [&](int u) {
            l[u] = ++clk;
            for (int v:g[u]) {
                dfs(v);
            }
            r[u] = clk;
        };
        dfs(1);
        seg_build(1, 1, n);
        vector<int> ans;
        for (auto& i:operations) {
            if (i[0] == 1) {
                seg_update(1, l[i[1]], l[i[1]], i[2]);
            }
            if (i[0] == 2) {
                seg_update(1, l[i[1]], r[i[1]], i[2]);
            }
            if (i[0] == 3) {
                ans.push_back(seg_query(1, l[i[1]], r[i[1]]).v);
            }
        }
        return ans;
    }
};
通过求出树的欧拉序,可以将树上最近公共祖先(LCA)问题转化为区间问题。
用 RMQ 求 LCA 的方法:先求出树的欧拉序,并预处理出每个节点的深度;然后记录每个点 $x$ 在欧拉序中第一次出现的位置 $p[x]$(下面代码中为 `fc[x]`)。设 $p[x] \le p[y]$(否则交换 $x$ 和 $y$),则两点 $x$ 和 $y$ 的 LCA 就是欧拉序子区间 $[p[x], p[y]]$ 中深度最小的点。
单纯求 LCA 时可以用这种方法;如果还需要维护路径上的值,则仍需使用倍增。
vector<int> eulerTour, fc(n+1), dep(n+1); // 节点编号从1开始
function<void(int,int,int)> dfs = [&](int u, int fno) {
dep[u] = dep[fa[u][0]] + 1; // 每个节点的深度
// 每个节点欧拉序中第一次出现的位置
fc[u] = eulerTour.size(); // 每个节点第一次在欧拉序中出现的位置
eulerTour.push_back(u); // 欧拉序
for (auto [v, e] : g[u]) {
if (v == fno)
continue;
dfs(v, e, u);
eulerTour.push_back(u);
}
};
dfs(1,0);
// 用st表维护欧拉序区间最小值,按照深度比较
int st[eulerTour.size()][30]; //st[i][j] 代表区间[i, i+2^j)最小值
auto ST = [&](const vector<int>& a) {
int sz = a.size();
for (int i=0; i<sz; i++) st[i][0] = a[i];
for (int j=1; (1<<j)<=sz; j++) {//区间大小
for (int i=0; i+(1<<j)-1<sz; i++) {//区间下限
int x = st[i][j-1], y = st[i+(1<<(j-1))][j-1];
st[i][j] = dep[x] < dep[y] ? x : y;
}
}
};
ST(eulerTour);
auto ask = [&](int l, int r) {
int k = 0;
while ((1<<(k+1))<=r-l+1) k++;
int x = st[l][k], y = st[r-(1<<k)+1][k];
return dep[x] < dep[y] ? x : y;
};
// 获取x和y的lca
ask(min(fc[x], fc[y]), max(fc[x], fc[y]));