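Create LineStrings from points, with a maximum number of points per LineString, no crossings, and each point used by only one LineString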

Question
import pandas as pd
import geopandas as gpd
import itertools
from shapely.geometry import Point, LineString, MultiLineString

points = pd.DataFrame({'X':[1, 1, 1, 1, 4, 4, 4, 4],
                     'Y': [1, 2, 3, 4, 1, 2, 3, 4]})

gdf_points = gpd.GeoDataFrame(points, geometry=gpd.points_from_xy(x=points.X, y=points.Y))
gdf_points = gdf_points['geometry']

max_string = 3
start_point = Point(2, 0.5)

[Image: plot of the points, with start_point shown in red]

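Goals:

  • Create LineStrings beginning at the red point, start_point
  • Each string should connect at most max_string points
  • Once a LineString connects max_string points, start a new string at start_point
  • Lines must not cross
  • A point must not be connected to more than one line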

One of the 80 solutions could be:

[Image: one example solution]

My approach:

def points_on_multiple_lines(points, lines):
    for point in points:
        is_on_line = 0
        for line in lines:
            if line.dwithin(point, 0.0):
                is_on_line += 1

                start = Point(line.coords[0])
                end = Point(line.coords[1])
                if (not point.equals(start)) and (not point.equals(end)):
                    return False

        if is_on_line >= 3:
            return False
    return True

def lines_not_touch(lines):
    for line in lines:
        for i in lines:
            if line.crosses(i):
                return False

    return True


valid_lines = []
for shuffled_points in itertools.permutations(gdf_points):
    string_line = []
    complete_line = []
    for i, point in enumerate(shuffled_points):
        if i == 0:
            start_line = LineString([start_point, point])
            complete_line.append(start_line)
        elif i % max_string == 0:
            start_line = LineString([start_point, point])
            complete_line.append(start_line)
        else:
            current_line = LineString([shuffled_points[i - 1], point])
            complete_line.append(current_line)

    if points_on_multiple_lines(shuffled_points, complete_line) and lines_not_touch(complete_line):
        valid_lines.append(MultiLineString(complete_line))

With more points, the execution time grows very quickly. Do you know how to improve the code, or whether there are packages that could help?

python-3.x geopandas shapely
1 Answer

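I dug into this and was able to compute all valid lines in about 7 seconds, instead of the 28 seconds your code takes.

It should also scale better to larger numbers of points, because not every combination is created and then tested against the requirements.

Here is my approach: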

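1. Find the list of workable combinations of points per line

In your case, you want to use all 8 points with lines of 1 to 3 points each, so the possible combinations of line lengths are: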

  • (2, 3, 3)
  • (1, 1, 3, 3)
  • (1, 2, 2, 3)
  • (2, 2, 2, 2)
  • (1, 1, 1, 2, 3)
  • (1, 1, 2, 2, 2)
  • (1, 1, 1, 1, 1, 3)
  • (1, 1, 1, 1, 2, 2)
  • (1, 1, 1, 1, 1, 1, 2)
  • (1, 1, 1, 1, 1, 1, 1, 1)
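
The list above can be reproduced with a small recursive helper. This is only a sketch: `length_partitions` is a hypothetical name and not part of the code below, which gets the same tuples with itertools.combinations_with_replacement.

```python
def length_partitions(total, max_len):
    """Return every non-decreasing tuple of lengths in 1..max_len that sums to total."""
    results = []

    def extend(remaining, smallest, prefix):
        if remaining == 0:
            results.append(tuple(prefix))
            return
        # Never pick a length smaller than the previous one, so each
        # multiset of lengths is produced exactly once.
        for length in range(smallest, min(max_len, remaining) + 1):
            extend(remaining - length, length, prefix + [length])

    extend(total, 1, [])
    return results


partitions = length_partitions(8, 3)
print(len(partitions))  # 10
for p in partitions:
    print(p)
```

Because the recursion only ever extends a prefix that can still reach the target, it avoids generating tuples that are later thrown away for having the wrong sum.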

2. Create all possible lines that begin at start_point
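
A dependency-free sketch of step 2, using plain coordinate tuples instead of shapely geometries. `candidate_lines` is a hypothetical helper name. Note that itertools.combinations yields each subset once, in input order, so every set of points is visited in a single order only.

```python
from itertools import combinations


def candidate_lines(start, points, max_len):
    """Group every candidate line (start followed by 1..max_len points) by length."""
    lines = {k: [] for k in range(1, max_len + 1)}
    for k in lines:
        for chosen in combinations(points, k):
            # Each line is a tuple of vertices beginning at the start point.
            lines[k].append((start, *chosen))
    return lines


points = [(1, 1), (1, 2), (1, 3), (1, 4), (4, 1), (4, 2), (4, 3), (4, 4)]
lines = candidate_lines((2, 0.5), points, 3)
print({k: len(v) for k, v in lines.items()})  # {1: 8, 2: 28, 3: 56}
```

If the visiting order within a line matters for your use case, itertools.permutations could be swapped in, at the cost of many more candidates.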

3. Create all combinations of lines that match the length combinations
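
To get a feel for how many candidates step 3 produces before any filtering, the count for one length tuple can be worked out without building the lines. This is a sketch under the assumption that lines of each length are chosen without repetition and the choices for different lengths are multiplied together; `count_candidates` is a hypothetical name.

```python
from collections import Counter
from math import comb, prod


def count_candidates(length_tuple, lines_per_length):
    """Number of line combinations for one tuple of line lengths.

    lines_per_length maps a line length to how many lines of that length exist.
    """
    counts = Counter(length_tuple)
    # Choose `k` distinct lines among those of each length, then multiply.
    return prod(comb(lines_per_length[length], k) for length, k in counts.items())


# 8 points and lines of up to 3 points: 8, 28 and 56 lines per length.
lines_per_length = {1: 8, 2: 28, 3: 56}
print(count_candidates((1,) * 8, lines_per_length))  # 1
print(count_candidates((2, 3, 3), lines_per_length))
```

Most of these candidates reuse a point and are discarded in the next step, which is why that filter dominates the run time.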

4. Keep only the combinations in which no point is shared between lines
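
The idea behind step 4, sketched with plain coordinate tuples. `uses_each_point_once` is a hypothetical name, and the sketch assumes every line starts with the shared start point, which is therefore skipped.

```python
def uses_each_point_once(lines):
    """True if no point (other than the shared first vertex) appears twice."""
    seen = set()
    for line in lines:
        for point in line[1:]:  # line[0] is the shared start point
            if point in seen:
                return False
            seen.add(point)
    return True


start = (2, 0.5)
ok = [(start, (1, 1), (1, 2)), (start, (4, 1))]
bad = [(start, (1, 1), (1, 2)), (start, (1, 2), (4, 1))]
print(uses_each_point_once(ok))   # True
print(uses_each_point_once(bad))  # False
```

Returning as soon as a repeated point is found keeps the check cheap for the many combinations that fail early.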

5. Keep the combinations whose lines do not intersect or touch anywhere other than start_point
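
The code below relies on shapely for this check. As a dependency-free illustration of what "intersect or touch" means for straight segments, here is a sketch based on the classic orientation test. It is a swapped-in technique, not shapely's implementation, all helper names are hypothetical, and it is slightly stricter: a later segment that passes through the start point counts as a conflict.

```python
def orient(a, b, c):
    """Sign of the turn a -> b -> c: 1 left, -1 right, 0 collinear."""
    cross = (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])
    return (cross > 0) - (cross < 0)


def in_box(a, b, c):
    """True if c lies inside the bounding box of segment a-b."""
    return (min(a[0], b[0]) <= c[0] <= max(a[0], b[0])
            and min(a[1], b[1]) <= c[1] <= max(a[1], b[1]))


def segments_meet(p1, p2, p3, p4):
    """True if segments p1-p2 and p3-p4 cross, touch, or overlap."""
    o1, o2 = orient(p1, p2, p3), orient(p1, p2, p4)
    o3, o4 = orient(p3, p4, p1), orient(p3, p4, p2)
    if o1 != o2 and o3 != o4:
        return True
    # Collinear cases: an endpoint of one segment lies on the other.
    return ((o1 == 0 and in_box(p1, p2, p3)) or (o2 == 0 and in_box(p1, p2, p4))
            or (o3 == 0 and in_box(p3, p4, p1)) or (o4 == 0 and in_box(p3, p4, p2)))


def same_ray(start, a, b):
    """True if a and b lie on the same ray out of start (the segments overlap)."""
    dot = (a[0] - start[0]) * (b[0] - start[0]) + (a[1] - start[1]) * (b[1] - start[1])
    return orient(start, a, b) == 0 and dot > 0


def lines_conflict(line_a, line_b, start):
    """True if two lines, both beginning at start, meet anywhere except start."""
    segs_a = list(zip(line_a, line_a[1:]))
    segs_b = list(zip(line_b, line_b[1:]))
    for i, (a1, a2) in enumerate(segs_a):
        for j, (b1, b2) in enumerate(segs_b):
            if i == 0 and j == 0:
                # Both segments leave start; they only clash if they overlap.
                if same_ray(start, a2, b2):
                    return True
                continue
            if segments_meet(a1, a2, b1, b2):
                return True
    return False


start = (2, 0.5)
print(lines_conflict([start, (1, 1), (1, 2)], [start, (4, 1), (4, 2)], start))  # False
print(lines_conflict([start, (1, 1), (4, 2)], [start, (4, 1), (1, 2)], start))  # True
```

With exact inputs such as integers and halves the arithmetic is exact; for arbitrary floats a tolerance would be needed, which is one reason to prefer a geometry library in practice.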

Voilà:

import os
import itertools
from collections import Counter
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point, LineString
import matplotlib as mpl
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor, as_completed


def find_combinations(numbers, target):
    valid_combinations = []
    for r in range(1, target + 1):
        # The use of itertools.combinations_with_replacement allows for duplicates in combinations.
        # For example, summing to 4 with 1 and 2 can be (1, 1, 1, 1), (1, 1, 2), (2, 2).
        combinations_list = itertools.combinations_with_replacement(numbers, r)
        for comb in combinations_list:
            if sum(comb) == target:
                valid_combinations.append(comb)
    return valid_combinations


def find_all_lines(points, max_length):
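    # We create a dictionary to group available lines by length.
    # Note: this reads the module-level `start_point` defined in the __main__ block below.
    lines = {length: [] for length in range(1, max_length + 1)}
    for r in range(1, max_length + 1):
        # itertools.combinations keeps the input order, so each set of points
        # is visited in exactly one order.
        for combination in itertools.combinations(points, r):
            line = [start_point] + list(combination)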
            lines[r].append(LineString(line))
    return lines
        

def filter_combinations(lines, target_combs):
    
    all_combinations = []
    
    for comb in target_combs:
        
        combs = []
        count = Counter(comb)
        
        # We get the combinations for each length of line.
        # For example if a combination of length is (1, 2, 2, 3),
        # it will return all :
        #   - 1-element long lines
        #   - pairs of 2-elements long lines
        #   - 3-element long lines
        for element, cnt in count.items():
            combs.append(itertools.combinations(lines[element], cnt))
        
        # Use itertools.product to find all combinations when multiple lengths are mixed
        combinations_list = list(itertools.product(*combs))
        
        # Flatten the combinations (no tuples)
        for comb in combinations_list:
            flat_comb = []
            for sub_comb in comb:
                flat_comb.extend(sub_comb)
            all_combinations.append(flat_comb)
               
    return all_combinations


def remove_duplicate_points_chunk(chunk):
    # This is a worker for parallel mode
    unique_combinations = []
    for comb in chunk:
        points_used = set()
        valid = True
        for line in comb:
            points_in_line = list(line.coords)[1:]
            for point in points_in_line:
                if point in points_used:
                    valid = False
                    break
                points_used.add(point)
            if not valid:
                break
        if valid:
            unique_combinations.append(comb)
    return unique_combinations

def remove_duplicate_points(combinations, parallel=True):
    unique_combinations = []
    
    if parallel:
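        # os.cpu_count() can return None, so fall back to a single worker.
        num_workers = os.cpu_count() or 1
        # Ceiling division, at least 1: a zero chunk size would make range() fail.
        chunk_size = max(1, -(-len(combinations) // num_workers))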
        chunks = [combinations[i:i + chunk_size] for i in range(0, len(combinations), chunk_size)]
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(remove_duplicate_points_chunk, chunk) for chunk in chunks]
            for future in as_completed(futures):
                unique_combinations.extend(future.result())  
    else:
        # Serial mode reuses the same check as the parallel worker.
        unique_combinations = remove_duplicate_points_chunk(combinations)
    
    return unique_combinations

    
def remove_intersecting_lines(combinations, start_point):
    non_intersecting_combinations = []

    for comb in combinations:
        valid = True
        for i in range(len(comb)):
            for j in range(i + 1, len(comb)):
                line1 = comb[i]
                line2 = comb[j]
                
                # Exclude intersections at the start point
                if line1.intersects(line2):
                    intersection = line1.intersection(line2)
                    if not (isinstance(intersection, Point) and intersection.equals(start_point)):
                        valid = False
                        break
            if not valid:
                break
        if valid:
            non_intersecting_combinations.append(comb)

    return non_intersecting_combinations

    
if __name__ == "__main__":
    
    debug = True
    parallel = True
    
    # Data setup
    points = pd.DataFrame({'X': [1, 1, 1, 1, 4, 4, 4, 4],
                        'Y': [1, 2, 3, 4, 1, 2, 3, 4]})

    gdf_points = gpd.GeoDataFrame(points, geometry=gpd.points_from_xy(x=points.X, y=points.Y))
    gdf_points = gdf_points['geometry']
    max_string = 3
    start_point = Point(2, 0.5)
    
    # Find length combinations for using all points
    total_points = len(points)
    target_combinations = find_combinations([i for i in range(1, max_string + 1)], total_points)
    if debug:
        print(f"{len(target_combinations)} combinations that sum to {total_points}:")
        for comb in target_combinations:
            print(comb)
    
    # Find all possible lines
    all_lines = find_all_lines(gdf_points, max_string)
    if debug:
        print(f"{np.sum([len(all_lines[group]) for group in all_lines])} lines available")
    
    # Create lines combinations based on length compatibility
    filtered_combinations = filter_combinations(all_lines, target_combinations)
    if debug:
        print(f"{len(filtered_combinations)} combinations that use all points")
        
    # Remove combinations with duplicate points
    unique_combinations = remove_duplicate_points(filtered_combinations, parallel=parallel)
    if debug:
        print(f"{len(unique_combinations)} combinations with no duplicates")
        
    # Remove combinations with intersecting lines
    non_intersecting_combinations = remove_intersecting_lines(unique_combinations, start_point)
    if debug:
        print(f"{len(non_intersecting_combinations)} non-intersecting combinations")

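I use concurrent.futures.ProcessPoolExecutor to run this in parallel, but it can be disabled with the parallel=False flag. In serial mode, the search for duplicate points takes 12 seconds on my machine, compared with 6 seconds in parallel mode.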
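
When splitting work across processes, the chunk size has to stay at least 1 even if there are fewer items than workers. A minimal sketch of such a splitter, where `split_into_chunks` is a hypothetical helper name and nothing here starts any processes:

```python
def split_into_chunks(items, num_chunks):
    """Split items into at most num_chunks contiguous, roughly equal slices."""
    if num_chunks < 1:
        raise ValueError("num_chunks must be at least 1")
    # Ceiling division, clamped to 1 so range() never gets a zero step.
    chunk_size = max(1, -(-len(items) // num_chunks))
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


print(split_into_chunks(list(range(10)), 4))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(split_into_chunks([1, 2], 8))           # [[1], [2]]
```

Each slice could then be handed to executor.submit, one future per chunk.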

If you want to plot the results, I wrote some code that plots individual combinations or batches of subplots:

    # OPTIONAL: Plotting the results (continues the `if __name__ == "__main__":` block)
    plot = False
    plot_all = False
    subplot = True 
    
    if plot:
        
        if plot_all:
            i_max = len(non_intersecting_combinations)
        else:
            i_max = 5
            
        for i, comb in enumerate(non_intersecting_combinations):
            
            if i<i_max:        
                fig, ax = plt.subplots()
                
                cmap = mpl.colormaps['viridis'].resampled(len(comb))
                
                # Plot each valid line with a different color
                for j, line in enumerate(comb):
                    x, y = line.xy
                    ax.plot(x, y, color=cmap.colors[j], linewidth=2)
                    ax.plot(x, y, 'o', color=cmap.colors[j])

                # Plot the start point in red
                ax.plot(start_point.x, start_point.y, 'ro')

                # Set labels and title
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                ax.set_title('Non-Crossing LineStrings from Start Point')

    if subplot:
        
        # Number of subplots per figure
        max_subplots_per_fig = 40
        ncols = 10
        nrows = 4

        # Calculate the number of figures needed
        num_figs = (len(non_intersecting_combinations) + max_subplots_per_fig - 1) // max_subplots_per_fig

        for fig_idx in range(num_figs):
            fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 8))
            fig.subplots_adjust(hspace=0.5, wspace=0.5)
            axs = axs.ravel()
            
            for i in range(max_subplots_per_fig):
                comb_idx = fig_idx * max_subplots_per_fig + i
                if comb_idx >= len(non_intersecting_combinations):
                    break
                
                comb = non_intersecting_combinations[comb_idx]
                ax = axs[i]
                
                cmap = mpl.colormaps['viridis'].resampled(len(comb))
                
                for j, line in enumerate(comb):
                    x, y = line.xy
                    ax.plot(x, y, color=cmap.colors[j], linewidth=2)
                    ax.plot(x, y, 'o', color=cmap.colors[j])
                
                ax.plot(start_point.x, start_point.y, 'ro')
                ax.set_title(f'Combination {comb_idx + 1}', fontsize=8)
                
                ax.tick_params(axis='both', which='major', labelsize=8)
            
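            # Hide unused subplots (count the axes actually used, so a full
            # figure does not lose its last plot)
            n_used = min(max_subplots_per_fig,
                         len(non_intersecting_combinations) - fig_idx * max_subplots_per_fig)
            for k in range(n_used, len(axs)):
                fig.delaxes(axs[k])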
                
            fig.tight_layout()
            
    plt.show()

Hope this helps!

I don't know whether it can be improved further, but it is a step toward optimizing your approach.
