#include <stdio.h>
#include <stdlib.h>
typedef struct Parent {
int node;
int sum;
} Parent;
typedef struct DSU {
Parent* parent;
int* rank;
} DSU;
void create_dsu(DSU* dsu, int n, int* wts)
{
dsu->parent = malloc(sizeof(Parent) * n);
dsu->rank = malloc(sizeof(int) * n);
for (int i = 0; i < n; i++) {
dsu->parent[i].sum = wts[i];
dsu->parent[i].node = i;
dsu->rank[i] = 0;
}
}
int find_parent(DSU* dsu, int n)
{
if (n == dsu->parent[n].node) {
return n;
}
return dsu->parent[n].node = find_parent(dsu, n);
}
void union_by_rank(DSU* dsu, int u, int v)
{
int up = find_parent(dsu, u);
int vp = find_parent(dsu, v);
if (up == vp) {
return;
}
else if (dsu->rank[up] > dsu->rank[vp]) {
dsu->parent[vp].node = up;
dsu->parent[up].sum += dsu->parent[vp].sum;
}
else if (dsu->rank[vp] > dsu->rank[up]) {
dsu->parent[up].node = vp;
dsu->parent[vp].sum += dsu->parent[up].sum;
}
else {
dsu->parent[up].node = vp;
dsu->rank[vp]++;
dsu->parent[vp].sum += dsu->parent[up].sum;
}
}
int find_sum(DSU* dsu, int u)
{
int up = find_parent(dsu, u); // causes a segfault
//
// printf("%d\n", dsu->parent[3].sum); -> 17
// printf("%d\n", dsu->parent[0].sum); -> 11
return (dsu->parent[up].sum);
}
int main()
{
int arr[] = { 11, 13, 1, 3, 5 };
DSU dsu;
create_dsu(&dsu, 5, arr);
union_by_rank(&dsu, 1, 3);
union_by_rank(&dsu, 2, 3);
union_by_rank(&dsu, 0, 4);
//
printf("%d\n", find_sum(&dsu, 2));
}
这里我尝试用 C 实现不相交集并集。 在程序中,我使用 union_by_rank 函数连接了 5 个节点。 函数 findSum 旨在求输入节点之一的集合之和。 例子: 节点的值为 11, 13, 1, 3, 5 如果节点 1-2-3 连接且节点 0-4 连接
那么节点 2 的总和将为 13 + 1 + 3 = 17(节点 1、2 和 3 权重的总和)
由于某种原因,当 find_parent int findSum 函数在 union 函数中正常工作时,它会出现段错误。因为我可以使用 printf 打印 findSum 中的值
find_parent 应该像在 union 函数中那样工作
好吧,所以 find_parent 的逻辑是错误的。 find_parent 应该是
int find_parent(DSU* dsu, int n)
{
if (n == dsu->parent[n].node) {
return n;
}
return dsu->parent[n].node = find_parent(dsu, dsu->parent[n].node);
}