/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ml.OpenAI;
import apoc.result.StringResult;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.jetbrains.annotations.NotNull;
import org.neo4j.graphdb.QueryExecutionException;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@Extended
public class Prompt {
    @Context
    public Transaction tx;
    @Context
    public Log log;
    @Context
    public ApocConfig apocConfig;
    @Context
    public ProcedureCallContext procedureCallContext;

    /** Markdown code-fence delimiter; the system prompt asks the model to wrap its Cypher in it. */
    public static final String BACKTICKS = "```";
    /** System prompt used by {@link #schema(Map)} to request a short plain-language schema explanation. */
    public static final String EXPLAIN_SCHEMA_PROMPT = "You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.\nExplain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.\nKeep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm.\n";
    /** System prompt asking the model for a single Cypher statement wrapped in triple backticks. */
    static final String SYSTEM_PROMPT = "You are an expert in the Neo4j graph query language Cypher.\nGiven a graph database schema of entities (nodes) with labels and attributes and\nrelationships with start- and end-node, relationship-type, direction and properties\nyou are able to develop read only matching Cypher statements that express a user question as a graph database query.\nOnly answer with a single Cypher statement in triple backticks, if you can't determine a statement, answer with an empty response.\nDo not explain, apologize or provide additional detail, otherwise people will come to harm.\n";
    /**
     * Summarises the schema via {@code apoc.meta.data}. When {@code $sample} is null the sample size is
     * derived from the node count ({@code count{()}/1000 + 1}).
     */
    private static final String SCHEMA_QUERY = "call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})\nYIELD label, other, elementType, type, property\nWITH label, elementType, \n     apoc.text.join(collect(case when NOT type = \"RELATIONSHIP\" then property+\": \"+type else null end),\", \") AS properties,    \n     collect(case when type = \"RELATIONSHIP\" AND elementType = \"node\" then \"(:\" + label + \")-[:\" + property + \"]->(:\" + toString(other[0]) + \")\" else null end) as patterns\nwith  elementType as type, \napoc.text.join(collect(\":\"+label+\" {\"+properties+\"}\"),\"\\n\") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),\"\\n\") as patterns\nreturn collect(case type when \"relationship\" then entities end)[0] as relationships, \ncollect(case type when \"node\" then entities end)[0] as nodes, \ncollect(case type when \"node\" then patterns end)[0] as patterns \n";
    /** Template filled with the nodes / relationships / patterns columns of {@link #SCHEMA_QUERY}. */
    private static final String SCHEMA_PROMPT = "    nodes:\n    %s\n    relationships:\n    %s\n    patterns:\n    %s\n";

    /**
     * Generates a Cypher statement for {@code question} with the LLM and executes it in the current transaction.
     *
     * <p>If executing the generated statement throws a {@link QueryExecutionException}, a fresh statement is
     * generated and tried again, for at most {@code conf.retries} attempts in total (default 3; values of 1 or
     * less mean a single attempt). After the last failed attempt the exception is rethrown.
     *
     * @param question natural-language question to translate into Cypher
     * @param conf     options read here and by the helpers: {@code retries}, {@code sample}, {@code apiKey}, {@code model}
     * @return the rows produced by the generated statement; each row also carries the statement text
     *         when the caller YIELDs the {@code query} column
     * @throws QueryExecutionException if every attempt fails to execute
     */
    @Procedure(mode=Mode.READ)
    public Stream<PromptMapResult> query(@Name(value="question") String question, @Name(value="conf", defaultValue="{}") Map<String, Object> conf) {
        String schema = this.loadSchema(this.tx, conf);
        long retries = longConf(conf, "retries", 3L);
        // Only compute the extra column when the caller actually asked for it.
        boolean containsField = this.procedureCallContext.outputFields().anyMatch("query"::equals);
        while (true) {
            // tryQuery never throws: generation failures come back as a QueryResult with an empty query,
            // which then fails below and consumes a retry.
            String query = this.tryQuery(question, conf, schema).query;
            try {
                // Only failures raised by tx.execute itself are retried; anything surfacing while the
                // returned stream is consumed propagates to the caller.
                return this.tx.execute(query).stream()
                        .map(row -> containsField ? new PromptMapResult(row, query) : new PromptMapResult(row));
            }
            catch (QueryExecutionException quee) {
                if (this.log.isDebugEnabled()) {
                    this.log.debug("Generated query for question %s\n%s\nfailed with %s".formatted(question, query, quee.getMessage()));
                }
                // Decrement on every failure (not only when debug logging is on), so the loop terminates.
                if (--retries <= 0L) {
                    throw quee;
                }
            }
        }
    }

