Skip to main content

aptos_sdk/account/
keyless.rs

1//! Keyless (OIDC-based) account support.
2
3use crate::account::account::{Account, AuthenticationKey};
4use crate::crypto::{Ed25519PrivateKey, Ed25519PublicKey, KEYLESS_SCHEME};
5use crate::error::{AptosError, AptosResult};
6use crate::types::AccountAddress;
7use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
8use rand::RngCore;
9use serde::{Deserialize, Serialize};
10use sha3::{Digest, Sha3_256};
11use std::fmt;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use url::Url;
14
15// Re-export JwkSet for use with from_jwt_with_jwks and refresh_proof_with_jwks
16pub use jsonwebtoken::jwk::JwkSet;
17
18/// Keyless signature payload for transaction authentication.
19#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
20pub struct KeylessSignature {
21    /// Ephemeral public key bytes.
22    pub ephemeral_public_key: Vec<u8>,
23    /// Signature produced by the ephemeral key.
24    pub ephemeral_signature: Vec<u8>,
25    /// Zero-knowledge proof bytes.
26    pub proof: Vec<u8>,
27}
28
29impl KeylessSignature {
30    /// Serializes the signature using BCS.
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if BCS serialization fails.
35    pub fn to_bcs(&self) -> AptosResult<Vec<u8>> {
36        aptos_bcs::to_bytes(self).map_err(AptosError::bcs)
37    }
38}
39
40/// Short-lived key pair used for keyless signing.
41#[derive(Clone)]
42pub struct EphemeralKeyPair {
43    private_key: Ed25519PrivateKey,
44    public_key: Ed25519PublicKey,
45    expiry: SystemTime,
46    nonce: String,
47}
48
49impl EphemeralKeyPair {
50    /// Generates a new ephemeral key pair with the given expiry (in seconds).
51    pub fn generate(expiry_secs: u64) -> Self {
52        let private_key = Ed25519PrivateKey::generate();
53        let public_key = private_key.public_key();
54        let nonce = {
55            let mut bytes = [0u8; 16];
56            rand::rngs::OsRng.fill_bytes(&mut bytes);
57            const_hex::encode(bytes)
58        };
59        Self {
60            private_key,
61            public_key,
62            expiry: SystemTime::now() + Duration::from_secs(expiry_secs),
63            nonce,
64        }
65    }
66
67    /// Returns true if the key pair has expired.
68    pub fn is_expired(&self) -> bool {
69        SystemTime::now() >= self.expiry
70    }
71
72    /// Returns the nonce associated with this key pair.
73    pub fn nonce(&self) -> &str {
74        &self.nonce
75    }
76
77    /// Returns the public key.
78    pub fn public_key(&self) -> &Ed25519PublicKey {
79        &self.public_key
80    }
81}
82
83impl fmt::Debug for EphemeralKeyPair {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("EphemeralKeyPair")
86            .field("public_key", &self.public_key)
87            .field("expiry", &self.expiry)
88            .field("nonce", &self.nonce)
89            .finish_non_exhaustive()
90    }
91}
92
/// Supported OIDC providers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OidcProvider {
    /// Google identity provider.
    Google,
    /// Apple identity provider.
    Apple,
    /// Microsoft identity provider.
    Microsoft,
    /// Custom OIDC provider.
    Custom {
        /// Issuer URL, stored exactly as supplied.
        issuer: String,
        /// JWKS URL. [`OidcProvider::from_issuer`] leaves this empty when it
        /// cannot derive a safe URL from the issuer.
        jwks_url: String,
    },
}

impl OidcProvider {
    /// Returns the issuer URL.
    pub fn issuer(&self) -> &str {
        match self {
            OidcProvider::Google => "https://accounts.google.com",
            OidcProvider::Apple => "https://appleid.apple.com",
            OidcProvider::Microsoft => "https://login.microsoftonline.com/common/v2.0",
            OidcProvider::Custom { issuer, .. } => issuer,
        }
    }

    /// Returns the JWKS URL.
    pub fn jwks_url(&self) -> &str {
        match self {
            OidcProvider::Google => "https://www.googleapis.com/oauth2/v3/certs",
            OidcProvider::Apple => "https://appleid.apple.com/auth/keys",
            OidcProvider::Microsoft => {
                "https://login.microsoftonline.com/common/discovery/v2.0/keys"
            }
            OidcProvider::Custom { jwks_url, .. } => jwks_url,
        }
    }

