CLOVER🍀

That was when it all began.

RESTEasyでCORSの設定をFilterで行う

これは、なにをしたくて書いたもの?

そういえば、Jakarta EE(Java EE)でCORSの設定をしたことがないなと思いまして。

実装方法はいろいろあると思うのですが、すでに用意されているものとかないのかなと思って少し見ていたら、RESTEasyにあったので
軽く試してみることにしました。

RESTEasyのCorsFilter

CORS(オリジン間リソース共有)および使用するヘッダーについては、MDNを参照ということで。

オリジン間リソース共有 (CORS) - HTTP | MDN

RESTEasyが提供するCorsFilterについてのドキュメントはこちらです。

Chapter 32. CORS

ほぼ内容がありませんが。設定内容についてはJavadocを見てね、という感じです。

CorsFilter (RESTEasy 6.2.8.Final API)

設定をしたら、Applicationにシングルトンとして登録します。

You must allocate this and register it as a singleton provider from your Application class.

これ以上の説明はないので、実際に試してみましょう。

Java SE環境でSeBootstrapを使い、簡単なJakarta RESTful Web Services(JAX-RS)のリソースクラスを作成して確認することにします。

環境

今回の環境はこちら。

$ java --version
openjdk 21.0.2 2024-01-16
OpenJDK Runtime Environment (build 21.0.2+13-Ubuntu-122.04.1)
OpenJDK 64-Bit Server VM (build 21.0.2+13-Ubuntu-122.04.1, mixed mode, sharing)


$ mvn --version
Apache Maven 3.9.6 (bc0240f3c744dd6b6ec2920b3cd08dcc295161ae)
Maven home: $HOME/.sdkman/candidates/maven/current
Java version: 21.0.2, vendor: Private Build, runtime: /usr/lib/jvm/java-21-openjdk-amd64
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.15.0-105-generic", arch: "amd64", family: "unix"

サンプルコードを用意する

それでは、まずはサンプルとなるソースコードを用意します。

Maven依存関係など。

    <properties>
        <maven.compiler.release>21</maven.compiler.release>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.jboss.resteasy</groupId>
            <artifactId>resteasy-core</artifactId>
            <version>6.2.8.Final</version>
        </dependency>
        <dependency>
            <groupId>org.jboss.resteasy</groupId>
            <artifactId>resteasy-undertow-cdi</artifactId>
            <version>6.2.8.Final</version>
        </dependency>
    </dependencies>

今回使用するCorsFilterresteasy-coreに含まれています。Java SE環境でRESTEasyを起動するためのresteasy-undertow-cdiにも
含まれているのですが、明示的に指定することにしました。

なんとなくbeans.xmlも用意しておきます。中身は空で。

src/main/resources/META-INF/beans.xml




Applicationのサブクラス。

src/main/java/org/littlewings/resteasy/cors/RestApplication.java

package org.littlewings.resteasy.cors;

import jakarta.ws.rs.ApplicationPath;
import jakarta.ws.rs.core.Application;
import org.jboss.resteasy.plugins.interceptors.CorsFilter;

import java.util.Set;

@ApplicationPath("")
public class RestApplication extends Application {
    @Override
    public Set<Class<?>> getClasses() {
        return Set.of(HelloResource.class);
    }

    @Override
    public Set<Object> getSingletons() {
        CorsFilter corsFilter = new CorsFilter();
        // 後で

        return Set.of(corsFilter);
    }
}

JAX-RSリソースクラス。

src/main/java/org/littlewings/resteasy/cors/HelloResource.java

package org.littlewings.resteasy.cors;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.PreMatching;
import jakarta.ws.rs.core.MediaType;

@Path("hello")
public class HelloResource {
    @GET
    @Produces(MediaType.TEXT_PLAIN)
    public String message() {
        return "Hello World!!";
    }
}

mainクラス。

src/main/java/org/littlewings/resteasy/cors/App.java

package org.littlewings.resteasy.cors;

import jakarta.ws.rs.SeBootstrap;
import org.jboss.logging.Logger;

import java.util.concurrent.ExecutionException;

public class App {
    public static void main(String... args) throws ExecutionException, InterruptedException {
        Logger logger = Logger.getLogger(App.class);

        SeBootstrap.Configuration configuration =
                SeBootstrap
                        .Configuration
                        .builder()
                        .host("0.0.0.0")
                        .port(8080)
                        .build();

        SeBootstrap.Instance instance =
                SeBootstrap
                        .start(new RestApplication(), configuration)
                        .toCompletableFuture()
                        .get();

        logger.info("server startup.");
        System.console().readLine("> Enter stop.");

        instance
                .stop()
                .toCompletableFuture()
                .get();
    }
}

