/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rpc;

import ai.djl.inference.streaming.ChunkedBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.util.Utils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A minimal HTTP client that forwards an {@link Input} to a remote model endpoint and converts
 * the HTTP response into an {@link Output}.
 *
 * <p>Well-known hosted APIs (Gemini, Anthropic, OpenAI) get provider-specific authentication
 * headers. Streaming responses ({@code text/event-stream} or {@code application/jsonlines}) are
 * exposed as a {@link ChunkedBytesSupplier} that is filled asynchronously.
 */
public final class RpcClient {

    private static final Logger logger = LoggerFactory.getLogger(RpcClient.class);

    /** Argument names that configure the client itself and are never forwarded as HTTP headers. */
    private static final Set<String> RESERVED_KEYS =
            new HashSet<>(
                    Arrays.asList(
                            "engine",
                            "translatorfactory",
                            "translator",
                            "model_name",
                            "artifact_id",
                            "application",
                            "task",
                            "djl_rpc_uri",
                            "method",
                            "api_key",
                            "content-type"));

    private final URL url;
    private final String method;
    private final Map<CaseInsensitiveKey, String> headers;

    private RpcClient(URL url, String method, Map<CaseInsensitiveKey, String> headers) {
        this.url = url;
        this.method = method;
        this.headers = headers;
    }

    /**
     * Creates a client from model arguments.
     *
     * <p>{@code djl_rpc_uri} is the endpoint. {@code method} (default {@code POST}), {@code
     * api_key} and {@code content-type} (default {@code application/json}) are matched
     * case-insensitively. Every other non-reserved, non-null argument is sent as an HTTP header.
     *
     * @param arguments the model arguments
     * @return a configured client
     * @throws MalformedURLException if {@code djl_rpc_uri} is not a valid URL
     * @throws IllegalArgumentException if {@code djl_rpc_uri} is missing
     */
    public static RpcClient getClient(Map<String, ?> arguments) throws MalformedURLException {
        Object uri = arguments.get("djl_rpc_uri");
        if (uri == null) {
            throw new IllegalArgumentException("djl_rpc_uri is required.");
        }
        String url = uri.toString();
        // HttpURLConnection only accepts upper-case method names.
        String method = getOrDefault(arguments, "method", "POST").trim().toUpperCase(Locale.ROOT);
        String apiKey = getOrDefault(arguments, "api_key", null);
        String contentType = getOrDefault(arguments, "content-type", "application/json");

        Map<CaseInsensitiveKey, String> httpHeaders = new ConcurrentHashMap<>();
        for (Map.Entry<String, ?> entry : arguments.entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            if (value == null || RESERVED_KEYS.contains(key.toLowerCase(Locale.ROOT))) {
                continue;
            }
            httpHeaders.put(new CaseInsensitiveKey(key), value.toString().trim());
        }
        httpHeaders.put(new CaseInsensitiveKey("Content-Type"), contentType);

        String authHeader = "Authorization";
        if (url.startsWith("https://generativelanguage.googleapis.com/")) {
            if (apiKey == null) {
                apiKey = Utils.getEnvOrSystemProperty("GEMINI_API_KEY");
            }
            if (apiKey == null) {
                apiKey = Utils.getEnvOrSystemProperty("GOOGLE_API_KEY");
            }
            // Gemini's OpenAI-compatible endpoint uses Bearer auth; native endpoints do not.
            if (!url.endsWith("/openai/chat/completions")) {
                authHeader = "x-goog-api-key";
            }
        } else if (url.startsWith("https://api.anthropic.com/")) {
            if (apiKey == null) {
                apiKey = Utils.getEnvOrSystemProperty("ANTHROPIC_API_KEY");
            }
            authHeader = "x-api-key";
        } else if (url.startsWith("https://api.openai.com/")) {
            if (apiKey == null) {
                apiKey = Utils.getEnvOrSystemProperty("OPENAI_API_KEY");
            }
            // Organization and project are account specific; never hard-code them.
            String organization = Utils.getEnvOrSystemProperty("OPENAI_ORG_ID");
            if (organization != null) {
                httpHeaders.put(new CaseInsensitiveKey("OpenAI-Organization"), organization);
            }
            String project = Utils.getEnvOrSystemProperty("OPENAI_PROJECT");
            if (project != null) {
                httpHeaders.put(new CaseInsensitiveKey("OpenAI-Project"), project);
            }
        }
        if (apiKey == null) {
            apiKey = Utils.getEnvOrSystemProperty("GENAI_API_KEY");
        }
        if (apiKey != null) {
            String authValue = "Authorization".equals(authHeader) ? "Bearer " + apiKey : apiKey;
            httpHeaders.put(new CaseInsensitiveKey(authHeader), authValue);
        }
        return new RpcClient(new URL(url), method, httpHeaders);
    }

