6713 字
34 分钟
如何将微分方程写成python看得懂的函数?开始学习计算机求解微分方程的暴力美学 (初识数值实验·上)
2026-04-24
2026-05-15
统计加载中…… 阅读:    访客: 统计加载失败

前言#

本教程以 scipy.integratecupy 为例。以你会基本的 Python 编程和常微分方程基础为前提,以求解常微分方程(ODE)的初值问题为线索。我们将从科学计算的标准接口 solve_ivp 入手,教你如何把数学方程翻译成计算机看得懂的函数;然后亲手拆解黑盒,用 NumPy 实现经典的四阶龙格‑库塔法(RK4);最后以双摆系统的大规模模拟为例,用 CuPy 把计算搬上显卡,体验数值求解的暴力美学。学完本文,你将彻底掌握从单条轨迹求解理论到上万组轨迹并行加速的全过程。

ps: 本人并非相关专业,仅面向代码和交叉学科基础,如有不严谨还请见谅。

本期的源代码在以下仓库:

emumus
/
solve_ivp
Waiting for api.github.com...
00K
0K
0K
Waiting...

solve_ivp 标准的格式#

本节的目标:

  1. 学会 solve_ivp 函数的用法。
  2. 学会把数学微分方程改写成它的形式。

在深入研究如何用 GPU 加速或手动实现高级算法之前,我们必须先理解 Python 科学计算社区定义的第一协议: scipy.integrate.solve_ivp

solve_ivp 的接口定义是所有微分方程数值解的通用语言,也是描述一个微分方程最通用的标准格式。对大多数混作业的大学生来说,学到这里其实就够了。

在SciPy中,你只需要调用:

from scipy.integrate import solve_ivp
sol = solve_ivp(derivative_func, [t_start, t_end], y0, args=(alpha,))

就能在 sol.y[0] 中获得从 t_startt_end 期间,以f(tstart)=y0f(t_{start})=y_0的共几个点处的数值解。当然了,计算机不会像人一样生龙活虎地解微分方程,人家硬算的。

把数学微分方程转为python函数#

计算机不认识微分,只认识当点的变化率。

在数学纸面上,我们习惯写 dydt=f(y,t)\frac{d\mathbf{y}}{dt} = f(\mathbf{y}, t)。但在编程中,求解器本质上只会机械地(黑盒地)调用你写的函数。它只向你要一件东西:给定当前的时间 tt 和当前的状态 y\mathbf{y},请告诉我此刻的斜率(导数)是多少?

这个返回斜率的函数,就是我们需要编写的pyhton函数。你的函数必须遵循以下结构:

def derivative_func(
t, # 时间,更熟悉的说法是y=f(x)的x,自变量,求解范围指的就是这个数的范围。
state, # 状态,更熟悉的说法是y=f(x)的y,因变量。注意这里可以且必须是一个向量。
*args # 方程的其他外部参数会被这样传进来
):
# 1. 解构状态变量
y1, y2 = state # 如两个
# 2. 根据数学公式计算变化率
dy1_dt = ... # 可以自行用dy和t来计算
dy2_dt = ...
# 3. 返回变化率列表或数组
return [dy1_dt, dy2_dt, ...]

我们来看具体的例子:

示例 :单变量情况#

给定衰减系统微分方程

dydt=αy\frac{dy}{dt} = -\alpha y

其中 α\alpha 是衰减系数。

左侧的 dydt\frac{dy}{dt} 正是我们需要返回的变化率变化率的定义就是-的因变量状态 ×\times 系数 α\alpha 。同时注意因变量状态的参数 y 总是一个列表。故要写 y[0] 。返回的也必须是一个变化率列表 [dydt]

那么就可以写:

def decay_system(t, state, alpha):
dydt = -alpha * state[0] # 解包免了,我们就直接用state[0]
return [dydt]

注意: ty 的位置是确定的,alpha 以及后续可能出现的其他外部变量可以直接排列到函数参数中,不久之后通过 solve_ivp 传入,solve_ivp 调用你的函数的时候会自动把参数填进去的。