これで準備は完了です。

CorsFilterを設定する

では、Applicationのサブクラスに書いていたCorsFilterの設定をします。

今回はこんな感じにしました。

    @Override
    public Set<Object> getSingletons() {
        CorsFilter corsFilter = new CorsFilter();
        corsFilter.getAllowedOrigins().add("http://localhost:3000");
        corsFilter.setAllowedMethods("POST, GET, PUT, DELETE, OPTIONS");
        corsFilter.setAllowedHeaders("Content-Type, Origin, Authorization");
        corsFilter.setCorsMaxAge(86400);
        corsFilter.setAllowCredentials(true);

        return Set.of(corsFilter);
    }

Access-Control-Allow-Originに対する設定方法がないのでは…?と思ったのですが、Setを取得してこちらに直接追加するみたいです…。

そういえば、ドキュメントにもそう書いていました…。

CorsFilter filter = new CorsFilter();
filter.getAllowedOrigins().add("http://localhost");

Chapter 32. CORS

その他はsetterがあるんですけどね。

CorsFilter (RESTEasy 6.2.8.Final API)

確認する

準備はできたので、確認してみます。

起動。

$ mvn compile exec:java -Dexec.mainClass=org.littlewings.resteasy.cors.App

起動しました。

INFO: server startup.
> Enter stop.

設定したヘッダーがレスポンスに含まれるかどうか、OPTIONSで確認。

$ curl -i -XOPTIONS -H 'Origin: http://localhost:3000' localhost:8080/hello
HTTP/1.1 200 OK
Connection: keep-alive
Access-Control-Allow-Origin: http://localhost:3000
Vary: Origin
Access-Control-Allow-Credentials: true
Content-Length: 0
Access-Control-Max-Age: 86400
Date: Sat, 04 May 2024 12:02:17 GMT

指定のオリジンからアクセス。

$ curl -i -H 'Origin: http://localhost:3000' localhost:8080/hello
HTTP/1.1 200 OK
Connection: keep-alive
Access-Control-Allow-Origin: http://localhost:3000
Vary: Origin
Access-Control-Allow-Credentials: true
Content-Type: text/plain;charset=UTF-8
Content-Length: 13
Date: Sat, 04 May 2024 12:03:05 GMT

Hello World!!

許可していないオリジンからのアクセス。

$ curl -i -H 'Origin: http://example.com' localhost:8080/hello
HTTP/1.1 403 Forbidden
Connection: keep-alive
Content-Length: 0
Date: Sat, 04 May 2024 12:03:41 GMT

拒否されました。

OKですね。

少しポイントを

CorsFilterJAX-RSContainerRequestFilterContainerResponseFilterとして実装されています。

CorsFilter (RESTEasy 6.2.8.Final API)

このフィルターには@PreMatchingというアノテーションが付与されているのですが、これはリソースメソッドにマッチする前に
適用することを指示するもののようです。

As stated above, a ContainerRequestFilter that is annotated with @PreMatching is executed upon receiving a client request but before a resource method is matched. Thus, this type of filter has the ability to modify the input to the matching algorithm (see Request Matching) and, consequently, alter its outcome.

Jakarta RESTful Web Services / Filters and Interceptors / Filters

ところで、今回のようなCorsFilterApplication#getSingletonsを使って登録するしかなさそうなのですが、現在このメソッドは
非推奨になっているようです。

Deprecated. Automatic discovery of resources and providers or the getClasses method is preferred over getSingletons.

Application#getSingletons)

今後はどうしたらいいんでしょうね…。

WildFlyでの設定は?

今回はRESTEasyが実装しているクラスを使いましたが、WildFlyでやろうとするとundertowサブシステムのfilterを使って登録することに
なるようです。

WildFly Full Model Reference / subsystem=undertow / configuration=filter / response-header

おわりに

RESTEasyの提供するCorsFilterを使ってCORSの設定をしてみました。

本当、実装方法としてはいろいろと考えられると思うのですが、どうするのが一般的(?)なんでしょうね。

FastAPIとSentence Transformersを使って簡単なテキスト埋め込みAPIを作成する

これは、なにをしたくて書いたもの?

