001/**
002 * Copyright (C) 2006-2022 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.server.configuration;
017
018import static java.util.Collections.emptyList;
019import static java.util.Collections.emptyMap;
020import static java.util.Collections.singletonList;
021import static java.util.Locale.ENGLISH;
022
023import java.io.ByteArrayInputStream;
024import java.io.IOException;
025import java.io.InputStream;
026import java.io.OutputStream;
027import java.nio.ByteBuffer;
028import java.nio.charset.StandardCharsets;
029import java.util.Collection;
030import java.util.Comparator;
031import java.util.HashMap;
032import java.util.LinkedList;
033import java.util.List;
034import java.util.Map;
035import java.util.Queue;
036import java.util.Set;
037import java.util.Spliterator;
038import java.util.Spliterators;
039import java.util.TreeMap;
040import java.util.logging.Logger;
041import java.util.stream.Stream;
042import java.util.stream.StreamSupport;
043
044import javax.enterprise.context.Dependent;
045import javax.enterprise.inject.Instance;
046import javax.inject.Inject;
047import javax.servlet.ServletConfig;
048import javax.servlet.ServletContext;
049import javax.servlet.ServletContextEvent;
050import javax.servlet.ServletContextListener;
051import javax.servlet.ServletException;
052import javax.servlet.annotation.WebListener;
053import javax.servlet.http.HttpServletRequest;
054import javax.servlet.http.HttpServletResponse;
055import javax.websocket.CloseReason;
056import javax.websocket.DeploymentException;
057import javax.websocket.Endpoint;
058import javax.websocket.EndpointConfig;
059import javax.websocket.MessageHandler.Partial;
060import javax.websocket.Session;
061import javax.websocket.server.ServerContainer;
062import javax.websocket.server.ServerEndpointConfig;
063import javax.ws.rs.ApplicationPath;
064import javax.ws.rs.core.Application;
065import javax.ws.rs.core.HttpHeaders;
066import javax.xml.namespace.QName;
067
068import org.apache.cxf.Bus;
069import org.apache.cxf.common.logging.LogUtils;
070import org.apache.cxf.continuations.Continuation;
071import org.apache.cxf.continuations.ContinuationCallback;
072import org.apache.cxf.continuations.ContinuationProvider;
073import org.apache.cxf.endpoint.ServerRegistry;
074import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean;
075import org.apache.cxf.message.ExchangeImpl;
076import org.apache.cxf.message.Message;
077import org.apache.cxf.message.MessageImpl;
078import org.apache.cxf.service.model.EndpointInfo;
079import org.apache.cxf.transport.AbstractDestination;
080import org.apache.cxf.transport.Conduit;
081import org.apache.cxf.transport.MessageObserver;
082import org.apache.cxf.transport.http.AbstractHTTPDestination;
083import org.apache.cxf.transport.http.ContinuationProviderFactory;
084import org.apache.cxf.transport.http.DestinationRegistry;
085import org.apache.cxf.transport.http.HTTPSession;
086import org.apache.cxf.transport.servlet.ServletController;
087import org.apache.cxf.transport.servlet.ServletDestination;
088import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet;
089import org.apache.cxf.transports.http.configuration.HTTPServerPolicy;
090import org.apache.cxf.ws.addressing.EndpointReferenceType;
091import org.talend.sdk.component.server.front.cxf.CxfExtractor;
092import org.talend.sdk.component.server.front.memory.InMemoryRequest;
093import org.talend.sdk.component.server.front.memory.InMemoryResponse;
094import org.talend.sdk.component.server.front.memory.MemoryInputStream;
095import org.talend.sdk.component.server.front.memory.SimpleServletConfig;
096
097import lombok.Data;
098import lombok.EqualsAndHashCode;
099import lombok.RequiredArgsConstructor;
100import lombok.extern.slf4j.Slf4j;
101
102// ensure any JAX-RS command can use websockets
103@Slf4j
104@Dependent
105@WebListener
106public class WebSocketBroadcastSetup implements ServletContextListener {
107
108    private static final String EOM = "^@";
109
110    @Inject
111    private Bus bus;
112
113    @Inject
114    private CxfExtractor cxf;
115
116    @Inject
117    private Instance<Application> applications;
118
119    @Override
120    public void contextInitialized(final ServletContextEvent sce) {
121        final ServerContainer container =
122                ServerContainer.class.cast(sce.getServletContext().getAttribute(ServerContainer.class.getName()));
123
124        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class
125                .cast(bus
126                        .getExtension(ServerRegistry.class)
127                        .getServers()
128                        .iterator()
129                        .next()
130                        .getEndpoint()
131                        .get(JAXRSServiceFactoryBean.class.getName()));
132
133        final String appBase = StreamSupport
134                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
135                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
136                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
137                .map(ApplicationPath::value)
138                .findFirst()
139                .map(s -> !s.startsWith("/") ? "/" + s : s)
140                .orElse("/api/v1");
141        final String version = appBase.replaceFirst("/api", "");
142
143        final DestinationRegistry registry = cxf.getRegistry();
144        final ServletContext servletContext = sce.getServletContext();
145
146        final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry);
147        final ServletController controller = new ServletController(webSocketRegistry,
148                new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"),
149                new ServiceListGeneratorServlet(registry, bus));
150        webSocketRegistry.controller = controller;
151
152        Stream
153                .concat(factory
154                        .getClassResourceInfo()
155                        .stream()
156                        .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream())
157                        .filter(cri -> cri.getAnnotatedMethod().getDeclaringClass().getName().startsWith("org.talend."))
158                        .map(ori -> {
159                            final String uri = ori.getClassResourceInfo().getURITemplate().getValue()
160                                    + ori.getURITemplate().getValue();
161                            return ServerEndpointConfig.Builder
162                                    .create(Endpoint.class,
163                                            "/websocket" + version + "/"
164                                                    + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri)
165                                    .configurator(new ServerEndpointConfig.Configurator() {
166
167                                        @Override
168                                        public <T> T getEndpointInstance(final Class<T> clazz)
169                                                throws InstantiationException {
170                                            final Map<String, List<String>> headers = new HashMap<>();
171                                            if (!ori.getProduceTypes().isEmpty()) {
172                                                headers
173                                                        .put(HttpHeaders.CONTENT_TYPE, singletonList(
174                                                                ori.getProduceTypes().iterator().next().toString()));
175                                            }
176                                            if (!ori.getConsumeTypes().isEmpty()) {
177                                                headers
178                                                        .put(HttpHeaders.ACCEPT, singletonList(
179                                                                ori.getConsumeTypes().iterator().next().toString()));
180                                            }
181                                            return (T) new JAXRSEndpoint(appBase, controller, servletContext,
182                                                    ori.getHttpMethod(), uri, headers);
183                                        }
184                                    })
185                                    .build();
186                        }),
187                        Stream
188                                .of(ServerEndpointConfig.Builder
189                                        .create(Endpoint.class, "/websocket" + version + "/bus")
190                                        .configurator(new ServerEndpointConfig.Configurator() {
191
192                                            @Override
193                                            public <T> T getEndpointInstance(final Class<T> clazz)
194                                                    throws InstantiationException {
195
196                                                return (T) new JAXRSEndpoint(appBase, controller, servletContext, "GET",
197                                                        "/", emptyMap());
198                                            }
199                                        })
200                                        .build()))
201                .sorted(Comparator.comparing(ServerEndpointConfig::getPath))
202                .peek(e -> log.info("Deploying WebSocket(path={})", e.getPath()))
203                .forEach(config -> {
204                    try {
205                        container.addEndpoint(config);
206                    } catch (final DeploymentException e) {
207                        throw new IllegalStateException(e);
208                    }
209                });
210    }
211
212    @Data
213    @EqualsAndHashCode(callSuper = false)
214    private static class JAXRSEndpoint extends Endpoint {
215
216        private final String appBase;
217
218        private final ServletController controller;
219
220        private final ServletContext context;
221
222        private final String defaultMethod;
223
224        private final String defaultUri;
225
226        private final Map<String, List<String>> baseHeaders;
227
228        @RequiredArgsConstructor
229        private class PartialMessageHandler implements Partial<byte[]> {
230
231            private final Session session;
232
233            private InMemoryRequest request;
234
235            private InMemoryResponse response;
236
237            private void handleStart(final StringBuilder buffer, final InputStream message) {
238                final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
239                headers.putAll(baseHeaders);
240
241                try { // read headers from the message
242                    String line;
243                    int del;
244                    while ((line = readLine(buffer, message)) != null) {
245                        final boolean done = line.endsWith(EOM);
246                        if (done) {
247                            line = line.substring(0, line.length() - EOM.length());
248                        }
249                        if (!line.isEmpty()) {
250                            del = line.indexOf(':');
251                            if (del < 0) {
252                                headers.put(line.trim(), emptyList());
253                            } else {
254                                headers
255                                        .put(line.substring(0, del).trim(),
256                                                singletonList(line.substring(del + 1).trim()));
257                            }
258                        }
259                        if (done) {
260                            break;
261                        }
262                    }
263                } catch (final IOException ioe) {
264                    throw new IllegalStateException(ioe);
265                }
266
267                final List<String> uris = headers.get("destination");
268                final String uri;
269                if (uris == null || uris.isEmpty()) {
270                    uri = defaultUri;
271                } else {
272                    uri = uris.iterator().next();
273                }
274
275                final List<String> methods = headers.get("destinationMethod");
276                final String method;
277                if (methods == null || methods.isEmpty()) {
278                    method = defaultMethod;
279                } else {
280                    method = methods.iterator().next();
281                }
282
283                final String queryString;
284                final String path;
285                final int query = uri.indexOf('?');
286                if (query > 0) {
287                    queryString = uri.substring(query + 1);
288                    path = uri.substring(0, query);
289                } else {
290                    queryString = null;
291                    path = uri;
292                }
293
294                request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path,
295                        appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message),
296                        session::getUserPrincipal, controller);
297                response = new InMemoryResponse(session::isOpen,
298                        () -> {
299                            if (session.getBasicRemote().getBatchingAllowed()) {
300                                try {
301                                    session.getBasicRemote().flushBatch();
302                                } catch (final IOException e) {
303                                    throw new IllegalStateException(e);
304                                }
305                            }
306                        }, bytes -> {
307                            try {
308                                session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes));
309                            } catch (final IOException e) {
310                                throw new IllegalStateException(e);
311                            }
312                        }, (status, responseHeaders) -> {
313                            final StringBuilder top = new StringBuilder("MESSAGE\r\n");
314                            top.append("status: ").append(status).append("\r\n");
315                            responseHeaders
316                                    .forEach((k,
317                                            v) -> top.append(k)
318                                                    .append(": ")
319                                                    .append(String.join(",", v))
320                                                    .append("\r\n"));
321                            top.append("\r\n");// empty line, means the next bytes are the payload
322                            return top.toString();
323                        }) {
324
325                    @Override
326                    protected void onClose(final OutputStream stream) throws IOException {
327                        stream.write(EOM.getBytes(StandardCharsets.UTF_8));
328                    }
329                };
330                request.setResponse(response);
331            }
332
333            @Override
334            public void onMessage(final byte[] byteBuffer, final boolean last) {
335                final ByteArrayInputStream message = new ByteArrayInputStream(byteBuffer);
336
337                final StringBuilder buffer = new StringBuilder(128);
338                try { // read headers from the message
339                    if (request != null) {
340                        ((WebSocketInputStream) request.getInputStream()).addStream(message);
341                    } else if ("SEND".equalsIgnoreCase(readLine(buffer, message))) {
342                        handleStart(buffer, message);
343                    } else {
344                        throw new IllegalArgumentException("not a message");
345                    }
346                } catch (IOException e) {
347                    throw new IllegalStateException(e);
348                }
349
350                if (last) {
351                    try {
352                        controller.invoke(request, response);
353                    } catch (final ServletException e) {
354                        throw new IllegalArgumentException(e);
355                    } finally {
356                        request = null;
357                        response = null;
358                    }
359                }
360            }
361        }
362
363        @Override
364        public void onOpen(final Session session, final EndpointConfig endpointConfig) {
365            log.debug("Opened session {}", session.getId());
366            session.addMessageHandler(byte[].class, new PartialMessageHandler(session));
367        }
368
369        @Override
370        public void onClose(final Session session, final CloseReason closeReason) {
371            log.debug("Closed session {}", session.getId());
372        }
373
374        @Override
375        public void onError(final Session session, final Throwable throwable) {
376            log.warn("Error for session {}", session.getId(), throwable);
377        }
378
379        private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
380            int c;
381            while ((c = in.read()) != -1) {
382                if (c == '\n') {
383                    break;
384                } else if (c != '\r') {
385                    buffer.append((char) c);
386                }
387            }
388
389            if (buffer.length() == 0) {
390                return null;
391            }
392            final String string = buffer.toString();
393            buffer.setLength(0);
394            return string;
395        }
396    }
397
398    private static class WebSocketInputStream extends MemoryInputStream {
399
400        private int previous = Integer.MAX_VALUE;
401
402        private final Queue<InputStream> queue = new LinkedList<>();
403
404        private WebSocketInputStream(final InputStream delegate) {
405            super(delegate);
406            queue.add(delegate);
407        }
408
409        @Override
410        public int read() throws IOException {
411            if (finished) {
412                return -1;
413            }
414            if (previous != Integer.MAX_VALUE) {
415                previous = Integer.MAX_VALUE;
416                return previous;
417            }
418            final int read = delegate().read();
419            if (read == '^') {
420                previous = delegate().read();
421                if (previous == '@') {
422                    finished = true;
423                    return -1;
424                }
425            }
426            if (read < 0) {
427                finished = true;
428            }
429            return read;
430        }
431
432        private InputStream delegate() throws IOException {
433            if (queue.isEmpty()) {
434                throw new IOException("Don't have an input stream.");
435            }
436
437            if (queue.peek().available() == 0) {
438                queue.remove();
439            }
440
441            return queue.peek();
442        }
443
444        public void addStream(final InputStream stream) {
445            queue.add(stream);
446        }
447    }
448
449    private static class WebSocketRegistry implements DestinationRegistry {
450
451        private final DestinationRegistry delegate;
452
453        private ServletController controller;
454
455        private WebSocketRegistry(final DestinationRegistry registry) {
456            this.delegate = registry;
457        }
458
459        @Override
460        public void addDestination(final AbstractHTTPDestination destination) {
461            throw new UnsupportedOperationException();
462        }
463
464        @Override
465        public void removeDestination(final String path) {
466            throw new UnsupportedOperationException();
467        }
468
469        @Override
470        public AbstractHTTPDestination getDestinationForPath(final String path) {
471            return wrap(delegate.getDestinationForPath(path));
472        }
473
474        @Override
475        public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
476            return wrap(delegate.getDestinationForPath(path, tryDecoding));
477        }
478
479        @Override
480        public AbstractHTTPDestination checkRestfulRequest(final String address) {
481            return wrap(delegate.checkRestfulRequest(address));
482        }
483
484        @Override
485        public Collection<AbstractHTTPDestination> getDestinations() {
486            return delegate.getDestinations();
487        }
488
489        @Override
490        public AbstractDestination[] getSortedDestinations() {
491            return delegate.getSortedDestinations();
492        }
493
494        @Override
495        public Set<String> getDestinationsPaths() {
496            return delegate.getDestinationsPaths();
497        }
498
499        private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
500            try {
501                return destination == null ? null : new WebSocketDestination(destination, this);
502            } catch (final IOException e) {
503                throw new IllegalStateException(e);
504            }
505        }
506    }
507
    /**
     * {@link AbstractHTTPDestination} wrapping a real CXF destination for the WebSocket transport.
     *
     * <p>
     * Almost every method delegates to the wrapped destination. The exception is {@code invoke}: when the request
     * carries no continuation message, it eagerly creates and sets up an in message before delegating. This lets the
     * {@link WebSocketContinuationFactory} installed by the constructor be used for {@code @Suspended} resources.
     * {@code shutdown()} and {@code setMessageObserver(MessageObserver)} are not supported.
     * </p>
     */
    private static class WebSocketDestination extends AbstractHTTPDestination {

        // CXF localized logger, obtained for the ServletDestination class
        static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);

        private final AbstractHTTPDestination delegate;

        /**
         * @param delegate the real destination calls are forwarded to.
         * @param registry the WebSocket registry this destination belongs to.
         * @throws IOException propagated from the {@link AbstractHTTPDestination} constructor.
         */
        private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
                throws IOException {
            // built with an empty EndpointInfo: getEndpointInfo() returns the delegate's one anyway
            super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
            this.delegate = delegate;
            // inherited field: continuations of this destination are WebSocketContinuation instances
            this.cproviderFactory = new WebSocketContinuationFactory(registry);
        }

        @Override
        public EndpointReferenceType getAddress() {
            return delegate.getAddress();
        }

        @Override
        public Conduit getBackChannel(final Message inMessage) throws IOException {
            return delegate.getBackChannel(inMessage);
        }

        @Override
        public EndpointInfo getEndpointInfo() {
            return delegate.getEndpointInfo();
        }

        @Override
        public void shutdown() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setMessageObserver(final MessageObserver observer) {
            throw new UnsupportedOperationException();
        }

        @Override
        public MessageObserver getMessageObserver() {
            return delegate.getMessageObserver();
        }

        @Override
        protected Logger getLogger() {
            return LOG;
        }

        @Override
        public Bus getBus() {
            return delegate.getBus();
        }

        /**
         * Prepares the continuation support then delegates the invocation to the wrapped destination.
         * Note that the message built here is not passed to the delegate; only the side effects of
         * {@code setupMessage} matter.
         * NOTE(review): presumably setupMessage creates the continuation provider through
         * {@code cproviderFactory}, whose WebSocketContinuation stores the message as a request attribute - confirm.
         */
        @Override
        public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
                final HttpServletResponse resp) throws IOException {
            // eager create the message to ensure we set our continuation for @Suspended
            Message inMessage = retrieveFromContinuation(req);
            if (inMessage == null) { // no message attached to the request yet
                inMessage = new MessageImpl();

                final ExchangeImpl exchange = new ExchangeImpl();
                exchange.setInMessage(inMessage);
                setupMessage(inMessage, config, context, req, resp);

                exchange.setSession(new HTTPSession(req));
                MessageImpl.class.cast(inMessage).setDestination(this);
            }

            delegate.invoke(config, context, req, resp);
        }

        @Override
        public void finalizeConfig() {
            delegate.finalizeConfig();
        }

        @Override
        public String getBeanName() {
            return delegate.getBeanName();
        }

        @Override
        public EndpointReferenceType getAddressWithId(final String id) {
            return delegate.getAddressWithId(id);
        }

        @Override
        public String getId(final Map<String, Object> context) {
            return delegate.getId(context);
        }

        @Override
        public String getContextMatchStrategy() {
            return delegate.getContextMatchStrategy();
        }

        @Override
        public boolean isFixedParameterOrder() {
            return delegate.isFixedParameterOrder();
        }

        @Override
        public boolean isMultiplexWithAddress() {
            return delegate.isMultiplexWithAddress();
        }

        @Override
        public HTTPServerPolicy getServer() {
            return delegate.getServer();
        }

        @Override
        public void assertMessage(final Message message) {
            delegate.assertMessage(message);
        }

        @Override
        public boolean canAssert(final QName type) {
            return delegate.canAssert(type);
        }

        @Override
        public String getPath() {
            return delegate.getPath();
        }
    }
