CLOVER🍀

That was when it all began.

FastAPIとSentence Transformersを使って簡単なテキスト埋め込みAPIを作成する

これは、なにをしたくて書いたもの?

テキスト埋め込みを行うにはSentence Transformersを使うのがいいのかなと思っているのですが、できれば単体で動作するサーバーとして
使いたいなと。

これをやろうとするとLocalAIを使うのが1番近い気がするのですが、準備にかなり手間がかかります。

Embeddings / Huggingface embeddings

じゃあもういっそのこと簡単なAPIサーバーを自分で作ったらいいかなということで、作ることにしました。

Sentence Transformersのインストールには時間がかかるのですが、それさえできてしまえばテキスト埋め込みを動かすのにそれほど
大量のリソースは要らないので。

FastAPIで作る

お題としては「Sentence Transformersの機能を使ったテキスト埋め込みが行えるREST API」です。

FastAPIで作るのがよいかなと。

FastAPI

簡単にテストまで行うことにしました。

Testing - FastAPI

環境

今回の環境はこちら。

$ python3 --version
Python 3.10.12


$ pip3 --version
pip 22.0.2 from /usr/lib/python3/dist-packages/pip (python 3.10)

FastAPIでSentence Transformersを使ったテキスト埋め込みAPIを作る

まずはライブラリーのインストール。ASGIサーバーはUvicornを使うことにします。

$ pip3 install sentence-transformers fastapi uvicorn[standard]

テスト向けのライブラリーもインストール。

$ pip3 install pytest httpx

インストールしたライブラリーの一覧はこちら。

$ pip3 list
Package                  Version
------------------------ ----------
annotated-types          0.6.0
anyio                    4.3.0
certifi                  2024.2.2
charset-normalizer       3.3.2
click                    8.1.7
dnspython                2.6.1
email_validator          2.1.1
exceptiongroup           1.2.1
fastapi                  0.111.0
fastapi-cli              0.0.3
filelock                 3.14.0
fsspec                   2024.5.0
h11                      0.14.0
httpcore                 1.0.5
httptools                0.6.1
httpx                    0.27.0
huggingface-hub          0.23.0
idna                     3.7
iniconfig                2.0.0
Jinja2                   3.1.4
joblib                   1.4.2
markdown-it-py           3.0.0
MarkupSafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
orjson                   3.10.3
packaging                24.0
pillow                   10.3.0
pip                      22.0.2
pluggy                   1.5.0
pydantic                 2.7.1
pydantic_core            2.18.2
Pygments                 2.18.0
pytest                   8.2.0
python-dotenv            1.0.1
python-multipart         0.0.9
PyYAML                   6.0.1
regex                    2024.5.15
requests                 2.31.0
rich                     13.7.1
safetensors              0.4.3
scikit-learn             1.4.2
scipy                    1.13.0
sentence-transformers    2.7.0
setuptools               59.6.0
shellingham              1.5.4
sniffio                  1.3.1
starlette                0.37.2
sympy                    1.12
threadpoolctl            3.5.0
tokenizers               0.19.1
tomli                    2.0.1
torch                    2.3.0
tqdm                     4.66.4
transformers             4.40.2
triton                   2.3.0
typer                    0.12.3
typing_extensions        4.11.0
ujson                    5.10.0
urllib3                  2.2.1
uvicorn                  0.29.0
uvloop                   0.19.0
watchfiles               0.21.0
websockets               12.0

作成したソースコードはこちら。

api.py

from fastapi import FastAPI
from pydantic import BaseModel
import os
from sentence_transformers import SentenceTransformer

app = FastAPI()

class EmbeddingRequest(BaseModel):
    model: str
    text: str
    normalize: bool = False

class EmbeddingResponse(BaseModel):
    model: str
    embedding: list[float]
    dimension: int

@app.post("/embeddings/encode")
def encode(request: EmbeddingRequest) -> EmbeddingResponse:
    sentence_transformer_model = SentenceTransformer(
        request.model,
        device=os.getenv("EMBEDDING_API_DEVICE", "cpu")
    )

    embeddings = sentence_transformer_model.encode(sentences=[request.text], normalize_embeddings=request.normalize)
    embedding = embeddings[0]

    # numpy array to float list
    embedding_as_float = embedding.tolist()

    return EmbeddingResponse(
        model=request.model,
        embedding=embedding_as_float,
        dimension=sentence_transformer_model.get_sentence_embedding_dimension()
    )

