001/**
002 * Copyright (C) 2006-2025 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.server.configuration;
017
018import static java.util.Collections.emptyList;
019import static java.util.Collections.emptyMap;
020import static java.util.Collections.singletonList;
021import static java.util.Locale.ENGLISH;
022
023import java.io.ByteArrayInputStream;
024import java.io.IOException;
025import java.io.InputStream;
026import java.io.OutputStream;
027import java.nio.ByteBuffer;
028import java.nio.charset.StandardCharsets;
029import java.util.Collection;
030import java.util.Comparator;
031import java.util.HashMap;
032import java.util.LinkedList;
033import java.util.List;
034import java.util.Map;
035import java.util.Queue;
036import java.util.Set;
037import java.util.Spliterator;
038import java.util.Spliterators;
039import java.util.TreeMap;
040import java.util.logging.Logger;
041import java.util.stream.Stream;
042import java.util.stream.StreamSupport;
043
044import javax.enterprise.context.Dependent;
045import javax.enterprise.inject.Instance;
046import javax.inject.Inject;
047import javax.servlet.ServletConfig;
048import javax.servlet.ServletContext;
049import javax.servlet.ServletContextEvent;
050import javax.servlet.ServletContextListener;
051import javax.servlet.ServletException;
052import javax.servlet.annotation.WebListener;
053import javax.servlet.http.HttpServletRequest;
054import javax.servlet.http.HttpServletResponse;
055import javax.websocket.CloseReason;
056import javax.websocket.DeploymentException;
057import javax.websocket.Endpoint;
058import javax.websocket.EndpointConfig;
059import javax.websocket.MessageHandler.Partial;
060import javax.websocket.Session;
061import javax.websocket.server.ServerContainer;
062import javax.websocket.server.ServerEndpointConfig;
063import javax.ws.rs.ApplicationPath;
064import javax.ws.rs.core.Application;
065import javax.ws.rs.core.HttpHeaders;
066import javax.xml.namespace.QName;
067
068import org.apache.cxf.Bus;
069import org.apache.cxf.common.logging.LogUtils;
070import org.apache.cxf.continuations.Continuation;
071import org.apache.cxf.continuations.ContinuationCallback;
072import org.apache.cxf.continuations.ContinuationProvider;
073import org.apache.cxf.endpoint.ServerRegistry;
074import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean;
075import org.apache.cxf.message.ExchangeImpl;
076import org.apache.cxf.message.Message;
077import org.apache.cxf.message.MessageImpl;
078import org.apache.cxf.service.model.EndpointInfo;
079import org.apache.cxf.transport.AbstractDestination;
080import org.apache.cxf.transport.Conduit;
081import org.apache.cxf.transport.MessageObserver;
082import org.apache.cxf.transport.http.AbstractHTTPDestination;
083import org.apache.cxf.transport.http.ContinuationProviderFactory;
084import org.apache.cxf.transport.http.DestinationRegistry;
085import org.apache.cxf.transport.http.HTTPSession;
086import org.apache.cxf.transport.servlet.ServletController;
087import org.apache.cxf.transport.servlet.ServletDestination;
088import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet;
089import org.apache.cxf.transports.http.configuration.HTTPServerPolicy;
090import org.apache.cxf.ws.addressing.EndpointReferenceType;
091import org.talend.sdk.component.server.front.cxf.CxfExtractor;
092import org.talend.sdk.component.server.front.memory.InMemoryRequest;
093import org.talend.sdk.component.server.front.memory.InMemoryResponse;
094import org.talend.sdk.component.server.front.memory.MemoryInputStream;
095import org.talend.sdk.component.server.front.memory.SimpleServletConfig;
096
097import lombok.Data;
098import lombok.EqualsAndHashCode;
099import lombok.RequiredArgsConstructor;
100import lombok.extern.slf4j.Slf4j;
101
102// ensure any JAX-RS command can use websockets
103@Slf4j
104@Dependent
105@WebListener
106public class WebSocketBroadcastSetup implements ServletContextListener {
107
    // End-of-message marker: a header line ending with it terminates the header section, the request body stream
    // stops at it, and it is written when a response stream is closed.
    private static final String EOM = "^@";

    // CXF bus; the JAX-RS service factory of its first registered server is read at startup.
    @Inject
    private Bus bus;

    // Provides the CXF DestinationRegistry the WebSocket registry wraps.
    @Inject
    private CxfExtractor cxf;

    // JAX-RS applications; the first one annotated with @ApplicationPath gives the base path (default "/api/v1").
    @Inject
    private Instance<Application> applications;
118
119    @Override
120    public void contextInitialized(final ServletContextEvent sce) {
121        final ServerContainer container =
122                ServerContainer.class.cast(sce.getServletContext().getAttribute(ServerContainer.class.getName()));
123
124        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class
125                .cast(bus
126                        .getExtension(ServerRegistry.class)
127                        .getServers()
128                        .iterator()
129                        .next()
130                        .getEndpoint()
131                        .get(JAXRSServiceFactoryBean.class.getName()));
132
133        final String appBase = StreamSupport
134                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
135                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
136                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
137                .map(ApplicationPath::value)
138                .findFirst()
139                .map(s -> !s.startsWith("/") ? "/" + s : s)
140                .orElse("/api/v1");
141        final String version = appBase.replaceFirst("/api", "");
142
143        final DestinationRegistry registry = cxf.getRegistry();
144        final ServletContext servletContext = sce.getServletContext();
145
146        final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry);
147        final ServletController controller = new ServletController(webSocketRegistry,
148                new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"),
149                new ServiceListGeneratorServlet(registry, bus));
150        webSocketRegistry.controller = controller;
151
152        Stream
153                .concat(factory
154                        .getClassResourceInfo()
155                        .stream()
156                        .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream())
157                        .filter(cri -> cri.getAnnotatedMethod().getDeclaringClass().getName().startsWith("org.talend."))
158                        .map(ori -> {
159                            final String uri = ori.getClassResourceInfo().getURITemplate().getValue()
160                                    + ori.getURITemplate().getValue();
161                            return ServerEndpointConfig.Builder
162                                    .create(Endpoint.class,
163                                            "/websocket" + version + "/"
164                                                    + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri)
165                                    .configurator(new ServerEndpointConfig.Configurator() {
166
167                                        @Override
168                                        public <T> T getEndpointInstance(final Class<T> clazz)
169                                                throws InstantiationException {
170                                            final Map<String, List<String>> headers = new HashMap<>();
171                                            if (!ori.getProduceTypes().isEmpty()) {
172                                                headers
173                                                        .put(HttpHeaders.CONTENT_TYPE, singletonList(
174                                                                ori.getProduceTypes().iterator().next().toString()));
175                                            }
176                                            if (!ori.getConsumeTypes().isEmpty()) {
177                                                headers
178                                                        .put(HttpHeaders.ACCEPT, singletonList(
179                                                                ori.getConsumeTypes().iterator().next().toString()));
180                                            }
181                                            return (T) new JAXRSEndpoint(appBase, controller, servletContext,
182                                                    ori.getHttpMethod(), uri, headers);
183                                        }
184                                    })
185                                    .build();
186                        }),
187                        Stream
188                                .of(ServerEndpointConfig.Builder
189                                        .create(Endpoint.class, "/websocket" + version + "/bus")
190                                        .configurator(new ServerEndpointConfig.Configurator() {
191
192                                            @Override
193                                            public <T> T getEndpointInstance(final Class<T> clazz)
194                                                    throws InstantiationException {
195
196                                                return (T) new JAXRSEndpoint(appBase, controller, servletContext, "GET",
197                                                        "/", emptyMap());
198                                            }
199                                        })
200                                        .build()))
201                .sorted(Comparator.comparing(ServerEndpointConfig::getPath))
202                .peek(e -> log.info("Deploying WebSocket(path={})", e.getPath()))
203                .forEach(config -> {
204                    try {
205                        container.addEndpoint(config);
206                    } catch (final DeploymentException e) {
207                        throw new IllegalStateException(e);
208                    }
209                });
210    }
211
212    @Data
213    @EqualsAndHashCode(callSuper = false)
214    private static class JAXRSEndpoint extends Endpoint {
215
216        private final String appBase;
217
218        private final ServletController controller;
219
220        private final ServletContext context;
221
222        private final String defaultMethod;
223
224        private final String defaultUri;
225
226        private final Map<String, List<String>> baseHeaders;
227
228        @RequiredArgsConstructor
229        private class PartialMessageHandler implements Partial<byte[]> {
230
231            private final Session session;
232
233            private InMemoryRequest request;
234
235            private InMemoryResponse response;
236
237            private void handleStart(final StringBuilder buffer, final InputStream message) {
238                final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
239                headers.putAll(baseHeaders);
240
241                try { // read headers from the message
242                    String line;
243                    int del;
244                    while ((line = readLine(buffer, message)) != null) {
245                        final boolean done = line.endsWith(EOM);
246                        if (done) {
247                            line = line.substring(0, line.length() - EOM.length());
248                        }
249                        if (!line.isEmpty()) {
250                            del = line.indexOf(':');
251                            if (del < 0) {
252                                headers.put(line.trim(), emptyList());
253                            } else {
254                                headers
255                                        .put(line.substring(0, del).trim(),
256                                                singletonList(line.substring(del + 1).trim()));
257                            }
258                        }
259                        if (done) {
260                            break;
261                        }
262                    }
263                } catch (final IOException ioe) {
264                    throw new IllegalStateException(ioe);
265                }
266
267                final List<String> uris = headers.get("destination");
268                final String uri;
269                if (uris == null || uris.isEmpty()) {
270                    uri = defaultUri;
271                } else {
272                    uri = uris.iterator().next();
273                }
274
275                final List<String> methods = headers.get("destinationMethod");
276                final String method;
277                if (methods == null || methods.isEmpty()) {
278                    method = defaultMethod;
279                } else {
280                    method = methods.iterator().next();
281                }
282
283                final String queryString;
284                final String path;
285                final int query = uri.indexOf('?');
286                if (query > 0) {
287                    queryString = uri.substring(query + 1);
288                    path = uri.substring(0, query);
289                } else {
290                    queryString = null;
291                    path = uri;
292                }
293
294                request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path,
295                        appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message),
296                        session::getUserPrincipal, controller);
297                response = new InMemoryResponse(session::isOpen,
298                        () -> {
299                            if (session.getBasicRemote().getBatchingAllowed()) {
300                                try {
301                                    session.getBasicRemote().flushBatch();
302                                } catch (final IOException e) {
303                                    throw new IllegalStateException(e);
304                                }
305                            }
306                        }, bytes -> {
307                            try {
308                                session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes));
309                            } catch (final IOException e) {
310                                throw new IllegalStateException(e);
311                            }
312                        }, (status, responseHeaders) -> {
313                            final StringBuilder top = new StringBuilder("MESSAGE\r\n");
314                            top.append("status: ").append(status).append("\r\n");
315                            responseHeaders
316                                    .forEach((k,
317                                            v) -> top.append(k)
318                                                    .append(": ")
319                                                    .append(String.join(",", v))
320                                                    .append("\r\n"));
321                            top.append("\r\n");// empty line, means the next bytes are the payload
322                            return top.toString();
323                        }) {
324
325                    @Override
326                    protected void onClose(final OutputStream stream) throws IOException {
327                        stream.write(EOM.getBytes(StandardCharsets.UTF_8));
328                    }
329                };
330                request.setResponse(response);
331            }
332
333            @Override
334            public void onMessage(final byte[] byteBuffer, final boolean last) {
335                final ByteArrayInputStream message = new ByteArrayInputStream(byteBuffer);
336
337                final StringBuilder buffer = new StringBuilder(128);
338                try { // read headers from the message
339                    if (request != null) {
340                        ((WebSocketInputStream) request.getInputStream()).addStream(message);
341                    } else if ("SEND".equalsIgnoreCase(readLine(buffer, message))) {
342                        handleStart(buffer, message);
343                    } else {
344                        throw new IllegalArgumentException("not a message");
345                    }
346                } catch (IOException e) {
347                    throw new IllegalStateException(e);
348                }
349
350                if (last) {
351                    try {
352                        controller.invoke(request, response);
353                    } catch (final ServletException e) {
354                        throw new IllegalArgumentException(e);
355                    } finally {
356                        request = null;
357                        response = null;
358                    }
359                }
360            }
361        }
362
363        @Override
364        public void onOpen(final Session session, final EndpointConfig endpointConfig) {
365            log.debug("Opened session {}", session.getId());
366            session.addMessageHandler(byte[].class, new PartialMessageHandler(session));
367        }
368
369        @Override
370        public void onClose(final Session session, final CloseReason closeReason) {
371            log.debug("Closed session {}", session.getId());
372        }
373
374        @Override
375        public void onError(final Session session, final Throwable throwable) {
376            log.warn("Error for session {}", session.getId(), throwable);
377        }
378
379        private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
380            int c;
381            while ((c = in.read()) != -1) {
382                if (c == '\n') {
383                    break;
384                } else if (c != '\r') {
385                    buffer.append((char) c);
386                }
387            }
388
389            if (buffer.length() == 0) {
390                return null;
391            }
392            final String string = buffer.toString();
393            buffer.setLength(0);
394            return string;
395        }
396    }
397
398    private static class WebSocketInputStream extends MemoryInputStream {
399
400        private int previous = Integer.MAX_VALUE;
401
402        private final Queue<InputStream> queue = new LinkedList<>();
403
404        private WebSocketInputStream(final InputStream delegate) {
405            super(delegate);
406            queue.add(delegate);
407        }
408
409        @Override
410        public int read() throws IOException {
411            if (finished) {
412                return -1;
413            }
414            if (previous != Integer.MAX_VALUE) {
415                int result = previous;
416                previous = Integer.MAX_VALUE;
417                return result;
418            }
419            final int read = delegate().read();
420            if (read == '^') {
421                previous = delegate().read();
422                if (previous == '@') {
423                    finished = true;
424                    return -1;
425                }
426            }
427            if (read < 0) {
428                finished = true;
429            }
430            return read;
431        }
432
433        private InputStream delegate() throws IOException {
434            if (queue.isEmpty()) {
435                throw new IOException("Don't have an input stream.");
436            }
437
438            if (queue.peek().available() == 0) {
439                queue.remove();
440            }
441
442            return queue.peek();
443        }
444
445        public void addStream(final InputStream stream) {
446            queue.add(stream);
447        }
448    }
449
450    private static class WebSocketRegistry implements DestinationRegistry {
451
452        private final DestinationRegistry delegate;
453
454        private ServletController controller;
455
456        private WebSocketRegistry(final DestinationRegistry registry) {
457            this.delegate = registry;
458        }
459
460        @Override
461        public void addDestination(final AbstractHTTPDestination destination) {
462            throw new UnsupportedOperationException();
463        }
464
465        @Override
466        public void removeDestination(final String path) {
467            throw new UnsupportedOperationException();
468        }
469
470        @Override
471        public AbstractHTTPDestination getDestinationForPath(final String path) {
472            return wrap(delegate.getDestinationForPath(path));
473        }
474
475        @Override
476        public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
477            return wrap(delegate.getDestinationForPath(path, tryDecoding));
478        }
479
480        @Override
481        public AbstractHTTPDestination checkRestfulRequest(final String address) {
482            return wrap(delegate.checkRestfulRequest(address));
483        }
484
485        @Override
486        public Collection<AbstractHTTPDestination> getDestinations() {
487            return delegate.getDestinations();
488        }
489
490        @Override
491        public AbstractDestination[] getSortedDestinations() {
492            return delegate.getSortedDestinations();
493        }
494
495        @Override
496        public Set<String> getDestinationsPaths() {
497            return delegate.getDestinationsPaths();
498        }
499
500        private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
501            try {
502                return destination == null ? null : new WebSocketDestination(destination, this);
503            } catch (final IOException e) {
504                throw new IllegalStateException(e);
505            }
506        }
507    }
508
    /**
     * HTTP destination wrapper handed out by {@link WebSocketRegistry}.
     *
     * <p>It forwards calls to the wrapped destination and replaces the continuation provider factory with a
     * {@link WebSocketContinuationFactory}. According to the original intent, this keeps {@code @Suspended}
     * (asynchronous) JAX-RS calls working over WebSockets.
     *
     * <p>{@link #shutdown()} and {@link #setMessageObserver(MessageObserver)} are rejected with
     * {@link UnsupportedOperationException}.
     */
    private static class WebSocketDestination extends AbstractHTTPDestination {

        // logger returned by getLogger(), obtained from CXF LogUtils for the ServletDestination class
        static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);

        // the real destination calls are forwarded to
        private final AbstractHTTPDestination delegate;

        private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
                throws IOException {
            // the parent gets a blank EndpointInfo; getEndpointInfo() below returns the delegate's one instead
            super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
            this.delegate = delegate;
            // inherited field: swap in the WebSocket continuation factory
            this.cproviderFactory = new WebSocketContinuationFactory(registry);
        }

        @Override
        public EndpointReferenceType getAddress() {
            return delegate.getAddress();
        }

        @Override
        public Conduit getBackChannel(final Message inMessage) throws IOException {
            return delegate.getBackChannel(inMessage);
        }

        @Override
        public EndpointInfo getEndpointInfo() {
            return delegate.getEndpointInfo();
        }

        @Override
        public void shutdown() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setMessageObserver(final MessageObserver observer) {
            throw new UnsupportedOperationException();
        }

        @Override
        public MessageObserver getMessageObserver() {
            return delegate.getMessageObserver();
        }

        @Override
        protected Logger getLogger() {
            return LOG;
        }

        @Override
        public Bus getBus() {
            return delegate.getBus();
        }

        /**
         * Eagerly creates and sets up the inbound CXF message, unless one can be retrieved from a continuation, and
         * then forwards the call to the delegate.
         *
         * <p>The ordering matters: the exchange gets its in-message before {@code setupMessage} runs. Presumably
         * {@code setupMessage} creates the continuation through {@code cproviderFactory}, and
         * {@link WebSocketContinuation} reads {@code message.getExchange()} in its constructor. TODO confirm
         * against the CXF version in use.
         */
        @Override
        public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
                final HttpServletResponse resp) throws IOException {
            // eager create the message to ensure we set our continuation for @Suspended
            Message inMessage = retrieveFromContinuation(req);
            if (inMessage == null) {
                inMessage = new MessageImpl();

                final ExchangeImpl exchange = new ExchangeImpl();
                exchange.setInMessage(inMessage);
                setupMessage(inMessage, config, context, req, resp);

                exchange.setSession(new HTTPSession(req));
                MessageImpl.class.cast(inMessage).setDestination(this);
            }

            delegate.invoke(config, context, req, resp);
        }

        @Override
        public void finalizeConfig() {
            delegate.finalizeConfig();
        }

        @Override
        public String getBeanName() {
            return delegate.getBeanName();
        }

        @Override
        public EndpointReferenceType getAddressWithId(final String id) {
            return delegate.getAddressWithId(id);
        }

        @Override
        public String getId(final Map<String, Object> context) {
            return delegate.getId(context);
        }

        @Override
        public String getContextMatchStrategy() {
            return delegate.getContextMatchStrategy();
        }

        @Override
        public boolean isFixedParameterOrder() {
            return delegate.isFixedParameterOrder();
        }

        @Override
        public boolean isMultiplexWithAddress() {
            return delegate.isMultiplexWithAddress();
        }

        @Override
        public HTTPServerPolicy getServer() {
            return delegate.getServer();
        }

        @Override
        public void assertMessage(final Message message) {
            delegate.assertMessage(message);
        }

        @Override
        public boolean canAssert(final QName type) {
            return delegate.canAssert(type);
        }

        @Override
        public String getPath() {
            return delegate.getPath();
        }
    }
