0.题目描述
在一棵树中,将任意一条边的边权改为0 (以下称“删边”) ,使得所有询问中两个点之间距离的最大值最小(消耗的时间最少)
1.心路历程
描述很简单,也很容易抽象成一棵树,但是如何实现是个大问题,而且这道题数据范围很大, 0(n^2) 的算法肯定会被卡掉,所以一下子想出正解就很难,可以先从暴力的做法出发,可以拿 50 分,然后再去想正解。
我当时是写完暴力后想不出正解,去看了别人的题解才茅塞顿开。。。
2.暴力做法
2.1 暴力思路
非常简单,依次把每条边的边权都改成 0 试一遍,找最小的答案即可。
2.2 暴力障碍
-
快速求树上两点的距离
这个可以通过在树上维护前缀和实现,维护每个点到根节点的距离d,
x和y的距离dis[x,y]=d[x]+d[y]-2*lca[x,y](1号点到x的距离加上1号点到y的距离减去多算两次的lca[x,y]到1号点的距离即可) -
建立点和边的关系
把边权下放到出点上,在一棵树上以同一个点为出点的边只有一个,可以通过穷举点来穷举边,“删除”一条边时,只需遍历到这条边的出点时不计算当前边的权值即可。
-
不会背
lca
那你做个紫题干什么?
2.4 暴力代码
时间复杂度为 O(n^2) ,只能过一半的数据
#include <cstdio>#include <cstring>#include <iostream>
using namespace std;
const int N = 300010;
struct Node{ // lca,起点,终点 int fa, x, y; // 预处理好lca不然每次都重算一次时间复杂度就上去了}query[N];
int nxt[N * 2], w[N * 2], ver[N * 2], head[N]; // 链式前向星int d[N]; // 所有点到1号点的距离int tot;
int f[N][20], dep[N]; // 树上倍增求lcaint len;
void add(int x, int y, int e) { ver[++tot] = y; w[tot] = e; nxt[tot] = head[x]; head[x] = tot;}
void init(int x) { for (int i = 1; i <= 18; ++i) { f[x][i] = f[f[x][i - 1]][i - 1]; } for (int i = head[x]; i; i = nxt[i]) { int y = ver[i], e = w[i]; if (y == f[x][0]) continue; // 防止回搜 dep[y] = dep[x] + 1; f[y][0] = x; init(y); }}
int lca(int x, int y) { if (dep[x] > dep[y]) swap(x, y); for (int i = 18; i >= 0; --i) { if (dep[f[y][i]] >= dep[x]) y = f[y][i]; } if (x == y) return x; for (int i = 18; i >= 0; --i) { if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; } return f[x][0];}
void dfs(int x, int pos) { for (int i = head[x]; i; i = nxt[i]) { int y = ver[i], e = w[i]; if (y == f[x][0]) continue; if (y == pos) d[y] = d[x]; // pos点的入边的边权视为0 else d[y] = d[x] + e; dfs(y, pos); }}
int main() { int n, m; int res = 0x3f3f3f3f; scanf("%d%d", &n, &m); for (int i = 1; i < n; ++i) { int x, y, e; scanf("%d%d%d", &x, &y, &e); add(x, y, e), add(y, x, e); } dep[1] = 1; init(1); for (int i = 1; i <= m; ++i) { int x, y; scanf("%d%d", &x, &y); query[++len] = {lca(x, y), x, y}; // 提前维护好lca,否则TLE } for (int pos = 2; pos <= n; ++pos) { // 穷举“删掉”的边,1号点(根节点)没有入边 memset(d, 0, sizeof(d)); // 每次重算距离 dfs(1, pos); int maxn = 0; for (int i = 1; i <= len; ++i) { int x = query[i].x, y = query[i].y, fa = query[i].fa; int dis = d[x] + d[y] - 2 * d[fa]; // 两点的距离 maxn = max(maxn, dis); // 耗时最大的一组的时间为消耗的时间 } res = min(res, maxn); } printf("%d\n", res); return 0;}3.正解
3.1 思路
仔细观察这道题,可以发现答案是单调的:当答案过大的时候不用“删边”,当答案过小的时候删一条边不能满足条件,因此可以用二分答案的思路做这道题。
3.2 难点
-
代码量挺大,做好心理准备
-
判断函数的设计
如何检验目前的
mid是否合法是这道题的最大的难点做法是在“删边”前将所有询问中长度大于
mid的路径筛出来,然后找出所有长度大于mid的边经过的路径中最大的一条边,检验最长的路径减去最大边后是否大于mid。-
如果大于
mid,表明还需要删更多的边才能满足要求或不能满足要求,mid取小了 -
如果小于或等于
mid,表明“删边”后状态合法或不需要删边,mid取大了(此时mid可能是答案)
-
-
判断函数的实现
-
遍历每组询问,统计大于
mid的路径条数cnt,并记录最长路径maxn。 -
判断一条边是否被所有长度大于
mid的路径经过。 (难点)统计每条边被长度大于
mid的路径经过的次数s[i],如果s[i]=cnt,表明这条边被所有长度大于mid的路径经过。s数组可以通过树上差分实现 ( 边权差分 ),时间复杂度为O(m+n)。 -
遍历所有的边
w[i],判断maxn-w[i]与mid的关系。
-
4.提醒
-
再次提醒,提前维护好
lca,否则时间复杂度会退化,导致 TLE。 -
还是,看懂了就自己先写,写不动了再来看代码
Ac Code
#include <cstdio>#include <cstring>#include <iostream>
using namespace std;
const int N = 300010;
struct Node{ int fa, x, y;}query[N];
int nxt[N * 2], w[N * 2], ver[N * 2], head[N];int d[N], s[N], w_ver[N];/** * d[N]为到1号点的距离 * s[N]为差分数组 * w_ver[N]为每个点入边的边权*/int tot, n, m;
int f[N][21], dep[N];
void add(int x, int y, int e) { ver[++tot] = y; w[tot] = e; nxt[tot] = head[x]; head[x] = tot;}
void init(int x) { for (int i = 1; i <= 20; ++i) { f[x][i] = f[f[x][i - 1]][i - 1]; } for (int i = head[x]; i; i = nxt[i]) { const int &y = ver[i], &e = w[i]; if (y == f[x][0]) continue; dep[y] = dep[x] + 1; f[y][0] = x; d[y] = d[x] + e; // 维护距离 w_ver[y] = e; // 维护入边的边权 init(y); }}
int lca(int x, int y) { if (dep[x] > dep[y]) swap(x, y); for (int i = 20; i >= 0; --i) { if (dep[f[y][i]] >= dep[x]) y = f[y][i]; } if (x == y) return x; for (int i = 20; i >= 0; --i) { if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; } return f[x][0];}
void dfs(int x) { // 对差分数组进行前缀和运算,求出每条边被经过的次数 for (int i = head[x]; i; i = nxt[i]) { const int &y = ver[i]; if (y == f[x][0]) continue; dfs(y); s[x] += s[y]; }}
bool judge(int mid) { memset(s, 0, sizeof(s)); // 每次初始化差分数组 int maxn = 0, cnt = 0; // maxn为最大路径长度,cnt为长度大于mid的路径条数 for (int i = 1; i <= m; ++i) { int x = query[i].x, y = query[i].y, fa = query[i].fa; int dis = d[x] + d[y] - 2 * d[fa]; if (dis > mid) { maxn = max(dis, maxn); ++s[x], ++s[y]; s[fa] -= 2; ++cnt; } } if (!maxn) return true; // 不用删边,说明mid取大了,直接往左找,如果不优化会TLE一个点 dfs(1); for (int i = 1; i <= n; ++i) { // 最长路径减去最长边小于mid,状态合法,此mid和更小的mid都可能是答案,向左找 if (s[i] == cnt && maxn - w_ver[i] <= mid) return true; } return false; // 删了一条边还不够,说明mid取小了}
int main() { scanf("%d%d", &n, &m); for (int i = 1; i < n; ++i) { int x, y, e; scanf("%d%d%d", &x, &y, &e); add(x, y, e), add(y, x, e); } dep[1] = 1; init(1); for (int i = 1; i <= m; ++i) { scanf("%d%d", &query[i].x, &query[i].y); query[i].fa = lca(query[i].x, query[i].y); } int l = 0, r = 300000000; // 3e5 * 1e3,答案最大是 3e8 while (l < r) { int mid = l + r >> 1; if (judge(mid)) r = mid; // mid自身也可能是答案,不要舍掉 else l = mid + 1; } printf("%d\n", l); return 0;}部分信息可能已经过时