    /**
     * Sends {@code input} to the remote endpoint.
     *
     * <p>The input's properties are sent as HTTP headers and override client-level headers with
     * the same (case-insensitive) name. Its data, if any, is sent as the request body.
     *
     * @param input the request
     * @return the response. It carries the status code, message and headers. On success its
     *     content is the body, or a chunked supplier for streaming responses. On failure it is the
     *     error body, if the server sent one.
     * @throws IOException if offline mode is enabled or the request fails
     */
    public Output send(Input input) throws IOException {
        if (Utils.isOfflineMode()) {
            throw new IOException("Offline mode is enabled.");
        }
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        // True once the async stream reader owns (and will disconnect) the connection.
        boolean handedOff = false;
        try {
            conn.setRequestMethod(method);
            if ("POST".equals(method) || "PUT".equals(method)) {
                conn.setDoOutput(true);
            }
            Map<CaseInsensitiveKey, String> reqHeaders = new ConcurrentHashMap<>(headers);
            Map<String, String> prop = input.getProperties();
            for (Map.Entry<String, String> entry : prop.entrySet()) {
                if (entry.getValue() != null) {
                    reqHeaders.put(new CaseInsensitiveKey(entry.getKey()), entry.getValue());
                }
            }
            for (Map.Entry<CaseInsensitiveKey, String> entry : reqHeaders.entrySet()) {
                conn.addRequestProperty(entry.getKey().key, entry.getValue());
            }
            conn.connect();

            BytesSupplier content = input.getData();
            if (content != null) {
                try (OutputStream os = conn.getOutputStream()) {
                    os.write(content.getAsBytes());
                }
            }

            int code = conn.getResponseCode();
            Output out = new Output(code, conn.getResponseMessage());
            boolean isStream = copyResponseHeaders(conn, out);
            if (code == 200) {
                if (isStream) {
                    ChunkedBytesSupplier cbs = new ChunkedBytesSupplier();
                    out.add(cbs);
                    // NOTE(review): runs blocking I/O on the common ForkJoinPool.
                    CompletableFuture.runAsync(() -> handleStream(conn, cbs))
                            .whenComplete(
                                    (ignored, t) -> {
                                        if (t != null) {
                                            logger.warn("Streaming response handler failed.", t);
                                        }
                                    });
                    handedOff = true;
                } else {
                    try (InputStream is = conn.getInputStream()) {
                        out.add(Utils.toByteArray(is));
                    }
                }
            } else {
                readError(conn, code, out);
            }
            return out;
        } finally {
            // Also covers error responses with a streaming content type, which are read fully above.
            if (!handedOff) {
                conn.disconnect();
            }
        }
    }

    /**
     * Copies response headers (first value of each) into {@code out} without altering them.
     *
     * @return true if the Content-Type denotes a streaming response
     */
    private static boolean copyResponseHeaders(HttpURLConnection conn, Output out) {
        boolean isStream = false;
        for (Map.Entry<String, List<String>> entry : conn.getHeaderFields().entrySet()) {
            String key = entry.getKey();
            List<String> values = entry.getValue();
            // The status line is reported under a null key.
            if (key == null || values == null || values.isEmpty()) {
                continue;
            }
            String value = values.get(0);
            if (value == null) {
                continue;
            }
            if ("content-type".equalsIgnoreCase(key)) {
                // Media types compare case-insensitively; keep the stored value untouched.
                String type = value.toLowerCase(Locale.ROOT);
                if (type.startsWith("text/event-stream")
                        || type.startsWith("application/jsonlines")) {
                    isStream = true;
                }
            }
            out.addProperty(key, value);
        }
        return isStream;
    }