635
636    private static class WebSocketContinuationFactory implements ContinuationProviderFactory {
637
638        private static final String KEY = WebSocketContinuationFactory.class.getName();
639
640        private final WebSocketRegistry registry;
641
642        private WebSocketContinuationFactory(final WebSocketRegistry registry) {
643            this.registry = registry;
644        }
645
646        @Override
647        public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
648                final HttpServletResponse resp) {
649            return new WebSocketContinuation(inMessage, req, resp, registry);
650        }
651
652        @Override
653        public Message retrieveFromContinuation(final HttpServletRequest req) {
654            return Message.class.cast(req.getAttribute(KEY));
655        }
656    }
657
658    private static class WebSocketContinuation implements ContinuationProvider, Continuation {
659
660        private final Message message;
661
662        private final HttpServletRequest request;
663
664        private final HttpServletResponse response;
665
666        private final WebSocketRegistry registry;
667
668        private final ContinuationCallback callback;
669
670        private Object object;
671
672        private boolean resumed;
673
674        private boolean pending;
675
676        private boolean isNew;
677
678        private WebSocketContinuation(final Message message, final HttpServletRequest request,
679                final HttpServletResponse response, final WebSocketRegistry registry) {
680            this.message = message;
681            this.request = request;
682            this.response = response;
683            this.registry = registry;
684            this.request
685                    .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE,
686                            message.getExchange().getInMessage());
687            this.callback = message.getExchange().get(ContinuationCallback.class);
688        }
689
690        @Override
691        public Continuation getContinuation() {
692            return this;
693        }
694
695        @Override
696        public void complete() {
697            message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
698            if (callback != null) {
699                final Exception ex = message.getExchange().get(Exception.class);
700                if (ex == null) {
701                    callback.onComplete();
702                } else {
703                    callback.onError(ex);
704                }
705            }
706            try {
707                response.getWriter().close();
708            } catch (final IOException e) {
709                throw new IllegalStateException(e);
710            }
711        }
712
713        @Override
714        public boolean suspend(final long timeout) {
715            isNew = false;
716            resumed = false;
717            pending = true;
718            message.getExchange().getInMessage().getInterceptorChain().suspend();
719            return true;
720        }
721
722        @Override
723        public void resume() {
724            resumed = true;
725            try {
726                registry.controller.invoke(request, response);
727            } catch (final ServletException e) {
728                throw new IllegalStateException(e);
729            }
730        }
731
732        @Override
733        public void reset() {
734            pending = false;
735            resumed = false;
736            isNew = false;
737            object = null;
738        }
739
740        @Override
741        public boolean isNew() {
742            return isNew;
743        }
744
745        @Override
746        public boolean isPending() {
747            return pending;
748        }
749
750        @Override
751        public boolean isResumed() {
752            return resumed;
753        }
754
755        @Override
756        public boolean isTimeout() {
757            return false;
758        }
759
760        @Override
761        public Object getObject() {
762            return object;
763        }
764
765        @Override
766        public void setObject(final Object o) {
767            object = o;
768        }
769
770        @Override
771        public boolean isReadyForWrite() {
772            return true;
773        }
774    }
775}