これは、なにをしたくて書いたもの?
先日、GoogleからAIモデル「Gemma」が公開されました。
Gemma: Google introduces new state-of-the-art open models
グーグル、軽量でオープンな新AIモデル「Gemma」 - Impress Watch
今回は、こちらをHugging FaceのTransformersで試してみたいと思います。
Gemma
GemmaのWebサイトはこちら。
Gemma - Google が提供する最先端の軽量オープンモデル ファミリー。 | Google AI for Developers
同じくGoogleのAIモデルである「Gemini」を作った技術を元にした、軽量でオープンなモデルとされています。
Gemini モデルの作成に使用されたのと同じ研究とテクノロジーを基に構築された、軽量で最先端のオープンモデルのファミリーです。
ドキュメントはこちら。
Gemma モデルの概要 | Google AI for Developers
モデルにはパラメーターで2種類(2B、7B)、事前学習済み・instructモデルの2種類の組み合わせがあります。
Gemmaの入手元はいくつかありますが、今回はHugging Faceからダウンロードすることにします。
モデルは、以下の4種類があります。
- google/gemma-7b · Hugging Face
- google/gemma-7b-it · Hugging Face
- google/gemma-2b · Hugging Face
- google/gemma-2b-it · Hugging Face
サイズの関係もあって、今回はこちらを使うことにします。
google/gemma-2b-it · Hugging Face
Gemmaの利用には、ライセンスへの同意が必要です。
Hugging Faceのアカウントを持っていない場合はアカウントを作成し、同意する必要があります。
また、モデルをブラウザからダウンロードする場合はいいのですが、Transformersといったアプリケーションから利用する場合は
アプリケーション実行時に認証することになるので、アクセストークンが必要になります。
アクセストークンは、以下のページを参考に取得します。
取得したアクセストークンは、環境変数HF_TOKEN
に設定して使います。
Environment variables / Generic / HF_TOKEN
今回は、Gemmaのgoogle/gemma-2b-itモデルを使い、ドキュメントに記載されているチャットのサンプルをそのまま動かしてみます。
※それくらいしかできませんでした…
google/gemma-2b-it · Hugging Face
環境
今回の環境はこちら。
$ python3 --version Python 3.10.12 $ pip3 --version pip 22.0.2 from /usr/lib/python3/dist-packages/pip (python 3.10)
GoogleのLLM「Gemma」をTransformersで試す
まずは、必要なライブラリーをインストールします。ドキュメントではpip install -U transformers
とすればよいと書かれていますが、
CPU環境だとtransformers[torch]
としてインストールすることになります。
$ pip3 install transformers[torch]
インストールされたライブラリー一覧。
$ pip3 list Package Version ------------------------ ---------- accelerate 0.27.2 certifi 2024.2.2 charset-normalizer 3.3.2 filelock 3.13.1 fsspec 2024.2.0 huggingface-hub 0.20.3 idna 3.6 Jinja2 3.1.3 MarkupSafe 2.1.5 mpmath 1.3.0 networkx 3.2.1 numpy 1.26.4 nvidia-cublas-cu12 12.1.3.1 nvidia-cuda-cupti-cu12 12.1.105 nvidia-cuda-nvrtc-cu12 12.1.105 nvidia-cuda-runtime-cu12 12.1.105 nvidia-cudnn-cu12 8.9.2.26 nvidia-cufft-cu12 11.0.2.54 nvidia-curand-cu12 10.3.2.106 nvidia-cusolver-cu12 11.4.5.107 nvidia-cusparse-cu12 12.1.0.106 nvidia-nccl-cu12 2.19.3 nvidia-nvjitlink-cu12 12.3.101 nvidia-nvtx-cu12 12.1.105 packaging 23.2 pip 22.0.2 psutil 5.9.8 PyYAML 6.0.1 regex 2023.12.25 requests 2.31.0 safetensors 0.4.2 setuptools 59.6.0 sympy 1.12 tokenizers 0.15.2 torch 2.2.1 tqdm 4.66.2 transformers 4.38.1 triton 2.2.0 typing_extensions 4.9.0 urllib3 2.2.1
UsageのChat Templateをもとに、こんな感じのソースコードを作成。
app.py
app.py from transformers import AutoTokenizer, AutoModelForCausalLM import transformers import torch tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it") model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it") chat = [ { "role": "user", "content": "Write a hello world program" }, ] prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=150) print(tokenizer.decode(outputs[0]))
google/gemma-2b-it / Usage / Chat Template
初回実行時はHugging Faceからモデルをダウンロードすることになりますが、この時Gemmaの場合だと認証が必要になるので
アクセストークンを環境変数HF_TOKEN
に設定しておきます。
$ export HF_TOKEN=[Hugging Faceのアクセストークン]
実行。
※初回のダウンロードログは省略
※"```"ははてなブログ上のMarkdownと混ざるので、前にスペースを入れています
$ python3 app.py Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [03:54<00:00, 117.16s/it] <bos><start_of_turn>user Write a hello world program<end_of_turn> <start_of_turn>model ```python print("Hello, world!") ``` **Explanation:** * `print()` is a built-in Python function that prints the given argument to the console. * `"Hello, world!"` is the string that we want to print. * `` is the string delimiter, which tells `print()` to print the string on a single line. **Output:** ``` Hello, world! ``` **Note:** * The `print()` function can take multiple arguments, which will be separated by commas. * You can also use `\n` as the string delimiter to print a new line character. * The `print()` function is a built-in module, so you
動きました。
動くには動いたのですが、とてもとても重くてちょっといろいろ試す気になりません…。
この時点で5分とか待つことになりました…。
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [03:54<00:00, 117.16s/it]
そういえば、Llama 2を使った時はいきなりllama-cpp-pythonやLocalAIを使ってTransformersは使いませんでしたね。
モデルも品質を絞ったものだったりしましたし。
今度はllama-cpp-pythonまたはLocalAIでどうなるか確認することにしましょう。
おわりに
GoogleのLLM「Gemma」をTransformersで試してみました、が、自分の環境では重すぎてちょっと厳しかったです。
Hugging Faceのアクセストークンの使い方を知れたりしたのは良かったですが、今度はllama-cpp-pythonやLocalAIで試してみることに
しましょう。