Merge "Two minor bugfies"
[nonrtric.git] / policy-agent / src / main / java / org / oransc / policyagent / clients / AsyncRestClient.java
1 /*-
2  * ========================LICENSE_START=================================
3  * O-RAN-SC
4  * %%
5  * Copyright (C) 2019 Nordix Foundation
6  * %%
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  * ========================LICENSE_END===================================
19  */
20
21 package org.oransc.policyagent.clients;
22
23 import io.netty.channel.ChannelOption;
24 import io.netty.handler.ssl.SslContext;
25 import io.netty.handler.ssl.SslContextBuilder;
26 import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
27 import io.netty.handler.timeout.ReadTimeoutHandler;
28 import io.netty.handler.timeout.WriteTimeoutHandler;
29
30 import java.io.FileInputStream;
31 import java.io.IOException;
32 import java.lang.invoke.MethodHandles;
33 import java.security.KeyStore;
34 import java.security.KeyStoreException;
35 import java.security.NoSuchAlgorithmException;
36 import java.security.cert.Certificate;
37 import java.security.cert.CertificateException;
38 import java.security.cert.X509Certificate;
39 import java.util.Collections;
40 import java.util.List;
41 import java.util.concurrent.atomic.AtomicInteger;
42 import java.util.stream.Collectors;
43
44 import org.oransc.policyagent.configuration.ImmutableWebClientConfig;
45 import org.oransc.policyagent.configuration.WebClientConfig;
46 import org.slf4j.Logger;
47 import org.slf4j.LoggerFactory;
48 import org.springframework.http.MediaType;
49 import org.springframework.http.ResponseEntity;
50 import org.springframework.http.client.reactive.ReactorClientHttpConnector;
51 import org.springframework.lang.Nullable;
52 import org.springframework.util.ResourceUtils;
53 import org.springframework.web.reactive.function.client.ExchangeStrategies;
54 import org.springframework.web.reactive.function.client.WebClient;
55 import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec;
56 import org.springframework.web.reactive.function.client.WebClientResponseException;
57
58 import reactor.core.publisher.Mono;
59 import reactor.netty.http.client.HttpClient;
60 import reactor.netty.tcp.TcpClient;
61
62 /**
63  * Generic reactive REST client.
64  */
65 public class AsyncRestClient {
66     private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
67     private WebClient webClient = null;
68     private final String baseUrl;
69     private static final AtomicInteger sequenceNumber = new AtomicInteger();
70     private final WebClientConfig clientConfig;
71     static KeyStore clientTrustStore = null;
72
73     public AsyncRestClient(String baseUrl) {
74         this(baseUrl,
75             ImmutableWebClientConfig.builder().isTrustStoreUsed(false).trustStore("").trustStorePassword("").build());
76     }
77
78     public AsyncRestClient(String baseUrl, WebClientConfig config) {
79         this.baseUrl = baseUrl;
80         this.clientConfig = config;
81     }
82
83     public Mono<ResponseEntity<String>> postForEntity(String uri, @Nullable String body) {
84         Object traceTag = createTraceTag();
85         logger.debug("{} POST uri = '{}{}''", traceTag, baseUrl, uri);
86         logger.trace("{} POST body: {}", traceTag, body);
87         Mono<String> bodyProducer = body != null ? Mono.just(body) : Mono.empty();
88         return getWebClient() //
89             .flatMap(client -> {
90                 RequestHeadersSpec<?> request = client.post() //
91                     .uri(uri) //
92                     .contentType(MediaType.APPLICATION_JSON) //
93                     .body(bodyProducer, String.class);
94                 return retrieve(traceTag, request);
95             });
96     }
97
98     public Mono<String> post(String uri, @Nullable String body) {
99         return postForEntity(uri, body) //
100             .flatMap(this::toBody);
101     }
102
103     public Mono<String> postWithAuthHeader(String uri, String body, String username, String password) {
104         Object traceTag = createTraceTag();
105         logger.debug("{} POST (auth) uri = '{}{}''", traceTag, baseUrl, uri);
106         logger.trace("{} POST body: {}", traceTag, body);
107         return getWebClient() //
108             .flatMap(client -> {
109                 RequestHeadersSpec<?> request = client.post() //
110                     .uri(uri) //
111                     .headers(headers -> headers.setBasicAuth(username, password)) //
112                     .contentType(MediaType.APPLICATION_JSON) //
113                     .bodyValue(body);
114                 return retrieve(traceTag, request) //
115                     .flatMap(this::toBody);
116             });
117     }
118
119     public Mono<ResponseEntity<String>> putForEntity(String uri, String body) {
120         Object traceTag = createTraceTag();
121         logger.debug("{} PUT uri = '{}{}''", traceTag, baseUrl, uri);
122         logger.trace("{} PUT body: {}", traceTag, body);
123         return getWebClient() //
124             .flatMap(client -> {
125                 RequestHeadersSpec<?> request = client.put() //
126                     .uri(uri) //
127                     .contentType(MediaType.APPLICATION_JSON) //
128                     .bodyValue(body);
129                 return retrieve(traceTag, request);
130             });
131     }
132
133     public Mono<ResponseEntity<String>> putForEntity(String uri) {
134         Object traceTag = createTraceTag();
135         logger.debug("{} PUT uri = '{}{}''", traceTag, baseUrl, uri);
136         logger.trace("{} PUT body: <empty>", traceTag);
137         return getWebClient() //
138             .flatMap(client -> {
139                 RequestHeadersSpec<?> request = client.put() //
140                     .uri(uri);
141                 return retrieve(traceTag, request);
142             });
143     }
144
145     public Mono<String> put(String uri, String body) {
146         return putForEntity(uri, body) //
147             .flatMap(this::toBody);
148     }
149
150     public Mono<ResponseEntity<String>> getForEntity(String uri) {
151         Object traceTag = createTraceTag();
152         logger.debug("{} GET uri = '{}{}''", traceTag, baseUrl, uri);
153         return getWebClient() //
154             .flatMap(client -> {
155                 RequestHeadersSpec<?> request = client.get().uri(uri);
156                 return retrieve(traceTag, request);
157             });
158     }
159
160     public Mono<String> get(String uri) {
161         return getForEntity(uri) //
162             .flatMap(this::toBody);
163     }
164
165     public Mono<ResponseEntity<String>> deleteForEntity(String uri) {
166         Object traceTag = createTraceTag();
167         logger.debug("{} DELETE uri = '{}{}''", traceTag, baseUrl, uri);
168         return getWebClient() //
169             .flatMap(client -> {
170                 RequestHeadersSpec<?> request = client.delete().uri(uri);
171                 return retrieve(traceTag, request);
172             });
173     }
174
175     public Mono<String> delete(String uri) {
176         return deleteForEntity(uri) //
177             .flatMap(this::toBody);
178     }
179
180     private Mono<ResponseEntity<String>> retrieve(Object traceTag, RequestHeadersSpec<?> request) {
181         return request.retrieve() //
182             .toEntity(String.class) //
183             .doOnNext(entity -> logger.trace("{} Received: {}", traceTag, entity.getBody())) //
184             .doOnError(throwable -> onHttpError(traceTag, throwable));
185     }
186
187     private static Object createTraceTag() {
188         return sequenceNumber.incrementAndGet();
189     }
190
191     private void onHttpError(Object traceTag, Throwable t) {
192         if (t instanceof WebClientResponseException) {
193             WebClientResponseException exception = (WebClientResponseException) t;
194             logger.debug("{} HTTP error status = '{}', body '{}'", traceTag, exception.getStatusCode(),
195                 exception.getResponseBodyAsString());
196         } else {
197             logger.debug("{} HTTP error: {}", traceTag, t.getMessage());
198         }
199     }
200
201     private Mono<String> toBody(ResponseEntity<String> entity) {
202         if (entity.getBody() == null) {
203             return Mono.just("");
204         } else {
205             return Mono.just(entity.getBody());
206         }
207     }
208
209     private boolean isCertificateEntry(KeyStore trustStore, String alias) {
210         try {
211             return trustStore.isCertificateEntry(alias);
212         } catch (KeyStoreException e) {
213             logger.error("Error reading truststore {}", e.getMessage());
214             return false;
215         }
216     }
217
218     private Certificate getCertificate(KeyStore trustStore, String alias) {
219         try {
220             return trustStore.getCertificate(alias);
221         } catch (KeyStoreException e) {
222             logger.error("Error reading truststore {}", e.getMessage());
223             return null;
224         }
225     }
226
227     private static synchronized KeyStore getTrustStore(String trustStorePath, String trustStorePass)
228         throws NoSuchAlgorithmException, CertificateException, IOException, KeyStoreException {
229         if (clientTrustStore == null) {
230             KeyStore store = KeyStore.getInstance(KeyStore.getDefaultType());
231             store.load(new FileInputStream(ResourceUtils.getFile(trustStorePath)), trustStorePass.toCharArray());
232             clientTrustStore = store;
233         }
234         return clientTrustStore;
235     }
236
237     private SslContext createSslContextRejectingUntrustedPeers(String trustStorePath, String trustStorePass)
238         throws NoSuchAlgorithmException, CertificateException, IOException, KeyStoreException {
239
240         final KeyStore trustStore = getTrustStore(trustStorePath, trustStorePass);
241         List<Certificate> certificateList = Collections.list(trustStore.aliases()).stream() //
242             .filter(alias -> isCertificateEntry(trustStore, alias)) //
243             .map(alias -> getCertificate(trustStore, alias)) //
244             .collect(Collectors.toList());
245         final X509Certificate[] certificates = certificateList.toArray(new X509Certificate[certificateList.size()]);
246
247         return SslContextBuilder.forClient() //
248             .trustManager(certificates) //
249             .build();
250     }
251
252     private SslContext createSslContext()
253         throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException {
254         if (this.clientConfig.isTrustStoreUsed()) {
255             return createSslContextRejectingUntrustedPeers(this.clientConfig.trustStore(),
256                 this.clientConfig.trustStorePassword());
257         } else {
258             // Trust anyone
259             return SslContextBuilder.forClient() //
260                 .trustManager(InsecureTrustManagerFactory.INSTANCE) //
261                 .build();
262         }
263     }
264
265     private WebClient createWebClient(String baseUrl, SslContext sslContext) {
266         TcpClient tcpClient = TcpClient.create() //
267             .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) //
268             .secure(c -> c.sslContext(sslContext)) //
269             .doOnConnected(connection -> {
270                 connection.addHandlerLast(new ReadTimeoutHandler(30));
271                 connection.addHandlerLast(new WriteTimeoutHandler(30));
272             });
273         HttpClient httpClient = HttpClient.from(tcpClient);
274         ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient);
275
276         ExchangeStrategies exchangeStrategies = ExchangeStrategies.builder() //
277             .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(-1)) //
278             .build();
279
280         return WebClient.builder() //
281             .clientConnector(connector) //
282             .baseUrl(baseUrl) //
283             .exchangeStrategies(exchangeStrategies) //
284             .build();
285     }
286
287     private Mono<WebClient> getWebClient() {
288         if (this.webClient == null) {
289             try {
290                 SslContext sslContext = createSslContext();
291                 this.webClient = createWebClient(this.baseUrl, sslContext);
292             } catch (Exception e) {
293                 logger.error("Could not create WebClient {}", e.getMessage());
294                 return Mono.error(e);
295             }
296         }
297         return Mono.just(this.webClient);
298     }
299
300 }