    /**
     * Asks the LLM to explain the current database schema in plain language.
     *
     * @param conf options: {@code sample}, {@code apiKey}, {@code model}
     * @return a single row with the explanation text
     */
    @Procedure
    public Stream<StringResult> schema(@Name(value="conf", defaultValue="{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
        String schemaExplanation = this.prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", EXPLAIN_SCHEMA_PROMPT, "This database schema ", this.loadSchema(this.tx, conf), conf);
        return Stream.of(new StringResult(schemaExplanation));
    }

    /**
     * Generates {@code conf.count} candidate Cypher statements (default 1) for {@code question} without executing them.
     *
     * @param question natural-language question to translate into Cypher
     * @param conf     options: {@code count}, {@code sample}, {@code apiKey}, {@code model}
     * @return one {@link QueryResult} per requested candidate (none when count is 0 or negative)
     */
    @Procedure(mode=Mode.READ)
    public Stream<QueryResult> cypher(@Name(value="question") String question, @Name(value="conf", defaultValue="{}") Map<String, Object> conf) {
        String schema = this.loadSchema(this.tx, conf);
        long count = longConf(conf, "count", 1L);
        return LongStream.rangeClosed(1L, count).mapToObj(i -> this.tryQuery(question, conf, schema));
    }

    /**
     * Asks the LLM for one Cypher statement. Any exception is caught, debug-logged and turned into a
     * {@link QueryResult} with an empty query, so callers never see it thrown.
     */
    @NotNull
    private QueryResult tryQuery(String question, Map<String, Object> conf, String schema) {
        String query = "";
        try {
            query = this.prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf);
            return new QueryResult(query, null, null);
        }
        catch (QueryExecutionException e) {
            return this.failedQuery(question, query, e, e.getStatusCode());
        }
        catch (Exception e) {
            return this.failedQuery(question, query, e, e.getClass().getSimpleName());
        }
    }

    /** Debug-logs why generating a statement failed and wraps the failure into a {@link QueryResult}. */
    private QueryResult failedQuery(String question, String query, Exception e, String type) {
        if (this.log.isDebugEnabled()) {
            this.log.debug("Generating a query for question %s failed with %s: %s".formatted(question, type, e.getMessage()));
        }
        return new QueryResult(query, e.getMessage(), type);
    }

    /**
     * Sends a chat-completion request and returns the non-blank answers joined by a space.
     *
     * <p>Messages are added in order: system prompt, schema (as a second system message), user question,
     * and an assistant primer. Each is skipped when null or blank. Code fences are stripped from every
     * answer (see {@link #stripCodeFence(String)}), and runs of blank lines are collapsed.
     *
     * @param conf options: {@code apiKey} and {@code model} (default {@code gpt-3.5-turbo})
     */
    @NotNull
    @SuppressWarnings("unchecked")
    private String prompt(String userQuestion, String systemPrompt, String assistantPrompt, String schema, Map<String, Object> conf) throws JsonProcessingException, MalformedURLException {
        List<Map<String, Object>> prompt = new ArrayList<>();
        if (systemPrompt != null && !systemPrompt.isBlank()) {
            prompt.add(Map.of("role", "system", "content", systemPrompt));
        }
        if (schema != null && !schema.isBlank()) {
            prompt.add(Map.of("role", "system", "content", "The graph database schema consists of these elements\n" + schema));
        }
        if (userQuestion != null && !userQuestion.isBlank()) {
            prompt.add(Map.of("role", "user", "content", userQuestion));
        }
        if (assistantPrompt != null && !assistantPrompt.isBlank()) {
            prompt.add(Map.of("role", "assistant", "content", assistantPrompt));
        }
        String apiKey = (String) conf.get("apiKey");
        String model = (String) conf.getOrDefault("model", "gpt-3.5-turbo");
        // The response is read as {choices: [{message: {content: ...}}, ...]}.
        String result = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions", model, "messages", prompt, "$", this.apocConfig)
                .map(response -> (Map<String, Object>) response)
                .flatMap(response -> ((List<Map<String, Object>>) response.get("choices")).stream())
                .map(choice -> (String) ((Map<String, Object>) choice.get("message")).get("content"))
                .filter(content -> content != null && !content.isBlank())
                .map(Prompt::stripCodeFence)
                .collect(Collectors.joining(" "))
                .replaceAll("\n\n+", "\n");
        if (this.log.isDebugEnabled()) {
            this.log.debug("Generated query for question %s\n%s".formatted(userQuestion, result));
        }
        return result;
    }

