/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.security.oauthbearer.internals.secured;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.BasicOAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ClaimValidationUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.SerializedJwt;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.NumericDate;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.jwt.consumer.JwtContext;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ValidatorAccessTokenValidator
implements AccessTokenValidator {
    private static final Logger log = LoggerFactory.getLogger(ValidatorAccessTokenValidator.class);
    private final JwtConsumer jwtConsumer;
    private final String scopeClaimName;
    private final String subClaimName;

    public ValidatorAccessTokenValidator(Integer clockSkew, Set<String> expectedAudiences, String expectedIssuer, VerificationKeyResolver verificationKeyResolver, String scopeClaimName, String subClaimName, Map<String, Boolean> claimOptions) {
        Objects.requireNonNull(claimOptions, "Claim options cannot be null");
        JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder();
        if (clockSkew != null) {
            jwtConsumerBuilder.setAllowedClockSkewInSeconds(clockSkew.intValue());
        }
        if (expectedAudiences != null && !expectedAudiences.isEmpty()) {
            jwtConsumerBuilder.setExpectedAudience(expectedAudiences.toArray(new String[0]));
        } else {
            jwtConsumerBuilder.setSkipDefaultAudienceValidation();
        }
        if (expectedIssuer != null) {
            jwtConsumerBuilder.setExpectedIssuer(expectedIssuer);
        }
        jwtConsumerBuilder.setJwsAlgorithmConstraints(AlgorithmConstraints.DISALLOW_NONE).setRequireExpirationTime().setVerificationKeyResolver(verificationKeyResolver).build();
        if (claimOptions.getOrDefault("iatRequired", false).booleanValue()) {
            jwtConsumerBuilder.setRequireIssuedAt();
        }
        if (claimOptions.getOrDefault("jtiRequired", false).booleanValue()) {
            jwtConsumerBuilder.setRequireJwtId();
        }
        this.jwtConsumer = jwtConsumerBuilder.build();
        this.scopeClaimName = scopeClaimName;
        this.subClaimName = subClaimName;
    }

    public ValidatorAccessTokenValidator(Integer clockSkew, Set<String> expectedAudiences, String expectedIssuer, VerificationKeyResolver verificationKeyResolver, String scopeClaimName, String subClaimName) {
        this(clockSkew, expectedAudiences, expectedIssuer, verificationKeyResolver, scopeClaimName, subClaimName, Collections.emptyMap());
    }

    @Override
    public OAuthBearerToken validate(String accessToken) throws ValidateException {
        JwtContext jwt;
        SerializedJwt serializedJwt = new SerializedJwt(accessToken);
        try {
            jwt = this.jwtConsumer.process(serializedJwt.getToken());
        }
        catch (InvalidJwtException e) {
            throw new ValidateException(String.format("Could not validate the access token: %s", e.getMessage()), e);
        }
        JwtClaims claims = jwt.getJwtClaims();
        Object scopeRaw = this.getClaim(() -> claims.getClaimValue(this.scopeClaimName), this.scopeClaimName);
        Collection<Object> scopeRawCollection = scopeRaw instanceof String ? Collections.singletonList((String)scopeRaw) : (scopeRaw instanceof Collection ? (Collection)scopeRaw : Collections.emptySet());
        NumericDate expirationRaw = this.getClaim(() -> ((JwtClaims)claims).getExpirationTime(), "exp");
        String subRaw = this.getClaim(() -> claims.getStringClaimValue(this.subClaimName), this.subClaimName);
        NumericDate issuedAtRaw = this.getClaim(() -> ((JwtClaims)claims).getIssuedAt(), "iat");
        Set<String> scopes = ClaimValidationUtils.validateScopes(this.scopeClaimName, scopeRawCollection);
        long expiration = ClaimValidationUtils.validateExpiration("exp", expirationRaw != null ? Long.valueOf(expirationRaw.getValueInMillis()) : null);
        String sub = ClaimValidationUtils.validateSubject(this.subClaimName, subRaw);
        Long issuedAt = ClaimValidationUtils.validateIssuedAt("iat", issuedAtRaw != null ? Long.valueOf(issuedAtRaw.getValueInMillis()) : null);
        BasicOAuthBearerToken token = new BasicOAuthBearerToken(accessToken, scopes, expiration, sub, issuedAt);
        return token;
    }

    private <T> T getClaim(ClaimSupplier<T> supplier, String claimName) throws ValidateException {
        try {
            T value = supplier.get();
            log.debug("getClaim - {}: {}", (Object)claimName, value);
            return value;
        }
        catch (MalformedClaimException e) {
            throw new ValidateException(String.format("Could not extract the '%s' claim from the access token", claimName), e);
        }
    }

    public static interface ClaimSupplier<T> {
        public T get() throws MalformedClaimException;
    }
}

