在c中实现dsu时出现没有原因的段错误(可能是一个愚蠢的错误)

问题描述 投票:0回答:1
#include <stdio.h>
#include <stdlib.h>

typedef struct Parent {
    int node;
    int sum;
} Parent;

typedef struct DSU {
    Parent* parent;
    int* rank;
} DSU;

void create_dsu(DSU* dsu, int n, int* wts)
{
    dsu->parent = malloc(sizeof(Parent) * n);
    dsu->rank = malloc(sizeof(int) * n);

    for (int i = 0; i < n; i++) {
        dsu->parent[i].sum = wts[i];
        dsu->parent[i].node = i;
        dsu->rank[i] = 0;
    }
}

int find_parent(DSU* dsu, int n)
{
    if (n == dsu->parent[n].node) {
        return n;
    }

    return dsu->parent[n].node = find_parent(dsu, n);
}

void union_by_rank(DSU* dsu, int u, int v)
{
    int up = find_parent(dsu, u);
    int vp = find_parent(dsu, v);
    if (up == vp) {
        return;
    }

    else if (dsu->rank[up] > dsu->rank[vp]) {
        dsu->parent[vp].node = up;
        dsu->parent[up].sum += dsu->parent[vp].sum;
    }

    else if (dsu->rank[vp] > dsu->rank[up]) {
        dsu->parent[up].node = vp;
        dsu->parent[vp].sum += dsu->parent[up].sum;
    }

    else {
        dsu->parent[up].node = vp;
        dsu->rank[vp]++;
        dsu->parent[vp].sum += dsu->parent[up].sum;
    }
}

int find_sum(DSU* dsu, int u)
{
    int up = find_parent(dsu, u); // causes a segfault
                                  //
    // printf("%d\n", dsu->parent[3].sum); -> 17
    // printf("%d\n", dsu->parent[0].sum); -> 11
    return (dsu->parent[up].sum);
}

int main()
{
    int arr[] = { 11, 13, 1, 3, 5 };
    DSU dsu;
    create_dsu(&dsu, 5, arr);

    union_by_rank(&dsu, 1, 3);
    union_by_rank(&dsu, 2, 3);
    union_by_rank(&dsu, 0, 4);
    //
    printf("%d\n", find_sum(&dsu, 2));
}

这里我尝试用 C 实现不相交集并集。 在程序中,我使用 union_by_rank 函数连接了 5 个节点。 函数 findSum 旨在求输入节点之一的集合之和。 例子: 节点的值为 11, 13, 1, 3, 5 如果节点 1-2-3 连接且节点 0-4 连接

那么节点 2 的总和将为 13 + 1 + 3 = 17(节点 1、2 和 3 权重的总和)

由于某种原因,当 find_parent int findSum 函数在 union 函数中正常工作时,它会出现段错误。因为我可以使用 printf 打印 findSum 中的值

find_parent 应该像在 union 函数中那样工作

c pointers data-structures segmentation-fault disjoint-union
1个回答
0
投票

好吧,所以 find_parent 的逻辑是错误的。 find_parent 应该是

int find_parent(DSU* dsu, int n)
{
    if (n == dsu->parent[n].node) {
        return n;
    }

    return dsu->parent[n].node = find_parent(dsu, dsu->parent[n].node);
}
© www.soinside.com 2019 - 2024. All rights reserved.