    /**
     * Returns the text between the first and the last triple-backtick fence of {@code text}, or everything
     * after the first fence when there is no separate closing fence. Text without a fence is returned unchanged.
     */
    private static String stripCodeFence(String text) {
        int open = text.indexOf(BACKTICKS);
        if (open < 0) {
            return text;
        }
        int start = open + BACKTICKS.length();
        int close = text.lastIndexOf(BACKTICKS);
        // A single (e.g. truncated) or overlapping fence has no usable closing position.
        String body = close >= start ? text.substring(start, close) : text.substring(start);
        // Drop a "```cypher" language hint. A bare "CYPHER" line carries no query options, so removing
        // it never changes a real statement.
        return body.replaceFirst("(?i)^cypher[ \\t]*\\r?\\n", "");
    }

    /** Runs {@link #SCHEMA_QUERY} and renders each result row through {@link #SCHEMA_PROMPT}. */
    private String loadSchema(Transaction tx, Map<String, Object> conf) {
        // HashMap rather than Map.of: "sample" is usually absent (null), and Map.of rejects null values.
        Map<String, Object> params = new HashMap<>();
        params.put("sample", conf.get("sample"));
        return tx.execute(SCHEMA_QUERY, params).stream()
                .map(row -> SCHEMA_PROMPT.formatted(row.get("nodes"), row.get("relationships"), row.get("patterns")))
                .collect(Collectors.joining("\n"));
    }

    /**
     * Reads an integral option from {@code conf}, accepting any {@link Number}.
     *
     * @return {@code defaultValue} when the key is absent or mapped to null
     * @throws IllegalArgumentException if the value is present but not a number
     */
    private static long longConf(Map<String, Object> conf, String key, long defaultValue) {
        Object value = conf.get(key);
        if (value == null) {
            return defaultValue;
        }
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException("Config `" + key + "` must be a number but was: " + value);
        }
        return ((Number) value).longValue();
    }

    /**
     * Result row of {@code cypher}: one generated Cypher statement (empty when generation failed).
     *
     * <p>{@code error} and {@code type} are accepted but not retained, and {@link #hasError()} always returns
     * {@code false}. NOTE(review): public fields presumably become procedure output columns, so adding them
     * would change the procedure signature. Confirm before exposing them.
     */
    public class QueryResult {
        public final String query;

        public QueryResult(String query, String error, String type) {
            this.query = query;
        }

        /** Always {@code false}: the error details are not retained (see the class comment). */
        public boolean hasError() {
            return false;
        }
    }

    /** Result row of {@code query}: one row of the generated statement plus, optionally, the statement text. */
    public class PromptMapResult {
        public final Map<String, Object> value;
        /** The generated statement, or null when the caller did not YIELD {@code query}. */
        public final String query;

        public PromptMapResult(Map<String, Object> value, String query) {
            this.value = value;
            this.query = query;
        }

        public PromptMapResult(Map<String, Object> value) {
            this.value = value;
            this.query = null;
        }
    }
}

