CLOVER🍀

That was when it all began.

Spring WebFluxでProxyサーバーを書いてみる

ちょっとしたお題がありまして、Spring WebFluxでProxyサーバーを書いてみました。

お題は、こんな感じで。

  • Reverse Proxy
  • 緩いProxyサーバー(X-Forwarded〜とかは気にしない)
  • ほぼなにも考えず、バックエンドへのリクエストとレスポンスの内容を転送する
  • GETとPOSTをとりあえず対応

Spring WebFlux+WebClientで、ノンブロッキングに書いてみようという話です。

Web on Reactive Stack

Proxyサーバーのスペックとして、それなりに頑張ろうと思ったらJettyあたりを参考にするとよいのではないでしょうか。
https://github.com/eclipse/jetty.project/blob/jetty-9.4.9.v20180320/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ProxyServlet.java
https://github.com/eclipse/jetty.project/blob/jetty-9.4.9.v20180320/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AbstractProxyServlet.java

環境

今回の動作環境は、こんな感じで。

$ java -version
openjdk version "1.8.0_162"
OpenJDK Runtime Environment (build 1.8.0_162-8u162-b12-0ubuntu0.16.04.2-b12)
OpenJDK 64-Bit Server VM (build 25.162-b12, mixed mode)

$ mvn -version
Apache Maven 3.5.3 (3383c37e1f9e9b3bc3df5050c29c8aff9f295297; 2018-02-25T04:49:05+09:00)
Maven home: /usr/local/maven3/current
Java version: 1.8.0_162, vendor: Oracle Corporation
Java home: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "4.4.0-104-generic", arch: "amd64", family: "unix"

バックエンドのサーバー

プロキシ先のバックエンドのサーバーは、簡単にServletで書くことにします。

Maven依存関係。

        <dependency>
            <groupId>javax.servlet</groupId>
            <artifactId>javax.servlet-api</artifactId>
            <version>3.1.0</version>
            <scope>provided</scope>
        </dependency>

受け取ったリクエストの内容を、Bodyに書き出すようにします。
src/main/java/org/littlewings/servlet/SimpleServlet.java

package org.littlewings.servlet;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collections;
import java.util.List;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

@WebServlet("/*")
public class SimpleServlet extends HttpServlet {
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        execute(request, response);
    }

    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        execute(request, response);
    }

    void execute(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        request.setCharacterEncoding("UTF-8");

        response.setContentType("text/plain");
        response.setCharacterEncoding("UTF-8");

        response.addHeader("X-Custom-Header", "Header-Value");

        PrintWriter writer = response.getWriter();

        writer.println("========================================");
        writer.println("Request Method:");
        writer.println("  " + request.getMethod());

        writer.println("========================================");
        writer.println("Request URI:");
        String queryString = request.getQueryString();
        if (queryString != null) {
            writer.println("  " + request.getRequestURI() + "?" + queryString);
        } else {
            writer.println("  " + request.getRequestURI());
        }

        writer.println("========================================");
        writer.println("Request Headers:");
        Collections.list(request.getHeaderNames()).forEach(name -> {
            List<String> values = Collections.list(request.getHeaders(name));
            writer.println("  " + name + ":" + values);
        });

        writer.println("========================================");
        writer.println("Request Body:");
        BufferedReader reader = request.getReader();
        int c;
        while ((c = reader.read()) != -1) {
            writer.print((char)c);
        }
    }
}

あと、ひとつだけ、独自のヘッダーを入れていたり。

        response.addHeader("X-Custom-Header", "Header-Value");

まあ、いずれもクライアントからのリクエストの内容が到達しているか、クライアントへバックエンドのサーバーの
レスポンスが転送できているか、という確認の内容ですね。

これをパッケージングして、今回はApache Tomcat 9.0.7にデプロイして起動させておきます。

この後作成する、Proxyサーバーと同じホストで起動するので、利用するポートは18080としておきました。

準備

では、Spring WebFluxで作るProxyサーバーの準備を。

Maven依存関係は、こちら。

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-dependencies</artifactId>
                <version>2.0.1.RELEASE</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-webflux</artifactId>
        </dependency>
    </dependencies>

Proxy Controller

Proxyとして振る舞うプログラムは、RestControllerとして作成しました。

src/main/java/org/littlewings/spring/webflux/proxy/ProxyController.java

package org.littlewings.spring.webflux.proxy;

import java.net.ConnectException;

