/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.watsonx.ai.chat;

import com.ibm.watsonx.ai.chat.ChatResponse;
import com.ibm.watsonx.ai.chat.model.ChatUsage;
import com.ibm.watsonx.ai.chat.model.CompletedToolCall;
import com.ibm.watsonx.ai.chat.model.ExtractionTags;
import com.ibm.watsonx.ai.chat.model.FinishReason;
import com.ibm.watsonx.ai.chat.model.PartialChatResponse;
import com.ibm.watsonx.ai.chat.model.PartialToolCall;
import com.ibm.watsonx.ai.chat.model.ResultMessage;
import com.ibm.watsonx.ai.chat.model.ToolCall;
import com.ibm.watsonx.ai.chat.streaming.StreamingStateTracker;
import com.ibm.watsonx.ai.chat.streaming.StreamingToolFetcher;
import com.ibm.watsonx.ai.core.Json;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Incrementally processes the server-sent-event (SSE) lines of a streaming chat completion.
 * <p>
 * Every call to {@link #processChunk(String)} converts one SSE line into zero or more
 * {@link CallbackEvent}s (partial text, partial thinking, partial and completed tool calls) and
 * accumulates the state that {@link #buildResponse()} later uses to assemble the final
 * {@link ChatResponse}.
 * <p>
 * Shared state lives in concurrent maps and volatile fields, but the per-message buffers
 * ({@link StringBuilder}s and tool-fetcher {@link ArrayList}s) are not synchronized, so the lines
 * of a single stream should be fed to {@code processChunk} sequentially.
 */
public class SseEventProcessor {
    /** SSE field prefix of a line that carries an event payload. */
    private static final String DATA_PREFIX = "data:";

    private final Object usageLock = new Object();
    /** Finish reason per choice index; the first non-null value received wins. */
    private final Map<Integer, String> finishReasons = new ConcurrentHashMap<>();
    private volatile String role;
    private volatile String refusal;
    private volatile Long created;
    private volatile String createdAt;
    private volatile String id;
    private volatile String modelId;
    private volatile String object;
    private volatile String model;
    private volatile String modelVersion;
    /** Set by an {@code event: error} line so the following {@code data:} line is reported as an error. */
    private volatile boolean pendingSSEError = false;
    /** Latest usage statistics received; guarded by {@code usageLock}. */
    private ChatUsage chatUsage;
    private final Map<Integer, StringBuilder> contentBuffers = new ConcurrentHashMap<>();
    private final Map<Integer, StringBuilder> thinkingBuffers = new ConcurrentHashMap<>();
    private final Map<Integer, List<StreamingToolFetcher>> toolFetchers = new ConcurrentHashMap<>();
    /** Splits content into thinking/response segments; {@code null} when no extraction tags are set. */
    private final StreamingStateTracker stateTracker;
    private final Map<String, Boolean> toolHasParameters;
    private final ExtractionTags extractionTags;

    /**
     * Creates a processor for a single streamed chat completion.
     *
     * @param toolHasParameters tool name to "declares parameters" flag; tools mapped to {@code false}
     *                          report {@code "{}"} as the arguments of their partial tool calls
     * @param extractionTags    tags used to separate thinking from response text inside the content,
     *                          or {@code null} to treat all content as response text
     */
    public SseEventProcessor(Map<String, Boolean> toolHasParameters, ExtractionTags extractionTags) {
        this.toolHasParameters = toolHasParameters;
        this.extractionTags = extractionTags;
        this.stateTracker = Objects.nonNull(extractionTags) ? new StreamingStateTracker(extractionTags) : null;
    }

    /**
     * Processes one SSE line of the stream.
     * <ul>
     *   <li>{@code null}/blank lines, {@code event: close} and any non-{@code data:} line yield no events.</li>
     *   <li>{@code event: error} yields no events but marks the next {@code data:} line as an error,
     *       which is returned as {@link ProcessResult#error(Throwable)} wrapping the raw payload.</li>
     *   <li>Any other {@code data:} line is parsed as a {@link PartialChatResponse}; only its first
     *       choice is inspected.</li>
     * </ul>
     * Exceptions raised by {@code Json.fromJson} for a malformed payload propagate to the caller.
     *
     * @param partialMessage a single SSE line
     * @return the events produced by this line, an empty result, or an error result
     */
    public ProcessResult processChunk(String partialMessage) {
        if (Objects.isNull(partialMessage) || partialMessage.isBlank()) {
            return ProcessResult.empty();
        }
        if (partialMessage.startsWith("event: error")) {
            // The error details arrive in the next "data:" line.
            this.pendingSSEError = true;
            return ProcessResult.empty();
        }
        if (partialMessage.startsWith("event: close") || !partialMessage.startsWith(DATA_PREFIX)) {
            return ProcessResult.empty();
        }

        String messageData = extractData(partialMessage);
        if (this.pendingSSEError) {
            this.pendingSSEError = false;
            return ProcessResult.error(new RuntimeException(messageData));
        }
        if (messageData.isBlank()) {
            return ProcessResult.empty();
        }

        PartialChatResponse chunk = (PartialChatResponse) Json.fromJson(messageData, PartialChatResponse.class);
        if (Objects.nonNull(chunk.usage())) {
            synchronized (this.usageLock) {
                this.chatUsage = chunk.usage();
            }
        }
        if (Objects.isNull(chunk.choices()) || chunk.choices().isEmpty()) {
            return completePendingToolCalls();
        }

        List<CallbackEvent> events = new ArrayList<>();
        PartialChatResponse.ResultChoice choice = chunk.choices().get(0);
        Integer messageIndex = choice.index();
        StringBuilder contentBuffer = this.contentBuffers.computeIfAbsent(messageIndex, key -> new StringBuilder());
        StringBuilder thinkingBuffer = this.thinkingBuffers.computeIfAbsent(messageIndex, key -> new StringBuilder());

        captureMetadata(chunk, choice);
        boolean finishedNow = recordFinishReason(messageIndex, choice.finishReason());

        var delta = choice.delta();
        if (Objects.nonNull(delta.toolCalls())) {
            for (ToolCall deltaTool : delta.toolCalls()) {
                processToolCallDelta(messageIndex, deltaTool, events);
            }
        }

        String content = delta.content();
        if (Objects.nonNull(content) && !content.isEmpty()) {
            processContent(content, chunk, contentBuffer, thinkingBuffer, events);
        }

        // An empty reasoning token is skipped; it must not discard the events gathered above.
        String reasoning = delta.reasoningContent();
        if (Objects.nonNull(reasoning) && !reasoning.isEmpty()) {
            thinkingBuffer.append(reasoning);
            events.add(new CallbackEvent.PartialThinkingEvent(reasoning, chunk));
        }

        // Complete the last tool call only on the chunk that first reports the "tool_calls" finish
        // reason, so later chunks for the same choice cannot emit it a second time.
        if (finishedNow && FinishReason.TOOL_CALLS.value().equals(choice.finishReason())) {
            List<StreamingToolFetcher> tools = this.toolFetchers.get(messageIndex);
            if (Objects.nonNull(tools) && !tools.isEmpty()) {
                events.add(new CallbackEvent.CompleteToolCallEvent(tools.get(tools.size() - 1).build()));
            }
        }
        return ProcessResult.events(events);
    }

    /**
     * Returns the payload of a {@code data:} line. Following the SSE format, a single space after
     * the colon is not part of the value. The payload itself may contain {@code "data: "}.
     */
    private static String extractData(String line) {
        String value = line.substring(DATA_PREFIX.length());
        return value.startsWith(" ") ? value.substring(1) : value;
    }

    /**
     * Handles a chunk without choices (e.g. a trailing usage-only chunk). Any choice that streamed
     * tool calls but never received a finish reason gets {@code tool_calls} recorded, and its last
     * tool call is completed.
     */
    private ProcessResult completePendingToolCalls() {
        List<CallbackEvent> events = new ArrayList<>();
        this.toolFetchers.forEach((messageIndex, tools) -> {
            if (Objects.nonNull(this.finishReasons.get(messageIndex)) || tools.isEmpty()) {
                return;
            }
            this.finishReasons.put(messageIndex, FinishReason.TOOL_CALLS.value());
            events.add(new CallbackEvent.CompleteToolCallEvent(tools.get(tools.size() - 1).build()));
        });
        return events.isEmpty() ? ProcessResult.empty() : ProcessResult.events(events);
    }

    /** Remembers the first non-null value seen for each response-level and message-level attribute. */
    private void captureMetadata(PartialChatResponse chunk, PartialChatResponse.ResultChoice choice) {
        if (Objects.isNull(this.created)) {
            this.created = chunk.created();
        }
        if (Objects.isNull(this.createdAt)) {
            this.createdAt = chunk.createdAt();
        }
        if (Objects.isNull(this.id)) {
            this.id = chunk.id();
        }
        if (Objects.isNull(this.modelId)) {
            this.modelId = chunk.modelId();
        }
        if (Objects.isNull(this.object)) {
            this.object = chunk.object();
        }
        if (Objects.isNull(this.modelVersion)) {
            this.modelVersion = chunk.modelVersion();
        }
        if (Objects.isNull(this.model)) {
            this.model = chunk.model();
        }
        if (Objects.isNull(this.role)) {
            this.role = choice.delta().role();
        }
        if (Objects.isNull(this.refusal)) {
            this.refusal = choice.delta().refusal();
        }
    }

    /**
     * Stores {@code finishReason} for the choice unless one is already recorded.
     *
     * @return {@code true} if this call recorded the choice's finish reason
     */
    private boolean recordFinishReason(Integer messageIndex, String finishReason) {
        if (Objects.isNull(finishReason)) {
            return false;
        }
        return Objects.isNull(this.finishReasons.putIfAbsent(messageIndex, finishReason));
    }

    /**
     * Merges one streamed tool-call delta into its fetcher and emits the resulting events. The first
     * delta for a new tool index completes the previous tool call. A non-empty argument fragment
     * produces a {@link CallbackEvent.PartialToolCallEvent}.
     */
    private void processToolCallDelta(Integer messageIndex, ToolCall deltaTool, List<CallbackEvent> events) {
        Integer toolIndex = deltaTool.index();
        List<StreamingToolFetcher> tools = this.toolFetchers.computeIfAbsent(messageIndex, key -> new ArrayList<>());

        StreamingToolFetcher toolFetcher;
        if (toolIndex >= tools.size()) {
            toolFetcher = new StreamingToolFetcher(this.id, messageIndex, toolIndex);
            tools.add(toolFetcher);
            if (toolIndex > 0) {
                events.add(new CallbackEvent.CompleteToolCallEvent(tools.get(toolIndex - 1).build()));
            }
        } else {
            toolFetcher = tools.get(toolIndex);
        }

        toolFetcher.setId(deltaTool.id());
        var function = deltaTool.function();
        if (Objects.isNull(function)) {
            return;
        }
        toolFetcher.setName(function.name());
        toolFetcher.appendArguments(function.arguments());

        // Tools known to take no parameters always report "{}" as their arguments.
        Boolean toolHasParameter = this.toolHasParameters.get(toolFetcher.getName());
        String arguments = Boolean.FALSE.equals(toolHasParameter) ? "{}" : function.arguments();
        if (Objects.isNull(arguments) || arguments.isEmpty()) {
            return;
        }
        PartialToolCall partialToolCall = new PartialToolCall(
            this.id, messageIndex, toolFetcher.getToolIndex(), toolFetcher.getId(), toolFetcher.getName(), arguments);
        events.add(new CallbackEvent.PartialToolCallEvent(partialToolCall));
    }

    /**
     * Appends a content token to the content buffer and emits it. Without a state tracker the token
     * is emitted as response text. With a tracker, its thinking segments go to the thinking buffer
     * and are emitted as thinking events, and its response segments are emitted as response events.
     */
    private void processContent(String token, PartialChatResponse chunk, StringBuilder contentBuffer,
                                StringBuilder thinkingBuffer, List<CallbackEvent> events) {
        contentBuffer.append(token);
        if (Objects.isNull(this.stateTracker)) {
            events.add(new CallbackEvent.PartialResponseEvent(token, chunk));
            return;
        }
        StreamingStateTracker.Result result = this.stateTracker.update(token);
        Optional<String> segment = result.content();
        switch (result.state()) {
            case RESPONSE, NO_THINKING -> segment.ifPresent(text -> events.add(new CallbackEvent.PartialResponseEvent(text, chunk)));
            case THINKING -> segment.ifPresent(text -> {
                thinkingBuffer.append(text);
                events.add(new CallbackEvent.PartialThinkingEvent(text, chunk));
            });
        }
    }

    /**
     * Assembles the final response from everything accumulated so far. There is one choice per
     * message index that received a choice chunk. Empty buffers become {@code null}, and non-empty
     * ones are trimmed.
     *
     * @return the aggregated chat response
     */
    public ChatResponse buildResponse() {
        List<ChatResponse.ResultChoice> choices = this.contentBuffers.keySet().stream()
            .map(this::buildChoice)
            .toList();
        ChatUsage usage;
        synchronized (this.usageLock) {
            usage = this.chatUsage;
        }
        return ChatResponse.build()
            .created(this.created)
            .createdAt(this.createdAt)
            .id(this.id)
            .modelId(this.modelId)
            .object(this.object)
            .model(this.model)
            .modelVersion(this.modelVersion)
            .extractionTags(this.extractionTags)
            .usage(usage)
            .choices(choices)
            .build();
    }

    /** Builds the final choice for one message index from its buffers and tool fetchers. */
    private ChatResponse.ResultChoice buildChoice(Integer index) {
        String content = trimmedOrNull(this.contentBuffers.get(index));
        String thinking = trimmedOrNull(this.thinkingBuffers.get(index));
        List<StreamingToolFetcher> tools = this.toolFetchers.get(index);
        var toolCalls = Objects.nonNull(tools) && !tools.isEmpty()
            ? tools.stream().map(StreamingToolFetcher::build).map(CompletedToolCall::toolCall).toList()
            : null;
        ResultMessage resultMessage = new ResultMessage(this.role, content, thinking, this.refusal, toolCalls);
        return new ChatResponse.ResultChoice(index, resultMessage, this.finishReasons.get(index));
    }

    /** Returns {@code null} for a missing or empty buffer, otherwise its trimmed text. */
    private static String trimmedOrNull(StringBuilder buffer) {
        return Objects.isNull(buffer) || buffer.isEmpty() ? null : buffer.toString().trim();
    }

    /**
     * Outcome of processing one SSE line.
     *
     * @param events   callback events produced by the line (empty when none)
     * @param hasError whether the line carried a stream error
     * @param error    the error when {@code hasError} is {@code true}, otherwise {@code null}
     */
    public record ProcessResult(List<CallbackEvent> events, boolean hasError, Throwable error) {
        /** A result with no events and no error. */
        public static ProcessResult empty() {
            return new ProcessResult(List.of(), false, null);
        }

        /** A successful result carrying {@code events}. */
        public static ProcessResult events(List<CallbackEvent> events) {
            return new ProcessResult(events, false, null);
        }

        /** An error result wrapping {@code t}. */
        public static ProcessResult error(Throwable t) {
            return new ProcessResult(List.of(), true, t);
        }
    }

    /**
     * Closed set of events delivered to streaming callbacks. The permitted subtypes are the records
     * declared below.
     */
    public sealed interface CallbackEvent {

        public record ErrorEvent(Throwable error) implements CallbackEvent
        {
        }

        public record CompleteToolCallEvent(CompletedToolCall completeToolCall) implements CallbackEvent
        {
        }

        public record PartialToolCallEvent(PartialToolCall toolCall) implements CallbackEvent
        {
        }

        public record PartialThinkingEvent(String content, PartialChatResponse chunk) implements CallbackEvent
        {
        }

        public record PartialResponseEvent(String content, PartialChatResponse chunk) implements CallbackEvent
        {
        }
    }
}