solve_ivp(decay_system, [0, 10], 100, args=(0.5,))
TIP

出于业界习惯标准,我们现在应该开始接受将自变量 xx 称为时间维度 tt,因变量 yy 称为状态变量。后续将涉及多个状态之间相互影响和高阶互相影响的情况,这套称呼的优点将逐渐显现。

示例 :多变量耦合#

当系统有多个变量时,yy 就变成了一个状态向量。(骗你的只有一个你也得写成数组)

举例来说,洛伦兹方程的数学表达:

{dxdt=σ(yx)dydt=x(ρz)ydzdt=xyβz\begin{cases} \dfrac{dx}{dt} = \sigma\,(y - x)\\[4pt] \dfrac{dy}{dt} = x\,(\rho - z) - y\\[4pt] \dfrac{dz}{dt} = xy - \beta\,z \end{cases}

不同的状态之间互相影响在混在一起的向量状态在这种格式中是很方便随意互相取用的。同时又因为python里有很方便的解包语法。

def lorenz_partial(t, state, sigma, rho, beta):
x, y, z = state
dxdt = sigma * (y - x)
dydt = x * (rho - z) - y
dzdt = x * y - beta * z
return [dxdt, dydt, dzdt]
NOTE

为什么 solve_ivp 就知道 state 是三元组呢?其实 solve_ivp 会先根据你传入的初值数组长度确定状态个数,之后直接把初值作为state塞进你的函数,不管你的函数实际上解包多少参数。哈基py感觉不对会自己报错的(not enough values to unpack (expected m, got n))。所以一定要自己保证初值数组匹配函数的状态数量哦。

示例 :高阶常微分方程#

solve_ivp 只能求解一阶常微分方程组。如果你遇到的是高阶方程(比如二阶、三阶),必须先把它拆解成一阶方程组。这个过程十分简单粗暴:

d2ydt2=f(t,y,dydt)v=dydt,{dydt=vdvdt=f(t,y,v)\frac{d^2 y}{dt^2} = f(t, y, \frac{dy}{dt}) \\ \text{设}\,v = \frac{dy}{dt} ,\quad \begin{cases} \frac{dy}{dt} = v \\[4pt] \frac{dv}{dt} = f(t, y, v) \end{cases}

举例来说,简谐振荡器的数学表达

d2xdt2=kmxv=dxdt,{dxdt=vdvdt=kmx\frac{d^2 x}{dt^2} = -\frac{k}{m} x \\ \text{设}\,v = \frac{dx}{dt},\quad \begin{cases} \frac{dx}{dt} = v \\[4pt] \frac{dv}{dt} = -\frac{k}{m} x \end{cases}

这实际上就是两个互相干扰的一阶微分方程。xx 的变化率由 vv 决定,vv 的变化率由 xx 决定,何尝又不是一对苦命鸳鸯。

def harmonic_oscillator(t, state, k, m):
x, v = state
dxdt = v
dvdt = -(k / m) * x
return [dxdt, dvdt]
TIP
  1. 即使你的方程中没有显式出现时间(自变量) tt(即自治系统),在定义函数时也必须保留 t 作为第一个参数,这是为了兼容那些随时间变化的非自治系统。

  2. solve_ivp 是为通用CPU计算设计的。当你需要同时模拟一百万组不同的轨迹(比如蒙特卡洛模拟或大规模粒子系统)时,solve_ivp 会慢得绝望。但只要你有一张普通的消费级显卡,再结合后续章节的一些知识,某些任务上加速百倍起步完全不是问题

求解器 solve_ivp 的完全版用法#

solve_ivp 的完整签名如下:

solve_ivp(
fun, # 上面所说格式的函数
t_span, # t的左右界,如[0, 100]
y0, # 左界初值 *向量*
method='RK45', # 更多优化器待你解锁……不清楚的话就先用RK45吧!
t_eval=None,
dense_output=False,
events=None,
vectorized=False,
args=None,
**options
)
  • fun: 如上定义的函数。
  • t_span: 二元组 (t0, tf)。允许 tf < t0 反向积分。
  • y0: 初始条件。形状为 (n,) 的一维数组。
  • method: 选择求解器。刚性问题建议使用 'Radau''BDF'
  • t_eval: 指定需要输出结果的时间点。
    • 这与 t_span 并不冲突,t_span 确定它实际求解计算的范围,t_eval 指定我所需要的那些值。
    • 举例来说,我需要[1,2][1,2]的步长 0.1 的值,但我只有在0处的初值,则我会传入求解的范围 t_span=[0,2] 而所需的值 t_eval=np.arange(1, 2, 0.1)
    • 求解器仍会从0正常地算到2,只是最后只返回1到2之间的点。
  • dense_output: 启用后会在内存里保存一些求解信息,允许随时通过 sol.sol(t) 获取任意 ttspant\in t_{\mathrm{span}} 的解。
  • events: 检测事件函数 event(t, y, *args) 的零点。可设置 terminal=True 停止积分。
  • vectorized: 若 fun 能同时接受多个 y 列(形状 (n, k) ),设为 True 可加速(仅限部分方法)。
  • args: 额外参数的元组,按顺序传递给 funevents。即微分方程函数中 tstate 之后的参数。
  • 部分方法有自己独特的额外options,如 atolrtol 等等。

求解器 solve_ivp 返回的的对象#

solve_ivp 返回 Bunch 对象,以下字段较为常用:

  • t: ndarray: 返回各个求得值的时间点,形状 (n_points,)
  • y: ndarray: 在各个t处的解的值。形状 (n, n_points)
  • sol: OdeSolution|None: 找到的解作为 OdeSolution 实例。如果 dense_output 设置为 False,则为 None。
  • status: 算法终止的原因:
    • -1:积分步长失败。
    • 0:求解器成功到达 tspan 的终点。
    • 1:发生了终止事件。
  • message: str: 可读的终止原因描述。
  • success: bool: 如果求解器达到区间终点或发生终止事件(status >= 0),则为 True。

复习微分方程以及RK4原理#

现在,我们将打破黑盒,用 NumPy 手动实现 RK4 算法,手动复现一个青春版 solve_ivp

计算机是如何解微分方程的?#

在高数课上,我们解微分方程是靠各种模式背方法找原函数模式。但计算机可没那么聪明(至少在数值计算领域不是)。数值求解的核心思想极其简单粗暴:切线迭代(或者叫步进)

想象一下,你站在一张空白的画布上。你想画出这个 y(t)y(t)。你知道这个函数过初值点(0,0)(0, 0)

你手里有一个魔法罗盘(也就是那个返回导数的 Python 函数)。只要你告诉罗盘你当前的时间 tt 和位置 yy,罗盘就会立刻告诉你这一个距离下的斜率 dydt\frac{dy}{dt}

欧拉法(Euler Method):最天真的走法

你问一次罗盘,得到方向,然后沿着这个方向直勾勾地往前跨一大步(hh)。

yn+1=yn+hf(tn,yn)y_{n+1} = y_n + h \cdot f(t_n, y_n)

这就是最基础的欧拉法。但问题很明显:真实的曲线大概率是二阶可导的,你沿着切线直走,只要步子迈得稍微大一点,下一步落地时你就会偏离真实的曲线。随着时间推移,误差会疯狂累积,直到你的结果离题万里。

龙格-库塔法(RK4):货比四家#

为了解决直走偏航的问题,数学家们发明了经典的四阶龙格-库塔法(Runge-Kutta 4th Order Method,简称 RK4)。

RK4 的核心哲学是:不要急着迈步,先在这个步长 hh 范围内多探几个点的斜率,综合起来再走。

具体来说,它会在每一次跨步前,试探性地计算 4 个斜率(通常记为 kk):

  1. k1k_1(起点斜率): 在当前位置问罗盘得到的真实斜率。

    k1=f(tn,yn)k_1 = f(t_n, y_n)

  2. k2k_2(中点试探斜率1): 假装用 k1k_1 往前走半步,到达中点,问罗盘这个中点的斜率。

    k2=f(tn+h2,yn+hk12)k_2 = f(t_n + \frac{h}{2}, y_n + h\frac{k_1}{2})

  3. k3k_3(中点试探斜率2): 刚才用 k1k_1 走半步可能不太准,这次我们用刚求出的 k2k_2 重新走到半步的位置,再问罗盘获取斜率。

    k3=f(tn+h2,yn+hk22)k_3 = f(t_n + \frac{h}{2}, y_n + h\frac{k_2}{2})

  4. k4k_4(终点试探斜率): 最后,用 k3k_3 往前走完整整一步到达终点,在这个预测的终点处,再问最后一次罗盘。

    k4=f(tn+h,yn+hk3)k_4 = f(t_n + h, y_n + h k_3)

最终合并:

拿到这四个斜率后,RK4 给它们赋予不同的权重(中点的斜率最重要,权重是两头的两倍),求出一个平均斜率,然后用这个平均斜率迈出最终的一步:

yn+1=yn+h6(k1+2k2+2k3+k4)y_{n+1} = y_n + \frac{h}{6}(k_1 + 2k_2 + 2k_3 + k_4)

就是这么简单。这四个斜率线性组合,巧妙地抵消了泰勒展开四阶的误差。

动手实现青春版RK4求解器#

这里我们借助 NumPy 的向量化运算,让它可以同时处理上一节提到的多变量状态向量。

import numpy as np

适配 solve_ivp 的解对象的格式。我们先只兼容直接取 ty

class OdeSolution:
def __init__(self, t, y):
self.t = t
self.y = y

先实现一个单步的函数,我们以此通过 tty(x0)\mathbf{y}(x_0) 计算 y(x0+h)\mathbf{y}(x_0+h)。逻辑就是刚刚提到的斜率线性组合。

def rk4_step(fun, t, y, h, args):
"""
单步RK4
:param fun: 你的微分方程函数,签名为 fun(t, y, *args)
:param t: 当前时间
:param y: 当前状态 (numpy数组)
:param h: 步长
"""
k1 = np.array(fun(t, y, *args))
k2 = np.array(fun(t + h / 2, y + h * k1 / 2, *args))
k3 = np.array(fun(t + h / 2, y + h * k2 / 2, *args))
k4 = np.array(fun(t + h, y + h * k3, *args))
y_next = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
return y_next

然后简单地从y0开始,以h为步长循环就可以了。唯一要说的是我们提前构建了一个二维数组(其实是np矩阵) y_history,以列为单位,每次将上次用过的状态向量拿出来,算好后放到下一个列。

def my_solve_ivp(fun, t_span, y0, args, h=1e-3):
"""
青春版rk4求解器
"""
t_start, t_end = t_span
# 计算需要的步数
n_steps = int(np.ceil((t_end - t_start) / h))
# 生成要计算的所有时间点
t_eval = np.linspace(t_start, t_end, n_steps + 1)
# y_history 形状为 (状态数, 时间点数),与 scipy 标准保持一致
y_history = np.zeros((len(y0), n_steps + 1))
# 初值是这么用的
y_history[:, 0] = y0
current_y = np.array(y0)
for i in range(n_steps):
current_t = t_eval[i] # 每次拿取下一个t值
# 更新状态
current_y = rk4_step(fun, current_t, current_y, h, args)
y_history[:, i + 1] = current_y
return OdeSolution(t_eval, y_history)

有了这段代码,你就彻底摆脱了对 scipy.integrate 的依赖。虽然它功能单一(只支持固定步长 hh),但它的逻辑完全透明,并且是后续我们向 GPU 加速进军的基石。

WARNING

这个手搓的版本与 solve_ivp 相比多一个重要的默认参数 h=1e-3

步长 hh 需要你自己设定。hh 如果太大,误差就会爆炸;hh 如果太小(比如 10810^{-8}),不仅计算极其缓慢,还会因为计算机浮点数精度限制产生严重的舍入误差。

各种方法步长在一个黑洞方程上的表现
各种方法步长在一个黑洞方程上的表现

TIP

此处用于测试的方程 y=y2x/yy' = y - 2x/y 在初值 y(0)=1y(0)=1 在数学上有一个看起来很完美的特解 2x+1\sqrt{2x+1},但在计算机眼中,这个解就像是走钢丝。

但凡迭代过程稍微出现一点误差,它的通解是 y=1+2x+Ce2xy=\sqrt{1+2x+Ce^{2x}}。但凡开始有一点误差,指数就会爆出根号冲向天空。可以看到即使是 scipy.integrate.solve_ivp 的RK45也倒在了 x=3x=3 左右。

简单扩展:有没有RKn?#

有的有的。并且还有一个叫布彻尔表(Butcher Tableau)的东西,是数学家们总结出的一套系数,就是它指明了满足各阶精度的阶条件。这里不涉及太多数学概念。当然越高者系数计算量越大,并且一般方程不会有像 y=y2x/yy' = y - 2x/y 这样的走钢丝函数一样的性质。所以工程权衡下就有了RK45这个优化方法。

推荐拓展阅读:龙格库塔方法与它的布彻表.py - 知乎

简单扩展:RK45 又是什么?#

在使用 solve_ivp 时,默认的 method='RK45'(也被称为 Dormand-Prince 法)经常让人迷惑:它到底是 4 阶还是 5 阶?

它同时是 4 阶和 5 阶。

在现实的物理模拟中,系统有时变化剧烈(比如弹簧突然崩开),有时又风平浪静。如果一直用固定的步长 hh,在剧烈变化时误差极大,在平静时又浪费算力。

自适应步长是RK45的灵魂:

  • 它在每一步计算时,不仅算出一个 4 阶精度的预估值 y4y_4,还会额外算出一个 5 阶精度的预估值 y5y_5
  • 计算机把这两个结果拿来相减,得到一个误差估计值 error=y5y4error = |y_5 - y_4|
  • 如果 errorerror 超出了你容忍的阈值(比如上一节提到的 rtolatol),说明当前步长太大,求解器会拒绝这一步,把步长 hh 减小,重新算一次。
  • 如果 errorerror 极小,说明现在系统很平稳,求解器会偷偷把下一步的 hh 放大,从而飞速跨过平稳区,为你节省大量时间。

这就是为什么 solve_ivp 非常通用且鲁棒我喜欢鲁棒这个词,它抽象得令人记忆深刻。但在我们需要极致并行性能时,带有 if-else 判断的自适应步长算法在 GPU 上反而容易成为性能瓶颈,这也是我们为什么要在后续章节坚持回归纯粹固定步长 RK4 的原因。

显卡计算(听起来就专业)#

在前一章中,我们已经拥有了一个逻辑透明的青春版 RK4 求解器。对于求解单一物体的运动轨迹,现代 CPU 的单核性能绰绰有余,跑完几万步迭代也就是眨眼间的功夫。

但如果我们面对的是大规模的并行模拟呢?

WARNING

本章主要内容是代码,说俗了就是教你调库,晕代码的同学可以跳,AI写的也不错。

复杂耦合方程的噩梦:双摆系统#

为了直观地展示算力瓶颈,我们请出物理学界著名混沌系统双摆。想必很多人已经见识到了最近一些科普视频展示的双摆美丽的动态相空间示意图。

双摆系统由两个单摆首尾相连构成。尽管它的结构看似简单,但根据拉格朗日力学推导出的微分方程却极其复杂,且各状态变量(两个角度 θ1,θ2\theta_1, \theta_2 和两个角速度 ω1,ω2\omega_1, \omega_2)之间深度耦合。稍微改变一点点初始角度,其后续的运动轨迹就会天差地别。

{Δ=θ1θ2D=169cos2Δθ˙1=6ml2D(2p13cosΔp2)θ˙2=6ml2D(8p23cosΔp1)p˙1=12ml(3gsinθ1+lθ˙1θ˙2sinΔ)p˙2=12ml(gsinθ2lθ˙1θ˙2sinΔ)\begin{cases} \Delta &= \theta_1 - \theta_2 \\ D &= 16 - 9\cos^2\Delta \\ \dot{\theta}_1 &= \frac{6}{ml^2 D}\left(2p_1 - 3\cos\Delta \, p_2\right) \\ \dot{\theta}_2 &= \frac{6}{ml^2 D}\left(8p_2 - 3\cos\Delta \, p_1\right) \\ \dot{p}_1 &= -\frac{1}{2}ml\left(3g\sin\theta_1 + l\dot{\theta}_1\dot{\theta}_2\sin\Delta\right) \\ \dot{p}_2 &= -\frac{1}{2}ml\left(g\sin\theta_2 - l\dot{\theta}_1\dot{\theta}_2\sin\Delta\right) \end{cases}

用 Python 将其微分方程写出来,大致是这样的:

def double_pendulum_cp(t, group_state, m, l, g):
# 沿着列切片,拿到 N 个双摆的对应状态。截取出的变量形状皆为 (N,)
theta1, theta2, p1, p2 = group_state
# 重用量
delta = theta1 - theta2
cos_delta = np.cos(delta)
inv_D = 1.0 / (16.0 - 9.0 * cos_delta * cos_delta)
# 角速度
factor = 6.0 / (m * l * l) * inv_D
theta1_dot = factor * (2.0 * p1 - 3.0 * cos_delta * p2)
theta2_dot = factor * (8.0 * p2 - 3.0 * cos_delta * p1)
# 动量导数
pre = -0.5 * m * l # -1/2 m l
c1 = l * theta1_dot * theta2_dot * np.sin(delta)
p1_dot = pre * (3.0 * g * np.sin(theta1) + c1)
p2_dot = pre * (g * np.sin(theta2) - c1)
# 构建为 (4, N)
return np.array([theta1_dot, theta2_dot, p1_dot, p2_dot])

CPU 的叹息#

如同视频里那样,假设我们想研究:在不同的初始角度 (θ1,θ2)(\theta_1, \theta_2) 下,双摆在 10 秒后会处于什么状态?

为此,我们需要在 [π,π][-\pi, \pi] 的范围内,均匀划分出一个 n×nn \times n 的网格。哪怕只有 n=50n=50,这也意味着我们需要同时求解 2500 个独立但计算密集的微分方程

如果使用传统的 for 循环:

results = []
for th1 in np.linspace(-np.pi, np.pi, 50):
for th2 in np.linspace(-np.pi, np.pi, 50):
sol = solve_ivp(
double_pendulum, (0.0, 10.0),
[th1, th2, 0.0, 0.0],
h=1e-2, args=args,
)
results[(th1, th2)] = sol.y
progress.update(1)

你可以切换窗口去稍微摸一会儿鱼了。在 Python 中,解释器执行这个双重循环,并且在最内层频繁调用数学函数,效率是极其低下的。

跑完这 2500 条轨迹,可能需要十几分钟。如果你的网格扩大到 1000×10001000 \times 1000,那基本就可以下班了。

这时候,就该让 GPU 出场了。

怎么用我的显卡?#

在 Python 科学计算界,想用 GPU 加速有很多流派:

  • JAX:Google 出品,支持自动微分和 XLA 编译,非常强大。但它的生态深度绑定 Linux,在 Windows 上配置 CUDA 环境时常让人抓狂。
  • PyTorch / TensorFlow:虽然是深度学习框架,但本质上也是张量计算库,完全可以用来做数值积分。只是为了解个方程引入这么庞大的库,略显臃肿。
  • Taichi / Numba:需要你手动编写 Kernel 代码,学习成本较高。

我们以 CuPy 为例教学。它的API对标 Numpy,说不定是最平易近人的,但肯定是最方便我们理解和使用的。

它的核心理念可以用一句话概括:它是运行在 NVIDIA 显卡上的 NumPy。 只要你的电脑有一张 N 卡,pip install cupy 之后,把代码里的 import numpy as np 换成 import cupy as cp,绝大多数数组操作就能直接获得成百上千倍的加速。本质上,CuPy 在底层自动帮你把 NumPy 的广播操作编译成了 CUDA Kernel 并在几千个流处理器上并发执行。

TIP

示例代码里已经写好能用的画图函数了,你可以直接用哦

配置cuda和cupy本文略。这又是一个分硬件分人的事。

NVIDIA提供的cuda进行数值计算,其他卡也提供了各自家显卡用于数值计算的库,且几乎都有python的包装,逻辑大同小异。

新增维度来代替循环#

我们的循环基于 CPU 对代码逐个进行计算。GPU 不知道什么是循环,但它知道什么是矩阵。它为矩阵生,为矩阵死。

import cupy as cp # 导入cupy,它的API与numpy类似。

解 1 个双摆和解 2500 个双摆的逻辑是一模一样的,只是数组的维度变长了而已。我们将 2500 个 [th1, th2, w1, w2] 拼成一个形状为 (2500, 4) 的巨大二维矩阵,一次性喂给方程。

double_pendulum_cprk4_step 等函数的改动不提,简单把 npcp 即可。只要注意量级:

def double_pendulum_cp(t, group_state, m, l, g):
"""
注意:这里的 group_state 形状是 (4, N),N 是并行的双摆数量。
CuPy 会自动对 N 这个维度进行并行广播运算。
"""
# 将所有np换为cp

我们的 solve_ivp 也要进行改动,主要是殺了循环

def my_solve_ivp_cupy(
fun, t_span, y0, h, args=(), step_func=rk4_step_cupy, show_progress=None
):
t_start, t_end = t_span
y0 = cp.asarray(y0)
n_steps = int(np.ceil((t_end - t_start) / h))
t_eval = np.linspace(t_start, t_end, n_steps + 1)
# y_history 形状为 (状态数, 组数, 时间点数)
y_history = cp.zeros((y0.shape[0], y0.shape[1], n_steps + 1), dtype=y0.dtype)
y_history[:, :, 0] = y0
current_y = y0
iterator = range(n_steps)
if show_progress:
iterator = tqdm(iterator)
for i in iterator:
current_t = t_eval[i] # 每次拿取下一个t值
# 更新状态
current_y = step_func(fun, current_t, current_y, h, args)
y_history[:, :, i + 1] = current_y
return GroupedOdeSolution(t_eval, y_history)

之后我们就可以组装大量的初值向量 [th1, th2, p1, p2] 了。

# 生成 50x50 网格初始角度
theta1_vals = np.linspace(-np.pi, np.pi, 50)
theta2_vals = np.linspace(-np.pi, np.pi, 50)
aTH1, aTH2 = np.meshgrid(theta1_vals, theta2_vals, indexing="ij")
theta1_init = aTH1.ravel() # (2500,)
theta2_init = aTH2.ravel()
p1_init = np.zeros_like(theta1_init)
p2_init = np.zeros_like(theta1_init)
# 组装批量初始状态 (4, N)
y0_np = np.array([theta1_init, theta2_init, p1_init, p2_init])
y0_cp = cp.asarray(y0_np)
sol = solve_ivp(
double_pendulum_cp,
t_span=(0.0, 10.0),
y0=y0_cp,
h=1e-3,
step_func=rk5_step_cupy,
args=args,
show_progress=True,
)
results = {}
for i, y in enumerate(sol.y):
results[(i // 50, i % 50)] = y
plot_dpmat(results, args, interval=1)

在主流的消费级显卡(例如 RTX 3060 或 4070)上,完成这 2500 个双摆长达 10000 步(总计两千五百万次状态更新)的计算,通常只需要几秒。—— Gemini

RK5,1e-3,0~10,100x100. GTX2060 6G,cupy. 42s.
RK5,1e-3,0~10,100x100. GTX2060 6G,cupy. 42s.

RK5,1e-2,0~10,500x250. GTX2060 6G,cupy. 5s.
RK5,1e-2,0~10,500x250. GTX2060 6G,cupy. 5s.

TIP

因为100x100和250x500在我这里都没有超过一个批次的计算量,到显卡里也只是一瞬间的功夫,所以模拟规模在这里影响甚微。起主要作用的是步长:

图一的参数下共迭代 1e4 次,图二 1e3 次,于是时间差了约十倍。

关于显卡计算的一些坑与废话#

上了显卡,虽然爽,但有几个重要的事必须交代清楚:

TIP

1. 内存和显存之间的数据搬运一事

不要在 for 循环内部频繁地将 CuPy 数组拉回 NumPy 数组(即把数据从显存搬回内存)。CPU 和 GPU 之间的数据传输通道(PCIe 总线)相对而言是非常慢的。

正确的做法是:让数据在显卡里一直待到所有迭代结束,最后只把需要的最终结果(或者稀疏采样的历史轨迹)传回内存。

2. 为什么我们坚持用固定步长的 RK4,而不是把 RK45 搬上 GPU?

GPU 最喜欢的是成千上万个线程做完全一样的事情。它是典型的单指令多线程架构。

RK45 的自适应步长里充满了 if (error > tol) 的条件分支。如果 2500 个双摆中,有 1 个发生了剧烈翻滚需要缩小步长重算,那么显卡为了保持同步,会让另外 2499 个乖乖等待。这种现象叫做分支发散,它会瞬间摧毁 GPU 的并行效率。

3. 内存/显存爆了怎么办!

当然,如果你要计算的边长更大,你可能还要进行分块和手动显卡渲染。

分块,顾名思义,组装 y0_cp 之前先检查一下有多少组方程被同时计算,如果你的显存承受不了这个量级,那就先进行一次内存搬运,搬出去一批再算另一批,你的内存总不能比显存低吧……是这样吧?如果内存真的还不够的话,那可能要苦一下你的硬盘了,NumPy提供了高效的 .npy 格式用于存储NumPy数组。记得把你不需要的数据洗出啦丢掉,比如公式中的广义动量。

手动显卡渲染,本章示例代码中提供了基础的CPU色彩渲染,每帧会根据角度动态计算每个色块的颜色,所以用FuncAnimation播放会很卡。如果把所有帧的颜色数据先用显卡算好,再导出为视频文件,边长上限会更高。

4. 精度陷阱:float64 vs float32

NumPy 默认使用 float64(双精度浮点数),而现代消费级游戏显卡(N 卡的 GeForce 系列)为了与专业计算卡(Tesla/Quadro)区分,在硬件层面上人为阉割了 float64 的计算单元,其 FP64 性能通常只有 FP32(单精度)的 1/321/32 甚至 1/641/64

CuPy 默认也会跟随你的输入类型。如果你发现 GPU 并没有想象中快,检查一下你是不是用了 float64。但要注意,在求解微分方程时,单精度 float32 会更快地积累截断误差,在长时间模拟中可能会导致系统崩溃。这是一个必须根据实际物理需求进行权衡的取舍。

下回预告#

至此,我们已经成功利用普通的游戏显卡,完成了一次算力飞跃。

但在长时间演化的物理系统(如天体力学或分子动力学)中,你会发现双摆的总能量并不守恒,这并不是代码写错了,而是 RK4 的数学基因决定的。

下一节,我们将重新审视微分方程求解过程,分析更精确的隐式方法和保结构的辛方法

如何将微分方程写成python看得懂的函数?开始学习计算机求解微分方程的暴力美学 (初识数值实验·上)
https://blog.emumu.xyz/posts/2026-04-24-00/
作者
月宮絵夢
发布于
2026-04-24
许可协议
CC BY-NC-SA 4.0