O(n) 复杂度实现最大回撤的计算

看到大多数计算最大回撤的代码都是 $O(n^{2})$ 的算法复杂度,其实最大回撤的计算用 $O(n)$ 的算法复杂度就能实现,只需对 $O(n^{2})$ 复杂度的代码稍作修改即可。

我们先来回顾下最大回撤的定义:

$最大回撤 = max(1-策略当日价值/当日之前账户最高价值)$

我们来生成一个数组,包含连续1000个时间点的账户价值

import matplotlib.pyplot as plt
import numpy as np
np.random.seed(1)
a = np.random.randn(1000)
values = np.cumsum(a)
plt.plot(values)
plt.show()

上述1000个时间点的账户价值画出来的资金曲线是这样的

$O(n^{2})$ 的实现方式

对于这样一个账户价值的序列计算最大回撤,通常的实现方式是这样的

def get_max_drawdown_slow(array):
    drawdowns = []
    for i in range(len(array)):
        max_array = max(array[:i+1])
        drawdown = max_array - array[i]
        drawdowns.append(drawdown)
    return max(drawdowns)

最外层的循环遍历数组中所有的元素。对于array中的每个元素array[i],找到这个元素之前所有元素的最大值:max_array = max(array[:i+1]),然后max_array - array[i]就代表了以i为截止点的所有回撤的最大值,把max_array加入数据drawdowns数组中。最大回撤就是drawdowns的最大值。

求最大值的复杂度是 $O(n)$,所以max(drawdowns)的复杂度 $O(n)$,每一次循环中计算max(array[:i+1])的复杂度是 $O(i+1)$, 所以总的算法复杂度是 $O(n^{2})$ 。

$O(n)$ 的实现方式

上述代码中,每次循环里面我们都计算max(array[:i+1]),相当于每次我们都遍历前i+1个数来求最大值。但其实

$前i + 1个数的最大值 = max(前i个数的最大值,第i + 1个数)$

用动态规划改进我们的代码。每次循环我们把迄今为止的最大值记下来,下次循环只需将当前元素值和之前记录的最大值比较一下,就能求得新的最大值。

改进后的代码如下:

def get_max_drawdown_fast(array):
    drawdowns = []
    max_so_far = array[0]
    for i in range(len(array)):
        if array[i] > max_so_far:
            drawdown = 0
            drawdowns.append(drawdown)
            max_so_far = array[i]
        else:
            drawdown = max_so_far - array[i]
            drawdowns.append(drawdown)
    return max(drawdowns)

max_so_far记录当前遍历过的元素的最大值,并被不断更新。每次计算最大值不再需要遍历之前的所有元素,只要比较max_so_far和array[i]即可。比较两个数值的复杂度是 $O(1)$,所以总的复杂度是 $O(n)$。

性能测试

我们来测试一下上面两个算法的性能

import timeit
print timeit.timeit('get_max_drawdown_slow(values)', setup="from __main__ import get_max_drawdown_slow, values", number=100)
print timeit.timeit('get_max_drawdown_fast(values)', setup="from __main__ import get_max_drawdown_fast, values", number=100)

每个函数运行100次,计算总时间,结果如下

2.69495010376
0.0435910224915

可以看到 $O(n)$ 算法有60倍的性能提升。

附:完整的源代码

import timeit
import matplotlib.pyplot as plt
import numpy as np

def get_max_drawdown_slow(array):
    drawdowns = []
    for i in range(len(array)):
        max_array = max(array[:i+1])
        drawdown = max_array - array[i]
        drawdowns.append(drawdown)
    return max(drawdowns)

def get_max_drawdown_fast(array):
    drawdowns = []
    max_so_far = array[0]
    for i in range(len(array)):
        if array[i] > max_so_far:
            drawdown = 0
            drawdowns.append(drawdown)
            max_so_far = array[i]
        else:
            drawdown = max_so_far - array[i]
            drawdowns.append(drawdown)
    return max(drawdowns)

np.random.seed(1)
a = np.random.randn(1000)
values = np.cumsum(a)
print timeit.timeit('get_max_drawdown_slow(values)', setup="from __main__ import get_max_drawdown_slow, values", number=100)
print timeit.timeit('get_max_drawdown_fast(values)', setup="from __main__ import get_max_drawdown_fast, values", number=100)
print get_max_drawdown_slow(values)
print get_max_drawdown_fast(values)
plt.plot(values)
plt.show()