/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.spark.sql.connector.read.partitioner;

import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Projections;
import com.mongodb.client.model.Sorts;
import com.mongodb.spark.sql.connector.assertions.Assertions;
import com.mongodb.spark.sql.connector.config.MongoConfig;
import com.mongodb.spark.sql.connector.config.ReadConfig;
import com.mongodb.spark.sql.connector.read.MongoInputPartition;
import com.mongodb.spark.sql.connector.read.partitioner.FieldPartitioner;
import com.mongodb.spark.sql.connector.read.partitioner.PartitionerHelper;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.bson.BsonDocument;
import org.bson.conversions.Bson;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;

/**
 * Field partitioner that derives partition boundaries from a random sample of the collection.
 *
 * <p>The number of documents per partition is estimated from the configured partition size and
 * the average object size reported by the collection storage stats. A {@code $sample} stage then
 * draws {@code samples.per.partition} samples per expected partition, the samples are sorted by
 * the partition field, and every {@code samples.per.partition}-th sample (plus the last one) is
 * used as a right-hand partition boundary.
 *
 * <p>Configuration (read from the partitioner options):
 * <ul>
 *   <li>{@code partition.size}: partition size in MB, must be greater than zero. Default: 64.
 *   <li>{@code samples.per.partition}: samples taken per partition, must be greater than one.
 *       Default: 10.
 * </ul>
 */
@ApiStatus.Internal
public final class SamplePartitioner extends FieldPartitioner {
    public static final String PARTITION_SIZE_MB_CONFIG = "partition.size";
    private static final int PARTITION_SIZE_MB_DEFAULT = 64;
    static final String SAMPLES_PER_PARTITION_CONFIG = "samples.per.partition";
    private static final int SAMPLES_PER_PARTITION_DEFAULT = 10;

    /** Bytes per (decimal) megabyte; declared as {@code long} so the size conversion cannot overflow. */
    private static final long BYTES_PER_MB = 1000L * 1000L;

    /**
     * Generates the input partitions for the configured collection.
     *
     * <p>Falls back to the single partitioner when the storage stats are empty, or when the
     * estimated number of documents per partition is at least the number of matching documents.
     *
     * @param readConfig the read configuration
     * @return the partitions to read
     */
    @Override
    public List<MongoInputPartition> generatePartitions(ReadConfig readConfig) {
        MongoConfig partitionerOptions = readConfig.getPartitionerOptions();
        String partitionField = this.getPartitionField(readConfig);

        int partitionSizeInMb = Assertions.validateConfig(
                partitionerOptions.getInt(PARTITION_SIZE_MB_CONFIG, PARTITION_SIZE_MB_DEFAULT),
                i -> i > 0,
                () -> String.format("Invalid config: %s should be greater than zero.", PARTITION_SIZE_MB_CONFIG));
        // Multiply in long arithmetic: an int product overflows once the size reaches 2148 MB.
        long partitionSizeInBytes = partitionSizeInMb * BYTES_PER_MB;

        int samplesPerPartition = Assertions.validateConfig(
                partitionerOptions.getInt(SAMPLES_PER_PARTITION_CONFIG, SAMPLES_PER_PARTITION_DEFAULT),
                i -> i > 1,
                () -> String.format("Invalid config: %s should be greater than one.", SAMPLES_PER_PARTITION_CONFIG));

        BsonDocument storageStats = PartitionerHelper.storageStats(readConfig);
        if (storageStats.isEmpty()) {
            LOGGER.warn("Unable to get collection stats (collstats) returning a single partition.");
            return PartitionerHelper.SINGLE_PARTITIONER.generatePartitions(readConfig);
        }

        BsonDocument matchQuery = PartitionerHelper.matchQuery(readConfig.getAggregationPipeline());
        long count = countDocuments(readConfig, storageStats, matchQuery);

        double avgObjSizeInBytes = storageStats.getNumber("avgObjSize").doubleValue();
        // A zero average size yields +Infinity here, which takes the single partition branch below.
        double numDocumentsPerPartition = Math.floor((double) partitionSizeInBytes / avgObjSizeInBytes);
        if (numDocumentsPerPartition >= (double) count) {
            LOGGER.info(
                    "Fewer documents ({}) than the calculated number of documents per partition ({}). Returning a single partition",
                    count,
                    numDocumentsPerPartition);
            return PartitionerHelper.SINGLE_PARTITIONER.generatePartitions(readConfig);
        }

        // samplesPerPartition samples for each expected partition (count / numDocumentsPerPartition).
        int numberOfSamples =
                (int) Math.ceil((double) ((long) samplesPerPartition * count) / numDocumentsPerPartition);
        List<BsonDocument> samples = fetchSortedSamples(readConfig, matchQuery, partitionField, numberOfSamples);

        return this.createMongoInputPartitions(
                partitionField, this.getRightHandBoundaries(samples, samplesPerPartition), readConfig);
    }

    /**
     * Returns the number of documents the read will cover: the {@code count} value from the
     * storage stats when there is no match query, otherwise the result of counting the documents
     * that satisfy the match query.
     */
    private static long countDocuments(ReadConfig readConfig, BsonDocument storageStats, BsonDocument matchQuery) {
        if (matchQuery.isEmpty()) {
            return storageStats.getNumber("count").longValue();
        }
        return readConfig.withCollection(coll -> coll.countDocuments(matchQuery));
    }

    /**
     * Runs the sampling aggregation: match, {@code $sample}, project down to the partition field
     * and sort ascending by it. The {@code _id} field is excluded from the projection unless it
     * is the partition field itself.
     */
    private static List<BsonDocument> fetchSortedSamples(
            ReadConfig readConfig, BsonDocument matchQuery, String partitionField, int numberOfSamples) {
        Bson projection = partitionField.equals("_id")
                ? Projections.include(partitionField)
                : Projections.fields(Projections.include(partitionField), Projections.excludeId());
        List<Bson> pipeline = Arrays.asList(
                Aggregates.match(matchQuery),
                Aggregates.sample(numberOfSamples),
                Aggregates.project(projection),
                Aggregates.sort(Sorts.ascending(partitionField)));
        return readConfig.withCollection(coll -> coll.aggregate(pipeline)
                .allowDiskUse(readConfig.getAggregationAllowDiskUse())
                .into(new ArrayList<>()));
    }

    /**
     * Picks the right-hand partition boundaries out of the sorted samples.
     *
     * <p>Keeps the samples at indices {@code 0, samplesPerPartition, 2 * samplesPerPartition, ...}
     * plus the final sample, then drops the first kept sample (index 0) so that only right-hand
     * boundaries remain. An empty sample list produces an empty result.
     *
     * @param samples the samples, sorted by the partition field
     * @param samplesPerPartition the number of samples that make up one partition
     * @return the boundary documents
     */
    @NotNull
    private List<BsonDocument> getRightHandBoundaries(List<BsonDocument> samples, int samplesPerPartition) {
        int lastIndex = samples.size() - 1;
        return IntStream.range(0, samples.size())
                .filter(n -> n % samplesPerPartition == 0 || n == lastIndex)
                .mapToObj(samples::get)
                .skip(1L)
                .collect(Collectors.toList());
    }
}