テキスト埋め込みを行うにはSentence Transformersを使うのがいいのかなと思っているのですが、できれば単体で動作するサーバーとして
使いたいなと。

これをやろうとするとLocalAIを使うのが1番近い気がするのですが、準備にかなり手間がかかります。

Embeddings / Huggingface embeddings

じゃあもういっそのこと簡単なAPIサーバーを自分で作ったらいいかなということで、作ることにしました。

Sentence Transformersのインストールには時間がかかるのですが、それさえできてしまえばテキスト埋め込みを動かすのにそれほど
大量のリソースは要らないので。

FastAPIで作る

お題としては「Sentence Transformersの機能を使ったテキスト埋め込みが行えるREST API」です。

FastAPIで作るのがよいかなと。

FastAPI

簡単にテストまで行うことにしました。

Testing - FastAPI

環境

今回の環境はこちら。

$ python3 --version
Python 3.10.12


$ pip3 --version
pip 22.0.2 from /usr/lib/python3/dist-packages/pip (python 3.10)

FastAPIでSentence Transformersを使ったテキスト埋め込みAPIを作る

まずはライブラリーのインストール。ASGIサーバーはUvicornを使うことにします。

$ pip3 install sentence-transformers fastapi uvicorn[standard]

テスト向けのライブラリーもインストール。

$ pip3 install pytest httpx

インストールしたライブラリーの一覧はこちら。

$ pip3 list
Package                  Version
------------------------ ----------
annotated-types          0.6.0
anyio                    4.3.0
certifi                  2024.2.2
charset-normalizer       3.3.2
click                    8.1.7
dnspython                2.6.1
email_validator          2.1.1
exceptiongroup           1.2.1
fastapi                  0.111.0
fastapi-cli              0.0.3
filelock                 3.14.0
fsspec                   2024.5.0
h11                      0.14.0
httpcore                 1.0.5
httptools                0.6.1
httpx                    0.27.0
huggingface-hub          0.23.0
idna                     3.7
iniconfig                2.0.0
Jinja2                   3.1.4
joblib                   1.4.2
markdown-it-py           3.0.0
MarkupSafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
orjson                   3.10.3
packaging                24.0
pillow                   10.3.0
pip                      22.0.2
pluggy                   1.5.0
pydantic                 2.7.1
pydantic_core            2.18.2
Pygments                 2.18.0
pytest                   8.2.0
python-dotenv            1.0.1
python-multipart         0.0.9
PyYAML                   6.0.1
regex                    2024.5.15
requests                 2.31.0
rich                     13.7.1
safetensors              0.4.3
scikit-learn             1.4.2
scipy                    1.13.0
sentence-transformers    2.7.0
setuptools               59.6.0
shellingham              1.5.4
sniffio                  1.3.1
starlette                0.37.2
sympy                    1.12
threadpoolctl            3.5.0
tokenizers               0.19.1
tomli                    2.0.1
torch                    2.3.0
tqdm                     4.66.4
transformers             4.40.2
triton                   2.3.0
typer                    0.12.3
typing_extensions        4.11.0
ujson                    5.10.0
urllib3                  2.2.1
uvicorn                  0.29.0
uvloop                   0.19.0
watchfiles               0.21.0
websockets               12.0

作成したソースコードはこちら。

api.py

from fastapi import FastAPI
from pydantic import BaseModel
import os
from sentence_transformers import SentenceTransformer

app = FastAPI()

class EmbeddingRequest(BaseModel):
    model: str
    text: str
    normalize: bool = False

class EmbeddingResponse(BaseModel):
    model: str
    embedding: list[float]
    dimention: int

@app.post("/embeddings/encode")
def encode(request: EmbeddingRequest) -> EmbeddingResponse:
    sentence_transformer_model = SentenceTransformer(
        request.model,
        device=os.getenv("EMBEDDING_API_DEVICE", "cpu")
    )

    embeddings = sentence_transformer_model.encode(sentences=[request.text], normalize_embeddings=request.normalize)
    embedding = embeddings[0]

    # numpy array to float list
    embedding_as_float = embedding.tolist()

    return EmbeddingResponse(
        model=request.model,
        embedding=embedding_as_float,
        dimention=sentence_transformer_model.get_sentence_embedding_dimension()
    )

リクエストにはテキスト埋め込みに使うモデルと対象のテキスト、正規化の有無を

class EmbeddingRequest(BaseModel):
    model: str
    text: str
    normalize: bool = False

