これは、なにをしたくて書いたもの?
テキスト埋め込みを行うにはSentence Transformersを使うのがいいのかなと思っているのですが、できれば単体で動作するサーバーとして
使いたいなと。
これをやろうとするとLocalAIを使うのが1番近い気がするのですが、準備にかなり手間がかかります。
Embeddings / Huggingface embeddings
じゃあもういっそのこと簡単なAPIサーバーを自分で作ったらいいかなということで、作ることにしました。
Sentence Transformersのインストールには時間がかかるのですが、それさえできてしまえばテキスト埋め込みを動かすのにそれほど
大量のリソースは要らないので。
FastAPIで作る
お題としては「Sentence Transformersの機能を使ったテキスト埋め込みが行えるREST API」です。
FastAPIで作るのがよいかなと。
簡単にテストまで行うことにしました。
環境
今回の環境はこちら。
$ python3 --version Python 3.10.12 $ pip3 --version pip 22.0.2 from /usr/lib/python3/dist-packages/pip (python 3.10)
FastAPIでSentence Transformersを使ったテキスト埋め込みAPIを作る
まずはライブラリーのインストール。ASGIサーバーはUvicornを使うことにします。
$ pip3 install sentence-transformers fastapi uvicorn[standard]
テスト向けのライブラリーもインストール。
$ pip3 install pytest httpx
インストールしたライブラリーの一覧はこちら。
$ pip3 list Package Version ------------------------ ---------- annotated-types 0.6.0 anyio 4.3.0 certifi 2024.2.2 charset-normalizer 3.3.2 click 8.1.7 dnspython 2.6.1 email_validator 2.1.1 exceptiongroup 1.2.1 fastapi 0.111.0 fastapi-cli 0.0.3 filelock 3.14.0 fsspec 2024.5.0 h11 0.14.0 httpcore 1.0.5 httptools 0.6.1 httpx 0.27.0 huggingface-hub 0.23.0 idna 3.7 iniconfig 2.0.0 Jinja2 3.1.4 joblib 1.4.2 markdown-it-py 3.0.0 MarkupSafe 2.1.5 mdurl 0.1.2 mpmath 1.3.0 networkx 3.3 numpy 1.26.4 nvidia-cublas-cu12 12.1.3.1 nvidia-cuda-cupti-cu12 12.1.105 nvidia-cuda-nvrtc-cu12 12.1.105 nvidia-cuda-runtime-cu12 12.1.105 nvidia-cudnn-cu12 8.9.2.26 nvidia-cufft-cu12 11.0.2.54 nvidia-curand-cu12 10.3.2.106 nvidia-cusolver-cu12 11.4.5.107 nvidia-cusparse-cu12 12.1.0.106 nvidia-nccl-cu12 2.20.5 nvidia-nvjitlink-cu12 12.4.127 nvidia-nvtx-cu12 12.1.105 orjson 3.10.3 packaging 24.0 pillow 10.3.0 pip 22.0.2 pluggy 1.5.0 pydantic 2.7.1 pydantic_core 2.18.2 Pygments 2.18.0 pytest 8.2.0 python-dotenv 1.0.1 python-multipart 0.0.9 PyYAML 6.0.1 regex 2024.5.15 requests 2.31.0 rich 13.7.1 safetensors 0.4.3 scikit-learn 1.4.2 scipy 1.13.0 sentence-transformers 2.7.0 setuptools 59.6.0 shellingham 1.5.4 sniffio 1.3.1 starlette 0.37.2 sympy 1.12 threadpoolctl 3.5.0 tokenizers 0.19.1 tomli 2.0.1 torch 2.3.0 tqdm 4.66.4 transformers 4.40.2 triton 2.3.0 typer 0.12.3 typing_extensions 4.11.0 ujson 5.10.0 urllib3 2.2.1 uvicorn 0.29.0 uvloop 0.19.0 watchfiles 0.21.0 websockets 12.0
作成したソースコードはこちら。
api.py
from fastapi import FastAPI from pydantic import BaseModel import os from sentence_transformers import SentenceTransformer app = FastAPI() class EmbeddingRequest(BaseModel): model: str text: str normalize: bool = False class EmbeddingResponse(BaseModel): model: str embedding: list[float] dimention: int @app.post("/embeddings/encode") def encode(request: EmbeddingRequest) -> EmbeddingResponse: sentence_transformer_model = SentenceTransformer( request.model, device=os.getenv("EMBEDDING_API_DEVICE", "cpu") ) embeddings = sentence_transformer_model.encode(sentences=[request.text], normalize_embeddings=request.normalize) embedding = embeddings[0] # numpy array to float list embedding_as_float = embedding.tolist() return EmbeddingResponse( model=request.model, embedding=embedding_as_float, dimention=sentence_transformer_model.get_sentence_embedding_dimension() )
リクエストにはテキスト埋め込みに使うモデルと対象のテキスト、正規化の有無を
class EmbeddingRequest(BaseModel): model: str text: str normalize: bool = False
レスポンスにはリクエストで指定されたモデル、テキスト埋め込みの結果、ベクトルの次元数を返すことにしました。
class EmbeddingResponse(BaseModel): model: str embedding: list[float] dimention: int
APIの実装はこんな感じですね。
@app.post("/embeddings/encode") def encode(request: EmbeddingRequest) -> EmbeddingResponse: sentence_transformer_model = SentenceTransformer( request.model, device=os.getenv("EMBEDDING_API_DEVICE", "cpu") ) embeddings = sentence_transformer_model.encode(sentences=[request.text], normalize_embeddings=request.normalize) embedding = embeddings[0] # numpy array to float list embedding_as_float = embedding.tolist() return EmbeddingResponse( model=request.model, embedding=embedding_as_float, dimention=sentence_transformer_model.get_sentence_embedding_dimension() )
モデルは、実行時に自動的にHugging Face Hubからダウンロードしてきます。
numpyの配列をリストに変換する必要があったところが困ったくらいですね…。
起動。
$ uvicorn api:app # または $ uvicorn api:app --reload
確認。
$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World"}' | jq { "model": "all-MiniLM-L6-v2", "embedding": [ -0.03447727486491203, 0.03102317824959755, 0.006734995171427727, 0.026108944788575172, -0.039361994713544846, 〜省略〜 0.03323201462626457, 0.02379228174686432, -0.022889817133545876, 0.03893755003809929, 0.0302068330347538 ], "dimention": 384 }
もうひとつ、モデルを変更して確認してみましょう。
$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World"}' | jq { "model": "intfloat/multilingual-e5-base", "embedding": [ 0.03324141725897789, 0.04988044500350952, 0.00241446984000504, 0.011555945500731468, 0.03409387916326523, 〜省略〜 -0.018477996811270714, 0.04818818345665932, -0.04364151135087013, -0.04888230562210083, 0.03604992479085922 ], "dimention": 768 }
OKですね。
正規化を使うリクエストはこうなのですが、結果は変わりませんでした…。
$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World", "normalize": true}' | jq $ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World", "normalize": true}' | jq
あとはテストを書いておきます。
test_api.py
from fastapi.testclient import TestClient from api import app, EmbeddingRequest, EmbeddingResponse client = TestClient(app) def test_encode_basic(): request = EmbeddingRequest(model="all-MiniLM-L6-v2", text="Hello World") raw_response = client.post("/embeddings/encode", json=request.model_dump()) assert raw_response.status_code == 200 response = EmbeddingResponse.model_validate(raw_response.json()) assert response.model == "all-MiniLM-L6-v2" assert len(response.embedding) == 384 assert response.dimention == 384 def test_encode_e5(): request = EmbeddingRequest(model="intfloat/multilingual-e5-base", text="passage: Hello World") raw_response = client.post("/embeddings/encode", json=request.model_dump()) assert raw_response.status_code == 200 response = EmbeddingResponse.model_validate(raw_response.json()) assert response.model == "intfloat/multilingual-e5-base" assert len(response.embedding) == 768 assert response.dimention == 768
参考にしたのはこちらのページと
こちら。
Pydanticはあまり見ていなかったので、ちょっと手間取りました…。
確認。
$ pytest ===================================================================================== test session starts ====================================================================================== platform linux -- Python 3.10.12, pytest-8.2.0, pluggy-1.5.0 rootdir: /path/to plugins: anyio-4.3.0 collected 2 items test_api.py .. [100%] ====================================================================================== 2 passed in 9.16s =======================================================================================
OKですね。
おわりに
FastAPIとSentence Transformersを使って、簡単なテキスト埋め込みAPIを作成してみました。
特にPython以外でテキスト埋め込みをやりたいと思った時に、どうやってテキスト埋め込みを行うかにちょっと困っていたので、こうやって
自分で作ったものを使ってみてもいいかなと。
FastAPIのちょっとした勉強にもなりました。