/**
 * Copyright (C) 2006-2025 Talend Inc. - www.talend.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.talend.sdk.component.server.configuration;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static java.util.Collections.unmodifiableMap;
import static java.util.Locale.ENGLISH;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.TreeMap;
import java.util.logging.Logger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import javax.enterprise.context.Dependent;
import javax.enterprise.inject.Instance;
import javax.inject.Inject;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler.Partial;
import javax.websocket.Session;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;
import javax.ws.rs.ApplicationPath;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.HttpHeaders;
import javax.xml.namespace.QName;

import org.apache.cxf.Bus;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.continuations.Continuation;
import org.apache.cxf.continuations.ContinuationCallback;
import org.apache.cxf.continuations.ContinuationProvider;
import org.apache.cxf.endpoint.ServerRegistry;
import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean;
import org.apache.cxf.message.ExchangeImpl;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageImpl;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.transport.AbstractDestination;
import org.apache.cxf.transport.Conduit;
import org.apache.cxf.transport.MessageObserver;
import org.apache.cxf.transport.http.AbstractHTTPDestination;
import org.apache.cxf.transport.http.ContinuationProviderFactory;
import org.apache.cxf.transport.http.DestinationRegistry;
import org.apache.cxf.transport.http.HTTPSession;
import org.apache.cxf.transport.servlet.ServletController;
import org.apache.cxf.transport.servlet.ServletDestination;
import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet;
import org.apache.cxf.transports.http.configuration.HTTPServerPolicy;
import org.apache.cxf.ws.addressing.EndpointReferenceType;
import org.talend.sdk.component.server.front.cxf.CxfExtractor;
import org.talend.sdk.component.server.front.memory.InMemoryRequest;
import org.talend.sdk.component.server.front.memory.InMemoryResponse;
import org.talend.sdk.component.server.front.memory.MemoryInputStream;
import org.talend.sdk.component.server.front.memory.SimpleServletConfig;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

/**
 * Ensures any JAX-RS command can be used over websockets.
 * <p>
 * At startup this listener registers one websocket endpoint per JAX-RS operation whose declaring class lives in
 * an {@code org.talend.*} package. The endpoint path is
 * {@code /websocket<version>/<lowercased http method><resource uri>}. It also registers a generic
 * {@code /websocket<version>/bus} endpoint. Incoming websocket messages are turned into in-memory servlet requests
 * and dispatched through a dedicated CXF {@link ServletController}. The response is sent back to the client as
 * binary websocket messages.
 * <p>
 * Message format (as parsed by {@code PartialMessageHandler}):
 * <ul>
 * <li>first line {@code SEND};</li>
 * <li>then {@code name: value} header lines, ended by an empty line or by a line ending with {@value #EOM};</li>
 * <li>the optional {@code destination} header overrides the target URI (may contain a query string);</li>
 * <li>the optional {@code destinationMethod} header overrides the HTTP method;</li>
 * <li>the remaining bytes are the request body, terminated by {@value #EOM}.</li>
 * </ul>
 */
@Slf4j
@Dependent
@WebListener
public class WebSocketBroadcastSetup implements ServletContextListener {

    // end-of-message marker terminating both the header block and the payload of a websocket message
    private static final String EOM = "^@";

    @Inject
    private Bus bus;

    @Inject
    private CxfExtractor cxf;

    @Inject
    private Instance<Application> applications;

