Fix Rest client of policy agent
[nonrtric.git] / policy-agent / src / main / java / org / oransc / policyagent / clients / AsyncRestClient.java
index cefc7ca..4bc29ee 100644 (file)
@@ -29,6 +29,7 @@ import io.netty.handler.timeout.WriteTimeoutHandler;
 
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.lang.invoke.MethodHandles;
 import java.security.KeyStore;
 import java.security.KeyStoreException;
@@ -41,7 +42,8 @@ import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
-import org.oransc.policyagent.configuration.ImmutableWebClientConfig;
+import javax.net.ssl.KeyManagerFactory;
+
 import org.oransc.policyagent.configuration.WebClientConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -50,12 +52,14 @@ import org.springframework.http.ResponseEntity;
 import org.springframework.http.client.reactive.ReactorClientHttpConnector;
 import org.springframework.lang.Nullable;
 import org.springframework.util.ResourceUtils;
+import org.springframework.web.reactive.function.client.ExchangeStrategies;
 import org.springframework.web.reactive.function.client.WebClient;
 import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec;
 import org.springframework.web.reactive.function.client.WebClientResponseException;
 
 import reactor.core.publisher.Mono;
 import reactor.netty.http.client.HttpClient;
+import reactor.netty.resources.ConnectionProvider;
 import reactor.netty.tcp.TcpClient;
 
 /**
@@ -67,10 +71,12 @@ public class AsyncRestClient {
     private final String baseUrl;
     private static final AtomicInteger sequenceNumber = new AtomicInteger();
     private final WebClientConfig clientConfig;
+    static KeyStore clientTrustStore = null;
+    private boolean sslEnabled = true;
 
     public AsyncRestClient(String baseUrl) {
-        this(baseUrl,
-            ImmutableWebClientConfig.builder().isTrustStoreUsed(false).trustStore("").trustStorePassword("").build());
+        this(baseUrl, null);
+        this.sslEnabled = false;
     }
 
     public AsyncRestClient(String baseUrl, WebClientConfig config) {
@@ -176,9 +182,10 @@ public class AsyncRestClient {
     }
 
     private Mono<ResponseEntity<String>> retrieve(Object traceTag, RequestHeadersSpec<?> request) {
+        final Class<String> clazz = String.class;
         return request.retrieve() //
-            .toEntity(String.class) //
-            .doOnNext(entity -> logger.trace("{} Received: {}", traceTag, entity.getBody()))
+            .toEntity(clazz) //
+            .doOnNext(entity -> logger.trace("{} Received: {}", traceTag, entity.getBody())) //
             .doOnError(throwable -> onHttpError(traceTag, throwable));
     }
 
@@ -192,7 +199,7 @@ public class AsyncRestClient {
             logger.debug("{} HTTP error status = '{}', body '{}'", traceTag, exception.getStatusCode(),
                 exception.getResponseBodyAsString());
         } else {
-            logger.debug("{} HTTP error: {}", traceTag, t.getMessage());
+            logger.debug("{} HTTP error", traceTag, t);
         }
     }
 
@@ -222,12 +229,21 @@ public class AsyncRestClient {
         }
     }
 
-    SslContext createSslContextSecure(String trustStorePath, String trustStorePass)
+    private static synchronized KeyStore getTrustStore(String trustStorePath, String trustStorePass)
         throws NoSuchAlgorithmException, CertificateException, IOException, KeyStoreException {
+        if (clientTrustStore == null) {
+            KeyStore store = KeyStore.getInstance(KeyStore.getDefaultType());
+            store.load(new FileInputStream(ResourceUtils.getFile(trustStorePath)), trustStorePass.toCharArray());
+            clientTrustStore = store;
+        }
+        return clientTrustStore;
+    }
 
-        final KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
-        trustStore.load(new FileInputStream(ResourceUtils.getFile(trustStorePath)), trustStorePass.toCharArray());
+    private SslContext createSslContextRejectingUntrustedPeers(String trustStorePath, String trustStorePass,
+        KeyManagerFactory keyManager)
+        throws NoSuchAlgorithmException, CertificateException, IOException, KeyStoreException {
 
+        final KeyStore trustStore = getTrustStore(trustStorePath, trustStorePass);
         List<Certificate> certificateList = Collections.list(trustStore.aliases()).stream() //
             .filter(alias -> isCertificateEntry(trustStore, alias)) //
             .map(alias -> getCertificate(trustStore, alias)) //
@@ -235,43 +251,78 @@ public class AsyncRestClient {
         final X509Certificate[] certificates = certificateList.toArray(new X509Certificate[certificateList.size()]);
 
         return SslContextBuilder.forClient() //
+            .keyManager(keyManager) //
             .trustManager(certificates) //
             .build();
     }
 
-    private SslContext createSslContext()
+    private SslContext createSslContext(KeyManagerFactory keyManager)
         throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException {
         if (this.clientConfig.isTrustStoreUsed()) {
-            return createSslContextSecure(this.clientConfig.trustStore(), this.clientConfig.trustStorePassword());
+            return createSslContextRejectingUntrustedPeers(this.clientConfig.trustStore(),
+                this.clientConfig.trustStorePassword(), keyManager);
         } else {
+            // Trust anyone
             return SslContextBuilder.forClient() //
+                .keyManager(keyManager) //
                 .trustManager(InsecureTrustManagerFactory.INSTANCE) //
                 .build();
         }
     }
 
-    private WebClient createWebClient(String baseUrl, SslContext sslContext) {
-        TcpClient tcpClient = TcpClient.create() //
+    private TcpClient createTcpClientSecure(SslContext sslContext) {
+        return TcpClient.create(ConnectionProvider.newConnection()) //
             .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) //
             .secure(c -> c.sslContext(sslContext)) //
             .doOnConnected(connection -> {
                 connection.addHandlerLast(new ReadTimeoutHandler(30));
                 connection.addHandlerLast(new WriteTimeoutHandler(30));
             });
+    }
+
+    private TcpClient createTcpClientInsecure() {
+        return TcpClient.create(ConnectionProvider.newConnection()) //
+            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) //
+            .doOnConnected(connection -> {
+                connection.addHandlerLast(new ReadTimeoutHandler(30));
+                connection.addHandlerLast(new WriteTimeoutHandler(30));
+            });
+    }
+
+    private WebClient createWebClient(String baseUrl, TcpClient tcpClient) {
         HttpClient httpClient = HttpClient.from(tcpClient);
         ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient);
-
+        ExchangeStrategies exchangeStrategies = ExchangeStrategies.builder() //
+            .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(-1)) //
+            .build();
         return WebClient.builder() //
             .clientConnector(connector) //
             .baseUrl(baseUrl) //
+            .exchangeStrategies(exchangeStrategies) //
             .build();
     }
 
     private Mono<WebClient> getWebClient() {
         if (this.webClient == null) {
             try {
-                SslContext sslContext = createSslContext();
-                this.webClient = createWebClient(this.baseUrl, sslContext);
+                if (this.sslEnabled) {
+                    final KeyManagerFactory keyManager =
+                        KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+                    final KeyStore keyStore = KeyStore.getInstance(this.clientConfig.keyStoreType());
+                    final String keyStoreFile = this.clientConfig.keyStore();
+                    final String keyStorePassword = this.clientConfig.keyStorePassword();
+                    final String keyPassword = this.clientConfig.keyPassword();
+                    try (final InputStream inputStream = new FileInputStream(keyStoreFile)) {
+                        keyStore.load(inputStream, keyStorePassword.toCharArray());
+                    }
+                    keyManager.init(keyStore, keyPassword.toCharArray());
+                    SslContext sslContext = createSslContext(keyManager);
+                    TcpClient tcpClient = createTcpClientSecure(sslContext);
+                    this.webClient = createWebClient(this.baseUrl, tcpClient);
+                } else {
+                    TcpClient tcpClient = createTcpClientInsecure();
+                    this.webClient = createWebClient(this.baseUrl, tcpClient);
+                }
             } catch (Exception e) {
                 logger.error("Could not create WebClient {}", e.getMessage());
                 return Mono.error(e);