    /// Infers a provider from an issuer URL.
    ///
    /// The three built-in issuers are matched exactly. Any other issuer becomes
    /// [`OidcProvider::Custom`] with the issuer string stored unchanged.
    ///
    /// # Security
    ///
    /// For unknown issuers, the JWKS URL is constructed as
    /// `{issuer}/.well-known/jwks.json`, with any trailing `/` removed from the
    /// issuer first so the result never contains `//.well-known`.
    ///
    /// Issuers that do not start with `https://`, or that have nothing after the
    /// scheme, are accepted at construction time but get an empty JWKS URL, which
    /// fails with a clear error at JWKS fetch time. This prevents SSRF via
    /// `http://`, `file://`, or other dangerous URL schemes without changing the
    /// function signature. Callers controlling issuer input should additionally
    /// validate the host (e.g., block private IP ranges) if SSRF is a concern.
    pub fn from_issuer(issuer: &str) -> Self {
        match issuer {
            "https://accounts.google.com" => OidcProvider::Google,
            "https://appleid.apple.com" => OidcProvider::Apple,
            "https://login.microsoftonline.com/common/v2.0" => OidcProvider::Microsoft,
            _ => {
                // SECURITY: the issuer may come from an unverified JWT, so only
                // HTTPS issuers yield a fetchable JWKS URL. Everything after the
                // scheme is kept, minus trailing slashes.
                let base = issuer
                    .strip_prefix("https://")
                    .map(|rest| rest.trim_end_matches('/'))
                    .filter(|rest| !rest.is_empty());

                let jwks_url = match base {
                    Some(rest) => format!("https://{rest}/.well-known/jwks.json"),
                    // Empty URL: rejected later when the JWKS fetch parses it.
                    None => String::new(),
                };

                OidcProvider::Custom {
                    issuer: issuer.to_string(),
                    jwks_url,
                }
            }
        }
    }
}
170
171/// Pepper bytes used in keyless address derivation.
172///
173/// # Security
174///
175/// The pepper is secret material used to derive keyless account addresses.
176/// It is automatically zeroized when dropped to prevent key material from
177/// lingering in memory.
178#[derive(Clone, PartialEq, Eq, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
179pub struct Pepper(Vec<u8>);
180
181impl std::fmt::Debug for Pepper {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(f, "Pepper(REDACTED)")
184    }
185}
186
187impl Pepper {
188    /// Creates a new pepper from raw bytes.
189    pub fn new(bytes: Vec<u8>) -> Self {
190        Self(bytes)
191    }
192
193    /// Returns the pepper as bytes.
194    pub fn as_bytes(&self) -> &[u8] {
195        &self.0
196    }
197
198    /// Creates a pepper from hex.
199    ///
200    /// # Errors
201    ///
202    /// Returns an error if the hex string is invalid or cannot be decoded.
203    pub fn from_hex(hex_str: &str) -> AptosResult<Self> {
204        Ok(Self(const_hex::decode(hex_str)?))
205    }
206
207    /// Returns the pepper as hex.
208    pub fn to_hex(&self) -> String {
209        const_hex::encode_prefixed(&self.0)
210    }
211}
212
213/// Zero-knowledge proof bytes.
214#[derive(Clone, Debug, PartialEq, Eq)]
215pub struct ZkProof(Vec<u8>);
216
217impl ZkProof {
218    /// Creates a new proof from raw bytes.
219    pub fn new(bytes: Vec<u8>) -> Self {
220        Self(bytes)
221    }
222
223    /// Returns the proof as bytes.
224    pub fn as_bytes(&self) -> &[u8] {
225        &self.0
226    }
227
228    /// Creates a proof from hex.
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if the hex string is invalid or cannot be decoded.
233    pub fn from_hex(hex_str: &str) -> AptosResult<Self> {
234        Ok(Self(const_hex::decode(hex_str)?))
235    }
236
237    /// Returns the proof as hex.
238    pub fn to_hex(&self) -> String {
239        const_hex::encode_prefixed(&self.0)
240    }
241}
242
243/// Service for obtaining pepper values.
244pub trait PepperService: Send + Sync {
245    /// Fetches the pepper for a JWT.
246    fn get_pepper(
247        &self,
248        jwt: &str,
249    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AptosResult<Pepper>> + Send + '_>>;
250}
251
252/// Service for generating zero-knowledge proofs.
253pub trait ProverService: Send + Sync {
254    /// Generates the proof for keyless authentication.
255    fn generate_proof<'a>(
256        &'a self,
257        jwt: &'a str,
258        ephemeral_key: &'a EphemeralKeyPair,
259        pepper: &'a Pepper,
260    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AptosResult<ZkProof>> + Send + 'a>>;
261}
262
263/// HTTP pepper service client.
264#[derive(Clone, Debug)]
265pub struct HttpPepperService {
266    url: Url,
267    client: reqwest::Client,
268}
269
270impl HttpPepperService {
271    /// Creates a new HTTP pepper service client.
272    pub fn new(url: Url) -> Self {
273        Self {
274            url,
275            client: reqwest::Client::new(),
276        }
277    }
278}
279
280#[derive(Serialize)]
281struct PepperRequest<'a> {
282    jwt: &'a str,
283}
284
285#[derive(Deserialize)]
286struct PepperResponse {
287    pepper: String,
288}
289
290impl PepperService for HttpPepperService {
291    fn get_pepper(
292        &self,
293        jwt: &str,
294    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AptosResult<Pepper>> + Send + '_>> {
295        let jwt = jwt.to_owned();
296        Box::pin(async move {
297            let response = self
298                .client
299                .post(self.url.clone())
300                .json(&PepperRequest { jwt: &jwt })
301                .send()
302                .await?
303                .error_for_status()?;
304
305            // SECURITY: Stream body with size limit to prevent OOM
306            let bytes =
307                crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
308            let payload: PepperResponse = serde_json::from_slice(&bytes).map_err(|e| {
309                AptosError::InvalidJwt(format!("failed to parse pepper response: {e}"))
310            })?;
311            Pepper::from_hex(&payload.pepper)
312        })
313    }
314}
315
316/// HTTP prover service client.
317#[derive(Clone, Debug)]
318pub struct HttpProverService {
319    url: Url,
320    client: reqwest::Client,
321}
322
323impl HttpProverService {
324    /// Creates a new HTTP prover service client.
325    pub fn new(url: Url) -> Self {
326        Self {
327            url,
328            client: reqwest::Client::new(),
329        }
330    }
331}
332
333#[derive(Serialize)]
334struct ProverRequest<'a> {
335    jwt: &'a str,
336    ephemeral_public_key: String,
337    nonce: &'a str,
338    pepper: String,
339}
340
341#[derive(Deserialize)]
342struct ProverResponse {
343    proof: String,
344}
345
346impl ProverService for HttpProverService {
347    fn generate_proof<'a>(
348        &'a self,
349        jwt: &'a str,
350        ephemeral_key: &'a EphemeralKeyPair,
351        pepper: &'a Pepper,
352    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AptosResult<ZkProof>> + Send + 'a>>
353    {
354        Box::pin(async move {
355            let request = ProverRequest {
356                jwt,
357                ephemeral_public_key: const_hex::encode_prefixed(
358                    ephemeral_key.public_key.to_bytes(),
359                ),
360                nonce: ephemeral_key.nonce(),
361                pepper: pepper.to_hex(),
362            };
363
364            let response = self
365                .client
366                .post(self.url.clone())
367                .json(&request)
368                .send()
369                .await?
370                .error_for_status()?;
371
372            // SECURITY: Stream body with size limit to prevent OOM
373            let bytes =
374                crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
375            let payload: ProverResponse = serde_json::from_slice(&bytes).map_err(|e| {
376                AptosError::InvalidJwt(format!("failed to parse prover response: {e}"))
377            })?;
378            ZkProof::from_hex(&payload.proof)
379        })
380    }
381}
382
383/// Account authenticated via OIDC.
384pub struct KeylessAccount {
385    ephemeral_key: EphemeralKeyPair,
386    provider: OidcProvider,
387    issuer: String,
388    audience: String,
389    user_id: String,
390    pepper: Pepper,
391    proof: ZkProof,
392    address: AccountAddress,
393    auth_key: AuthenticationKey,
394    jwt_expiration: Option<SystemTime>,
395}
396
397impl KeylessAccount {
398    /// Creates a keyless account from an OIDC JWT token.
399    ///
400    /// This method verifies the JWT signature using the OIDC provider's JWKS endpoint
401    /// before extracting claims and creating the account.
402    ///
403    /// # Network Requests
404    ///
405    /// This method makes HTTP requests to:
406    /// - The OIDC provider's JWKS endpoint to fetch signing keys
407    /// - The pepper service to obtain the pepper
408    /// - The prover service to generate a ZK proof
409    ///
410    /// For more control over network calls and caching, use [`Self::from_jwt_with_jwks`]
411    /// with pre-fetched JWKS.
412    ///
413    /// # Errors
414    ///
415    /// This function will return an error if:
416    /// - The JWT signature verification fails
417    /// - The JWT cannot be decoded or is missing required claims (iss, aud, sub, nonce)
418    /// - The JWT nonce doesn't match the ephemeral key's nonce
419    /// - The JWT is expired
420    /// - The JWKS cannot be fetched from the provider (network timeout, DNS failure,
421    ///   connection errors, HTTP errors, or invalid JWKS response)
422    /// - The pepper service fails to return a pepper
423    /// - The prover service fails to generate a proof
424    pub async fn from_jwt(
425        jwt: &str,
426        ephemeral_key: EphemeralKeyPair,
427        pepper_service: &dyn PepperService,
428        prover_service: &dyn ProverService,
429    ) -> AptosResult<Self> {
430        // First, decode without verification to get the issuer for JWKS lookup
431        let unverified_claims = decode_claims_unverified(jwt)?;
432        let issuer = unverified_claims
433            .iss
434            .as_ref()
435            .ok_or_else(|| AptosError::InvalidJwt("missing iss claim".into()))?;
436
437        // Determine provider and fetch JWKS
438        let provider = OidcProvider::from_issuer(issuer);
439        let client = reqwest::Client::builder()
440            .timeout(JWKS_FETCH_TIMEOUT)
441            .build()
442            .map_err(|e| AptosError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
443        let jwks = fetch_jwks(&client, provider.jwks_url()).await?;
444
445        // Now verify and decode the JWT properly
446        let claims = decode_and_verify_jwt(jwt, &jwks)?;
447        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
448
449        if nonce != ephemeral_key.nonce() {
450            return Err(AptosError::InvalidJwt("JWT nonce mismatch".into()));
451        }
452
453        let pepper = pepper_service.get_pepper(jwt).await?;
454        let proof = prover_service
455            .generate_proof(jwt, &ephemeral_key, &pepper)
456            .await?;
457
458        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
459        let auth_key = AuthenticationKey::new(address.to_bytes());
460
461        Ok(Self {
462            provider: OidcProvider::from_issuer(&issuer),
463            issuer,
464            audience,
465            user_id,
466            pepper,
467            proof,
468            address,
469            auth_key,
470            jwt_expiration: exp,
471            ephemeral_key,
472        })
473    }
474
475    /// Creates a keyless account from a JWT with pre-fetched JWKS.
476    ///
477    /// This method is useful when you want to:
478    /// - Cache the JWKS to avoid repeated network requests
479    /// - Have more control over HTTP client configuration
480    /// - Implement custom caching strategies based on HTTP cache headers
481    ///
482    /// # Errors
483    ///
484    /// This function will return an error if:
485    /// - The JWT signature verification fails
486    /// - The JWT cannot be decoded or is missing required claims (iss, aud, sub, nonce)
487    /// - The JWT nonce doesn't match the ephemeral key's nonce
488    /// - The JWT is expired
489    /// - The pepper service fails to return a pepper
490    /// - The prover service fails to generate a proof
491    pub async fn from_jwt_with_jwks(
492        jwt: &str,
493        jwks: &JwkSet,
494        ephemeral_key: EphemeralKeyPair,
495        pepper_service: &dyn PepperService,
496        prover_service: &dyn ProverService,
497    ) -> AptosResult<Self> {
498        // Verify and decode the JWT using the provided JWKS
499        let claims = decode_and_verify_jwt(jwt, jwks)?;
500        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
501
502        if nonce != ephemeral_key.nonce() {
503            return Err(AptosError::InvalidJwt("JWT nonce mismatch".into()));
504        }
505
506        let pepper = pepper_service.get_pepper(jwt).await?;
507        let proof = prover_service
508            .generate_proof(jwt, &ephemeral_key, &pepper)
509            .await?;
510
511        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
512        let auth_key = AuthenticationKey::new(address.to_bytes());
513
514        Ok(Self {
515            provider: OidcProvider::from_issuer(&issuer),
516            issuer,
517            audience,
518            user_id,
519            pepper,
520            proof,
521            address,
522            auth_key,
523            jwt_expiration: exp,
524            ephemeral_key,
525        })
526    }
527
528    /// Returns the OIDC provider.
529    pub fn provider(&self) -> &OidcProvider {
530        &self.provider
531    }
532
533    /// Returns the issuer.
534    pub fn issuer(&self) -> &str {
535        &self.issuer
536    }
537
538    /// Returns the audience.
539    pub fn audience(&self) -> &str {
540        &self.audience
541    }
542
543    /// Returns the user identifier (sub claim).
544    pub fn user_id(&self) -> &str {
545        &self.user_id
546    }
547
548    /// Returns the proof.
549    pub fn proof(&self) -> &ZkProof {
550        &self.proof
551    }
552
553    /// Returns true if the account is still valid.
554    pub fn is_valid(&self) -> bool {
555        if self.ephemeral_key.is_expired() {
556            return false;
557        }
558
559        match self.jwt_expiration {
560            Some(exp) => SystemTime::now() < exp,
561            None => true,
562        }
563    }
564
565    /// Refreshes the proof using a new JWT.
566    ///
567    /// This method verifies the JWT signature using the OIDC provider's JWKS endpoint.
568    ///
569    /// # Network Requests
570    ///
571    /// This method makes HTTP requests to fetch the JWKS from the OIDC provider.
572    /// For more control over network calls and caching, use [`Self::refresh_proof_with_jwks`].
573    ///
574    /// # Errors
575    ///
576    /// Returns an error if:
577    /// - The JWKS cannot be fetched (network timeout, DNS failure, connection errors)
578    /// - The JWT signature verification fails
579    /// - The JWT cannot be decoded
580    /// - The JWT nonce does not match the ephemeral key
581    /// - The JWT identity does not match the account
582    /// - The prover service fails to generate a new proof
583    pub async fn refresh_proof(
584        &mut self,
585        jwt: &str,
586        prover_service: &dyn ProverService,
587    ) -> AptosResult<()> {
588        // Fetch JWKS and verify JWT
589        let client = reqwest::Client::builder()
590            .timeout(JWKS_FETCH_TIMEOUT)
591            .build()
592            .map_err(|e| AptosError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
593        let jwks = fetch_jwks(&client, self.provider.jwks_url()).await?;
594        self.refresh_proof_with_jwks(jwt, &jwks, prover_service)
595            .await
596    }
597
598    /// Refreshes the proof using a new JWT with pre-fetched JWKS.
599    ///
600    /// This method is useful for caching the JWKS or using a custom HTTP client.
601    ///
602    /// # Errors
603    ///
604    /// Returns an error if:
605    /// - The JWT signature verification fails
606    /// - The JWT cannot be decoded
607    /// - The JWT nonce does not match the ephemeral key
608    /// - The JWT identity does not match the account
609    /// - The prover service fails to generate a new proof
610    pub async fn refresh_proof_with_jwks(
611        &mut self,
612        jwt: &str,
613        jwks: &JwkSet,
614        prover_service: &dyn ProverService,
615    ) -> AptosResult<()> {
616        let claims = decode_and_verify_jwt(jwt, jwks)?;
617        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
618
619        if nonce != self.ephemeral_key.nonce() {
620            return Err(AptosError::InvalidJwt("JWT nonce mismatch".into()));
621        }
622
623        if issuer != self.issuer || audience != self.audience || user_id != self.user_id {
624            return Err(AptosError::InvalidJwt(
625                "JWT identity does not match account".into(),
626            ));
627        }
628
629        let proof = prover_service
630            .generate_proof(jwt, &self.ephemeral_key, &self.pepper)
631            .await?;
632        self.proof = proof;
633        self.jwt_expiration = exp;
634        Ok(())
635    }
636
637    /// Signs a message and returns the structured keyless signature.
638    pub fn sign_keyless(&self, message: &[u8]) -> KeylessSignature {
639        let signature = self.ephemeral_key.private_key.sign(message).to_bytes();
640        KeylessSignature {
641            ephemeral_public_key: self.ephemeral_key.public_key.to_bytes().to_vec(),
642            ephemeral_signature: signature.to_vec(),
643            proof: self.proof.as_bytes().to_vec(),
644        }
645    }
646
647    /// Creates a keyless account from pre-verified JWT claims.
648    ///
649    /// This is useful for testing or when JWT verification is handled externally.
650    /// The caller is responsible for ensuring the JWT was properly verified.
651    ///
652    /// # Errors
653    ///
654    /// This function will return an error if:
655    /// - The nonce doesn't match the ephemeral key's nonce
656    /// - The pepper service fails to return a pepper
657    /// - The prover service fails to generate a proof
658    #[doc(hidden)]
659    #[allow(clippy::too_many_arguments)]
660    pub async fn from_verified_claims(
661        issuer: String,
662        audience: String,
663        user_id: String,
664        nonce: String,
665        exp: Option<SystemTime>,
666        ephemeral_key: EphemeralKeyPair,
667        pepper_service: &dyn PepperService,
668        prover_service: &dyn ProverService,
669        jwt_for_services: &str,
670    ) -> AptosResult<Self> {
671        if nonce != ephemeral_key.nonce() {
672            return Err(AptosError::InvalidJwt("nonce mismatch".into()));
673        }
674
675        let pepper = pepper_service.get_pepper(jwt_for_services).await?;
676        let proof = prover_service
677            .generate_proof(jwt_for_services, &ephemeral_key, &pepper)
678            .await?;
679
680        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
681        let auth_key = AuthenticationKey::new(address.to_bytes());
682
683        Ok(Self {
684            provider: OidcProvider::from_issuer(&issuer),
685            issuer,
686            audience,
687            user_id,
688            pepper,
689            proof,
690            address,
691            auth_key,
692            jwt_expiration: exp,
693            ephemeral_key,
694        })
695    }
696}
697
698impl Account for KeylessAccount {
699    fn address(&self) -> AccountAddress {
700        self.address
701    }
702
703    fn authentication_key(&self) -> AuthenticationKey {
704        self.auth_key
705    }
706
707    fn sign(&self, message: &[u8]) -> crate::error::AptosResult<Vec<u8>> {
708        let signature = self.sign_keyless(message);
709        signature
710            .to_bcs()
711            .map_err(|e| crate::error::AptosError::Bcs(e.to_string()))
712    }
713
714    fn public_key_bytes(&self) -> Vec<u8> {
715        self.ephemeral_key.public_key.to_bytes().to_vec()
716    }
717
718    fn signature_scheme(&self) -> u8 {
719        KEYLESS_SCHEME
720    }
721}
722
723impl fmt::Debug for KeylessAccount {
724    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
725        f.debug_struct("KeylessAccount")
726            .field("address", &self.address)
727            .field("provider", &self.provider)
728            .field("issuer", &self.issuer)
729            .field("audience", &self.audience)
730            .field("user_id", &self.user_id)
731            .finish_non_exhaustive()
732    }
733}
734
735#[derive(Debug, Deserialize)]
736struct JwtClaims {
737    iss: Option<String>,
738    aud: Option<AudClaim>,
739    sub: Option<String>,
740    exp: Option<u64>,
741    nonce: Option<String>,
742}
743
744#[derive(Debug, Deserialize)]
745#[serde(untagged)]
746enum AudClaim {
747    Single(String),
748    Multiple(Vec<String>),
749}
750
751impl AudClaim {
752    fn first(&self) -> Option<&str> {
753        match self {
754            AudClaim::Single(value) => Some(value.as_str()),
755            AudClaim::Multiple(values) => values.first().map(std::string::String::as_str),
756        }
757    }
758}
759
/// Timeout applied to the HTTP client used for JWKS fetches: 10 seconds.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(10);

/// Upper bound on a JWKS response body: 1 MiB (JWKS payloads are typically under 10 KB).
/// The pepper and prover clients in this file reuse the same bound.
const MAX_JWKS_RESPONSE_SIZE: usize = 1 << 20;
765
766/// Fetches the JWKS (JSON Web Key Set) from an OIDC provider.
767///
768/// # Errors
769///
770/// Returns an error if:
771/// - The JWKS cannot be fetched (network timeouts, DNS resolution failures,
772///   TLS/connection errors, or HTTP errors)
773/// - The JWKS endpoint returns a non-success status code
774/// - The response cannot be parsed as valid JWKS JSON
775async fn fetch_jwks(client: &reqwest::Client, jwks_url: &str) -> AptosResult<JwkSet> {
776    // SECURITY: Validate the JWKS URL scheme to prevent SSRF.
777    // The issuer comes from an untrusted JWT, so the derived JWKS URL could
778    // point to internal services (e.g., cloud metadata endpoints).
779    let parsed_url = Url::parse(jwks_url)
780        .map_err(|e| AptosError::InvalidJwt(format!("invalid JWKS URL: {e}")))?;
781    if parsed_url.scheme() != "https" {
782        return Err(AptosError::InvalidJwt(format!(
783            "JWKS URL must use HTTPS scheme, got: {}",
784            parsed_url.scheme()
785        )));
786    }
787
788    // Note: timeout is configured on the client, not per-request
789    let response = client.get(jwks_url).send().await?;
790
791    if !response.status().is_success() {
792        return Err(AptosError::InvalidJwt(format!(
793            "JWKS endpoint returned status: {}",
794            response.status()
795        )));
796    }
797
798    // SECURITY: Stream body with size limit to prevent OOM from a
799    // compromised or malicious JWKS endpoint (including chunked encoding).
800    let bytes = crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
801    let jwks: JwkSet = serde_json::from_slice(&bytes)
802        .map_err(|e| AptosError::InvalidJwt(format!("failed to parse JWKS: {e}")))?;
803    Ok(jwks)
804}
805
806/// Decodes and verifies a JWT using the provided JWKS.
807///
808/// This function:
809/// 1. Extracts the `kid` (key ID) from the JWT header
810/// 2. Finds the matching key in the JWKS
811/// 3. Verifies the signature and decodes the claims
812///
813/// # Errors
814///
815/// Returns an error if:
816/// - The JWT header cannot be decoded
817/// - No matching key is found in the JWKS
818/// - The signature verification fails
819/// - The claims cannot be decoded
820fn decode_and_verify_jwt(jwt: &str, jwks: &JwkSet) -> AptosResult<JwtClaims> {
821    // Decode header to get the key ID
822    let header = decode_header(jwt)
823        .map_err(|e| AptosError::InvalidJwt(format!("failed to decode JWT header: {e}")))?;
824
825    let kid = header
826        .kid
827        .as_ref()
828        .ok_or_else(|| AptosError::InvalidJwt("JWT header missing 'kid' field".into()))?;
829
830    // Find the matching key in the JWKS
831    let signing_key = jwks.find(kid).ok_or_else(|| {
832        AptosError::InvalidJwt("no matching key found for provided key identifier".into())
833    })?;
834
835    // Create decoding key from JWK
836    let decoding_key = DecodingKey::from_jwk(signing_key)
837        .map_err(|e| AptosError::InvalidJwt(format!("failed to create decoding key: {e}")))?;
838
839    // Determine the algorithm strictly from the JWK to prevent algorithm substitution attacks
840    let jwk_alg = signing_key
841        .common
842        .key_algorithm
843        .ok_or_else(|| AptosError::InvalidJwt("JWK missing 'alg' (key_algorithm) field".into()))?;
844
845    let algorithm = match jwk_alg {
846        // RSA algorithms
847        jsonwebtoken::jwk::KeyAlgorithm::RS256 => Algorithm::RS256,
848        jsonwebtoken::jwk::KeyAlgorithm::RS384 => Algorithm::RS384,
849        jsonwebtoken::jwk::KeyAlgorithm::RS512 => Algorithm::RS512,
850        // RSA-PSS algorithms
851        jsonwebtoken::jwk::KeyAlgorithm::PS256 => Algorithm::PS256,
852        jsonwebtoken::jwk::KeyAlgorithm::PS384 => Algorithm::PS384,
853        jsonwebtoken::jwk::KeyAlgorithm::PS512 => Algorithm::PS512,
854        // ECDSA algorithms
855        jsonwebtoken::jwk::KeyAlgorithm::ES256 => Algorithm::ES256,
856        jsonwebtoken::jwk::KeyAlgorithm::ES384 => Algorithm::ES384,
857        // EdDSA algorithm
858        jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Algorithm::EdDSA,
859        _ => {
860            return Err(AptosError::InvalidJwt(format!(
861                "unsupported JWK algorithm: {jwk_alg:?}"
862            )));
863        }
864    };
865
866    // Ensure the JWT header algorithm matches the JWK algorithm to prevent substitution
867    if header.alg != algorithm {
868        return Err(AptosError::InvalidJwt(format!(
869            "JWT header algorithm ({:?}) does not match JWK algorithm ({:?})",
870            header.alg, algorithm
871        )));
872    }
873
874    // Configure validation - we'll validate exp ourselves with more detailed errors
875    let mut validation = Validation::new(algorithm);
876    validation.validate_exp = false;
877    validation.validate_aud = false; // We'll check aud after decoding
878    validation.set_required_spec_claims::<String>(&[]);
879
880    let data = decode::<JwtClaims>(jwt, &decoding_key, &validation)
881        .map_err(|e| AptosError::InvalidJwt(format!("JWT verification failed: {e}")))?;
882
883    Ok(data.claims)
884}
885
886/// Decodes JWT claims without signature verification.
887///
888/// This is used only to extract the issuer (and other metadata) before we know
889/// which JWKS endpoint to fetch. This is safe because:
890/// 1. The extracted issuer is only used to determine which JWKS endpoint to fetch.
891/// 2. The JWT is fully verified immediately afterwards using `decode_and_verify_jwt`.
892/// 3. No security decisions are made based on these unverified claims.
893fn decode_claims_unverified(jwt: &str) -> AptosResult<JwtClaims> {
894    // Use dangerous decode only for initial issuer extraction to select the JWKS.
895    // The JWT is not trusted at this point: no authorization decisions are made
896    // based on these unverified claims, and the token is fully verified (including
897    // signature and claims validation) in `decode_and_verify_jwt` right after the
898    // appropriate JWKS has been fetched.
899    let data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(jwt)
900        .map_err(|e| AptosError::InvalidJwt(format!("failed to decode JWT claims: {e}")))?;
901    Ok(data.claims)
902}
903
904fn extract_claims(
905    claims: &JwtClaims,
906) -> AptosResult<(String, String, String, Option<SystemTime>, String)> {
907    let issuer = claims
908        .iss
909        .clone()
910        .ok_or_else(|| AptosError::InvalidJwt("missing iss claim".into()))?;
911    let audience = claims
912        .aud
913        .as_ref()
914        .and_then(|aud| aud.first())
915        .map(std::string::ToString::to_string)
916        .ok_or_else(|| AptosError::InvalidJwt("missing aud claim".into()))?;
917    let user_id = claims
918        .sub
919        .clone()
920        .ok_or_else(|| AptosError::InvalidJwt("missing sub claim".into()))?;
921    let nonce = claims
922        .nonce
923        .clone()
924        .ok_or_else(|| AptosError::InvalidJwt("missing nonce claim".into()))?;
925
926    let exp_time = claims.exp.map(|exp| UNIX_EPOCH + Duration::from_secs(exp));
927    if let Some(exp) = exp_time
928        && SystemTime::now() >= exp
929    {
930        let exp_secs = claims.exp.unwrap_or(0);
931        return Err(AptosError::InvalidJwt(format!(
932            "JWT is expired (exp: {exp_secs} seconds since UNIX_EPOCH)"
933        )));
934    }
935
936    Ok((issuer, audience, user_id, exp_time, nonce))
937}
938
939fn derive_keyless_address(
940    issuer: &str,
941    audience: &str,
942    user_id: &str,
943    pepper: &Pepper,
944) -> AccountAddress {
945    let issuer_hash = sha3_256_bytes(issuer.as_bytes());
946    let audience_hash = sha3_256_bytes(audience.as_bytes());
947    let user_hash = sha3_256_bytes(user_id.as_bytes());
948
949    let mut hasher = Sha3_256::new();
950    hasher.update(issuer_hash);
951    hasher.update(audience_hash);
952    hasher.update(user_hash);
953    hasher.update(pepper.as_bytes());
954    hasher.update([KEYLESS_SCHEME]);
955    let result = hasher.finalize();
956
957    let mut address = [0u8; 32];
958    address.copy_from_slice(&result);
959    AccountAddress::new(address)
960}
961
962fn sha3_256_bytes(data: &[u8]) -> [u8; 32] {
963    let mut hasher = Sha3_256::new();
964    hasher.update(data);
965    let result = hasher.finalize();
966    let mut output = [0u8; 32];
967    output.copy_from_slice(&result);
968    output
969}
970
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};

    /// Pepper service stub that always returns the same pepper.
    struct StaticPepperService {
        pepper: Pepper,
    }

    impl PepperService for StaticPepperService {
        fn get_pepper(
            &self,
            _jwt: &str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AptosResult<Pepper>> + Send + '_>>
        {
            Box::pin(async move { Ok(self.pepper.clone()) })
        }
    }

    /// Prover service stub that always returns the same proof.
    struct StaticProverService {
        proof: ZkProof,
    }

    impl ProverService for StaticProverService {
        fn generate_proof<'a>(
            &'a self,
            _jwt: &'a str,
            _ephemeral_key: &'a EphemeralKeyPair,
            _pepper: &'a Pepper,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AptosResult<ZkProof>> + Send + 'a>>
        {
            Box::pin(async move { Ok(self.proof.clone()) })
        }
    }

    /// Minimal claim set serialized into the test tokens.
    #[derive(Serialize, Deserialize)]
    struct TestClaims {
        iss: String,
        aud: String,
        sub: String,
        exp: u64,
        nonce: String,
    }

    /// Current wall-clock time as whole seconds since the Unix epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_secs()
    }

    /// Builds Google-issued claims that expire one hour after `now`.
    fn google_claims(aud: &str, sub: &str, nonce: String, now: u64) -> TestClaims {
        TestClaims {
            iss: "https://accounts.google.com".to_string(),
            aud: aud.to_string(),
            sub: sub.to_string(),
            exp: now + 3600,
            nonce,
        }
    }

    /// Claims shared by the `decode_*` tests.
    fn generic_claims() -> TestClaims {
        google_claims(
            "test-aud",
            "test-sub",
            "test-nonce".to_string(),
            unix_now_secs(),
        )
    }

    /// Encodes `claims` under `header` using the shared test secret.
    fn sign(header: &Header, claims: &TestClaims) -> String {
        encode(header, claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    /// Fixed pepper/prover stubs used by the account-construction tests.
    fn static_services() -> (StaticPepperService, StaticProverService) {
        (
            StaticPepperService {
                pepper: Pepper::new(vec![1, 2, 3, 4]),
            },
            StaticProverService {
                proof: ZkProof::new(vec![9, 9, 9]),
            },
        )
    }

    #[tokio::test]
    async fn test_keyless_account_creation() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let now = unix_now_secs();
        let nonce = ephemeral.nonce().to_string();

        // The stub services never inspect the token, so an HS256 JWT is sufficient.
        let jwt = sign(
            &Header::new(Algorithm::HS256),
            &google_claims("client-id", "user-123", nonce.clone(), now),
        );
        let (pepper_service, prover_service) = static_services();

        // JWKS cannot be mocked in a unit test, so go through `from_verified_claims`.
        let exp_time = UNIX_EPOCH + Duration::from_secs(now + 3600);
        let account = KeylessAccount::from_verified_claims(
            "https://accounts.google.com".to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            nonce,
            Some(exp_time),
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await
        .unwrap();

        assert_eq!(account.issuer(), "https://accounts.google.com");
        assert_eq!(account.audience(), "client-id");
        assert_eq!(account.user_id(), "user-123");
        assert!(account.is_valid());
        assert!(!account.address().is_zero());
    }

    #[tokio::test]
    async fn test_keyless_account_nonce_mismatch() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let jwt = sign(
            &Header::new(Algorithm::HS256),
            &google_claims(
                "client-id",
                "user-123",
                ephemeral.nonce().to_string(),
                unix_now_secs(),
            ),
        );
        let (pepper_service, prover_service) = static_services();

        // Pass a nonce that differs from `ephemeral.nonce()` to trigger the mismatch.
        let result = KeylessAccount::from_verified_claims(
            "https://accounts.google.com".to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            "wrong-nonce".to_string(),
            None,
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await;

        assert!(matches!(result, Err(AptosError::InvalidJwt(_))));
    }

    #[test]
    fn test_decode_claims_unverified() {
        let jwt = sign(&Header::new(Algorithm::HS256), &generic_claims());

        let decoded = decode_claims_unverified(&jwt).unwrap();
        assert_eq!(decoded.iss.unwrap(), "https://accounts.google.com");
        assert_eq!(decoded.sub.unwrap(), "test-sub");
        assert_eq!(decoded.nonce.unwrap(), "test-nonce");
    }

    #[test]
    fn test_oidc_provider_detection() {
        let google = OidcProvider::from_issuer("https://accounts.google.com");
        assert!(matches!(google, OidcProvider::Google));

        let apple = OidcProvider::from_issuer("https://appleid.apple.com");
        assert!(matches!(apple, OidcProvider::Apple));

        let other = OidcProvider::from_issuer("https://unknown.example.com");
        assert!(matches!(other, OidcProvider::Custom { .. }));
    }

    #[test]
    fn test_decode_and_verify_jwt_missing_kid() {
        // HS256 token whose header carries no `kid`.
        let jwt = sign(&Header::new(Algorithm::HS256), &generic_claims());
        let jwks = JwkSet { keys: vec![] };

        let err = decode_and_verify_jwt(&jwt, &jwks).unwrap_err();
        assert!(
            matches!(&err, AptosError::InvalidJwt(msg) if msg.contains("kid")),
            "Expected error about missing kid, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_no_matching_key() {
        // The header names a `kid`, but the key set below is empty.
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-kid-123".to_string());
        let jwt = sign(&header, &generic_claims());
        let jwks = JwkSet { keys: vec![] };

        let err = decode_and_verify_jwt(&jwt, &jwks).unwrap_err();
        assert!(
            matches!(&err, AptosError::InvalidJwt(msg) if msg.contains("no matching key")),
            "Expected error about no matching key, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_invalid_jwt_format() {
        let jwks = JwkSet { keys: vec![] };

        // First input is not token-shaped at all; the second has three segments
        // but does not decode to a usable JWT.
        for malformed in ["not-a-valid-jwt", "aaa.bbb.ccc"] {
            assert!(decode_and_verify_jwt(malformed, &jwks).is_err());
        }
    }
}