    /**
     * Deploys the websocket endpoints mirroring the JAX-RS operations of the first CXF server found on the bus.
     *
     * @param sce the servlet context event, used to find the websocket {@link ServerContainer}.
     * @throws IllegalStateException if no websocket container is available or an endpoint can't be deployed.
     */
    @Override
    public void contextInitialized(final ServletContextEvent sce) {
        final ServletContext servletContext = sce.getServletContext();
        final ServerContainer container =
                ServerContainer.class.cast(servletContext.getAttribute(ServerContainer.class.getName()));
        if (container == null) { // fail with an explicit message instead of a NullPointerException during deploy
            throw new IllegalStateException("No " + ServerContainer.class.getName()
                    + " available in the servlet context, can't deploy the JAX-RS websocket endpoints");
        }

        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class
                .cast(bus
                        .getExtension(ServerRegistry.class)
                        .getServers()
                        .iterator()
                        .next()
                        .getEndpoint()
                        .get(JAXRSServiceFactoryBean.class.getName()));

        // JAX-RS application base path, always starting with '/', defaulting to /api/v1
        final String appBase = StreamSupport
                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
                .map(ApplicationPath::value)
                .findFirst()
                .map(s -> !s.startsWith("/") ? "/" + s : s)
                .orElse("/api/v1");
        final String version = appBase.replaceFirst("/api", "");

        final DestinationRegistry registry = cxf.getRegistry();

        final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry);
        final ServletController controller = new ServletController(webSocketRegistry,
                new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"),
                new ServiceListGeneratorServlet(registry, bus));
        // the registry needs the controller (continuation resume) and the controller needs the registry
        webSocketRegistry.controller = controller;

