/*
 * Decompiled with CFR 0.152.
 */
package io.rsocket.broker.rsocket;

import io.rsocket.RSocket;
import io.rsocket.broker.RoutingTable;
import io.rsocket.broker.common.Tags;
import io.rsocket.broker.common.WellKnownKey;
import io.rsocket.broker.frames.Address;
import io.rsocket.broker.frames.RoutingType;
import io.rsocket.broker.query.RSocketQuery;
import io.rsocket.broker.rsocket.RSocketLocator;
import io.rsocket.broker.rsocket.ResolvingRSocket;
import io.rsocket.loadbalance.LoadbalanceStrategy;
import java.util.List;
import java.util.Map;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link RSocketLocator} for {@link RoutingType#UNICAST} routing.
 *
 * <p>Looks up candidate RSockets by the tags of an {@link Address}. A single match is
 * returned directly, several matches are narrowed down by a {@link LoadbalanceStrategy},
 * and when nothing matches yet a {@link ResolvingRSocket} is returned that defers the
 * lookup until the routing table emits a join event for those tags.
 */
public class UnicastRSocketLocator implements RSocketLocator {
    private static final Logger logger = LoggerFactory.getLogger(UnicastRSocketLocator.class);
    private final RSocketQuery rSocketQuery;
    private final RoutingTable routingTable;
    private final String defaultLoadBalancer;
    private final Map<String, LoadbalanceStrategy> loadbalancers;

    /**
     * Creates a locator.
     *
     * @param rSocketQuery        used to find RSockets matching a set of tags
     * @param routingTable        source of join events, used when no RSocket matches yet
     * @param loadbalancers       available strategies keyed by name
     * @param defaultLoadBalancer name of the strategy used when the tags do not select one
     * @throws IllegalStateException if {@code loadbalancers} has no entry for
     *                               {@code defaultLoadBalancer}
     */
    public UnicastRSocketLocator(RSocketQuery rSocketQuery, RoutingTable routingTable, Map<String, LoadbalanceStrategy> loadbalancers, String defaultLoadBalancer) {
        // Reject a default that cannot be resolved before any state is stored.
        if (!loadbalancers.containsKey(defaultLoadBalancer)) {
            throw new IllegalStateException("No Loadbalancer for " + defaultLoadBalancer + ". Found " + loadbalancers.keySet());
        }
        this.rSocketQuery = rSocketQuery;
        this.routingTable = routingTable;
        this.loadbalancers = loadbalancers;
        this.defaultLoadBalancer = defaultLoadBalancer;
    }

    /**
     * @return {@code true} only for {@link RoutingType#UNICAST}
     */
    @Override
    public boolean supports(RoutingType routingType) {
        return RoutingType.UNICAST == routingType;
    }

    /**
     * Finds the RSocket that should receive traffic for the given address.
     *
     * @param address the address whose tags are used for the lookup
     * @return the only match, a load-balanced pick among several matches, or a
     *         {@link ResolvingRSocket} when there is currently no match
     */
    @Override
    public RSocket locate(Address address) {
        List<RSocket> candidates = this.rSocketQuery.query(address.getTags());
        if (candidates.isEmpty()) {
            // Nothing registered yet: hand back a socket that resolves later.
            return this.resolvingRSocket(address.getTags());
        }
        if (candidates.size() == 1) {
            return candidates.get(0);
        }
        return this.loadbalance(candidates, address.getTags());
    }

    /**
     * Builds a {@link ResolvingRSocket} backed by the first element emitted by
     * {@code routingTable.joinEvents(tags)}; once it arrives the query is re-run and
     * the result is load-balanced.
     */
    private ResolvingRSocket resolvingRSocket(Tags tags) {
        Publisher<RSocket> deferred = this.routingTable.joinEvents(tags)
                .next()
                .map(routeSetup -> this.selectAfterJoin(tags));
        return new ResolvingRSocket(deferred);
    }

    /**
     * Re-runs the query for {@code tags} and load-balances the result, logging a
     * warning when the query still comes back empty.
     */
    private RSocket selectAfterJoin(Tags tags) {
        List<RSocket> candidates = this.rSocketQuery.query(tags);
        boolean warnAboutNoMatch = logger.isWarnEnabled() && candidates.isEmpty();
        if (warnAboutNoMatch) {
            logger.warn("Unable to locate RSockets for tags {}", tags);
        }
        // NOTE(review): an empty list is still handed to the strategy; how a strategy
        // reacts to that is not visible from this class.
        return this.loadbalance(candidates, tags);
    }

    /**
     * Picks one RSocket from {@code rSockets} using the strategy chosen for {@code tags}.
     */
    private RSocket loadbalance(List<RSocket> rSockets, Tags tags) {
        return this.strategyFor(tags).select(rSockets);
    }

    /**
     * Resolves the strategy named by the {@link WellKnownKey#LB_METHOD} tag, falling
     * back to the configured default when the tag is absent or names no usable strategy.
     */
    private LoadbalanceStrategy strategyFor(Tags tags) {
        if (tags.containsKey(WellKnownKey.LB_METHOD)) {
            String requested = tags.get(WellKnownKey.LB_METHOD);
            if (this.loadbalancers.containsKey(requested)) {
                LoadbalanceStrategy named = this.loadbalancers.get(requested);
                if (named != null) {
                    return named;
                }
            }
        }
        return this.loadbalancers.get(this.defaultLoadBalancer);
    }
}

