《算法导论》笔记1-4

分治策略

最大子数组问题

假设让你买一支股票,并且你已经知道未来 17 天的走势,要求最大化收益。

股票走势

什么时候收益最大?当然是最低价买进,最高价卖出。如果这种策略有效的话,那么确定最大化收益的方法就是:

  1. 找到最高和最低价。
  2. 从最高价开始向左找最低价,从最低价开始向右找最高价。
  3. 取两对价格中差值最大的。

以上面走势为例就是:

  1. 最高价在第 2 天为 113 元,最低价在第 8 天为 63 元。
  2. [1, 2] 差值为 13 元,[8, 12] 差值为 43 元。

但是该策略有问题,比如:

股票走势

股票走势只给 5 天,按照上面的策略,应该买 [1, 2],差值是 1 元。但很明显最大收益是 [2, 3],差值为 3 元,所以不使用这种策略。

当然,可以使用暴力法,求出所有可能的买进卖出组合,只要卖出日期在买入日期即可。但是这种求法的运行时间太长,需要 Ω(n^2)。

问题转换:
为了设计出运行时间为 O(n^2) 的算法,我们就要从另一个角度看待输入的数据,就是从观察每日价格转为观察每日的价格变化,对于第 i 天的价格变化,定义为第 i 天和第 i - 1 天的价格差。

比如上面的每天股票走势,就转换成:

价格变化

如果把这行价格变化看做一个数组 A,那么求最大收益就转换成:在数组 A 中找出和最大的非空连续数组,这个连续子数组就称为最大子数组。
比如这里的最大子数组就是 A[8 … 11],它的和是 43。

只有当数组中包含负数时,最大子数组问题才有意义。(如果所有元素都是正数,那最大子数组显然就是整个数组)

使用分支策略的求解方法:
如果我们要找子数组 A[low … high] 的最大子数组,那么使用分治策略就意味着我们要将子数组划分成两个规模尽量相等的子数组 A[low … mid] 和 A[mid + 1 … high]。

对于 A[low … high] 来说,它的任意连续子数组 A[i … j] 所处的位置必然是以下三种情况:

  • 完全位于子数组 A[low … mid] 中,因此 low <= i <= j <= mid。
  • 完全位于子数组 A[mid + 1 … high] 中,因此 mid < i <= j <= high。
  • 跨越了中点,因此 low <= i <= mid < j <= high。

那么,最大子数组肯定也是这三种情况之一。即这三种情况中子数组和最大的就是最大子数组,所以可以递归求解 A[low … mid] 和 A[mid + 1 … high] 的最大子数组,然后再找跨越中点的最大子数组,最后在它们中选和最大的。

对于第三种情况,我们可以很容易地在线性时间(相对于子数组 A[low … high]的规模)内求出跨越中点的最大子数组,就是找出左区间最大和及其所到的下标,右区间同样操作。它们的解合并起来就是跨越中点时,最大子数组的起始和末尾位置,以及最大和。

跨越中点的最大子数组

找跨越中点的最大子数组的伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
find_max_crossing_subarray(A, low, mid, high)
left_sum = MIN
sum = 0
# 从 mid 到 low 找左区间的最大和,下标
for i in range(low, mid, -1)
sum += A[i]
if sum > left_sum
left_sum = sum
max_left = i
right_sum = MIN
sum = 0
# 找右区间的最大和,下标
for i in range(mid + 1, high)
sum += A[i]
if sum > right_sum
right_sum = sum
max_right = i
return (max_left, max_right, left_sum + right_sum)

有了线性时间的 find_max_crossing_subarray,就可以设计求解最大数组问题的分治算法的伪代码了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
find_max_subarray(A, low, high)
# 只有一个元素
if low == high
return (low, high, A[low])

mid = [(low + high) / 2]
left_low, left_high, left_sum = find_max_subarray(A, low, mid)
right_low, right_high, right_sum = find_max_subarray(A, mid + 1, high)
cross_low, cross_high, cross_sum = find_max_crossing_subarray(A, low, mid, high)