リクエストにはテキスト埋め込みに使うモデルと対象のテキスト、正規化の有無を

class EmbeddingRequest(BaseModel):
    model: str
    text: str
    normalize: bool = False

レスポンスにはリクエストで指定されたモデル、テキスト埋め込みの結果、ベクトルの次元数を返すことにしました。

class EmbeddingResponse(BaseModel):
    model: str
    embedding: list[float]
    dimension: int

APIの実装はこんな感じですね。

@app.post("/embeddings/encode")
def encode(request: EmbeddingRequest) -> EmbeddingResponse:
    sentence_transformer_model = SentenceTransformer(
        request.model,
        device=os.getenv("EMBEDDING_API_DEVICE", "cpu")
    )

    embeddings = sentence_transformer_model.encode(sentences=[request.text], normalize_embeddings=request.normalize)
    embedding = embeddings[0]

    # numpy array to float list
    embedding_as_float = embedding.tolist()

    return EmbeddingResponse(
        model=request.model,
        embedding=embedding_as_float,
        dimension=sentence_transformer_model.get_sentence_embedding_dimension()
    )

モデルは、実行時に自動的にHugging Face Hubからダウンロードしてきます。

numpyの配列をリストに変換する必要があったところが困ったくらいですね…。

起動。

$ uvicorn api:app

# または
$ uvicorn api:app --reload

確認。

$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World"}' | jq
{
  "model": "all-MiniLM-L6-v2",
  "embedding": [
    -0.03447727486491203,
    0.03102317824959755,
    0.006734995171427727,
    0.026108944788575172,
    -0.039361994713544846,

    〜省略〜

    0.03323201462626457,
    0.02379228174686432,
    -0.022889817133545876,
    0.03893755003809929,
    0.0302068330347538
  ],
  "dimension": 384
}

もうひとつ、モデルを変更して確認してみましょう。

$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World"}' | jq
{
  "model": "intfloat/multilingual-e5-base",
  "embedding": [
    0.03324141725897789,
    0.04988044500350952,
    0.00241446984000504,
    0.011555945500731468,
    0.03409387916326523,

    〜省略〜

    -0.018477996811270714,
    0.04818818345665932,
    -0.04364151135087013,
    -0.04888230562210083,
    0.03604992479085922
  ],
  "dimension": 768
}

OKですね。

正規化を使うリクエストはこうなのですが、結果は変わりませんでした…。

$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World", "normalize": true}' | jq


$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World", "normalize": true}' | jq

あとはテストを書いておきます。

test_api.py

from fastapi.testclient import TestClient
from api import app, EmbeddingRequest, EmbeddingResponse

client = TestClient(app)

def test_encode_basic():
    request = EmbeddingRequest(model="all-MiniLM-L6-v2", text="Hello World")
    raw_response = client.post("/embeddings/encode", json=request.model_dump())

    assert raw_response.status_code == 200

    response = EmbeddingResponse.model_validate(raw_response.json())
    assert response.model == "all-MiniLM-L6-v2"
    assert len(response.embedding) == 384
    assert response.dimension == 384

def test_encode_e5():
    request = EmbeddingRequest(model="intfloat/multilingual-e5-base", text="passage: Hello World")
    raw_response = client.post("/embeddings/encode", json=request.model_dump())

    assert raw_response.status_code == 200

    response = EmbeddingResponse.model_validate(raw_response.json())
    assert response.model == "intfloat/multilingual-e5-base"
    assert len(response.embedding) == 768
    assert response.dimension == 768

参考にしたのはこちらのページと

Testing - FastAPI

こちら。

JSON - Pydantic

Pydanticはあまり見ていなかったので、ちょっと手間取りました…。

確認。

$ pytest
===================================================================================== test session starts ======================================================================================
platform linux -- Python 3.10.12, pytest-8.2.0, pluggy-1.5.0
rootdir: /path/to
plugins: anyio-4.3.0
collected 2 items

test_api.py ..                                                                                                                                                                           [100%]

====================================================================================== 2 passed in 9.16s =======================================================================================

OKですね。

おわりに

FastAPIとSentence Transformersを使って、簡単なテキスト埋め込みAPIを作成してみました。

特にPython以外でテキスト埋め込みをやりたいと思った時に、どうやってテキスト埋め込みを行うかにちょっと困っていたので、こうやって
自分で作ったものを使ってみてもいいかなと。

FastAPIのちょっとした勉強にもなりました。