001/** 002 * Copyright (C) 2006-2023 Talend Inc. - www.talend.com 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package org.talend.sdk.component.server.configuration; 017 018import static java.util.Collections.emptyList; 019import static java.util.Collections.emptyMap; 020import static java.util.Collections.singletonList; 021import static java.util.Locale.ENGLISH; 022 023import java.io.ByteArrayInputStream; 024import java.io.IOException; 025import java.io.InputStream; 026import java.io.OutputStream; 027import java.nio.ByteBuffer; 028import java.nio.charset.StandardCharsets; 029import java.util.Collection; 030import java.util.Comparator; 031import java.util.HashMap; 032import java.util.LinkedList; 033import java.util.List; 034import java.util.Map; 035import java.util.Queue; 036import java.util.Set; 037import java.util.Spliterator; 038import java.util.Spliterators; 039import java.util.TreeMap; 040import java.util.logging.Logger; 041import java.util.stream.Stream; 042import java.util.stream.StreamSupport; 043 044import javax.enterprise.context.Dependent; 045import javax.enterprise.inject.Instance; 046import javax.inject.Inject; 047import javax.servlet.ServletConfig; 048import javax.servlet.ServletContext; 049import javax.servlet.ServletContextEvent; 050import javax.servlet.ServletContextListener; 051import javax.servlet.ServletException; 052import javax.servlet.annotation.WebListener; 053import javax.servlet.http.HttpServletRequest; 054import javax.servlet.http.HttpServletResponse; 055import javax.websocket.CloseReason; 056import javax.websocket.DeploymentException; 057import javax.websocket.Endpoint; 058import javax.websocket.EndpointConfig; 059import javax.websocket.MessageHandler.Partial; 060import javax.websocket.Session; 061import javax.websocket.server.ServerContainer; 062import javax.websocket.server.ServerEndpointConfig; 063import javax.ws.rs.ApplicationPath; 064import javax.ws.rs.core.Application; 065import javax.ws.rs.core.HttpHeaders; 066import javax.xml.namespace.QName; 067 068import org.apache.cxf.Bus; 069import org.apache.cxf.common.logging.LogUtils; 070import org.apache.cxf.continuations.Continuation; 071import org.apache.cxf.continuations.ContinuationCallback; 072import org.apache.cxf.continuations.ContinuationProvider; 073import org.apache.cxf.endpoint.ServerRegistry; 074import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean; 075import org.apache.cxf.message.ExchangeImpl; 076import org.apache.cxf.message.Message; 077import org.apache.cxf.message.MessageImpl; 078import org.apache.cxf.service.model.EndpointInfo; 079import org.apache.cxf.transport.AbstractDestination; 080import org.apache.cxf.transport.Conduit; 081import org.apache.cxf.transport.MessageObserver; 082import org.apache.cxf.transport.http.AbstractHTTPDestination; 083import org.apache.cxf.transport.http.ContinuationProviderFactory; 084import org.apache.cxf.transport.http.DestinationRegistry; 085import org.apache.cxf.transport.http.HTTPSession; 086import org.apache.cxf.transport.servlet.ServletController; 087import org.apache.cxf.transport.servlet.ServletDestination; 088import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet; 089import org.apache.cxf.transports.http.configuration.HTTPServerPolicy; 090import org.apache.cxf.ws.addressing.EndpointReferenceType; 091import org.talend.sdk.component.server.front.cxf.CxfExtractor; 092import org.talend.sdk.component.server.front.memory.InMemoryRequest; 093import org.talend.sdk.component.server.front.memory.InMemoryResponse; 094import org.talend.sdk.component.server.front.memory.MemoryInputStream; 095import org.talend.sdk.component.server.front.memory.SimpleServletConfig; 096 097import lombok.Data; 098import lombok.EqualsAndHashCode; 099import lombok.RequiredArgsConstructor; 100import lombok.extern.slf4j.Slf4j; 101 102// ensure any JAX-RS command can use websockets 103@Slf4j 104@Dependent 105@WebListener 106public class WebSocketBroadcastSetup implements ServletContextListener { 107 108 private static final String EOM = "^@"; 109 110 @Inject 111 private Bus bus; 112 113 @Inject 114 private CxfExtractor cxf; 115 116 @Inject 117 private Instance<Application> applications; 118 119 @Override 120 public void contextInitialized(final ServletContextEvent sce) { 121 final ServerContainer container = 122 ServerContainer.class.cast(sce.getServletContext().getAttribute(ServerContainer.class.getName())); 123 124 final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class 125 .cast(bus 126 .getExtension(ServerRegistry.class) 127 .getServers() 128 .iterator() 129 .next() 130 .getEndpoint() 131 .get(JAXRSServiceFactoryBean.class.getName())); 132 133 final String appBase = StreamSupport 134 .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false) 135 .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class)) 136 .map(a -> a.getClass().getAnnotation(ApplicationPath.class)) 137 .map(ApplicationPath::value) 138 .findFirst() 139 .map(s -> !s.startsWith("/") ? "/" + s : s) 140 .orElse("/api/v1"); 141 final String version = appBase.replaceFirst("/api", ""); 142 143 final DestinationRegistry registry = cxf.getRegistry(); 144 final ServletContext servletContext = sce.getServletContext(); 145 146 final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry); 147 final ServletController controller = new ServletController(webSocketRegistry, 148 new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"), 149 new ServiceListGeneratorServlet(registry, bus)); 150 webSocketRegistry.controller = controller; 151 152 Stream 153 .concat(factory 154 .getClassResourceInfo() 155 .stream() 156 .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream()) 157 .filter(cri -> cri.getAnnotatedMethod().getDeclaringClass().getName().startsWith("org.talend.")) 158 .map(ori -> { 159 final String uri = ori.getClassResourceInfo().getURITemplate().getValue() 160 + ori.getURITemplate().getValue(); 161 return ServerEndpointConfig.Builder 162 .create(Endpoint.class, 163 "/websocket" + version + "/" 164 + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri) 165 .configurator(new ServerEndpointConfig.Configurator() { 166 167 @Override 168 public <T> T getEndpointInstance(final Class<T> clazz) 169 throws InstantiationException { 170 final Map<String, List<String>> headers = new HashMap<>(); 171 if (!ori.getProduceTypes().isEmpty()) { 172 headers 173 .put(HttpHeaders.CONTENT_TYPE, singletonList( 174 ori.getProduceTypes().iterator().next().toString())); 175 } 176 if (!ori.getConsumeTypes().isEmpty()) { 177 headers 178 .put(HttpHeaders.ACCEPT, singletonList( 179 ori.getConsumeTypes().iterator().next().toString())); 180 } 181 return (T) new JAXRSEndpoint(appBase, controller, servletContext, 182 ori.getHttpMethod(), uri, headers); 183 } 184 }) 185 .build(); 186 }), 187 Stream 188 .of(ServerEndpointConfig.Builder 189 .create(Endpoint.class, "/websocket" + version + "/bus") 190 .configurator(new ServerEndpointConfig.Configurator() { 191 192 @Override 193 public <T> T getEndpointInstance(final Class<T> clazz) 194 throws InstantiationException { 195 196 return (T) new JAXRSEndpoint(appBase, controller, servletContext, "GET", 197 "/", emptyMap()); 198 } 199 }) 200 .build())) 201 .sorted(Comparator.comparing(ServerEndpointConfig::getPath)) 202 .peek(e -> log.info("Deploying WebSocket(path={})", e.getPath())) 203 .forEach(config -> { 204 try { 205 container.addEndpoint(config); 206 } catch (final DeploymentException e) { 207 throw new IllegalStateException(e); 208 } 209 }); 210 } 211 212 @Data 213 @EqualsAndHashCode(callSuper = false) 214 private static class JAXRSEndpoint extends Endpoint { 215 216 private final String appBase; 217 218 private final ServletController controller; 219 220 private final ServletContext context; 221 222 private final String defaultMethod; 223 224 private final String defaultUri; 225 226 private final Map<String, List<String>> baseHeaders; 227 228 @RequiredArgsConstructor 229 private class PartialMessageHandler implements Partial<byte[]> { 230 231 private final Session session; 232 233 private InMemoryRequest request; 234 235 private InMemoryResponse response; 236 237 private void handleStart(final StringBuilder buffer, final InputStream message) { 238 final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); 239 headers.putAll(baseHeaders); 240 241 try { // read headers from the message 242 String line; 243 int del; 244 while ((line = readLine(buffer, message)) != null) { 245 final boolean done = line.endsWith(EOM); 246 if (done) { 247 line = line.substring(0, line.length() - EOM.length()); 248 } 249 if (!line.isEmpty()) { 250 del = line.indexOf(':'); 251 if (del < 0) { 252 headers.put(line.trim(), emptyList()); 253 } else { 254 headers 255 .put(line.substring(0, del).trim(), 256 singletonList(line.substring(del + 1).trim())); 257 } 258 } 259 if (done) { 260 break; 261 } 262 } 263 } catch (final IOException ioe) { 264 throw new IllegalStateException(ioe); 265 } 266 267 final List<String> uris = headers.get("destination"); 268 final String uri; 269 if (uris == null || uris.isEmpty()) { 270 uri = defaultUri; 271 } else { 272 uri = uris.iterator().next(); 273 } 274 275 final List<String> methods = headers.get("destinationMethod"); 276 final String method; 277 if (methods == null || methods.isEmpty()) { 278 method = defaultMethod; 279 } else { 280 method = methods.iterator().next(); 281 } 282 283 final String queryString; 284 final String path; 285 final int query = uri.indexOf('?'); 286 if (query > 0) { 287 queryString = uri.substring(query + 1); 288 path = uri.substring(0, query); 289 } else { 290 queryString = null; 291 path = uri; 292 } 293 294 request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path, 295 appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message), 296 session::getUserPrincipal, controller); 297 response = new InMemoryResponse(session::isOpen, 298 () -> { 299 if (session.getBasicRemote().getBatchingAllowed()) { 300 try { 301 session.getBasicRemote().flushBatch(); 302 } catch (final IOException e) { 303 throw new IllegalStateException(e); 304 } 305 } 306 }, bytes -> { 307 try { 308 session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes)); 309 } catch (final IOException e) { 310 throw new IllegalStateException(e); 311 } 312 }, (status, responseHeaders) -> { 313 final StringBuilder top = new StringBuilder("MESSAGE\r\n"); 314 top.append("status: ").append(status).append("\r\n"); 315 responseHeaders 316 .forEach((k, 317 v) -> top.append(k) 318 .append(": ") 319 .append(String.join(",", v)) 320 .append("\r\n")); 321 top.append("\r\n");// empty line, means the next bytes are the payload 322 return top.toString(); 323 }) { 324 325 @Override 326 protected void onClose(final OutputStream stream) throws IOException { 327 stream.write(EOM.getBytes(StandardCharsets.UTF_8)); 328 } 329 }; 330 request.setResponse(response); 331 } 332 333 @Override 334 public void onMessage(final byte[] byteBuffer, final boolean last) { 335 final ByteArrayInputStream message = new ByteArrayInputStream(byteBuffer); 336 337 final StringBuilder buffer = new StringBuilder(128); 338 try { // read headers from the message 339 if (request != null) { 340 ((WebSocketInputStream) request.getInputStream()).addStream(message); 341 } else if ("SEND".equalsIgnoreCase(readLine(buffer, message))) { 342 handleStart(buffer, message); 343 } else { 344 throw new IllegalArgumentException("not a message"); 345 } 346 } catch (IOException e) { 347 throw new IllegalStateException(e); 348 } 349 350 if (last) { 351 try { 352 controller.invoke(request, response); 353 } catch (final ServletException e) { 354 throw new IllegalArgumentException(e); 355 } finally { 356 request = null; 357 response = null; 358 } 359 } 360 } 361 } 362 363 @Override 364 public void onOpen(final Session session, final EndpointConfig endpointConfig) { 365 log.debug("Opened session {}", session.getId()); 366 session.addMessageHandler(byte[].class, new PartialMessageHandler(session)); 367 } 368 369 @Override 370 public void onClose(final Session session, final CloseReason closeReason) { 371 log.debug("Closed session {}", session.getId()); 372 } 373 374 @Override 375 public void onError(final Session session, final Throwable throwable) { 376 log.warn("Error for session {}", session.getId(), throwable); 377 } 378 379 private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException { 380 int c; 381 while ((c = in.read()) != -1) { 382 if (c == '\n') { 383 break; 384 } else if (c != '\r') { 385 buffer.append((char) c); 386 } 387 } 388 389 if (buffer.length() == 0) { 390 return null; 391 } 392 final String string = buffer.toString(); 393 buffer.setLength(0); 394 return string; 395 } 396 } 397 398 private static class WebSocketInputStream extends MemoryInputStream { 399 400 private int previous = Integer.MAX_VALUE; 401 402 private final Queue<InputStream> queue = new LinkedList<>(); 403 404 private WebSocketInputStream(final InputStream delegate) { 405 super(delegate); 406 queue.add(delegate); 407 } 408 409 @Override 410 public int read() throws IOException { 411 if (finished) { 412 return -1; 413 } 414 if (previous != Integer.MAX_VALUE) { 415 previous = Integer.MAX_VALUE; 416 return previous; 417 } 418 final int read = delegate().read(); 419 if (read == '^') { 420 previous = delegate().read(); 421 if (previous == '@') { 422 finished = true; 423 return -1; 424 } 425 } 426 if (read < 0) { 427 finished = true; 428 } 429 return read; 430 } 431 432 private InputStream delegate() throws IOException { 433 if (queue.isEmpty()) { 434 throw new IOException("Don't have an input stream."); 435 } 436 437 if (queue.peek().available() == 0) { 438 queue.remove(); 439 } 440 441 return queue.peek(); 442 } 443 444 public void addStream(final InputStream stream) { 445 queue.add(stream); 446 } 447 } 448 449 private static class WebSocketRegistry implements DestinationRegistry { 450 451 private final DestinationRegistry delegate; 452 453 private ServletController controller; 454 455 private WebSocketRegistry(final DestinationRegistry registry) { 456 this.delegate = registry; 457 } 458 459 @Override 460 public void addDestination(final AbstractHTTPDestination destination) { 461 throw new UnsupportedOperationException(); 462 } 463 464 @Override 465 public void removeDestination(final String path) { 466 throw new UnsupportedOperationException(); 467 } 468 469 @Override 470 public AbstractHTTPDestination getDestinationForPath(final String path) { 471 return wrap(delegate.getDestinationForPath(path)); 472 } 473 474 @Override 475 public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) { 476 return wrap(delegate.getDestinationForPath(path, tryDecoding)); 477 } 478 479 @Override 480 public AbstractHTTPDestination checkRestfulRequest(final String address) { 481 return wrap(delegate.checkRestfulRequest(address)); 482 } 483 484 @Override 485 public Collection<AbstractHTTPDestination> getDestinations() { 486 return delegate.getDestinations(); 487 } 488 489 @Override 490 public AbstractDestination[] getSortedDestinations() { 491 return delegate.getSortedDestinations(); 492 } 493 494 @Override 495 public Set<String> getDestinationsPaths() { 496 return delegate.getDestinationsPaths(); 497 } 498 499 private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) { 500 try { 501 return destination == null ? null : new WebSocketDestination(destination, this); 502 } catch (final IOException e) { 503 throw new IllegalStateException(e); 504 } 505 } 506 } 507 508 private static class WebSocketDestination extends AbstractHTTPDestination { 509 510 static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class); 511 512 private final AbstractHTTPDestination delegate; 513 514 private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry) 515 throws IOException { 516 super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false); 517 this.delegate = delegate; 518 this.cproviderFactory = new WebSocketContinuationFactory(registry); 519 } 520 521 @Override 522 public EndpointReferenceType getAddress() { 523 return delegate.getAddress(); 524 } 525 526 @Override 527 public Conduit getBackChannel(final Message inMessage) throws IOException { 528 return delegate.getBackChannel(inMessage); 529 } 530 531 @Override 532 public EndpointInfo getEndpointInfo() { 533 return delegate.getEndpointInfo(); 534 } 535 536 @Override 537 public void shutdown() { 538 throw new UnsupportedOperationException(); 539 } 540 541 @Override 542 public void setMessageObserver(final MessageObserver observer) { 543 throw new UnsupportedOperationException(); 544 } 545 546 @Override 547 public MessageObserver getMessageObserver() { 548 return delegate.getMessageObserver(); 549 } 550 551 @Override 552 protected Logger getLogger() { 553 return LOG; 554 } 555 556 @Override 557 public Bus getBus() { 558 return delegate.getBus(); 559 } 560 561 @Override 562 public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req, 563 final HttpServletResponse resp) throws IOException { 564 // eager create the message to ensure we set our continuation for @Suspended 565 Message inMessage = retrieveFromContinuation(req); 566 if (inMessage == null) { 567 inMessage = new MessageImpl(); 568 569 final ExchangeImpl exchange = new ExchangeImpl(); 570 exchange.setInMessage(inMessage); 571 setupMessage(inMessage, config, context, req, resp); 572 573 exchange.setSession(new HTTPSession(req)); 574 MessageImpl.class.cast(inMessage).setDestination(this); 575 } 576 577 delegate.invoke(config, context, req, resp); 578 } 579 580 @Override 581 public void finalizeConfig() { 582 delegate.finalizeConfig(); 583 } 584 585 @Override 586 public String getBeanName() { 587 return delegate.getBeanName(); 588 } 589 590 @Override 591 public EndpointReferenceType getAddressWithId(final String id) { 592 return delegate.getAddressWithId(id); 593 } 594 595 @Override 596 public String getId(final Map<String, Object> context) { 597 return delegate.getId(context); 598 } 599 600 @Override 601 public String getContextMatchStrategy() { 602 return delegate.getContextMatchStrategy(); 603 } 604 605 @Override 606 public boolean isFixedParameterOrder() { 607 return delegate.isFixedParameterOrder(); 608 } 609 610 @Override 611 public boolean isMultiplexWithAddress() { 612 return delegate.isMultiplexWithAddress(); 613 } 614 615 @Override 616 public HTTPServerPolicy getServer() { 617 return delegate.getServer(); 618 } 619 620 @Override 621 public void assertMessage(final Message message) { 622 delegate.assertMessage(message); 623 } 624 625 @Override 626 public boolean canAssert(final QName type) { 627 return delegate.canAssert(type); 628 } 629 630 @Override 631 public String getPath() { 632 return delegate.getPath(); 633 } 634 } 635 636 private static class WebSocketContinuationFactory implements ContinuationProviderFactory { 637 638 private static final String KEY = WebSocketContinuationFactory.class.getName(); 639 640 private final WebSocketRegistry registry; 641 642 private WebSocketContinuationFactory(final WebSocketRegistry registry) { 643 this.registry = registry; 644 } 645 646 @Override 647 public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req, 648 final HttpServletResponse resp) { 649 return new WebSocketContinuation(inMessage, req, resp, registry); 650 } 651 652 @Override 653 public Message retrieveFromContinuation(final HttpServletRequest req) { 654 return Message.class.cast(req.getAttribute(KEY)); 655 } 656 } 657 658 private static class WebSocketContinuation implements ContinuationProvider, Continuation { 659 660 private final Message message; 661 662 private final HttpServletRequest request; 663 664 private final HttpServletResponse response; 665 666 private final WebSocketRegistry registry; 667 668 private final ContinuationCallback callback; 669 670 private Object object; 671 672 private boolean resumed; 673 674 private boolean pending; 675 676 private boolean isNew; 677 678 private WebSocketContinuation(final Message message, final HttpServletRequest request, 679 final HttpServletResponse response, final WebSocketRegistry registry) { 680 this.message = message; 681 this.request = request; 682 this.response = response; 683 this.registry = registry; 684 this.request 685 .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE, 686 message.getExchange().getInMessage()); 687 this.callback = message.getExchange().get(ContinuationCallback.class); 688 } 689 690 @Override 691 public Continuation getContinuation() { 692 return this; 693 } 694 695 @Override 696 public void complete() { 697 message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE); 698 if (callback != null) { 699 final Exception ex = message.getExchange().get(Exception.class); 700 if (ex == null) { 701 callback.onComplete(); 702 } else { 703 callback.onError(ex); 704 } 705 } 706 try { 707 response.getWriter().close(); 708 } catch (final IOException e) { 709 throw new IllegalStateException(e); 710 } 711 } 712 713 @Override 714 public boolean suspend(final long timeout) { 715 isNew = false; 716 resumed = false; 717 pending = true; 718 message.getExchange().getInMessage().getInterceptorChain().suspend(); 719 return true; 720 } 721 722 @Override 723 public void resume() { 724 resumed = true; 725 try { 726 registry.controller.invoke(request, response); 727 } catch (final ServletException e) { 728 throw new IllegalStateException(e); 729 } 730 } 731 732 @Override 733 public void reset() { 734 pending = false; 735 resumed = false; 736 isNew = false; 737 object = null; 738 } 739 740 @Override 741 public boolean isNew() { 742 return isNew; 743 } 744 745 @Override 746 public boolean isPending() { 747 return pending; 748 } 749 750 @Override 751 public boolean isResumed() { 752 return resumed; 753 } 754 755 @Override 756 public boolean isTimeout() { 757 return false; 758 } 759 760 @Override 761 public Object getObject() { 762 return object; 763 } 764 765 @Override 766 public void setObject(final Object o) { 767 object = o; 768 } 769 770 @Override 771 public boolean isReadyForWrite() { 772 return true; 773 } 774 } 775}