# 找到三种情况中最大的
if left_sum >= right_sum and left_sum >= cross_sum
return (left_low, left_right, left_sum)
elif right_sum >= left_sum and right_sum >= cross_sum
return (right_low, right_right, right_sum)
else
return (cross_low, cross_high, cross_sum)

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
* 找到跨越中点的最大子数组
*
* @param A 子数组A
* @param low 起始下标
* @param high 终止下标
* @return {起始下标, 终止下标, 最大和}
*/
int[] findMaxCrossingSubArray(int[] A, int low, int mid, int high) {
int leftSum = Integer.MIN_VALUE, maxLeft = 0;
int sum = 0;
for (int i = mid; i >= low; i--) {
sum += A[i];
if (sum > leftSum) {
leftSum = sum;
maxLeft = i;
}
}
int rightSum = Integer.MIN_VALUE, maxRright = 0;
sum = 0;
for (int i = mid + 1; i <= high; i++) {
sum += A[i];
if (sum > rightSum) {
rightSum = sum;
maxRright = i;
}
}
return new int[] { maxLeft, maxRright, leftSum + rightSum };
}

/**
* 找到最大子数组
*
* @param A 子数组A
* @param low 起始下标
* @param high 终止下标
* @return {起始下标, 终止下标, 最大和}
*/
int[] findMaxSubArray(int[] A, int low, int high) {
if (low == high)
return new int[] { low, high, A[low] };

int mid = (low + high) / 2;
int[] f1 = findMaxSubArray(A, low, mid);
int leftLow = f1[0], leftHigh = f1[1], leftSum = f1[2];

int[] f2 = findMaxSubArray(A, mid + 1, high);
int rightLow = f2[0], rightHigh = f2[1], rightSum = f2[2];

int[] f3 = findMaxCrossingSubArray(A, low, mid, high);
int crossLow = f3[0], crossHigh = f3[1], crossSum = f3[2];

if (leftSum >= rightSum && leftSum >= crossSum) {
return new int[] { leftLow, leftHigh, leftSum };
} else if (rightSum >= leftSum && rightSum >= crossSum) {
return new int[] { rightLow, rightHigh, rightSum };
} else {
return new int[] { crossLow, crossHigh, crossSum };
}
}

矩阵乘法的 Strassen 算法

矩阵乘法你应该了解过,下面是它的伪代码:

1
2
3
4
5
6
7
8
9
square_matrix_multiply(A, B)
n = A.rows
C = [n][n]
for i in range(1, n)
for j in range(1, n)
c[i][j] = 0
for k in range(1, n)
c[i][j] += a[i][k] * b[k][j]
return C

而使用 Strassen 算法求矩阵乘法只用 O(n^2.81) 的时间复杂度。

一个简单的分治算法:

为了简单说明,当使用分治算法计算 C = A x B 时,假定三个矩阵均为 n x n 矩阵,其中 n 为 2 的幂。
做出这个假设是因为每个分解步骤中,n x n 矩阵都被划分为 4 个 n/2 x n/2 的子矩阵,假定 n 是 2 的幂,那么只要 n >= 2 即可保证子矩阵规模 n/2 为整数。

假定将 A、B 和 C 均分解为 4 个 n/2 x n/2 的子矩阵:

1
2
3
4
5
6
7
8
9
10
11
12
A = |A11, A12|   B = |B11, B12|  C = |C11, C12|
|A21, A22| |B21, B22| |C21, C22|

因此可以将 C = A x B 改写成:
|C11, C12| = |A11, A12| x |B11, B12|
|C21, C22| |A21, A22| |B21, B22|

等价于:
C11 = A11 x B11 + A12 x B21
C12 = A11 x B12 + A12 x B22
C21 = A21 x B11 + A22 x B21
C22 = A21 x B12 + A22 x B22

利用上面的等价公式,为我们可以直接设计一个递归的分治算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
square_matrix_multiply_recursive(A, B)
n = A.rows
C = [n][n]
if n == 1
c[1][1] = A[1][1] x B[1][1]
else
# 将 A、B 和 C 划分成子矩阵,然后递归计算
C[1][1] = square_matrix_multiply_recursive(A[1][1], B[1][1])
+ square_matrix_multiply_recursive(A[1][2], B[2][1])
C[1][2] = square_matrix_multiply_recursive(A[1][1], B[1][2])
+ square_matrix_multiply_recursive(A[1][2], B[2][2])
C[2][1] = square_matrix_multiply_recursive(A[2][1], B[1][1])
+ square_matrix_multiply_recursive(A[2][2], B[2][1])
C[2][2] = square_matrix_multiply_recursive(A[2][1], B[1][2])
+ square_matrix_multiply_recursive(A[2][2], B[2][2])
Author

Zoctan

Posted on

2018-03-26

Updated on

2023-03-14

Licensed under