    /** Reads the error body (if any) into {@code out} and logs the failure. */
    private static void readError(HttpURLConnection conn, int code, Output out)
            throws IOException {
        // getErrorStream() may return null; try-with-resources tolerates a null resource.
        try (InputStream is = conn.getErrorStream()) {
            if (is == null) {
                logger.warn("Failed to invoke model server, code: {}", code);
            } else {
                String error = Utils.toString(is);
                out.add(error);
                logger.warn("Failed to invoke model server: {}", error);
            }
        }
    }

    /** Returns the first non-null value whose key equals {@code key} ignoring case, else {@code def}. */
    private static String getOrDefault(Map<String, ?> arguments, String key, String def) {
        for (Map.Entry<String, ?> entry : arguments.entrySet()) {
            if (entry.getKey().equalsIgnoreCase(key) && entry.getValue() != null) {
                return entry.getValue().toString();
            }
        }
        return def;
    }

    /**
     * Reads a streaming response and feeds it to {@code cbs}, then disconnects.
     *
     * <p>SSE {@code data:} lines are joined with '\n' into one chunk per event. Events are ended
     * by a blank line or by EOF. SSE {@code event:} and comment ({@code :}) lines are skipped.
     * Any other non-empty line, for example from jsonlines, becomes its own chunk. The last chunk
     * is always flagged as final.
     */
    private static void handleStream(HttpURLConnection conn, ChunkedBytesSupplier cbs) {
        // Each chunk is held back one step so the final one can be flagged as "last".
        BytesSupplier pendingChunk = null;
        try (BufferedReader reader =
                new BufferedReader(
                        new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
            StringBuilder event = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("event:") || line.startsWith(":")) {
                    continue;
                }
                if (line.startsWith("data:")) {
                    if (event.length() > 0) {
                        event.append('\n');
                    }
                    event.append(sseFieldValue(line, 5));
                    continue;
                }
                String payload;
                if (!line.isEmpty()) {
                    payload = line;
                } else if (event.length() > 0) {
                    payload = event.toString();
                    event.setLength(0);
                } else {
                    continue;
                }
                if (pendingChunk != null) {
                    cbs.appendContent(pendingChunk, false);
                }
                pendingChunk = BytesSupplier.wrap(payload);
            }
            // Flush an SSE event that was not terminated by a blank line before EOF.
            if (event.length() > 0) {
                if (pendingChunk != null) {
                    cbs.appendContent(pendingChunk, false);
                }
                pendingChunk = BytesSupplier.wrap(event.toString());
            }
        } catch (IOException e) {
            logger.warn("Failed run inference.", e);
            // Preserve arrival order: emit what was received before reporting the failure.
            if (pendingChunk != null) {
                cbs.appendContent(pendingChunk, false);
            }
            pendingChunk = BytesSupplier.wrap("connection abort exceptionally");
        } finally {
            if (pendingChunk == null) {
                pendingChunk = BytesSupplier.wrap(new byte[0]);
            }
            cbs.appendContent(pendingChunk, true);
            conn.disconnect();
        }
    }

    /** Returns the SSE field value after {@code prefixLength} chars, dropping one leading space. */
    private static String sseFieldValue(String line, int prefixLength) {
        String value = line.substring(prefixLength);
        return value.startsWith(" ") ? value.substring(1) : value;
    }

    /** A map key that compares header names case-insensitively while keeping the original spelling. */
    static final class CaseInsensitiveKey {

        final String key;

        public CaseInsensitiveKey(String key) {
            this.key = Objects.requireNonNull(key, "key");
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof CaseInsensitiveKey)) {
                return false;
            }
            return key.equalsIgnoreCase(((CaseInsensitiveKey) o).key);
        }

        /**
         * Hashes each code point folded to lower(upper(cp)). This is the same folding that
         * {@link String#equalsIgnoreCase} applies, so equal keys always hash equally. In contrast,
         * {@code String.toLowerCase} can map e.g. 'µ' and 'μ' differently.
         */
        @Override
        public int hashCode() {
            int h = 0;
            for (int i = 0; i < key.length(); ) {
                int cp = key.codePointAt(i);
                h = 31 * h + Character.toLowerCase(Character.toUpperCase(cp));
                i += Character.charCount(cp);
            }
            return h;
        }

        @Override
        public String toString() {
            return key;
        }
    }
}

