最坏情况时间为 O(n) 的二维峰值查找算法?

问题描述 投票:0回答:4

我正在学习麻省理工学院的this算法课程。在第一堂课中,教授提出了以下问题:-

二维数组中的峰值是一个值,它的所有 4 个邻居都小于或等于它,即。对于

a[i][j]
成为局部最大值,

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i+1][j-1] <= a[i][j]

现在给定一个 NxN 2D 数组,在数组中找到一个峰值

通过迭代所有元素并返回峰值,可以在

O(N^2)
时间内轻松解决这个问题。

但是,可以通过使用分而治之的解决方案进行优化,以便在

O(NlogN)
时间内解决,如此处所述。

但是他们说存在一个

O(N)
时间算法可以解决这个问题。请建议我们如何在
O(N)
时间内解决这个问题。

PS(对于那些了解Python的人)课程人员在这里解释了一种方法(问题1-5.寻峰证明),并且还在他们的问题集中提供了一些Python代码。但所解释的方法完全不明显并且很难破译。 python 代码同样令人困惑。所以我复制了下面代码的主要部分,供那些了解 python 并且可以从代码中看出正在使用什么算法的人使用。

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it's empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing column
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

#Helper Method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from 
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """

    answer = []
    for a in list1:
        for b in list2:
            answer.append ((a, b))
    return answer
arrays algorithm data-structures multidimensional-array language-agnostic
4个回答
15
投票
  1. 假设数组的宽度大于高度,否则我们会向另一个方向分裂。
  2. 将数组分成三部分:中心列、左侧和右侧。
  3. 遍历中心列和两个相邻列并寻找最大值。
    • 如果它在中心柱中 - 这是我们的峰值
    • 如果它在左侧,则在子数组上运行此算法
      left_side + central_column
    • 如果它在右侧,则在子数组上运行此算法
      right_side + central_column

为什么这有效:

对于最大元素位于中心列的情况 - 显而易见。如果不是,我们可以从该最大值逐步增加元素,并且绝对不会穿过中心行,因此相应的一半中肯定会存在峰值。

为什么这是 O(n):

步骤 #3 需要小于或等于

max_dimension
次迭代,并且
max_dimension
每两个算法步骤至少减半。这给出了
n+n/2+n/4+...
,即
O(n)
。重要细节:我们按最大方向进行分割。对于方形阵列,这意味着分割方向将是交替的。这与您链接到的 PDF 中的上次尝试有所不同。

注意:我不确定它是否与您给出的代码中的算法完全匹配,它可能是也可能不是不同的方法。


3
投票

看到那个(n):

计算步骤如图

查看算法实现:

1) 从 1a) 或 1b) 开始

1a) 设置左半部分、分隔线、右半部分。

1b) 设置上半部分、分隔线、下半部分。

2)找到分频器上的全局最大值。 [θn]

3)找到其邻居的值。 并将曾经访问过的最大节点记录为bestSeen节点。 [θ1]

# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
    bestSeen = neighbor
    if not trace is None: trace.setBestSeen(bestSeen)

4)检查全局最大值是否大于 bestSeen 及其邻居。 [θ1]

//第4步是该算法起作用的关键

# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
    if not trace is None: trace.foundPeak(bestLoc)
    return bestLoc

5) 如果 4) 为 True,则返回全局最大值作为 2-D 峰值。

否则如果这次做了1a),选择BestSeen的一半,返回步骤1b)

否则,选择BestSeen的一半,返回步骤1a)


要直观地看到这个算法为什么有效,就像抓住最大价值的一边,不断缩小边界,最终得到BestSeen值。

#可视化模拟

第一轮

第二轮

第3轮

第四轮

第5轮

第6轮

终于

对于这个10*10的矩阵,我们只用了6步就找到了二维峰值,非常有说服力地证明它确实是theta n


猎鹰


1
投票

这是实现 @maxim1000 算法的工作 Java 代码。以下代码在线性时间内找到二维数组中的峰值。

import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;

    void run(){
        N = 1000;
        M = 100;

        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
//      for(int i=0; i<N; i++){   // TO print the array. 
//          System. out.println(Arrays.toString(arr[i]));
//      }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }

    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }

    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        int max = arr[loRow][loCol];
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax  = false;

        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }

    // helper function that splits on the middle 
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            int ans = arr[loCol][loRow];
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax   = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }

    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
        if(row-1>=0 && arr[row-1][col]>arr[row][col])  return false;
        if(row+1<N && arr[row+1][col]>arr[row][col])   return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col])  return false;
        if(col+1<M && arr[row][col+1]>arr[row][col])   return false;
        return true;
    }
}

0
投票

抱歉,由于我还没有50声望,所以我无法评论最佳答案(@maxim1000的算法),所以我使用答案部分来提出我的问题。

在我想出的示例数组上尝试此算法后,该算法似乎无法正确识别这种情况下的峰值。

考虑以下 5x5 数组:

    0   1   2   3   4
    ------------------
0 |  9   19  20  31  64
1 |  49  72  71  51  44
2 |  5   141 95  6   7
3 |  35  36  27  28  50
4 |  10  88  29  189 21

第一次迭代: 该算法选择第 2 列作为中间列,并在第 1、2 和 3 列中查找最大值。第 3 列中的最大值为 189。因此,算法继续前进到该子数组的

right_side + central_column
,创建一个 5x3 子数组并递归地对其应用算法。

    0   1   2
    -----------
0 |  20  31  64
1 |  71  51  44
2 |  95   6   7
3 |  27  28  50
4 |  29  189 21

第二次迭代: 该算法现在选择第 2 行作为中间行,并在第 1、2 和 3 行中查找最大值。第 2 行中的最大值为 95。由于最大值位于中间行,因此该算法错误地将 95 视为峰值,即使尽管它不是数组中的实际峰值。

预先感谢您的任何澄清或建议!

© www.soinside.com 2019 - 2024. All rights reserved.