import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@RestController
public class ProxyController {
    @RequestMapping(value = "/**", method = {RequestMethod.GET, RequestMethod.POST})
    public Flux<DataBuffer> proxy(ServerWebExchange exchange) {
        ServerHttpRequest request = exchange.getRequest();
        ServerHttpResponse response = exchange.getResponse();

        String remoteHost = "localhost";
        int remotePort = 18080;

        WebClient proxyClient =
                WebClient
                        .builder()
                        .baseUrl(String.format("http://%s:%d", remoteHost, remotePort))
                        .build();

        Mono<ClientResponse> monoRemoteResponse =
                proxyClient
                        .method(request.getMethod())
                        .uri(uriBuilder ->
                                uriBuilder
                                        .path(request.getPath().value())
                                        .queryParams(request.getQueryParams())
                                        .build()
                        )
                        .headers(headers -> {
                            HttpHeaders requestHeaders = request.getHeaders();
                            requestHeaders.forEach(headers::addAll);
                        })
                        .body(request.getBody(), DataBuffer.class)
                        .exchange();

        Flux<ClientResponse> remoteResponse =
                Flux.from(monoRemoteResponse);

        return remoteResponse
                .flatMap(remoteClientResponse -> {
                    response.setStatusCode(remoteClientResponse.statusCode());
                    ClientResponse.Headers remoteHeaders = remoteClientResponse.headers();
                    remoteHeaders.asHttpHeaders().forEach((name, values) -> {
                        if ("Content-Type".equalsIgnoreCase(name)) {
                            // デフォルトの「text/event-stream」を潰す
                            response.getHeaders().set(name, values.get(0));
                        } else {
                            response.getHeaders().addAll(name, values);
                        }
                    });

                    return remoteClientResponse.bodyToFlux(DataBuffer.class);
                })
                .doOnError(throwable -> {
                    if (throwable.getCause() != null && throwable.getCause() instanceof ConnectException) {
                        response.setStatusCode(HttpStatus.BAD_GATEWAY);
                        response.getHeaders().set("Content-Type", "text/plain");
                    } else {
                        response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR);
                        response.getHeaders().set("Content-Type", "text/plain");
                    }
                })
                .onErrorReturn(response.bufferFactory().wrap(new byte[0]));
    }
}