レスポンスにはリクエストで指定されたモデル、テキスト埋め込みの結果、ベクトルの次元数を返すことにしました。

class EmbeddingResponse(BaseModel):
    model: str
    embedding: list[float]
    dimention: int

APIの実装はこんな感じですね。

@app.post("/embeddings/encode")
def encode(request: EmbeddingRequest) -> EmbeddingResponse:
    sentence_transformer_model = SentenceTransformer(
        request.model,
        device=os.getenv("EMBEDDING_API_DEVICE", "cpu")
    )

    embeddings = sentence_transformer_model.encode(sentences=[request.text], normalize_embeddings=request.normalize)
    embedding = embeddings[0]

    # numpy array to float list
    embedding_as_float = embedding.tolist()

    return EmbeddingResponse(
        model=request.model,
        embedding=embedding_as_float,
        dimention=sentence_transformer_model.get_sentence_embedding_dimension()
    )

モデルは、実行時に自動的にHugging Face Hubからダウンロードしてきます。

numpyの配列をリストに変換する必要があったところが困ったくらいですね…。

起動。

$ uvicorn api:app

# または
$ uvicorn api:app --reload

確認。

$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World"}' | jq
{
  "model": "all-MiniLM-L6-v2",
  "embedding": [
    -0.03447727486491203,
    0.03102317824959755,
    0.006734995171427727,
    0.026108944788575172,
    -0.039361994713544846,

    〜省略〜

    0.03323201462626457,
    0.02379228174686432,
    -0.022889817133545876,
    0.03893755003809929,
    0.0302068330347538
  ],
  "dimention": 384
}

もうひとつ、モデルを変更して確認してみましょう。

$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World"}' | jq
{
  "model": "intfloat/multilingual-e5-base",
  "embedding": [
    0.03324141725897789,
    0.04988044500350952,
    0.00241446984000504,
    0.011555945500731468,
    0.03409387916326523,

    〜省略〜

    -0.018477996811270714,
    0.04818818345665932,
    -0.04364151135087013,
    -0.04888230562210083,
    0.03604992479085922
  ],
  "dimention": 768
}

OKですね。

正規化を使うリクエストはこうなのですが、結果は変わりませんでした…。

$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "all-MiniLM-L6-v2", "text": "Hello World", "normalize": true}' | jq


$ curl -s -XPOST -H 'Content-Type: application/json' localhost:8000/embeddings/encode -d '{"model": "intfloat/multilingual-e5-base", "text": "query: Hello World", "normalize": true}' | jq

あとはテストを書いておきます。

test_api.py

from fastapi.testclient import TestClient
from api import app, EmbeddingRequest, EmbeddingResponse

client = TestClient(app)

def test_encode_basic():
    request = EmbeddingRequest(model="all-MiniLM-L6-v2", text="Hello World")
    raw_response = client.post("/embeddings/encode", json=request.model_dump())

    assert raw_response.status_code == 200

    response = EmbeddingResponse.model_validate(raw_response.json())
    assert response.model == "all-MiniLM-L6-v2"
    assert len(response.embedding) == 384
    assert response.dimention == 384

def test_encode_e5():
    request = EmbeddingRequest(model="intfloat/multilingual-e5-base", text="passage: Hello World")
    raw_response = client.post("/embeddings/encode", json=request.model_dump())

    assert raw_response.status_code == 200

    response = EmbeddingResponse.model_validate(raw_response.json())
    assert response.model == "intfloat/multilingual-e5-base"
    assert len(response.embedding) == 768
    assert response.dimention == 768

参考にしたのはこちらのページと

Testing - FastAPI

こちら。

JSON - Pydantic

Pydanticはあまり見ていなかったので、ちょっと手間取りました…。

確認。

$ pytest
===================================================================================== test session starts ======================================================================================
platform linux -- Python 3.10.12, pytest-8.2.0, pluggy-1.5.0
rootdir: /path/to
plugins: anyio-4.3.0
collected 2 items

test_api.py ..                                                                                                                                                                           [100%]

====================================================================================== 2 passed in 9.16s =======================================================================================

OKですね。

おわりに

FastAPIとSentence Transformersを使って、簡単なテキスト埋め込みAPIを作成してみました。

特にPython以外でテキスト埋め込みをやりたいと思った時に、どうやってテキスト埋め込みを行うかにちょっと困っていたので、こうやって
自分で作ったものを使ってみてもいいかなと。

FastAPIのちょっとした勉強にもなりました。