给定一棵有 n 个顶点的加权树。有 q 个查询,对于每个查询,您都会获得整数 (u,k)。找到顶点 v 的数量,使得从 u 到 v 的路线上的最小边等于 k。 (n,q <= 1e5)
我尝试使用 dfs 但我认为最好的解决方案是 O(n*q)
我当前的代码:
#include <bits/stdc++.h>
using namespace std;
const int INF = 1e9;
struct Edge {
int to;
int weight;
};
vector<vector<Edge>> adj;
vector<int> mn;
void dfs(int u, int parent, int minWeight) {
mn[u] = minWeight;
for (auto edge : adj[u]) {
if (edge.to != parent) {
dfs(edge.to, u, min(minWeight, edge.weight));
}
}
}
int main() {
int n, q;
cin >> n >> q;
adj.resize(n + 1);
mn.resize(n + 1);
for (int i = 0; i < n - 1; ++i) {
int u, v, w;
cin >> u >> v >> w;
adj[u].push_back({v, w});
adj[v].push_back({u, w});
}
while (q--) {
int u, k;
cin >> u >> k;
fill(mn.begin(), mn.end(), INF);
dfs(u, -1, INF);
int cnt = 0;
for (int v = 1; v <= n; ++v) {
if (v != u && mn[v] == k) {
cnt++;
}
}
cout << cnt << endl;
}
return 0;
}
这可以通过首先读取所有查询,然后按边权重以非递增顺序对它们进行排序来离线解决。我们可以使用不相交集来维护仅使用权重大于某个值的边形成的森林。我们还按非递增顺序对树中的边进行排序,并按该顺序添加某些权重的边。每当我们重新添加边时,我们都会检查对该特定权重的查询。添加这些边后,任何节点的组件大小的增加是路径上该边权重最小的顶点数量。请注意,查询树中不存在的边权重始终会导致
0
。
我们可以使用不相交集的修改版本,以便每个组件的根存储组件的否定大小,以便更容易回答查询以及按大小实现并集。该解决方案的时间复杂度为
O(N log N + (N + Q) log Q + (N + Q)α(N))
(其中 α
是反阿克曼函数,并且在此实际上是常数)。
这个可以在线解决,但是代码会变得复杂很多。
#include <vector>
#include <iostream>
#include <map>
#include <functional>
#include <utility>
std::vector<int> ds; // the disjoint set
int find(int u) {
return ds[u] < 0 ? u : ds[u] = find(ds[u]);
}
int main() {
int n, q;
std::cin >> n >> q;
std::vector<int> answers(q);
ds.assign(n + 1, -1);
std::map<int, std::vector<std::pair<int, int>>, std::greater<>> edgesForWeight, queriesForWeight;
for (int i = 1, u, v, w; i < n; ++i) {
std::cin >> u >> v >> w;
edgesForWeight[w].push_back({u, v});
}
for (int i = 0, u, k; i < q; ++i) {
std::cin >> u >> k;
queriesForWeight[k].push_back({i, u});
}
for (const auto& [weight, edges] : edgesForWeight) {
auto queriesIt = queriesForWeight.find(weight);
if (queriesIt != queriesForWeight.end())
for (auto [qidx, node] : queriesIt->second)
answers[qidx] = ds[find(node)];
for (auto [u, v] : edges) {
u = find(u), v = find(v);
if (ds[u] > ds[v]) std::swap(u, v);
ds[u] += ds[v];
ds[v] = u;
}
if (queriesIt != queriesForWeight.end())
for (auto [qidx, node] : queriesIt->second)
answers[qidx] -= ds[find(node)];
}
for (int ans : answers) std::cout << ans << '\n';
}