networkx 中的约束弹簧布局

问题描述 投票:0回答:1

我在networkx中有一个有向图。

节点有一个“高度”标签。这是高度为 0、1、2、3、4、5 和 6 的示例:

Bruhat graph that I'm trying to lay out with a spring layout

我想运行弹簧布局(二维),但将节点限制为固定高度。也就是说,我想“约束”弹簧布局,以便节点的 x 坐标移动,而 y 坐标不移动。

我对networkx比较陌生。实现这一目标的最佳方法是什么?预先感谢。

constraints networkx directed-acyclic-graphs
1个回答
1
投票

根据@Joe的要求,我在这里发布答案。

这只是把上面建议的代码拼接在一起而已,并无原创之处。

您的图需要在每个节点上带有一个 "height"(高度)属性。先定义下文给出的函数,之后下面的示例代码即可运行:

G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (2, 3)])
# Use each node's label as its "height" so the path's nodes sit at
# increasing y positions when drawn.
for node, attrs in G.nodes(data=True):
    attrs["height"] = node
draw_graph_with_height(G, figsize=(5, 5))

enter image description here

#    Copyright (C) 2004-2015 by
#    Aric Hagberg <[email protected]>
#    Dan Schult <[email protected]>
#    Pieter Swart <[email protected]>
#    All rights reserved.
#    BSD license.
# import numpy as np

# taken from networkx.drawing.layout and added hold_dim
def _fruchterman_reingold(A, dim=2, k=None, pos=None, fixed=None,
                          iterations=50, hold_dim=None):
    # Position nodes in adjacency matrix A using Fruchterman-Reingold
    # Entry point for NetworkX graph is fruchterman_reingold_layout()
    try:
        nnodes, _ = A.shape
    except AttributeError:
        raise RuntimeError(
            "fruchterman_reingold() takes an adjacency matrix as input")

    A = np.asarray(A)  # make sure we have an array instead of a matrix

    if pos is None:
        # random initial positions
        pos = np.asarray(np.random.random((nnodes, dim)), dtype=A.dtype)
    else:
        # make sure positions are of same type as matrix
        pos = pos.astype(A.dtype)

    # optimal distance between nodes
    if k is None:
        k = np.sqrt(1.0 / nnodes)
    # the initial "temperature"  is about .1 of domain area (=1x1)
    # this is the largest step allowed in the dynamics.
    t = 0.1
    # simple cooling scheme.
    # linearly step down by dt on each iteration so last iteration is size dt.
    dt = t / float(iterations + 1)
    delta = np.zeros((pos.shape[0], pos.shape[0], pos.shape[1]), dtype=A.dtype)
    # the inscrutable (but fast) version
    # this is still O(V^2)
    # could use multilevel methods to speed this up significantly
    for _ in range(iterations):
        # matrix of difference between points
        for i in range(pos.shape[1]):
            delta[:, :, i] = pos[:, i, None] - pos[:, i]
        # distance between points
        distance = np.sqrt((delta**2).sum(axis=-1))
        # enforce minimum distance of 0.01
        distance = np.where(distance < 0.01, 0.01, distance)
        # displacement "force"
        displacement = np.transpose(np.transpose(delta)*(k * k / distance**2 - A * distance / k))\
            .sum(axis=1)
        # update positions
        length = np.sqrt((displacement**2).sum(axis=1))
        length = np.where(length < 0.01, 0.1, length)
        delta_pos = np.transpose(np.transpose(displacement) * t / length)
        if fixed is not None:
            # don't change positions of fixed nodes
            delta_pos[fixed] = 0.0
        # only update y component
        if hold_dim == 0:
            pos[:, 1] += delta_pos[:, 1]
        # only update x component
        elif hold_dim == 1:
            pos[:, 0] += delta_pos[:, 0]
        else:
            pos += delta_pos
        # cool temperature
        t -= dt
        pos = _rescale_layout(pos)
    return pos

def _rescale_layout(pos, scale=1):
    # rescale to (0,pscale) in all axes

    # shift origin to (0,0)
    lim = 0  # max coordinate for all axes
    for i in range(pos.shape[1]):
        pos[:, i] -= pos[:, i].min()
        lim = max(pos[:, i].max(), lim)
    # rescale to (0,scale) in all directions, preserves aspect
    for i in range(pos.shape[1]):
        pos[:, i] *= scale / lim
    return pos

def draw_graph_with_height(g, highlighted_nodes=frozenset(), figsize=(15, 15),
                           iterations=150, title='', fixed_nodes=None):
    """Draw ``g`` with a spring layout whose y coordinates follow node heights.

    Every node must carry a ``"height"`` attribute. Each node starts at a
    random x in [0, 5) and at y = 2 * height. A Fruchterman-Reingold layout
    is then run with ``hold_dim=1``, so the forces only move x coordinates.

    Parameters
    ----------
    g : networkx graph
        Graph to draw. For the layout forces it is converted with
        ``g.to_undirected()``; the original graph ``g`` is what gets drawn.
    highlighted_nodes : iterable of nodes
        If truthy, the whole graph is drawn faded (``alpha=.1``, blue) and
        the subgraph induced by these nodes is drawn on top in red.
    figsize : tuple
        Matplotlib figure size.
    iterations : int
        Number of layout iterations.
    title : str
        Optional figure title.
    fixed_nodes : iterable of nodes, optional
        Nodes whose force displacement is suppressed during the layout.
        Defaults to the first and last node in ``g``'s node order (the
        previous hard-coded behaviour).

    Returns
    -------
    dict
        Mapping node -> (x, y) position used for drawing.
    """
    nodes = list(g.nodes())

    if nodes:
        # Random x, y = 2 * height; hold_dim=1 below keeps the forces from
        # moving y.
        start = np.asarray(
            [(5 * np.random.random(), 2 * data["height"])
             for _, data in g.nodes(data=True)],
            dtype=float)
        # Explicit nodelist guarantees row i of A corresponds to nodes[i].
        A = nx.adjacency_matrix(g.to_undirected(), nodelist=nodes)
        A = np.asarray(A.todense(), dtype=float)

        if fixed_nodes is None:
            # Original heuristic: anchor the first and last node.
            fixed = [0, len(nodes) - 1]
        else:
            index = {node: i for i, node in enumerate(nodes)}
            fixed = [index[node] for node in fixed_nodes]

        new_pos = _fruchterman_reingold(A, dim=2, pos=start, fixed=fixed,
                                        iterations=iterations, hold_dim=1)
        pos_dict = {node: tuple(xy) for node, xy in zip(nodes, new_pos)}
    else:
        # Empty graph: nothing to lay out, so skip the layout call entirely.
        pos_dict = {}

    fig, ax = plt.subplots(figsize=figsize)

    if title:
        fig.suptitle(title, fontsize=16)

    if highlighted_nodes:
        nx.draw(g, pos=pos_dict, ax=ax, alpha=.1, font_size=14,
                node_color='b')
        gsub = nx.subgraph(g, highlighted_nodes)
        nx.draw(gsub, pos=pos_dict, ax=ax, node_color='r')
    else:
        nx.draw(g, pos=pos_dict, ax=ax)

    plt.show()
    return pos_dict
© www.soinside.com 2019 - 2024. All rights reserved.