package org.apache.zeppelin.realm.jwt;

import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.SignedJWT;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.servlet.ServletException;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Groups;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAccount;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/zeppelin/realm/jwt/KnoxJwtRealm.class */
public class KnoxJwtRealm extends AuthorizingRealm {
    private static final Logger LOGGER = LoggerFactory.getLogger(KnoxJwtRealm.class);
    private String providerUrl;
    private String redirectParam;
    private String cookieName;
    private String publicKeyPath;
    private String login;
    private String logout;
    private Boolean logoutAPI;
    private Groups hadoopGroups;

    protected void onInit() {
        super.onInit();
        try {
            this.hadoopGroups = new Groups(new Configuration());
        } catch (Exception e) {
            LOGGER.error("Exception in onInit", e);
        }
    }

    public boolean supports(AuthenticationToken authenticationToken) {
        return authenticationToken instanceof JWTAuthenticationToken;
    }

    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) {
        JWTAuthenticationToken jWTAuthenticationToken = (JWTAuthenticationToken) authenticationToken;
        if (!validateToken(jWTAuthenticationToken.getToken())) {
            return null;
        }
        try {
            SimpleAccount simpleAccount = new SimpleAccount(getName(jWTAuthenticationToken), jWTAuthenticationToken.getToken(), getName());
            simpleAccount.addRole(mapGroupPrincipals(getName(jWTAuthenticationToken)));
            return simpleAccount;
        } catch (ParseException e) {
            LOGGER.error("ParseException in doGetAuthenticationInfo", e);
            return null;
        }
    }

    public String getName(JWTAuthenticationToken jWTAuthenticationToken) throws ParseException {
        return SignedJWT.parse(jWTAuthenticationToken.getToken()).getJWTClaimsSet().getSubject();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public boolean validateToken(String str) {
        try {
            SignedJWT parse = SignedJWT.parse(str);
            if (!validateSignature(parse)) {
                LOGGER.warn("Signature of JWT token could not be verified. Please check the public key");
                return false;
            }
            if (!validateExpiration(parse)) {
                LOGGER.warn("Expiration time validation of JWT token failed.");
                return false;
            }
            String str2 = (String) SecurityUtils.getSubject().getPrincipal();
            if (str2 == null) {
                return true;
            }
            return parse.getJWTClaimsSet().getSubject().equals(str2);
        } catch (ParseException e) {
            LOGGER.info("ParseException in validateToken", e);
            return false;
        }
    }

    public static RSAPublicKey parseRSAPublicKey(String str) throws IOException, ServletException {
        try {
            return (RSAPublicKey) ((X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate(new ByteArrayInputStream(FileUtils.readFileToString(new File(str), Charset.defaultCharset()).getBytes(StandardCharsets.UTF_8)))).getPublicKey();
        } catch (UnsupportedEncodingException e) {
            throw new ServletException(e);
        } catch (IOException e2) {
            throw new IOException(e2);
        } catch (CertificateException e3) {
            throw new ServletException(str.startsWith("-----BEGIN CERTIFICATE-----\n") ? "CertificateException - be sure not to include PEM header and footer in the PEM configuration element." : "CertificateException - PEM may be corrupt", e3);
        }
    }

    protected boolean validateSignature(SignedJWT signedJWT) {
        boolean z = false;
        if (JWSObject.State.SIGNED == signedJWT.getState() && signedJWT.getSignature() != null) {
            try {
                if (signedJWT.verify(new RSASSAVerifier(parseRSAPublicKey(this.publicKeyPath)))) {
                    z = true;
                }
            } catch (Exception e) {
                LOGGER.info("Exception in validateSignature", e);
            }
        }
        return z;
    }

    protected boolean validateExpiration(SignedJWT signedJWT) {
        boolean z = false;
        try {
            Date expirationTime = signedJWT.getJWTClaimsSet().getExpirationTime();
            if (expirationTime == null || new Date().before(expirationTime)) {
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("SSO token expiration date has been successfully validated");
                }
                z = true;
            } else {
                LOGGER.warn("SSO expiration date validation failed.");
            }
        } catch (ParseException e) {
            LOGGER.warn("SSO expiration date validation failed.", e);
        }
        return z;
    }

    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principalCollection) {
        return new SimpleAuthorizationInfo(mapGroupPrincipals(principalCollection.toString()));
    }

    public Set<String> mapGroupPrincipals(String str) {
        HashSet hashSet;
        try {
            List groups = this.hadoopGroups.getGroups(str);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug(String.format("group found %s, %s", str, groups.toString()));
            }
            hashSet = new HashSet(groups);
        } catch (IOException e) {
            if (e.toString().contains("No groups found for user")) {
                LOGGER.info(String.format("No groups found for user %s", str));
            } else {
                LOGGER.info(String.format("errorGettingUserGroups for %s", str));
            }
            hashSet = new HashSet();
        }
        return hashSet;
    }

    public String getProviderUrl() {
        return this.providerUrl;
    }

    public void setProviderUrl(String str) {
        this.providerUrl = str;
    }

    public String getRedirectParam() {
        return this.redirectParam;
    }

    public void setRedirectParam(String str) {
        this.redirectParam = str;
    }

    public String getCookieName() {
        return this.cookieName;
    }

    public void setCookieName(String str) {
        this.cookieName = str;
    }

    public String getPublicKeyPath() {
        return this.publicKeyPath;
    }

    public void setPublicKeyPath(String str) {
        this.publicKeyPath = str;
    }

    public String getLogin() {
        return this.login;
    }

    public void setLogin(String str) {
        this.login = str;
    }

    public String getLogout() {
        return this.logout;
    }

    public void setLogout(String str) {
        this.logout = str;
    }

    public Boolean getLogoutAPI() {
        return this.logoutAPI;
    }

    public void setLogoutAPI(Boolean bool) {
        this.logoutAPI = bool;
    }
}
