Gemma3のファインチューニング - データ加工から学習まで
大規模言語モデルGemma3をファインチューニングした。
Google Colabで、データ加工から学習、保存、推論まで一通り実装した。最終的に高精度化するまでの過程をまとめる。
やったこと
Gemma3-1b-itをLoRAでファインチューニングし、日本語の質問応答モデルを作成した。
データセットをAIプロンプト形式に変換し、LoRAのパラメータ(r=8、alpha=16)を設定し、20エポック学習させた。学習率は3e-4で安定した。
ファインチューニング後のモデルは、特定のドメイン(例:菌類QA)で高い精度を発揮する。ベースモデルでは答えられなかった質問にも、正確に回答できるようになった。
Gemma3について
Gemma3はGoogleが公開している小規模な大規模言語モデルだ。
1bパラメータ版と2bパラメータ版がある。今回は1b版(gemma-3-1b-it)を使った。小さいモデルだが、ファインチューニングすれば十分実用的になる。
小さいモデルを選んだ理由は、Google Colabの無料枠でも学習できるから。GPUメモリが限られているので、1bモデルが適している。
データ加工の流れ
データをGemma3のプロンプト形式に変換する必要がある。
Gemma3のプロンプト形式は、<start_of_turn>user、<start_of_turn>modelというタグを使う。これを間違えると、学習そのものが失敗する。
最初、形式を間違えた。一度つまずいた。
タグの位置が違ったり、改行が抜けていたりすると、モデルが正しく学習できない。ドキュメントを何度も確認して、正しい形式を理解した。
正しい形式でデータを作成したら、学習がうまくいった。これだ。
LoRAファインチューニング
LoRA(Low-Rank Adaptation)を使った。
全パラメータを学習させると、GPUメモリが足りない。LoRAなら、一部のパラメータだけを学習させるので、メモリ消費が少ない。
# LoRAの設定
lora_config = LoraConfig(
r=8, # LoRAのランク
lora_alpha=16, # スケーリング係数
target_modules=["q_proj", "v_proj"], # 対象モジュール
lora_dropout=0.05, # ドロップアウト
bias="none",
task_type="CAUSAL_LM"
)
rは8に設定した。小さすぎると表現力が落ちる。大きすぎるとメモリを消費する。8がバランスが良い。
lora_alphaは16に設定した。rの2倍にするのが一般的だ。
学習パラメータの調整
学習率とエポック数の調整が重要だ。
最初は学習率1e-3で試した。学習が不安定だった。損失が発散する。
5e-4に下げた。まだ不安定。
3e-4まで下げたら、安定した。損失が順調に減少する。
エポック数は20に設定した。10エポックだと学習が不十分。30エポックだと過学習のリスクがある。20エポックでちょうど良かった。
学習の実行
Google Colabで学習を実行した。
# トレーニング設定
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=20,
per_device_train_batch_size=4,
learning_rate=3e-4,
warmup_steps=100,
logging_steps=10,
save_steps=500
)
# トレーナーの初期化
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset
)
# 学習開始
trainer.train()
バッチサイズは4に設定した。大きくするとメモリ不足になる。小さくすると学習が遅い。4がちょうど良い。
学習には約2時間かかった。Google Colabの無料T4 GPUで実行した。
全体の流れ
データを準備し、Gemma3のプロンプト形式に変換し、LoRAでファインチューニングし、モデルを保存する。推論時は、LoRAのアダプタを読み込んで使用する。
使ってみて
Gemma3-1bをLoRAでファインチューニングした。
ポイントは以下の3つ:
- データをGemma3のプロンプト形式(
<start_of_turn>タグ)に正確に変換
- LoRAのパラメータ(r=8、alpha=16)を適切に設定
- 学習率を3e-4、20エポックで学習
小さいモデルでも、ファインチューニングすれば特定のドメインで高い精度を出せる。Google Colabの無料枠で十分学習できるのが嬉しい。
同じような小規模モデルのファインチューニングを考えている方の参考になれば嬉しいです。
まとめ
今回は、Gemma3-1bをLoRAでファインチューニングし、日本語QAモデルを作成した。
ポイントは以下の4つ:
- Gemma3のプロンプト形式を正確に理解(
<start_of_turn>タグが重要)
- LoRAパラメータ(r=8、alpha=16)でメモリ消費を抑制
- 学習率3e-4、エポック20で安定した学習
- Google Colab無料T4 GPUで約2時間で完了
データ加工が一番重要だった。プロンプト形式を間違えると、学習が失敗する。正しい形式を理解すれば、あとは学習させるだけだ。
さらに深く学ぶには
この記事で興味を持った方におすすめのリンク:
自分の関連記事:
最後まで読んでくださり、ありがとうございました。