001/**
002 * Copyright (C) 2006-2022 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.server.front;
017
018import static java.util.Collections.emptyMap;
019import static java.util.Optional.ofNullable;
020import static java.util.concurrent.CompletableFuture.completedFuture;
021import static java.util.stream.Collectors.joining;
022import static java.util.stream.Collectors.toList;
023import static java.util.stream.Collectors.toSet;
024
025import java.io.ByteArrayInputStream;
026import java.io.ByteArrayOutputStream;
027import java.io.IOException;
028import java.io.InputStream;
029import java.nio.charset.StandardCharsets;
030import java.security.Principal;
031import java.util.Collection;
032import java.util.Collections;
033import java.util.List;
034import java.util.Map;
035import java.util.concurrent.CompletableFuture;
036import java.util.concurrent.CompletionStage;
037import java.util.concurrent.ExecutionException;
038import java.util.stream.Stream;
039
040import javax.annotation.PostConstruct;
041import javax.enterprise.context.ApplicationScoped;
042import javax.inject.Inject;
043import javax.json.bind.Jsonb;
044import javax.servlet.ServletContext;
045import javax.servlet.ServletException;
046import javax.servlet.http.HttpServletRequest;
047import javax.ws.rs.HttpMethod;
048import javax.ws.rs.WebApplicationException;
049import javax.ws.rs.core.Context;
050import javax.ws.rs.core.Response;
051import javax.ws.rs.core.UriInfo;
052
053import org.apache.cxf.Bus;
054import org.apache.cxf.transport.http.DestinationRegistry;
055import org.apache.cxf.transport.servlet.ServletController;
056import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet;
057import org.talend.sdk.component.server.api.BulkReadResource;
058import org.talend.sdk.component.server.front.cxf.CxfExtractor;
059import org.talend.sdk.component.server.front.memory.InMemoryRequest;
060import org.talend.sdk.component.server.front.memory.InMemoryResponse;
061import org.talend.sdk.component.server.front.memory.MemoryInputStream;
062import org.talend.sdk.component.server.front.memory.SimpleServletConfig;
063import org.talend.sdk.component.server.front.model.BulkRequests;
064import org.talend.sdk.component.server.front.model.BulkResponses;
065import org.talend.sdk.component.server.front.model.ErrorDictionary;
066import org.talend.sdk.component.server.front.model.error.ErrorPayload;
067import org.talend.sdk.component.server.service.qualifier.ComponentServer;
068
069import lombok.extern.slf4j.Slf4j;
070
071@Slf4j
072@ApplicationScoped
073public class BulkReadResourceImpl implements BulkReadResource {
074
075    private static final CompletableFuture[] EMPTY_PROMISES = new CompletableFuture[0];
076
077    @Inject
078    private CxfExtractor cxf;
079
080    @Inject
081    private Bus bus;
082
083    @Inject
084    @Context
085    private ServletContext servletContext;
086
087    @Inject
088    @Context
089    private HttpServletRequest httpServletRequest;
090
091    @Inject
092    @Context
093    private UriInfo uriInfo;
094
095    @Inject
096    @Context
097    private HttpServletRequest request;
098
099    @Inject
100    @ComponentServer
101    private Jsonb defaultMapper;
102
103    private ServletController controller;
104
105    private final String appPrefix = "/api/v1";
106
107    private final Collection<String> blacklisted =
108            Stream.of(appPrefix + "/component/icon/", appPrefix + "/component/dependency/").collect(toSet());
109
110    private final BulkResponses.Result forbiddenInBulkModeResponse =
111            new BulkResponses.Result(Response.Status.FORBIDDEN.getStatusCode(), emptyMap(),
112                    "{\"code\":\"UNAUTHORIZED\",\"description\":\"Forbidden endpoint in bulk mode.\"}"
113                            .getBytes(StandardCharsets.UTF_8));
114
115    private final BulkResponses.Result forbiddenResponse =
116            new BulkResponses.Result(Response.Status.FORBIDDEN.getStatusCode(), emptyMap(),
117                    "{\"code\":\"UNAUTHORIZED\",\"description\":\"Secured endpoint, ensure to pass the right token.\"}"
118                            .getBytes(StandardCharsets.UTF_8));
119
120    private final BulkResponses.Result invalidResponse =
121            new BulkResponses.Result(Response.Status.BAD_REQUEST.getStatusCode(), emptyMap(),
122                    "{\"code\":\"UNEXPECTED\",\"description\":\"unknownEndpoint.\"}".getBytes(StandardCharsets.UTF_8));
123
124    @PostConstruct
125    private void init() {
126        final DestinationRegistry registry = cxf.getRegistry();
127        controller = new ServletController(registry,
128                new SimpleServletConfig(servletContext, "Talend Component Kit Bulk Transport"),
129                new ServiceListGeneratorServlet(registry, bus));
130    }
131
132    @Override
133    public CompletionStage<BulkResponses> bulk(final BulkRequests requests) {
134        final Collection<CompletableFuture<BulkResponses.Result>> responses =
135                ofNullable(requests.getRequests()).map(Collection::stream).orElseGet(Stream::empty).map(request -> {
136                    if (isBlacklisted(request)) {
137                        return completedFuture(forbiddenInBulkModeResponse);
138                    }
139                    if (request.getPath() == null || !request.getPath().startsWith(appPrefix)
140                            || request.getPath().contains("?")) {
141                        return completedFuture(invalidResponse);
142                    }
143                    return doExecute(request, uriInfo);
144                }).collect(toList());
145        return CompletableFuture
146                .allOf(responses.toArray(EMPTY_PROMISES))
147                .handle((ignored, error) -> new BulkResponses(responses.stream().map(it -> {
148                    try {
149                        return it.get();
150                    } catch (final InterruptedException e) {
151                        Thread.currentThread().interrupt();
152                        throw new IllegalStateException(e);
153                    } catch (final ExecutionException e) {
154                        throw new WebApplicationException(Response
155                                .serverError()
156                                .entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, e.getMessage()))
157                                .build());
158                    }
159                }).collect(toList())));
160    }
161
162    private boolean isBlacklisted(final BulkRequests.Request request) {
163        return blacklisted.stream().anyMatch(it -> request.getPath() == null || request.getPath().startsWith(it));
164    }
165
166    private CompletableFuture<BulkResponses.Result> doExecute(final BulkRequests.Request inputRequest,
167            final UriInfo info) {
168        final Map<String, List<String>> headers =
169                ofNullable(inputRequest.getHeaders()).orElseGet(Collections::emptyMap);
170        final String path = ofNullable(inputRequest.getPath()).map(it -> it.substring(appPrefix.length())).orElse("/");
171
172        // theorically we should encode these params but should be ok this way for now - due to the param we can accept
173        final String queryString = ofNullable(inputRequest.getQueryParameters())
174                .map(Map::entrySet)
175                .map(Collection::stream)
176                .orElseGet(Stream::empty)
177                .flatMap(it -> ofNullable(it.getValue())
178                        .map(Collection::stream)
179                        .orElseGet(Stream::empty)
180                        .map(value -> it.getKey() + '=' + value))
181                .collect(joining("&"));
182
183        final int port = info.getBaseUri().getPort();
184        final Principal userPrincipal = request.getUserPrincipal(); // this is ap proxy so ready it early
185        final InMemoryRequest request = new InMemoryRequest(ofNullable(inputRequest.getVerb()).orElse(HttpMethod.GET),
186                headers, path, appPrefix + path, appPrefix, queryString, port < 0 ? 8080 : port, servletContext,
187                new MemoryInputStream(ofNullable(inputRequest.getPayload())
188                        .map(it -> it.getBytes(StandardCharsets.UTF_8))
189                        .map(ByteArrayInputStream::new)
190                        .map(InputStream.class::cast)
191                        .orElse(null)),
192                () -> userPrincipal, controller);
193        final BulkResponses.Result result = new BulkResponses.Result();
194        final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
195        final CompletableFuture<BulkResponses.Result> promise = new CompletableFuture<>();
196        final InMemoryResponse response = new InMemoryResponse(() -> true, () -> {
197            result.setResponse(outputStream.toByteArray());
198            promise.complete(result);
199        }, bytes -> {
200            try {
201                outputStream.write(bytes);
202            } catch (final IOException e) {
203                throw new IllegalStateException(e);
204            }
205        }, (status, responseHeaders) -> {
206            result.setStatus(status);
207            result.setHeaders(headers);
208            return "";
209        });
210        request.setResponse(response);
211        try {
212            controller.invoke(request, response);
213        } catch (final ServletException e) {
214            result.setStatus(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode());
215            result
216                    .setResponse(defaultMapper
217                            .toJson(new ErrorPayload(ErrorDictionary.UNEXPECTED, e.getMessage()))
218                            .getBytes(StandardCharsets.UTF_8));
219            promise.complete(result);
220            throw new IllegalStateException(e);
221        }
222        return promise;
223    }
224}