【PyTorch】torch.compileで推論速度を劇的に改善する実践ガイドとよくあるエラー解決法

問題の概要:torch.compileを使っても推論速度が改善しない、またはエラーが発生する

PyTorch 2.0で導入されたtorch.compileは、モデルの推論(および学習)速度を向上させる強力な機能です。しかし、実際に導入してみると、以下のような問題に直面することがあります。

  • コンパイルは成功したが、推論速度がほとんど変わらない、むしろ遅くなった
  • 「RuntimeError: expected int but got torch.SymInt」などの謎のエラーが発生する
  • 動的な入力形状(可変長シーケンスなど)を持つモデルでコンパイルに失敗する
  • GPUメモリ使用量が予想外に増加する

本記事では、これらの課題を解決し、torch.compileを実践で効果的に活用するためのステップバイステップガイドを提供します。

原因の解説:なぜtorch.compileがうまく働かないのか?

torch.compileは、PyTorchのEager Execution(逐次実行)グラフを最適化されたカーネルにコンパイルします。その過程で発生する主な問題の原因は以下の通りです。

1. グラフブレークによる最適化の阻害

モデル内に「グラフブレーク」を引き起こす操作(例:条件分岐がデータ依存、Pythonの制御フロー、特定のNumPy操作)があると、コンパイラは一つの大きなグラフではなく、複数の小さなグラフに分割します。これにより、最適化の効果が大幅に減少し、グラフ間のオーバーヘッドで逆に速度が低下することがあります。

2. 動的形状への非対応

デフォルトの設定では、コンパイラは最初に受け取ったテンソルの形状を固定とみなします。推論時にバッチサイズやシーケンス長が変動すると、形状が異なるたびに再コンパイル(「グラフの再キャプチャ」)が発生し、大きなオーバーヘッドとなります。

3. サポート外の演算子やカスタム操作

一部の演算子や、純粋なPyTorchテンソル操作で書かれていないカスタム関数(C++拡張など)は、コンパイルグラフに統合できず、グラフブレークの原因になります。

4. 初期コンパイルのコスト

torch.compileには最初の実行時(または形状が変化した時)にコンパイルを行う「ウォームアップ」コストがかかります。少数回の推論しか行わない場合、このコストがメリットを上回ってしまいます。

解決方法:torch.compileを効果的に活用するステップバイステップガイド

ステップ1: 基本コンパイルとベンチマーク

まずは最もシンプルな方法でコンパイルし、ベースラインを計測します。

import torch
import torch.nn as nn
import time

# サンプルモデル
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 10),
        )
    def forward(self, x):
        return self.layers(x)

model = SimpleModel().cuda()
input_tensor = torch.randn(32, 512).cuda()

# コンパイル前の推論速度計測
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = model(input_tensor)
torch.cuda.synchronize()
print(f"Eager time: {time.time() - start:.4f} sec")

# 基本コンパイル
compiled_model = torch.compile(model)

# ウォームアップ実行(コンパイル発生)
with torch.no_grad():
    _ = compiled_model(input_tensor)

# コンパイル後の推論速度計測
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = compiled_model(input_tensor)
torch.cuda.synchronize()
print(f"Compiled time: {time.time() - start:.4f} sec")

ステップ2: モードとバックエンドの選択による最適化

torch.compileには最適化の積極性を決めるmode引数があります。推論では通常"reduce-overhead"または"max-autotune"が有効です。

# 推論オーバーヘッド低減に特化したモード(推論推奨)
compiled_model_reduce = torch.compile(model, mode="reduce-overhead")

# 最大限の最適化を試みるモード(計算量の多いモデル向け)
compiled_model_max = torch.compile(model, mode="max-autotune")

# バックエンドを指定(通常はデフォルトでOK)
compiled_model_backend = torch.compile(model, backend="inductor")

エラー例と解決策: mode="max-autotune"でメモリ不足エラーが発生した場合、より軽量な"reduce-overhead"に切り替えてください。

ステップ3: 動的形状への対応 – 静的/動的コンパイルの切り分け

入力形状が動的な場合は、dynamic=Trueを設定することで、特定の次元を動的としてコンパイルできます。これにより、形状が変わっても再コンパイルが頻発するのを防ぎます。

# バッチサイズ次元を動的としてコンパイル(シーケンス長は静的)
compiled_model_dynamic = torch.compile(model, dynamic=True)

