X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=blobdiff_plain;f=policy-agent%2Fsrc%2Fmain%2Fjava%2Forg%2Foransc%2Fpolicyagent%2Fclients%2FAsyncRestClient.java;h=4bc29ee0452eb1fec7b75a62b583b97374797b60;hb=48ae0f39d9d03cc1ec976762e6f7400447ace0a4;hp=2435ef083c886747df531ae762bb421cfb77b187;hpb=6a8a0d5350a77b6d1e4a8f95c0fe8fbfeef77339;p=nonrtric.git diff --git a/policy-agent/src/main/java/org/oransc/policyagent/clients/AsyncRestClient.java b/policy-agent/src/main/java/org/oransc/policyagent/clients/AsyncRestClient.java index 2435ef08..4bc29ee0 100644 --- a/policy-agent/src/main/java/org/oransc/policyagent/clients/AsyncRestClient.java +++ b/policy-agent/src/main/java/org/oransc/policyagent/clients/AsyncRestClient.java @@ -20,35 +20,83 @@ package org.oransc.policyagent.clients; +import io.netty.channel.ChannelOption; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.timeout.ReadTimeoutHandler; +import io.netty.handler.timeout.WriteTimeoutHandler; + +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; import java.lang.invoke.MethodHandles; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import javax.net.ssl.KeyManagerFactory; +import org.oransc.policyagent.configuration.WebClientConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.lang.Nullable; +import org.springframework.util.ResourceUtils; +import org.springframework.web.reactive.function.client.ExchangeStrategies; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec; +import org.springframework.web.reactive.function.client.WebClientResponseException; + import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; +import reactor.netty.resources.ConnectionProvider; +import reactor.netty.tcp.TcpClient; +/** + * Generic reactive REST client. + */ public class AsyncRestClient { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); - private final WebClient client; + private WebClient webClient = null; private final String baseUrl; + private static final AtomicInteger sequenceNumber = new AtomicInteger(); + private final WebClientConfig clientConfig; + static KeyStore clientTrustStore = null; + private boolean sslEnabled = true; public AsyncRestClient(String baseUrl) { - this.client = WebClient.create(baseUrl); + this(baseUrl, null); + this.sslEnabled = false; + } + + public AsyncRestClient(String baseUrl, WebClientConfig config) { this.baseUrl = baseUrl; + this.clientConfig = config; } public Mono> postForEntity(String uri, @Nullable String body) { - logger.debug("POST uri = '{}{}''", baseUrl, uri); + Object traceTag = createTraceTag(); + logger.debug("{} POST uri = '{}{}''", traceTag, baseUrl, uri); + logger.trace("{} POST body: {}", traceTag, body); Mono bodyProducer = body != null ? Mono.just(body) : Mono.empty(); - RequestHeadersSpec request = client.post() // - .uri(uri) // - .contentType(MediaType.APPLICATION_JSON) // - .body(bodyProducer, String.class); - return retrieve(request); + return getWebClient() // + .flatMap(client -> { + RequestHeadersSpec request = client.post() // + .uri(uri) // + .contentType(MediaType.APPLICATION_JSON) // + .body(bodyProducer, String.class); + return retrieve(traceTag, request); + }); } public Mono post(String uri, @Nullable String body) { @@ -57,23 +105,45 @@ public class AsyncRestClient { } public Mono postWithAuthHeader(String uri, String body, String username, String password) { - logger.debug("POST (auth) uri = '{}{}''", baseUrl, uri); - RequestHeadersSpec request = client.post() // - .uri(uri) // - .headers(headers -> headers.setBasicAuth(username, password)) // - .contentType(MediaType.APPLICATION_JSON) // - .bodyValue(body); - return retrieve(request) // - .flatMap(this::toBody); + Object traceTag = createTraceTag(); + logger.debug("{} POST (auth) uri = '{}{}''", traceTag, baseUrl, uri); + logger.trace("{} POST body: {}", traceTag, body); + return getWebClient() // + .flatMap(client -> { + RequestHeadersSpec request = client.post() // + .uri(uri) // + .headers(headers -> headers.setBasicAuth(username, password)) // + .contentType(MediaType.APPLICATION_JSON) // + .bodyValue(body); + return retrieve(traceTag, request) // + .flatMap(this::toBody); + }); } public Mono> putForEntity(String uri, String body) { - logger.debug("PUT uri = '{}{}''", baseUrl, uri); - RequestHeadersSpec request = client.put() // - .uri(uri) // - .contentType(MediaType.APPLICATION_JSON) // - .bodyValue(body); - return retrieve(request); + Object traceTag = createTraceTag(); + logger.debug("{} PUT uri = '{}{}''", traceTag, baseUrl, uri); + logger.trace("{} PUT body: {}", traceTag, body); + return getWebClient() // + .flatMap(client -> { + RequestHeadersSpec request = client.put() // + .uri(uri) // + .contentType(MediaType.APPLICATION_JSON) // + .bodyValue(body); + return retrieve(traceTag, request); + }); + } + + public Mono> putForEntity(String uri) { + Object traceTag = createTraceTag(); + logger.debug("{} PUT uri = '{}{}''", traceTag, baseUrl, uri); + logger.trace("{} PUT body: ", traceTag); + return getWebClient() // + .flatMap(client -> { + RequestHeadersSpec request = client.put() // + .uri(uri); + return retrieve(traceTag, request); + }); } public Mono put(String uri, String body) { @@ -82,9 +152,13 @@ public class AsyncRestClient { } public Mono> getForEntity(String uri) { - logger.debug("GET uri = '{}{}''", baseUrl, uri); - RequestHeadersSpec request = client.get().uri(uri); - return retrieve(request); + Object traceTag = createTraceTag(); + logger.debug("{} GET uri = '{}{}''", traceTag, baseUrl, uri); + return getWebClient() // + .flatMap(client -> { + RequestHeadersSpec request = client.get().uri(uri); + return retrieve(traceTag, request); + }); } public Mono get(String uri) { @@ -93,9 +167,13 @@ public class AsyncRestClient { } public Mono> deleteForEntity(String uri) { - logger.debug("DELETE uri = '{}{}''", baseUrl, uri); - RequestHeadersSpec request = client.delete().uri(uri); - return retrieve(request); + Object traceTag = createTraceTag(); + logger.debug("{} DELETE uri = '{}{}''", traceTag, baseUrl, uri); + return getWebClient() // + .flatMap(client -> { + RequestHeadersSpec request = client.delete().uri(uri); + return retrieve(traceTag, request); + }); } public Mono delete(String uri) { @@ -103,12 +181,29 @@ public class AsyncRestClient { .flatMap(this::toBody); } - private Mono> retrieve(RequestHeadersSpec request) { + private Mono> retrieve(Object traceTag, RequestHeadersSpec request) { + final Class clazz = String.class; return request.retrieve() // - .toEntity(String.class); + .toEntity(clazz) // + .doOnNext(entity -> logger.trace("{} Received: {}", traceTag, entity.getBody())) // + .doOnError(throwable -> onHttpError(traceTag, throwable)); + } + + private static Object createTraceTag() { + return sequenceNumber.incrementAndGet(); } - Mono toBody(ResponseEntity entity) { + private void onHttpError(Object traceTag, Throwable t) { + if (t instanceof WebClientResponseException) { + WebClientResponseException exception = (WebClientResponseException) t; + logger.debug("{} HTTP error status = '{}', body '{}'", traceTag, exception.getStatusCode(), + exception.getResponseBodyAsString()); + } else { + logger.debug("{} HTTP error", traceTag, t); + } + } + + private Mono toBody(ResponseEntity entity) { if (entity.getBody() == null) { return Mono.just(""); } else { @@ -116,4 +211,124 @@ public class AsyncRestClient { } } + private boolean isCertificateEntry(KeyStore trustStore, String alias) { + try { + return trustStore.isCertificateEntry(alias); + } catch (KeyStoreException e) { + logger.error("Error reading truststore {}", e.getMessage()); + return false; + } + } + + private Certificate getCertificate(KeyStore trustStore, String alias) { + try { + return trustStore.getCertificate(alias); + } catch (KeyStoreException e) { + logger.error("Error reading truststore {}", e.getMessage()); + return null; + } + } + + private static synchronized KeyStore getTrustStore(String trustStorePath, String trustStorePass) + throws NoSuchAlgorithmException, CertificateException, IOException, KeyStoreException { + if (clientTrustStore == null) { + KeyStore store = KeyStore.getInstance(KeyStore.getDefaultType()); + store.load(new FileInputStream(ResourceUtils.getFile(trustStorePath)), trustStorePass.toCharArray()); + clientTrustStore = store; + } + return clientTrustStore; + } + + private SslContext createSslContextRejectingUntrustedPeers(String trustStorePath, String trustStorePass, + KeyManagerFactory keyManager) + throws NoSuchAlgorithmException, CertificateException, IOException, KeyStoreException { + + final KeyStore trustStore = getTrustStore(trustStorePath, trustStorePass); + List certificateList = Collections.list(trustStore.aliases()).stream() // + .filter(alias -> isCertificateEntry(trustStore, alias)) // + .map(alias -> getCertificate(trustStore, alias)) // + .collect(Collectors.toList()); + final X509Certificate[] certificates = certificateList.toArray(new X509Certificate[certificateList.size()]); + + return SslContextBuilder.forClient() // + .keyManager(keyManager) // + .trustManager(certificates) // + .build(); + } + + private SslContext createSslContext(KeyManagerFactory keyManager) + throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException { + if (this.clientConfig.isTrustStoreUsed()) { + return createSslContextRejectingUntrustedPeers(this.clientConfig.trustStore(), + this.clientConfig.trustStorePassword(), keyManager); + } else { + // Trust anyone + return SslContextBuilder.forClient() // + .keyManager(keyManager) // + .trustManager(InsecureTrustManagerFactory.INSTANCE) // + .build(); + } + } + + private TcpClient createTcpClientSecure(SslContext sslContext) { + return TcpClient.create(ConnectionProvider.newConnection()) // + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) // + .secure(c -> c.sslContext(sslContext)) // + .doOnConnected(connection -> { + connection.addHandlerLast(new ReadTimeoutHandler(30)); + connection.addHandlerLast(new WriteTimeoutHandler(30)); + }); + } + + private TcpClient createTcpClientInsecure() { + return TcpClient.create(ConnectionProvider.newConnection()) // + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) // + .doOnConnected(connection -> { + connection.addHandlerLast(new ReadTimeoutHandler(30)); + connection.addHandlerLast(new WriteTimeoutHandler(30)); + }); + } + + private WebClient createWebClient(String baseUrl, TcpClient tcpClient) { + HttpClient httpClient = HttpClient.from(tcpClient); + ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient); + ExchangeStrategies exchangeStrategies = ExchangeStrategies.builder() // + .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(-1)) // + .build(); + return WebClient.builder() // + .clientConnector(connector) // + .baseUrl(baseUrl) // + .exchangeStrategies(exchangeStrategies) // + .build(); + } + + private Mono getWebClient() { + if (this.webClient == null) { + try { + if (this.sslEnabled) { + final KeyManagerFactory keyManager = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + final KeyStore keyStore = KeyStore.getInstance(this.clientConfig.keyStoreType()); + final String keyStoreFile = this.clientConfig.keyStore(); + final String keyStorePassword = this.clientConfig.keyStorePassword(); + final String keyPassword = this.clientConfig.keyPassword(); + try (final InputStream inputStream = new FileInputStream(keyStoreFile)) { + keyStore.load(inputStream, keyStorePassword.toCharArray()); + } + keyManager.init(keyStore, keyPassword.toCharArray()); + SslContext sslContext = createSslContext(keyManager); + TcpClient tcpClient = createTcpClientSecure(sslContext); + this.webClient = createWebClient(this.baseUrl, tcpClient); + } else { + TcpClient tcpClient = createTcpClientInsecure(); + this.webClient = createWebClient(this.baseUrl, tcpClient); + } + } catch (Exception e) { + logger.error("Could not create WebClient {}", e.getMessage()); + return Mono.error(e); + } + } + return Mono.just(this.webClient); + } + }