636
637    private static class WebSocketContinuationFactory implements ContinuationProviderFactory {
638
639        private static final String KEY = WebSocketContinuationFactory.class.getName();
640
641        private final WebSocketRegistry registry;
642
643        private WebSocketContinuationFactory(final WebSocketRegistry registry) {
644            this.registry = registry;
645        }
646
647        @Override
648        public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
649                final HttpServletResponse resp) {
650            return new WebSocketContinuation(inMessage, req, resp, registry);
651        }
652
653        @Override
654        public Message retrieveFromContinuation(final HttpServletRequest req) {
655            return Message.class.cast(req.getAttribute(KEY));
656        }
657    }
658
659    private static class WebSocketContinuation implements ContinuationProvider, Continuation {
660
661        private final Message message;
662
663        private final HttpServletRequest request;
664
665        private final HttpServletResponse response;
666
667        private final WebSocketRegistry registry;
668
669        private final ContinuationCallback callback;
670
671        private Object object;
672
673        private boolean resumed;
674
675        private boolean pending;
676
677        private boolean isNew;
678
679        private WebSocketContinuation(final Message message, final HttpServletRequest request,
680                final HttpServletResponse response, final WebSocketRegistry registry) {
681            this.message = message;
682            this.request = request;
683            this.response = response;
684            this.registry = registry;
685            this.request
686                    .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE,
687                            message.getExchange().getInMessage());
688            this.callback = message.getExchange().get(ContinuationCallback.class);
689        }
690
691        @Override
692        public Continuation getContinuation() {
693            return this;
694        }
695
696        @Override
697        public void complete() {
698            message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
699            if (callback != null) {
700                final Exception ex = message.getExchange().get(Exception.class);
701                if (ex == null) {
702                    callback.onComplete();
703                } else {
704                    callback.onError(ex);
705                }
706            }
707            try {
708                response.getWriter().close();
709            } catch (final IOException e) {
710                throw new IllegalStateException(e);
711            }
712        }
713
714        @Override
715        public boolean suspend(final long timeout) {
716            isNew = false;
717            resumed = false;
718            pending = true;
719            message.getExchange().getInMessage().getInterceptorChain().suspend();
720            return true;
721        }
722
723        @Override
724        public void resume() {
725            resumed = true;
726            try {
727                registry.controller.invoke(request, response);
728            } catch (final ServletException e) {
729                throw new IllegalStateException(e);
730            }
731        }
732
733        @Override
734        public void reset() {
735            pending = false;
736            resumed = false;
737            isNew = false;
738            object = null;
739        }
740
741        @Override
742        public boolean isNew() {
743            return isNew;
744        }
745
746        @Override
747        public boolean isPending() {
748            return pending;
749        }
750
751        @Override
752        public boolean isResumed() {
753            return resumed;
754        }
755
756        @Override
757        public boolean isTimeout() {
758            return false;
759        }
760
761        @Override
762        public Object getObject() {
763            return object;
764        }
765
766        @Override
767        public void setObject(final Object o) {
768            object = o;
769        }
770
771        @Override
772        public boolean isReadyForWrite() {
773            return true;
774        }
775    }
776}