        Stream
                .concat(factory
                        .getClassResourceInfo()
                        .stream()
                        .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream())
                        .filter(ori -> ori
                                .getAnnotatedMethod()
                                .getDeclaringClass()
                                .getName()
                                .startsWith("org.talend."))
                        .map(ori -> {
                            final String uri = ori.getClassResourceInfo().getURITemplate().getValue()
                                    + ori.getURITemplate().getValue();

                            // default request headers derived from the operation metadata, computed once per
                            // operation. The client can override them in the header block of each message.
                            // NOTE(review): Content-Type is seeded from @Produces and Accept from @Consumes, which
                            // looks inverted for a request. Kept as-is for compatibility - confirm.
                            final Map<String, List<String>> headers = new HashMap<>();
                            if (!ori.getProduceTypes().isEmpty()) {
                                headers
                                        .put(HttpHeaders.CONTENT_TYPE, singletonList(
                                                ori.getProduceTypes().iterator().next().toString()));
                            }
                            if (!ori.getConsumeTypes().isEmpty()) {
                                headers
                                        .put(HttpHeaders.ACCEPT, singletonList(
                                                ori.getConsumeTypes().iterator().next().toString()));
                            }
                            final Map<String, List<String>> defaultHeaders = unmodifiableMap(headers);

                            return ServerEndpointConfig.Builder
                                    .create(Endpoint.class,
                                            "/websocket" + version + "/"
                                                    + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH)
                                                    + uri)
                                    .configurator(new ServerEndpointConfig.Configurator() {

                                        @Override
                                        @SuppressWarnings("unchecked") // the container asks for Endpoint.class
                                        public <T> T getEndpointInstance(final Class<T> clazz)
                                                throws InstantiationException {
                                            return (T) new JAXRSEndpoint(appBase, controller, servletContext,
                                                    ori.getHttpMethod(), uri, defaultHeaders);
                                        }
                                    })
                                    .build();
                        }),
                        Stream
                                .of(ServerEndpointConfig.Builder
                                        .create(Endpoint.class, "/websocket" + version + "/bus")
                                        .configurator(new ServerEndpointConfig.Configurator() {

                                            @Override
                                            @SuppressWarnings("unchecked") // the container asks for Endpoint.class
                                            public <T> T getEndpointInstance(final Class<T> clazz)
                                                    throws InstantiationException {
                                                // generic endpoint: the target comes from the message headers
                                                return (T) new JAXRSEndpoint(appBase, controller, servletContext,
                                                        "GET", "/", emptyMap());
                                            }
                                        })
                                        .build()))
                .sorted(Comparator.comparing(ServerEndpointConfig::getPath))
                .forEach(config -> {
                    log.info("Deploying WebSocket(path={})", config.getPath());
                    try {
                        container.addEndpoint(config);
                    } catch (final DeploymentException e) {
                        throw new IllegalStateException(e);
                    }
                });
    }

    /**
     * Websocket endpoint bound to one JAX-RS operation (or to the generic bus).
     * Each session gets its own {@link PartialMessageHandler}. The handler rebuilds a servlet request from the
     * received binary fragments and invokes the CXF {@link ServletController} once the last fragment arrives.
     */
    @Data
    @EqualsAndHashCode(callSuper = false)
    private static class JAXRSEndpoint extends Endpoint {

        private final String appBase;

        private final ServletController controller;

        private final ServletContext context;

        // used when the message has no "destinationMethod" header
        private final String defaultMethod;

        // used when the message has no "destination" header
        private final String defaultUri;

        // seed headers of every request; message headers override them
        private final Map<String, List<String>> baseHeaders;

        /**
         * Accumulates the fragments of one websocket message.
         * The first fragment starts a new in-memory request; later fragments are appended to its body.
         * The request is dispatched when the last fragment is received.
         */
        @RequiredArgsConstructor
        private class PartialMessageHandler implements Partial<byte[]> {

            private final Session session;

            // request/response of the message currently being received, null between messages
            private InMemoryRequest request;

            private InMemoryResponse response;

            /**
             * Parses the header block of a new message and prepares the in-memory request/response pair.
             *
             * @param buffer scratch buffer used by {@link #readLine(StringBuilder, InputStream)}.
             * @param message the first fragment, positioned just after the {@code SEND} line.
             */
            private void handleStart(final StringBuilder buffer, final InputStream message) {
                final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
                headers.putAll(baseHeaders);

                try { // read headers from the message
                    String line;
                    int del;
                    while ((line = readLine(buffer, message)) != null) {
                        final boolean done = line.endsWith(EOM);
                        if (done) {
                            line = line.substring(0, line.length() - EOM.length());
                        }
                        if (!line.isEmpty()) {
                            del = line.indexOf(':');
                            if (del < 0) {
                                headers.put(line.trim(), emptyList());
                            } else {
                                headers
                                        .put(line.substring(0, del).trim(),
                                                singletonList(line.substring(del + 1).trim()));
                            }
                        }
                        if (done) {
                            break;
                        }
                    }
                } catch (final IOException ioe) {
                    throw new IllegalStateException(ioe);
                }

                final List<String> uris = headers.get("destination");
                final String uri;
                if (uris == null || uris.isEmpty()) {
                    uri = defaultUri;
                } else {
                    uri = uris.iterator().next();
                }

                final List<String> methods = headers.get("destinationMethod");
                final String method;
                if (methods == null || methods.isEmpty()) {
                    method = defaultMethod;
                } else {
                    method = methods.iterator().next();
                }

                // split "path?query"; ">= 0" so a bare "?query" destination yields an empty path
                final String queryString;
                final String path;
                final int query = uri.indexOf('?');
                if (query >= 0) {
                    queryString = uri.substring(query + 1);
                    path = uri.substring(0, query);
                } else {
                    queryString = null;
                    path = uri;
                }

                request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path,
                        appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message),
                        session::getUserPrincipal, controller);
                response = new InMemoryResponse(session::isOpen,
                        () -> {
                            if (session.getBasicRemote().getBatchingAllowed()) {
                                try {
                                    session.getBasicRemote().flushBatch();
                                } catch (final IOException e) {
                                    throw new IllegalStateException(e);
                                }
                            }
                        }, bytes -> {
                            try {
                                session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes));
                            } catch (final IOException e) {
                                throw new IllegalStateException(e);
                            }
                        }, (status, responseHeaders) -> {
                            // response header block: "MESSAGE", "status: <code>", one line per header, blank line
                            final StringBuilder top = new StringBuilder("MESSAGE\r\n");
                            top.append("status: ").append(status).append("\r\n");
                            responseHeaders
                                    .forEach((k,
                                            v) -> top.append(k)
                                                    .append(": ")
                                                    .append(String.join(",", v))
                                                    .append("\r\n"));
                            top.append("\r\n");// empty line, means the next bytes are the payload
                            return top.toString();
                        }) {

                    @Override
                    protected void onClose(final OutputStream stream) throws IOException {
                        // terminate the payload with the end-of-message marker
                        stream.write(EOM.getBytes(StandardCharsets.UTF_8));
                    }
                };
                request.setResponse(response);
            }

            /**
             * Handles one fragment of a websocket message.
             *
             * @param byteBuffer the fragment bytes.
             * @param last whether this is the final fragment of the message; the request is dispatched then.
             * @throws IllegalArgumentException if a new message does not start with a {@code SEND} line.
             * @throws IllegalStateException if reading the fragment or dispatching the request fails.
             */
            @Override
            public void onMessage(final byte[] byteBuffer, final boolean last) {
                final ByteArrayInputStream message = new ByteArrayInputStream(byteBuffer);

                final StringBuilder buffer = new StringBuilder(128);
                try {
                    if (request != null) {
                        // continuation fragment: append it to the body of the request being built
                        ((WebSocketInputStream) request.getInputStream()).addStream(message);
                    } else if ("SEND".equalsIgnoreCase(readLine(buffer, message))) {
                        handleStart(buffer, message);
                    } else {
                        throw new IllegalArgumentException("not a message");
                    }
                } catch (final IOException e) {
                    throw new IllegalStateException(e);
                }

                if (last) {
                    try {
                        controller.invoke(request, response);
                    } catch (final ServletException e) {
                        throw new IllegalStateException(e);
                    } finally {
                        // always reset so the next message starts a fresh request
                        request = null;
                        response = null;
                    }
                }
            }
        }

        @Override
        public void onOpen(final Session session, final EndpointConfig endpointConfig) {
            log.debug("Opened session {}", session.getId());
            session.addMessageHandler(byte[].class, new PartialMessageHandler(session));
        }

        @Override
        public void onClose(final Session session, final CloseReason closeReason) {
            log.debug("Closed session {}", session.getId());
        }

        @Override
        public void onError(final Session session, final Throwable throwable) {
            log.warn("Error for session {}", session.getId(), throwable);
        }

        /**
         * Reads one line byte per byte, ignoring {@code \r} and stopping at {@code \n} or end of stream.
         * Each byte is appended as a char, so only single-byte (ASCII/Latin-1) content round-trips.
         *
         * @param buffer scratch buffer, cleared before returning a line.
         * @param in the stream to read from.
         * @return the line, or {@code null} if the line is empty or the stream is exhausted.
         * @throws IOException if reading fails.
         */
        private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
            int c;
            while ((c = in.read()) != -1) {
                if (c == '\n') {
                    break;
                } else if (c != '\r') {
                    buffer.append((char) c);
                }
            }

            if (buffer.length() == 0) {
                return null;
            }
            final String string = buffer.toString();
            buffer.setLength(0);
            return string;
        }
    }

    /**
     * Request body stream made of the fragments of one websocket message.
     * It reports end of stream when the {@value #EOM} marker is met or when all fragments are exhausted.
     */
    private static class WebSocketInputStream extends MemoryInputStream {

        // sentinel meaning "no look-ahead byte buffered" (outside the -1..255 range returned by read())
        private static final int NO_LOOKAHEAD = Integer.MAX_VALUE;

        // byte read ahead while checking for the end-of-message marker
        private int previous = NO_LOOKAHEAD;

        private final Queue<InputStream> queue = new LinkedList<>();

        private WebSocketInputStream(final InputStream delegate) {
            super(delegate);
            queue.add(delegate);
        }

        @Override
        public int read() throws IOException {
            if (finished) {
                return -1;
            }
            final int current;
            if (previous != NO_LOOKAHEAD) {
                current = previous;
                previous = NO_LOOKAHEAD;
            } else {
                current = delegate().read();
            }
            // also re-checks a buffered look-ahead byte so sequences like "^^@" are still detected
            if (current == EOM.charAt(0)) {
                final int next = delegate().read();
                if (next == EOM.charAt(1)) {
                    finished = true;
                    return -1;
                }
                previous = next;
            }
            if (current < 0) {
                finished = true;
            }
            return current;
        }

        /**
         * Returns the fragment to read from. Fully consumed fragments are skipped, including empty ones.
         * The last fragment is always kept, so reading past the end yields -1 instead of failing.
         */
        private InputStream delegate() throws IOException {
            while (queue.size() > 1 && queue.peek().available() == 0) {
                queue.remove();
            }
            final InputStream current = queue.peek();
            if (current == null) {
                throw new IOException("Don't have an input stream.");
            }
            return current;
        }

        public void addStream(final InputStream stream) {
            queue.add(stream);
        }
    }

    /**
     * Read-only view of the CXF {@link DestinationRegistry}.
     * Every looked up destination is wrapped in a {@link WebSocketDestination}; registration changes are rejected.
     */
    private static class WebSocketRegistry implements DestinationRegistry {

        private final DestinationRegistry delegate;

        // set once the controller using this registry is created (circular dependency)
        private ServletController controller;

        private WebSocketRegistry(final DestinationRegistry registry) {
            this.delegate = registry;
        }

        @Override
        public void addDestination(final AbstractHTTPDestination destination) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void removeDestination(final String path) {
            throw new UnsupportedOperationException();
        }

        @Override
        public AbstractHTTPDestination getDestinationForPath(final String path) {
            return wrap(delegate.getDestinationForPath(path));
        }

        @Override
        public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
            return wrap(delegate.getDestinationForPath(path, tryDecoding));
        }

        @Override
        public AbstractHTTPDestination checkRestfulRequest(final String address) {
            return wrap(delegate.checkRestfulRequest(address));
        }

        @Override
        public Collection<AbstractHTTPDestination> getDestinations() {
            return delegate.getDestinations();
        }

        @Override
        public AbstractDestination[] getSortedDestinations() {
            return delegate.getSortedDestinations();
        }

        @Override
        public Set<String> getDestinationsPaths() {
            return delegate.getDestinationsPaths();
        }

        private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
            try {
                return destination == null ? null : new WebSocketDestination(destination, this);
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Destination delegating to the real CXF destination.
     * It installs a {@link WebSocketContinuationFactory} so that suspended (async) JAX-RS invocations are resumed
     * through the websocket {@link ServletController}.
     */
    private static class WebSocketDestination extends AbstractHTTPDestination {

        static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);

        private final AbstractHTTPDestination delegate;

        private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
                throws IOException {
            super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
            this.delegate = delegate;
            this.cproviderFactory = new WebSocketContinuationFactory(registry);
        }

        @Override
        public EndpointReferenceType getAddress() {
            return delegate.getAddress();
        }

        @Override
        public Conduit getBackChannel(final Message inMessage) throws IOException {
            return delegate.getBackChannel(inMessage);
        }

        @Override
        public EndpointInfo getEndpointInfo() {
            return delegate.getEndpointInfo();
        }

        @Override
        public void shutdown() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setMessageObserver(final MessageObserver observer) {
            throw new UnsupportedOperationException();
        }

        @Override
        public MessageObserver getMessageObserver() {
            return delegate.getMessageObserver();
        }

        @Override
        protected Logger getLogger() {
            return LOG;
        }

        @Override
        public Bus getBus() {
            return delegate.getBus();
        }

        @Override
        public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
                final HttpServletResponse resp) throws IOException {
            // eager create the message to ensure we set our continuation for @Suspended
            Message inMessage = retrieveFromContinuation(req);
            if (inMessage == null) {
                inMessage = new MessageImpl();

                final ExchangeImpl exchange = new ExchangeImpl();
                exchange.setInMessage(inMessage);
                setupMessage(inMessage, config, context, req, resp);

                exchange.setSession(new HTTPSession(req));
                MessageImpl.class.cast(inMessage).setDestination(this);
            }

            delegate.invoke(config, context, req, resp);
        }

        @Override
        public void finalizeConfig() {
            delegate.finalizeConfig();
        }

        @Override
        public String getBeanName() {
            return delegate.getBeanName();
        }

        @Override
        public EndpointReferenceType getAddressWithId(final String id) {
            return delegate.getAddressWithId(id);
        }

        @Override
        public String getId(final Map<String, Object> context) {
            return delegate.getId(context);
        }

        @Override
        public String getContextMatchStrategy() {
            return delegate.getContextMatchStrategy();
        }

        @Override
        public boolean isFixedParameterOrder() {
            return delegate.isFixedParameterOrder();
        }

        @Override
        public boolean isMultiplexWithAddress() {
            return delegate.isMultiplexWithAddress();
        }

        @Override
        public HTTPServerPolicy getServer() {
            return delegate.getServer();
        }

        @Override
        public void assertMessage(final Message message) {
            delegate.assertMessage(message);
        }

        @Override
        public boolean canAssert(final QName type) {
            return delegate.canAssert(type);
        }

        @Override
        public String getPath() {
            return delegate.getPath();
        }
    }

    /**
     * Creates a {@link WebSocketContinuation} for every message handled by a {@link WebSocketDestination}.
     */
    private static class WebSocketContinuationFactory implements ContinuationProviderFactory {

        // request attribute read back by retrieveFromContinuation; not written anywhere in this file
        private static final String KEY = WebSocketContinuationFactory.class.getName();

        private final WebSocketRegistry registry;

        private WebSocketContinuationFactory(final WebSocketRegistry registry) {
            this.registry = registry;
        }

        @Override
        public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
                final HttpServletResponse resp) {
            return new WebSocketContinuation(inMessage, req, resp, registry);
        }

        @Override
        public Message retrieveFromContinuation(final HttpServletRequest req) {
            return Message.class.cast(req.getAttribute(KEY));
        }
    }

    /**
     * Minimal continuation.
     * Suspending pauses the CXF interceptor chain. Resuming re-dispatches the same request through the websocket
     * {@link ServletController}. Completing notifies the registered callback and closes the response writer.
     * Timeouts are not supported ({@link #isTimeout()} is always false).
     */
    private static class WebSocketContinuation implements ContinuationProvider, Continuation {

        private final Message message;

        private final HttpServletRequest request;

        private final HttpServletResponse response;

        private final WebSocketRegistry registry;

        private final ContinuationCallback callback;

        private Object object;

        private boolean resumed;

        private boolean pending;

        private boolean isNew;

        private WebSocketContinuation(final Message message, final HttpServletRequest request,
                final HttpServletResponse response, final WebSocketRegistry registry) {
            this.message = message;
            this.request = request;
            this.response = response;
            this.registry = registry;
            // lets the destination find the in-flight message again when the request is re-dispatched
            this.request
                    .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE,
                            message.getExchange().getInMessage());
            this.callback = message.getExchange().get(ContinuationCallback.class);
        }

        @Override
        public Continuation getContinuation() {
            return this;
        }

        @Override
        public void complete() {
            message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
            if (callback != null) {
                final Exception ex = message.getExchange().get(Exception.class);
                if (ex == null) {
                    callback.onComplete();
                } else {
                    callback.onError(ex);
                }
            }
            try {
                response.getWriter().close();
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            }
        }

        @Override
        public boolean suspend(final long timeout) {
            isNew = false;
            resumed = false;
            pending = true;
            message.getExchange().getInMessage().getInterceptorChain().suspend();
            return true;
        }

        @Override
        public void resume() {
            resumed = true;
            try {
                registry.controller.invoke(request, response);
            } catch (final ServletException e) {
                throw new IllegalStateException(e);
            }
        }

        @Override
        public void reset() {
            pending = false;
            resumed = false;
            isNew = false;
            object = null;
        }

        @Override
        public boolean isNew() {
            return isNew;
        }

        @Override
        public boolean isPending() {
            return pending;
        }

        @Override
        public boolean isResumed() {
            return resumed;
        }

        @Override
        public boolean isTimeout() {
            return false;
        }

        @Override
        public Object getObject() {
            return object;
        }

        @Override
        public void setObject(final Object o) {
            object = o;
        }

        @Override
        public boolean isReadyForWrite() {
            return true;
        }
    }
}