/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.network;

import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import java.security.Security;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import javax.net.ssl.SSLEngine;
import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.memory.SimpleMemoryPool;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.ChannelBuilder;
import org.apache.kafka.common.network.ChannelMetadataRegistry;
import org.apache.kafka.common.network.EchoServer;
import org.apache.kafka.common.network.KafkaChannel;
import org.apache.kafka.common.network.Mode;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.NetworkTestUtils;
import org.apache.kafka.common.network.Selector;
import org.apache.kafka.common.network.SelectorTest;
import org.apache.kafka.common.network.SslChannelBuilder;
import org.apache.kafka.common.network.SslSender;
import org.apache.kafka.common.network.SslTransportLayer;
import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.ssl.SslFactory;
import org.apache.kafka.common.security.ssl.mock.TestKeyManagerFactory;
import org.apache.kafka.common.security.ssl.mock.TestProviderCreator;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.test.TestSslUtils;
import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * SSL flavour of {@link SelectorTest}.
 *
 * <p>Every inherited selector test runs against an {@link EchoServer} started with
 * {@link SecurityProtocol#SSL}. This class adds tests for SSL-specific behaviour, such as
 * bytes left buffered inside the SSL transport layer. Concrete subclasses decide how the
 * client-side SSL configuration is built (see {@link #createSslClientConfigs(File)}).
 */
public abstract class SslSelectorTest extends SelectorTest {
    /** Client SSL configs produced by {@link #createSslClientConfigs(File)} during {@link #setUp()}. */
    private Map<String, Object> sslClientConfigs;

    /**
     * Starts an SSL echo server and builds a client-mode SSL selector.
     *
     * <p>The client trusts the server through a freshly created temporary trust store.
     */
    @Override
    @BeforeEach
    public void setUp() throws Exception {
        File trustStoreFile = File.createTempFile("truststore", ".jks");
        // Only needed for the lifetime of the test JVM; avoid littering the temp directory.
        trustStoreFile.deleteOnExit();
        Map<String, Object> sslServerConfigs =
            TestSslUtils.createSslConfig(false, true, Mode.SERVER, trustStoreFile, "server");
        this.server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs);
        this.server.start();
        this.time = new MockTime();
        this.sslClientConfigs = createSslClientConfigs(trustStoreFile);
        LogContext logContext = new LogContext();
        this.channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false, logContext);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.metrics = new Metrics();
        this.selector = new Selector(5000L, this.metrics, this.time, "MetricGroup", this.channelBuilder, logContext);
    }

    /**
     * Builds the client-side SSL configuration used by {@link #setUp()}.
     *
     * @param trustStoreFile trust store that the server configs in {@link #setUp()} were created with
     * @return client SSL configs passed to the client {@link SslChannelBuilder}
     */
    protected abstract Map<String, Object> createSslClientConfigs(File trustStoreFile)
        throws GeneralSecurityException, IOException;

    /**
     * Closes the selector, server and metrics created in {@link #setUp()}.
     *
     * <p>It also drops the static transport-layer registry. Without that, closed channels from
     * this test would stay reachable from the next one.
     */
    @Override
    @AfterEach
    public void tearDown() throws Exception {
        try {
            this.selector.close();
            this.server.close();
            this.metrics.close();
        } finally {
            TestSslChannelBuilder.TestSslTransportLayer.transportLayers.clear();
        }
    }

    /** @return the client SSL configs built in {@link #setUp()} */
    @Override
    protected Map<String, Object> clientConfigs() {
        return this.sslClientConfigs;
    }

    /**
     * Connects to a server that loads its key manager from a custom security provider.
     *
     * <p>It then checks that the cipher metric is created on connect and dropped to 0 on close.
     * The provider is registered JVM-wide, so it is always removed in {@code finally}, and all
     * locally created resources are closed even when an assertion fails.
     */
    @Test
    public void testConnectionWithCustomKeyManager() throws Exception {
        TestProviderCreator testProviderCreator = new TestProviderCreator();
        int requestSize = 100 * 1024;
        String node = "0";
        String request = TestUtils.randomString(requestSize);

        Map<String, Object> sslServerConfigs = TestSslUtils.createSslConfig(
            "TestAlgorithm", "TestAlgorithm", TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS);
        sslServerConfigs.put("security.providers", testProviderCreator.getClass().getName());

        // Created before the try block so the lambda below can capture it (effectively final).
        Metrics metrics = new Metrics();
        EchoServer server = null;
        Selector selector = null;
        try {
            server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs);
            server.start();

            MockTime time = new MockTime();
            File trustStoreFile = new File(TestKeyManagerFactory.TestKeyManager.mockTrustStoreFile);
            Map<String, Object> sslClientConfigs =
                TestSslUtils.createSslConfig(true, true, Mode.CLIENT, trustStoreFile, "client");

            TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
            channelBuilder.configure(sslClientConfigs);
            selector = new Selector(5000L, metrics, time, "MetricGroup", channelBuilder, new LogContext());

            selector.connect(node, new InetSocketAddress("localhost", server.port), 4096, 4096);
            NetworkTestUtils.waitForChannelReady(selector, node);
            selector.send(createSend(node, request));
            waitForBytesBuffered(selector, node);

            TestUtils.waitForCondition(() -> cipherMetrics(metrics).size() == 1,
                "Waiting for cipher metrics to be created.");
            Assertions.assertEquals(1, cipherMetrics(metrics).get(0).metricValue());
            Assertions.assertNotNull(selector.channel(node).channelMetadataRegistry().cipherInformation());

            selector.close(node);
            super.verifySelectorEmpty(selector);

            // The metric remains registered after disconnect, but its value drops to 0.
            Assertions.assertEquals(1, cipherMetrics(metrics).size());
            Assertions.assertEquals(0, cipherMetrics(metrics).get(0).metricValue());
        } finally {
            // Provider registration is global to the JVM; never let a failure leak it into other tests.
            Security.removeProvider(testProviderCreator.getProvider().getName());
            if (selector != null) {
                selector.close();
            }
            if (server != null) {
                server.close();
            }
            metrics.close();
        }
    }

    /**
     * Closing a channel whose SSL transport layer still holds buffered bytes must leave the
     * selector empty.
     */
    @Test
    public void testDisconnectWithIntermediateBufferedBytes() throws Exception {
        int requestSize = 100 * 1024;
        String node = "0";
        String request = TestUtils.randomString(requestSize);

        this.selector.close();
        this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.selector = new Selector(5000L, this.metrics, this.time, "MetricGroup", this.channelBuilder, new LogContext());

        connect(node, new InetSocketAddress("localhost", this.server.port));
        this.selector.send(createSend(node, request));
        waitForBytesBuffered(this.selector, node);

        this.selector.close(node);
        verifySelectorEmpty();
    }

    /**
     * Polls {@code selector} until the channel for {@code node} reports buffered bytes.
     * Fails after 2 seconds.
     */
    private void waitForBytesBuffered(Selector selector, String node) throws Exception {
        TestUtils.waitForCondition(() -> {
            try {
                selector.poll(0L);
                return selector.channel(node).hasBytesBuffered();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }, 2000L, "Failed to reach socket state with bytes buffered");
    }

    /** A channel with buffered bytes but no read interest must not be polled on every iteration. */
    @Test
    public void testBytesBufferedChannelWithNoIncomingBytes() throws Exception {
        verifyNoUnnecessaryPollWithBytesBuffered(key -> key.interestOps(key.interestOps() & ~SelectionKey.OP_READ));
    }

    /** A muted channel with buffered bytes must not be polled on every iteration. */
    @Test
    public void testBytesBufferedChannelAfterMute() throws Exception {
        verifyNoUnnecessaryPollWithBytesBuffered(key -> ((KafkaChannel) key.attachment()).mute());
    }

    /**
     * Leaves node 1 with bytes buffered in its SSL layer and disables reads on it.
     *
     * <p>It then drives 100 round trips on node 2 and asserts that node 1's key was selected only
     * once over the whole loop.
     *
     * @param disableRead stops the selector from reading node 1 (e.g. clears OP_READ or mutes it)
     */
    private void verifyNoUnnecessaryPollWithBytesBuffered(Consumer<SelectionKey> disableRead) throws Exception {
        this.selector.close();

        final String node1 = "1";
        String node2 = "2";
        final AtomicInteger node1Polls = new AtomicInteger();

        this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.selector = new Selector(5000L, this.metrics, this.time, "MetricGroup", this.channelBuilder, new LogContext()) {
            // Counts how often node1's key is handed to the selector, then delegates as usual.
            @Override
            void pollSelectionKeys(Set<SelectionKey> selectionKeys, boolean isImmediatelyConnected, long currentTimeNanos) {
                for (SelectionKey key : selectionKeys) {
                    KafkaChannel channel = (KafkaChannel) key.attachment();
                    if (channel != null && channel.id().equals(node1)) {
                        node1Polls.incrementAndGet();
                    }
                }
                super.pollSelectionKeys(selectionKeys, isImmediatelyConnected, currentTimeNanos);
            }
        };

        int largeRequestSize = 100 * 1024;
        connect(node1, new InetSocketAddress("localhost", this.server.port));
        this.selector.send(createSend(node1, TestUtils.randomString(largeRequestSize)));
        waitForBytesBuffered(this.selector, node1);
        TestSslChannelBuilder.TestSslTransportLayer.transportLayers.get(node1).truncateReadBuffer();
        disableRead.accept(this.selector.channel(node1).selectionKey());
        node1Polls.set(0);

        connect(node2, new InetSocketAddress("localhost", this.server.port));
        int expectedReceives = 100;
        int received = 0;
        String request = TestUtils.randomString(10);
        this.selector.send(createSend(node2, request));
        // Bound the loop by wall-clock time so a regression fails the test instead of hanging it.
        long deadline = System.currentTimeMillis() + 30_000L;
        while (received < expectedReceives) {
            if (System.currentTimeMillis() > deadline) {
                Assertions.fail("Timed out after " + received + " of " + expectedReceives
                    + " receives on node " + node2);
            }
            received += this.selector.completedReceives().size();
            if (!this.selector.completedSends().isEmpty()) {
                this.selector.send(createSend(node2, request));
            }
            this.selector.poll(5L);
        }
        Assertions.assertEquals(1, node1Polls.get());

        this.selector.close(node1);
        this.selector.close(node2);
        verifySelectorEmpty();
    }

    /**
     * Server-side variant of the out-of-memory mute test.
     *
     * <p>Two SSL senders each push a 900-byte payload into a selector backed by a 900-byte
     * memory pool. Only one request can be completed until that receive's memory is released.
     */
    @Override
    @Test
    public void testMuteOnOOM() throws Exception {
        this.selector.close();

        SimpleMemoryPool pool = new SimpleMemoryPool(900L, 900, false, null);
        String tlsProtocol = "TLSv1.2";
        File trustStoreFile = File.createTempFile("truststore", ".jks");
        trustStoreFile.deleteOnExit();
        Map<String, Object> sslServerConfigs = new TestSslUtils.SslConfigsBuilder(Mode.SERVER)
            .tlsProtocol(tlsProtocol)
            .createNewTrustStore(trustStoreFile)
            .build();
        this.channelBuilder = new SslChannelBuilder(Mode.SERVER, null, false, new LogContext());
        this.channelBuilder.configure(sslServerConfigs);
        this.selector = new Selector(-1, 5000L, this.metrics, this.time, "MetricGroup",
            new HashMap<String, String>(), true, false, this.channelBuilder, pool, new LogContext());

        try (ServerSocketChannel ss = ServerSocketChannel.open()) {
            ss.bind(new InetSocketAddress(0));
            InetSocketAddress serverAddress = (InetSocketAddress) ss.getLocalAddress();

            SslSender sender1 = createSender(tlsProtocol, serverAddress, randomPayload(900));
            SslSender sender2 = createSender(tlsProtocol, serverAddress, randomPayload(900));
            sender1.start();
            sender2.start();

            SocketChannel channelX = ss.accept();
            channelX.configureBlocking(false);
            SocketChannel channelY = ss.accept();
            channelY.configureBlocking(false);
            this.selector.register("clientX", channelX);
            this.selector.register("clientY", channelY);

            boolean handshaked = false;
            NetworkReceive firstReceive = null;
            Collection<NetworkReceive> completed;
            long deadline = System.currentTimeMillis() + 5000L;
            // Poll until both handshakes are done, one request has completed, and the pool is exhausted.
            while (System.currentTimeMillis() < deadline) {
                this.selector.poll(10L);
                completed = this.selector.completedReceives();
                if (firstReceive == null) {
                    if (!completed.isEmpty()) {
                        Assertions.assertEquals(1, completed.size(), "expecting a single request");
                        firstReceive = completed.iterator().next();
                        Assertions.assertTrue(this.selector.isMadeReadProgressLastPoll());
                        Assertions.assertEquals(0L, pool.availableMemory());
                    }
                } else {
                    Assertions.assertTrue(completed.isEmpty(), "only expecting single request");
                }

                handshaked = sender1.waitForHandshake(1L) && sender2.waitForHandshake(1L);
                if (handshaked && firstReceive != null && this.selector.isOutOfMemory()) {
                    break;
                }
            }
            Assertions.assertTrue(handshaked, "could not initiate connections within timeout");

            // While the first receive still holds the pool, no further request may complete.
            this.selector.poll(10L);
            Assertions.assertTrue(this.selector.completedReceives().isEmpty());
            Assertions.assertEquals(0L, pool.availableMemory());
            Assertions.assertNotNull(firstReceive, "First receive not complete");
            Assertions.assertTrue(this.selector.isOutOfMemory(), "Selector not out of memory");

            // Releasing the first receive returns its memory and lets the second request through.
            firstReceive.close();
            Assertions.assertEquals(900L, pool.availableMemory());

            completed = Collections.emptyList();
            deadline = System.currentTimeMillis() + 5000L;
            while (System.currentTimeMillis() < deadline && completed.isEmpty()) {
                this.selector.poll(1000L);
                completed = this.selector.completedReceives();
            }
            Assertions.assertEquals(1, completed.size(), "could not read remaining request within timeout");
            Assertions.assertEquals(0L, pool.availableMemory());
            Assertions.assertFalse(this.selector.isOutOfMemory());
        }
    }

    /** Connects over SSL using the inherited blocking connect helper. */
    @Override
    protected void connect(String node, InetSocketAddress serverAddr) throws IOException {
        blockingConnect(node, serverAddr);
    }

    /** Creates an {@link SslSender} that will send {@code payload} to {@code serverAddress}. */
    private SslSender createSender(String tlsProtocol, InetSocketAddress serverAddress, byte[] payload) {
        return new SslSender(tlsProtocol, serverAddress, payload);
    }

    /**
     * Client {@link SslChannelBuilder} that installs {@link TestSslTransportLayer}, so tests can
     * control socket reads and buffer contents.
     */
    private static class TestSslChannelBuilder extends SslChannelBuilder {
        public TestSslChannelBuilder(Mode mode) {
            super(mode, null, false, new LogContext());
        }

        @Override
        protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key,
                                                        ChannelMetadataRegistry metadataRegistry) throws IOException {
            SocketChannel socketChannel = (SocketChannel) key.channel();
            SSLEngine sslEngine = sslFactory.createSslEngine(socketChannel.socket());
            return new TestSslTransportLayer(id, key, sslEngine, metadataRegistry);
        }

        /**
         * {@link SslTransportLayer} that registers itself by channel id and reads from the socket
         * only every other call.
         *
         * <p>After a real read it returns 0 until OP_READ is observed set on its selection key.
         * Not thread-safe: the static registry is a plain {@link HashMap}, and this class assumes
         * the test thread is its only user.
         */
        static class TestSslTransportLayer extends SslTransportLayer {
            /** Most recently built transport layer per channel id; cleared in the outer test's tearDown. */
            static Map<String, TestSslTransportLayer> transportLayers = new HashMap<>();
            boolean muteSocket = false;

            public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine,
                                         ChannelMetadataRegistry metadataRegistry) {
                super(channelId, key, sslEngine, metadataRegistry);
                transportLayers.put(channelId, this);
            }

            @Override
            protected int readFromSocketChannel() throws IOException {
                if (muteSocket) {
                    // Stay muted until read interest is seen on the key, then allow the next call to read.
                    if ((selectionKey().interestOps() & SelectionKey.OP_READ) != 0) {
                        muteSocket = false;
                    }
                    return 0;
                }
                muteSocket = true;
                return super.readFromSocketChannel();
            }

            /**
             * Leaves one byte marked as consumed in the network read buffer and rewinds the
             * application read buffer, then mutes socket reads.
             */
            void truncateReadBuffer() throws Exception {
                netReadBuffer().position(1);
                appReadBuffer().position(0);
                muteSocket = true;
            }
        }
    }
}