# より細かい制御(PyTorch 2.1以降)
compiled_model_dynamic_detail = torch.compile(model, dynamic={
    "tracing_mode": "symbolic", # または "real"
    "freeze_inputs": False, # 入力を固定しない
})

注意点: dynamic=Trueは万能ではなく、場合によっては最適化が制限されます。可能であれば、推論時に形状を固定(パディングなど)することが最も効果的です。

ステップ4: グラフブレークの診断と対処

コンパイルの詳細情報を出力して、どこでグラフが分割されているかを確認できます。

# グラフブレークの原因を分析
import torch._dynamo.config
import logging

torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.output_code = True # 生成されたコードを表示(詳細)

compiled_model_debug = torch.compile(model)
with torch.no_grad():
    output = compiled_model_debug(input_tensor)

ログにGRAPH BREAKと表示された場合、その原因を探ります。よくある原因と対処法:

  • Pythonのprintassert: 推論時には削除するか、torch.compilefullgraph=Trueモードを使用しない。
  • データ依存の条件分岐: 可能であれば分岐をモデル外に出すか、分岐条件を固定化する。
  • サポート外の演算子: カスタム演算子を純粋なPyTorch操作で書き直すことを検討する。

ステップ5: 推論専用の最適化テクニック

# 1. 推論モードと自動混合精度の併用
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    compiled_model = torch.compile(model)
    # ... 推論実行

# 2. キャッシュの活用(同じ形状の推論が繰り返される場合)
compiled_model = torch.compile(model, mode="reduce-overhead")
# 最初の推論でコンパイルされ、キャッシュされる

# 3. モデルを事前にコンパイルして保存(実験的機能)
# 注意: 環境や入力形状が変わると無効になる可能性あり
torch._dynamo.reset()
compiled_model = torch.compile(model)
compiled_model(input_tensor) # ウォームアップ
# コンパイルされたグラフはメモリ内にキャッシュされる

コード例・コマンド例:ResNetでの実践例

import torchvision.models as models
import torch

# モデル準備
model = models.resnet50(weights='IMAGENET1K_V2').cuda().eval()
dummy_input = torch.randn(16, 3, 224, 224).cuda()

# 最適なコンパイル設定の適用(推論シナリオ)
compiled_model = torch.compile(
    model,
    mode="reduce-overhead",  # 推論オーバーヘッド低減
    dynamic=False,           # 入力形状固定(バッチサイズ16固定)
    fullgraph=False          # グラフブレークを許容(安定性優先)
)

# ウォームアップ
with torch.no_grad():
    for _ in range(2):
        _ = compiled_model(dummy_input)

# 本番推論
with torch.no_grad():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        _ = compiled_model(dummy_input)
    end.record()
    torch.cuda.synchronize()
    print(f"Compiled ResNet50 inference time: {start.elapsed_time(end)/100:.2f} ms per batch")

まとめ・補足情報

torch.compileは適切に使用すれば、特に推論速度の向上に大きな効果を発揮します。成功のためのポイントをまとめます。

主要な推奨設定

  • 推論用途: mode="reduce-overhead"がバランス良好。
  • 動的形状: 可能な限り固定し、難しい場合はdynamic=Trueを試す。
  • ウォームアップ: 本番計測前に数回実行し、コンパイルコストを排除する。
  • メモリ: max-autotuneで問題が起きたら、より軽量なモードに戻る。

パフォーマンスが期待通りでない場合のチェックリスト

  1. モデルが十分に大きく、計算量が多いか?(小さなモデルではオーバーヘッドが目立つ)
  2. 推論のバッチサイズは適切か?(大きすぎるとメモリ不足、小さすぎると効果薄)
  3. グラフブレークが多発していないか?(ログで確認)
  4. 同じ形状での推論が繰り返されているか?(動的形状による再コンパイルが発生していないか)

最後に、torch.compileは急速に進化している機能です。PyTorchのバージョンアップに伴い、サポートされる操作や最適化が強化されています。本記事の内容はPyTorch 2.2以降を基準としていますが、最新の公式ドキュメントやリリースノートを常に参照することをお勧めします。効果的なコンパイルにより、AIモデルの推論速度と効率性を大幅に高めることができるでしょう。

この記事は役に立ちましたか?