【PyTorch】Gradient Checkpointingで大規模モデル学習時のメモリ不足エラーを解決する方法

問題の概要:大規模モデル学習時の「CUDA out of memory」エラー

深層学習において、Transformerをはじめとする大規模なニューラルネットワークモデルを学習させようとすると、頻繁に以下のようなエラーに遭遇します。

RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 10.00 GiB total capacity; 5.50 GiB already allocated; 0 bytes free; 8.00 GiB reserved in total by PyTorch)

このエラーは、モデルの順伝播(フォワードパス)における中間活性化(activation)を保持するために必要なメモリが、GPUのメモリ容量を超えてしまったことを示しています。バッチサイズを小さくすれば一時的に解決することもありますが、学習の安定性や速度が犠牲になり、根本的な解決にはなりません。特に、層数が多くパラメータ数の大きなモデル(例:数十億パラメータのLLM)では、この問題が顕著になります。

原因の解説:計算グラフとメモリ使用量の関係

誤差逆伝播法(Backpropagation)を実行するためには、順伝播で計算された各層の中間活性化の値が必要です。デフォルトのPyTorchの動作では、これらすべての中間活性化がメモリ上に保持されます。そのため、モデルが深くなる(層数が増える)ほど、あるいはバッチサイズが大きくなるほど、必要なメモリ量は線形的に増加していきます。

例えば、100層のモデルがある場合、逆伝播のために100層分の中間活性化をすべてメモリ上に保存しておかなければなりません。これが「CUDA out of memory」エラーの主な原因です。Gradient Checkpointing(勾配チェックポイント)は、この「メモリ使用量と層数(計算量)のトレードオフ」を最適化する技術です。

Gradient Checkpointingの基本原理

Gradient Checkpointingの核心は、「すべての中間活性化を保存するのではなく、戦略的に選んだいくつかのチェックポイントでのみ活性化を保存し、それ以外の部分は必要になった時点で再計算する」という考え方です。

  • デフォルトの動作: 順伝播中にすべての活性化を保存 → メモリ使用量が多いが、再計算が不要で高速。
  • Checkpointingの動作: 少数の活性化のみを保存 → メモリ使用量が大幅に減少するが、保存されなかった部分の活性化は逆伝播中に順伝播を再実行して計算するため、計算時間が増加する。

これは、メモリ使用量を計算時間と交換(Time-Memory Trade-off)する古典的な手法です。適切に適用すれば、同じGPUで扱えるモデルのサイズやバッチサイズを数倍から数十倍に拡大することが可能になります。

解決方法:PyTorchにおけるGradient Checkpointingの実装手順

PyTorchでは、torch.utils.checkpoint モジュールを利用して、比較的簡単にGradient Checkpointingを導入できます。以下、ステップバイステップで説明します。

ステップ1: モジュールのインポート

まず、必要なモジュールをインポートします。

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

ステップ2: カスタムモデルへの適用(基本形)

モデルの一部のサブモジュール(例えば、Transformerのレイヤー)に対してCheckpointingを適用する場合は、checkpoint関数を使用します。この関数は、第一引数に実行したい関数(通常はモジュールのフォワードメソッド)、第二引数以降にその関数に渡す引数を取ります。

重要な点: checkpointでラップされた関数は、非決定的な操作(e.g., Dropout)を含む場合、use_reentrant=Falseオプションを設定するか、ランダム性を制御する必要があります(PyTorch 1.11以降)。

class MyLargeModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([MyTransformerLayer() for _ in range(num_layers)])

    def forward(self, x):
        # 各レイヤーをチェックポイントで実行
        for layer in self.layers:
            # checkpoint関数は、layer自体と入力xを引数として取る。
            # 注意: layer.forwardは呼び出さず、layerオブジェクトを渡す。
            x = checkpoint(layer, x, use_reentrant=False)
        return x

ステップ3: Sequentialモデルへの適用(簡易形)

モデルがnn.Sequentialで構成されている場合は、checkpoint_sequential関数が便利です。この関数はシーケンシャルなブロックを指定された数のセグメント(チェックポイント)に分割します。

class MySequentialModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.block = nn.Sequential(
            *[MyTransformerLayer() for _ in range(num_layers)]
        )

    def forward(self, x):
        # ブロックを4つのセグメントに分割してチェックポイントを適用
        # セグメント数はハイパーパラメータ。メモリと計算効率のバランスを取る。
        num_segments = 4
        return checkpoint_sequential(self.block, num_segments, x)

ステップ4: 学習ループでの実行と注意点

モデルを定義したら、通常通り学習ループを実行します。Checkpointingはモデルの内部で透過的に動作します。

model = MyLargeModel(num_layers=100).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for inputs, labels in dataloader:
    inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()

    # フォワードパス(内部でcheckpointが動作)
    outputs = model(inputs)

    loss = criterion(outputs, labels)
    loss.backward() # この中で、保存されていない活性化の再計算が行われる
    optimizer.step()

注意点:

  1. 非決定的性: use_reentrant=False (推奨) を指定しない場合、DropoutやBatchNormなどランダム性のある層を含むモデルで再現性の問題が生じる可能性があります。
  2. パフォーマンス: メモリ使用量は1/√n程度に削減できますが(nはチェックポイント数)、計算時間は約30%増加する可能性があります。プロファイリングを行い、ボトルネックがメモリなのか計算なのかを確認することが重要です。
  3. RNNへの適用: RNNの各タイムステップに適用するのは複雑であり、公式のチェックポイント関数はそのままでは不向きな場合があります。カスタム実装が必要になることがあります。

コード例・コマンド例:実践的なユースケース

実際に、メモリ不足エラーが発生しているスクリプトを修正する例を示します。

修正前(メモリ不足が発生):

# 巨大なTransformerブロック
transformer_block = nn.TransformerEncoderLayer(d_model=1024, nhead=16, batch_first=True)
encoder = nn.TransformerEncoder(transformer_block, num_layers=24) # 24層

# フォワードパス
output = encoder(input_tensor) # ここでCUDA OOMが発生!
loss = criterion(output, target)
loss.backward()

修正後(Gradient Checkpointingを適用):

from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])

    def forward(self, src):
        for layer in self.layers:
            # 各エンコーダーレイヤーをチェックポイントで実行
            src = checkpoint(layer, src, use_reentrant=False)
        return src

# モデル構築
transformer_block = nn.TransformerEncoderLayer(d_model=1024, nhead=16, batch_first=True)
encoder = CheckpointedTransformerEncoder(transformer_block, num_layers=24) # 24層でも学習可能に

# フォワードパス(メモリ使用量が大幅に削減される)
output = encoder(input_tensor)
loss = criterion(output, target)
loss.backward() # メモリ不足エラーが解消される

まとめ・補足情報

PyTorchのGradient Checkpointingは、限られたGPUメモリリソースで大規模なモデルを学習させるための必須の技術です。本記事で紹介したtorch.utils.checkpoint.checkpoint関数を用いることで、既存のモデルコードに比較的少ない修正で導入することができます。

適用の判断基準:

  • 「CUDA out of memory」エラーが発生している。
  • バッチサイズを極端に小さく(1や2)しなければ学習できない。
  • モデルの層数が非常に多い(例:50層以上)。

さらに進んだ情報:

  • チェックポイントの配置戦略: どのレイヤーをチェックポイントとするかは性能に影響します。均等に分割するのが一般的ですが、メモリ使用量の多い層(Attention層など)の直後に設定するなどの工夫も考えられます。
  • FSDP (Fully Sharded Data Parallel) との併用: 超大規模モデル学習では、Gradient Checkpointingと、モデルパラメータ、勾配、オプティマイザ状態を分散させるFSDPを併用することが一般的です。これにより、単一GPUのメモリ容量をはるかに超えるモデルの学習が可能になります。
  • プロファイリング: torch.cuda.memory_allocated()torch.cuda.max_memory_allocated()を使って、Checkpointing適用前後のメモリ使用量を計測し、効果を定量的に評価しましょう。

Gradient Checkpointingは、メモリと計算時間のトレードオフを管理する強力なツールです。大規模モデル開発に取り組む際は、この技術を活用して、ハードウェアの制約を超えたイノベーションを実現してください。

この記事は役に立ちましたか?