メソッドとしては、@RequestMappingのvalueを「/**」としてどのパスでも受け付けるようにして、引数はServerWebExchange、戻り値はFluxという
感じで。

    @RequestMapping(value = "/**", method = {RequestMethod.GET, RequestMethod.POST})
    public Flux<DataBuffer> proxy(ServerWebExchange exchange) {

サーバーとして使うリクエスト、レスポンスはServerWebExchangeからそれぞれServerHttpRequest、ServerHttpResponseとして取得することができます。

        ServerHttpRequest request = exchange.getRequest();
        ServerHttpResponse response = exchange.getResponse();

リクエスト転送先のホスト、ポートをベースにしてWebClientを作成します。

        String remoteHost = "localhost";
        int remotePort = 18080;

        WebClient proxyClient =
                WebClient
                        .builder()
                        .baseUrl(String.format("http://%s:%d", remoteHost, remotePort))
                        .build();

それから、ServerHttpRequestの内容をWebClientが送信するリクエストの内容にコピーしていきます。データは、特にここで変更したりしないので
DataBufferとして受け取るように作成。

        Mono<ClientResponse> monoRemoteResponse =
                proxyClient
                        .method(request.getMethod())
                        .uri(uriBuilder ->
                                uriBuilder
                                        .path(request.getPath().value())
                                        .queryParams(request.getQueryParams())
                                        .build()
                        )
                        .headers(headers -> {
                            HttpHeaders requestHeaders = request.getHeaders();
                            requestHeaders.forEach(headers::addAll);
                        })
                        .body(request.getBody(), DataBuffer.class)
                        .exchange();

WebClientの使い方は、ここを見つつ…。
WebClient

ここで返ってくるのはMonoなのですが、それなりにデータが大きいケースも考えるとFluxにした方がいいのではという気がしたので、
MonoをFluxにしておきます。

        Flux<ClientResponse> remoteResponse =
                Flux.from(monoRemoteResponse);

あとは、レスポンスの内容をクライアントに返すように作成。

        return remoteResponse
                .flatMap(remoteClientResponse -> {
                    response.setStatusCode(remoteClientResponse.statusCode());
                    ClientResponse.Headers remoteHeaders = remoteClientResponse.headers();
                    remoteHeaders.asHttpHeaders().forEach((name, values) -> {
                        if ("Content-Type".equalsIgnoreCase(name)) {
                            // デフォルトの「text/event-stream」を潰す
                            response.getHeaders().set(name, values.get(0));
                        } else {
                            response.getHeaders().addAll(name, values);
                        }
                    });

                    return remoteClientResponse.bodyToFlux(DataBuffer.class);
                })
                .doOnError(throwable -> {
                    if (throwable.getCause() != null && throwable.getCause() instanceof ConnectException) {
                        response.setStatusCode(HttpStatus.BAD_GATEWAY);
                        response.getHeaders().set("Content-Type", "text/plain");
                    } else {
                        response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR);
                        response.getHeaders().set("Content-Type", "text/plain");
                    }
                })
                .onErrorReturn(response.bufferFactory().wrap(new byte[0]));

ヘッダーについては、「Content-Type」のみsetするようにして、あとはひたすらaddAll。

                    response.setStatusCode(remoteClientResponse.statusCode());
                    ClientResponse.Headers remoteHeaders = remoteClientResponse.headers();
                    remoteHeaders.asHttpHeaders().forEach((name, values) -> {
                        if ("Content-Type".equalsIgnoreCase(name)) {
                            // デフォルトの「text/event-stream」を潰す
                            response.getHeaders().set(name, values.get(0));
                        } else {
                            response.getHeaders().addAll(name, values);
                        }
                    });

データは、DataBufferのFluxとして返すようにしました。

                    return remoteClientResponse.bodyToFlux(DataBuffer.class);

エラーケースについては、簡単に済ませています…。

起動クラス

Spring Bootアプリケーションとしての起動クラスは、こんな感じで簡単に。
src/main/java/org/littlewings/spring/webflux/proxy/App.java

package org.littlewings.spring.webflux.proxy;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class App {
    public static void main(String... args) {
        SpringApplication.run(App.class, args);
    }
}

確認

それでは、確認してみましょう。

Proxyとして作成したWebFluxを使ったアプリケーション、最初に作ったServletをデプロイしたApache Tomcatは起動済みとします。

GET。

$ curl -i -H 'X-Test-Header: sample1' http://localhost:8080/path1/path2?query-param=value
HTTP/1.1 200 OK
Content-Type: text/plain;charset=UTF-8
X-Custom-Header: Header-Value
Content-Length: 380
Date: Sun, 08 Apr 2018 12:13:59 GMT

========================================
Request Method:
  GET
========================================
Request URI:
  /path1/path2?query-param=value
========================================
Request Headers:
  accept-encoding:[gzip]
  host:[localhost:8080]
  user-agent:[curl/7.47.0]
  accept:[*/*]
  x-test-header:[sample1]
========================================
Request Body:

POST。

$ curl -i -H 'X-Test-Header: sample1' -H 'Content-Type: application/json' http://localhost:8080/path1/path2?query-param=value -d '{"param": "json-value"}'
HTTP/1.1 200 OK
Content-Type: text/plain;charset=UTF-8
X-Custom-Header: Header-Value
Content-Length: 460
Date: Sun, 08 Apr 2018 12:16:18 GMT

========================================
Request Method:
  POST
========================================
Request URI:
  /path1/path2?query-param=value
========================================
Request Headers:
  accept-encoding:[gzip]
  host:[localhost:8080]
  user-agent:[curl/7.47.0]
  accept:[*/*]
  x-test-header:[sample1]
  content-type:[application/json]
  content-length:[23]
========================================
Request Body:
{"param": "json-value"}

動いてそうですね。

まとめ

Spring WebFlux+WebClientを使って、簡単なProxyサーバーを書いてみました。

実はSpring WebFlux(Spring Boot 2/Spring Framework 5)を使うのは初めてだったので、けっこう勉強になりました。

オマケ

最初、このお題をReactor Netty単体でやろうとしたのですが、見事に挫折しました…。

Handlerをどう書いたらいいのかイマイチよく分からなくて、打開策としてSpring WebFluxに移ったらまああっさりと
うまくいきまして…。

追うなら、このあたりなのでしょうね。
https://github.com/spring-projects/spring-framework/blob/v5.0.5.RELEASE/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorHttpHandlerAdapter.java
https://github.com/spring-projects/spring-framework/blob/v5.0.5.RELEASE/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java

参考)
はじめてのSpring WebFlux (その1.5 - Spring Bootを使わずSpring WebFluxをマニュアルでBootstrap)