import { nanoid } from 'nanoid';
import type { Context } from 'hono';
import { z } from 'zod';
import {
  authorizationRequestSchema,
  tokenRequestSchema,
  authorizationQuerySchema,
  tokenFormDataSchema,
  bearerTokenSchema,
  basicAuthSchema,
  revokeTokenRequestSchema,
  introspectTokenRequestSchema,
} from './schemas';
import type {
  OIDCProviderConfig,
  OIDCClient,
  OIDCUser,
  AuthorizationCode,
  AuthorizationRequest,
  TokenRequest,
  TokenResponse,
  OIDCError,
  DiscoveryDocument,
} from './types';
import type { StorageAdapter } from './storage/adapter';
import { TokenManager } from './auth/token-manager';
import { JWTUtils } from './utils/jwt';
import type { KeyPair } from './utils/jwt';
import { PKCEUtils } from './utils/pkce';
import { ValidationUtils } from './utils/validation';
import { PasswordAuth } from './auth';

/**
 * OIDC Provider - a simplified, spec-oriented implementation kept deliberately small.
 *
 * The class exposes two layers:
 * - protocol-level methods (`handleAuthorizationRequest`, `handleTokenRequest`,
 *   `getUserInfo`, `revokeToken`, `introspectToken`, ...) that throw `Error` on failure;
 * - Hono HTTP handlers (`handleAuthorization`, `handleToken`, ...) that translate
 *   those errors into JSON error responses.
 */
export class OIDCProvider {
  private readonly config: OIDCProviderConfig;
  private readonly tokenTTL: Required<NonNullable<OIDCProviderConfig['tokenTTL']>>;
  private readonly storage: StorageAdapter;
  private readonly tokenManager: TokenManager;
  private readonly jwtUtils: JWTUtils;
  private readonly findUser: OIDCProviderConfig['findUser'];
  private readonly findClient: OIDCProviderConfig['findClient'];
  private passwordAuth: PasswordAuth;

  /**
   * @param config Provider configuration; `issuer`, `storage`, `findUser`,
   *   `findClient` and `authConfig` are mandatory.
   * @throws Error when a mandatory configuration entry is missing.
   */
  constructor(config: OIDCProviderConfig) {
    this.validateConfig(config);
    this.config = config;

    // Lifetimes are in seconds; a missing (or zero) value falls back to the default.
    this.tokenTTL = {
      accessToken: config.tokenTTL?.accessToken || 3600,
      refreshToken: config.tokenTTL?.refreshToken || 2592000,
      authorizationCode: config.tokenTTL?.authorizationCode || 600,
      idToken: config.tokenTTL?.idToken || 3600,
    };

    this.storage = config.storage;
    this.tokenManager = new TokenManager(config.storage);
    this.findUser = config.findUser;
    this.findClient = config.findClient;

    // Create the JWT helper using its simplified constructor.
    this.jwtUtils = new JWTUtils(config.signingKey);
    this.passwordAuth = new PasswordAuth(
      config,
      config.authConfig.passwordValidator,
      config.authConfig
    );
  }

  /**
   * Ensures every mandatory configuration entry is present (truthy).
   *
   * @throws Error naming the first missing entry.
   */
  private validateConfig(config: OIDCProviderConfig): void {
    const required = ['issuer', 'storage', 'findUser', 'findClient', 'authConfig'];
    for (const field of required) {
      if (!config[field as keyof OIDCProviderConfig]) {
        throw new Error(`配置项 ${field} 是必需的`);
      }
    }
    // If signingKey is provided as a string, HMAC is used.
    // If no signingKey is provided, an RSA key is generated automatically.
  }

  /** Scopes advertised/accepted when the configuration does not override them. */
  private get defaultScopes(): string[] {
    return this.config.scopes || ['openid', 'profile', 'email'];
  }

  /** Supported `response_type` values; defaults to the authorization code flow only. */
  private get supportedResponseTypes(): string[] {
    return this.config.responseTypes || ['code'];
  }

  /** Supported `grant_type` values. */
  private get supportedGrantTypes(): string[] {
    return this.config.grantTypes || ['authorization_code', 'refresh_token'];
  }

  /** Claims advertised in the discovery document. */
  private get supportedClaims(): string[] {
    return this.config.claims || ['sub', 'name', 'email', 'email_verified'];
  }

  /** Whether PKCE is advertised in the discovery document (default: true). */
  private get enablePKCE(): boolean {
    return this.config.enablePKCE ?? true;
  }

  /** Whether public clients must supply a PKCE challenge (default: false). */
  private get requirePKCE(): boolean {
    return this.config.requirePKCE ?? false;
  }

  /** Whether a refresh token is replaced each time it is used (default: true). */
  private get rotateRefreshTokens(): boolean {
    return this.config.rotateRefreshTokens ?? true;
  }

  /**
   * Looks a token value up as both an access token and a refresh token.
   *
   * @returns The stored record and its kind, or `null` when the token is unknown
   *   or expired. Expired records are deleted as a side effect.
   */
  async findToken(
    token: string
  ): Promise<{ tokenData: any; type: 'access' | 'refresh' } | null> {
    const [accessToken, refreshToken] = await Promise.all([
      this.tokenManager.getAccessToken(token),
      this.tokenManager.getRefreshToken(token),
    ]);

    const tokenData = accessToken || refreshToken;
    if (!tokenData) return null;

    if (tokenData.expires_at < new Date()) {
      // Clean up the expired token from whichever store(s) held it.
      await Promise.all(
        [
          accessToken && this.tokenManager.deleteAccessToken(token),
          refreshToken && this.tokenManager.deleteRefreshToken(token),
        ].filter(Boolean)
      );
      return null;
    }

    return { tokenData, type: accessToken ? 'access' : 'refresh' };
  }

  /**
   * Normalizes and validates the query parameters of an authorization request.
   *
   * @throws When either schema rejects the input.
   */
  parseAuthRequest(query: Record<string, string>): AuthorizationRequest {
    const normalized = authorizationQuerySchema.parse(query);

    return authorizationRequestSchema.parse({
      response_type: normalized.response_type || '',
      client_id: normalized.client_id || '',
      redirect_uri: normalized.redirect_uri || '',
      scope: normalized.scope || 'openid',
      state: normalized.state,
      nonce: normalized.nonce,
      code_challenge: normalized.code_challenge,
      code_challenge_method: normalized.code_challenge_method as 'plain' | 'S256' | undefined,
      prompt: normalized.prompt,
      max_age: normalized.max_age ? parseInt(normalized.max_age, 10) : undefined,
      id_token_hint: normalized.id_token_hint,
      login_hint: normalized.login_hint,
      acr_values: normalized.acr_values,
    });
  }

  /**
   * Validates a token request body. Client credentials from an HTTP Basic
   * `Authorization` header take precedence over credentials in the form body.
   *
   * @throws When any schema rejects the input.
   */
  parseTokenRequest(body: FormData, authHeader?: string): TokenRequest {
    const formData = tokenFormDataSchema.parse(body);

    // Handle client authentication supplied via HTTP Basic.
    if (authHeader?.startsWith('Basic ')) {
      const basicAuth = basicAuthSchema.parse(authHeader);
      formData.client_id = basicAuth.username || formData.client_id || '';
      formData.client_secret = basicAuth.password || formData.client_secret || '';
    }

    return tokenRequestSchema.parse({
      grant_type: formData.grant_type || '',
      client_id: formData.client_id || '',
      code: formData.code,
      redirect_uri: formData.redirect_uri,
      client_secret: formData.client_secret,
      refresh_token: formData.refresh_token,
      code_verifier: formData.code_verifier,
      scope: formData.scope,
    });
  }

  /**
   * Resolves a client and checks its credentials.
   *
   * @throws Error when the client is unknown, when a confidential client does not
   *   present its (non-empty) secret, or when a public client presents a secret.
   */
  async validateClient(clientId: string, clientSecret?: string): Promise<OIDCClient> {
    const client = await this.findClient(clientId);
    if (!client) throw new Error('客户端不存在');

    // Confidential clients must present their secret. A missing/empty secret on
    // either side is rejected explicitly so that `undefined === undefined` can
    // never count as a successful authentication.
    if (client.client_type === 'confidential') {
      const secretMatches =
        !!clientSecret && !!client.client_secret && clientSecret === client.client_secret;
      if (!secretMatches) {
        throw new Error('客户端认证失败');
      }
    }

    // Public clients must not send a secret.
    if (client.client_type === 'public' && clientSecret) {
      throw new Error('公开客户端不应发送密钥');
    }

    return client;
  }

  /**
   * Checks the PKCE parameters of an authorization request.
   *
   * @throws Error when PKCE is required for a public client but absent, or when
   *   the supplied challenge is not valid for its method.
   */
  validatePKCE(request: AuthorizationRequest, client: OIDCClient): void {
    const codeChallenge = request.code_challenge;

    if (!codeChallenge) {
      if (this.requirePKCE && client.client_type === 'public') {
        throw new Error('公开客户端必须使用PKCE');
      }
      return;
    }

    // The method defaults to 'plain' when the request omits it.
    const method = request.code_challenge_method || 'plain';
    if (!PKCEUtils.isValidCodeChallenge(codeChallenge, method)) {
      throw new Error('无效的PKCE参数');
    }
  }

  /** Builds the discovery document served to clients. */
  getDiscoveryDocument(): DiscoveryDocument {
    // Strip trailing slashes so endpoint URLs never contain "//"; the `issuer`
    // value itself is published exactly as configured.
    const baseUrl = this.config.issuer.replace(/\/+$/, '');

    // The signing algorithm follows the signingKey type: a string key means HMAC.
    const signingAlgorithm = typeof this.config.signingKey === 'string' ? 'HS256' : 'RS256';

    return {
      issuer: this.config.issuer,
      authorization_endpoint: `${baseUrl}/auth`,
      token_endpoint: `${baseUrl}/token`,
      userinfo_endpoint: `${baseUrl}/userinfo`,
      jwks_uri: `${baseUrl}/.well-known/jwks.json`,
      revocation_endpoint: `${baseUrl}/revoke`,
      introspection_endpoint: `${baseUrl}/introspect`,
      end_session_endpoint: `${baseUrl}/logout`,
      response_types_supported: this.supportedResponseTypes,
      grant_types_supported: this.supportedGrantTypes,
      scopes_supported: this.defaultScopes,
      claims_supported: this.supportedClaims,
      token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post', 'none'],
      id_token_signing_alg_values_supported: [signingAlgorithm],
      subject_types_supported: ['public'],
      code_challenge_methods_supported: this.enablePKCE ? ['plain', 'S256'] : undefined,
      response_modes_supported: ['query', 'fragment'],
    };
  }

  /**
   * Validates an authorization request and, for `response_type=code`, stores and
   * returns a fresh authorization code.
   *
   * @param request Parsed authorization request.
   * @param userId Identifier of the authenticated user; required.
   * @returns The authorization code together with the echoed `state`.
   * @throws Error on missing parameters, unknown client, unregistered redirect
   *   URI, failed validation, missing user, invalid PKCE or unsupported response type.
   */
  async handleAuthorizationRequest(request: AuthorizationRequest, userId?: string) {
    // Basic presence checks.
    if (!request.client_id || !request.response_type || !request.redirect_uri) {
      throw new Error('缺少必需参数');
    }

    const client = await this.findClient(request.client_id);
    if (!client || !client.redirect_uris.includes(request.redirect_uri)) {
      throw new Error('无效的客户端或重定向URI');
    }

    // Validate the request against the client and provider capabilities.
    const validation = ValidationUtils.validateAuthorizationRequest(
      request,
      client,
      this.defaultScopes,
      this.supportedResponseTypes
    );
    if (!validation.valid) {
      throw new Error(validation.errors.join(', '));
    }

    if (!userId) {
      throw new Error('需要用户认证');
    }

    this.validatePKCE(request, client);

    // Issue the authorization code.
    if (request.response_type === 'code') {
      const code = nanoid(32);
      const authCode: AuthorizationCode = {
        code,
        client_id: request.client_id,
        user_id: userId,
        redirect_uri: request.redirect_uri,
        scope: request.scope,
        code_challenge: request.code_challenge,
        code_challenge_method: request.code_challenge_method,
        nonce: request.nonce,
        state: request.state,
        expires_at: new Date(Date.now() + this.tokenTTL.authorizationCode * 1000),
        created_at: new Date(),
      };

      await this.tokenManager.storeAuthorizationCode(authCode);
      return { code, state: request.state };
    }

    throw new Error(`不支持的响应类型: ${request.response_type}`);
  }

  /**
   * Authenticates the client and dispatches to the handler for the grant type.
   *
   * @throws Error on missing parameters, failed client authentication or an
   *   unsupported grant type.
   */
  async handleTokenRequest(request: TokenRequest): Promise<TokenResponse> {
    if (!request.grant_type || !request.client_id) {
      throw new Error('缺少必需参数');
    }

    const client = await this.validateClient(request.client_id, request.client_secret);

    switch (request.grant_type) {
      case 'authorization_code':
        return this.handleAuthorizationCodeGrant(request, client);
      case 'refresh_token':
        return this.handleRefreshTokenGrant(request, client);
      default:
        throw new Error(`不支持的授权类型: ${request.grant_type}`);
    }
  }

  /**
   * Exchanges an authorization code for tokens.
   *
   * @throws Error when the code is missing, unknown, expired, bound to another
   *   client/redirect URI, fails PKCE verification, or its user no longer exists.
   */
  private async handleAuthorizationCodeGrant(
    request: TokenRequest,
    client: OIDCClient
  ): Promise<TokenResponse> {
    if (!request.code || !request.redirect_uri) {
      throw new Error('缺少授权码或重定向URI');
    }

    const authCode = await this.tokenManager.getAuthorizationCode(request.code);
    if (!authCode || authCode.expires_at < new Date()) {
      // Expired codes are removed from storage before failing.
      if (authCode) await this.tokenManager.deleteAuthorizationCode(request.code);
      throw new Error('授权码无效或已过期');
    }

    // The code must have been issued to this client for this redirect URI.
    if (authCode.client_id !== request.client_id || authCode.redirect_uri !== request.redirect_uri) {
      throw new Error('授权码不匹配');
    }

    // Verify PKCE when the code was issued with a challenge.
    if (authCode.code_challenge) {
      const verified =
        !!request.code_verifier &&
        PKCEUtils.verifyCodeChallenge(
          request.code_verifier,
          authCode.code_challenge,
          authCode.code_challenge_method || 'plain'
        );
      if (!verified) {
        throw new Error('PKCE验证失败');
      }
    }

    const user = await this.findUser(authCode.user_id);
    if (!user) throw new Error('用户不存在');

    // Delete the code now that it has been used.
    await this.tokenManager.deleteAuthorizationCode(request.code);

    return this.generateTokens(user, client, authCode.scope, authCode.nonce);
  }

  /**
   * Exchanges a refresh token for a new set of tokens.
   *
   * With rotation enabled the presented refresh token is deleted and replaced;
   * otherwise the presented token is returned again and no new refresh token is
   * stored.
   *
   * @throws Error when the refresh token is missing, unknown, expired, owned by
   *   another client, its user no longer exists, or the requested scope exceeds
   *   the originally granted scope.
   */
  private async handleRefreshTokenGrant(
    request: TokenRequest,
    client: OIDCClient
  ): Promise<TokenResponse> {
    if (!request.refresh_token) {
      throw new Error('缺少刷新token');
    }

    const refreshToken = await this.tokenManager.getRefreshToken(request.refresh_token);
    if (!refreshToken || refreshToken.expires_at < new Date()) {
      if (refreshToken) await this.tokenManager.deleteRefreshToken(request.refresh_token);
      throw new Error('刷新token无效或已过期');
    }

    if (refreshToken.client_id !== request.client_id) {
      throw new Error('刷新token客户端不匹配');
    }

    const user = await this.findUser(refreshToken.user_id);
    if (!user) throw new Error('用户不存在');

    // A requested scope may only narrow the originally granted scope.
    let scope = refreshToken.scope;
    if (request.scope) {
      const requestedScopes = request.scope.split(' ');
      const originalScopes = refreshToken.scope.split(' ');

      if (!requestedScopes.every(s => originalScopes.includes(s))) {
        throw new Error('请求的scope超出原始scope');
      }
      scope = request.scope;
    }

    if (this.rotateRefreshTokens) {
      // Rotate: issue a new refresh token, then invalidate the presented one.
      const tokens = await this.generateTokens(user, client, scope);
      await this.tokenManager.deleteRefreshToken(request.refresh_token);
      return tokens;
    }

    // No rotation: reuse the presented refresh token so that no orphaned (stored
    // but never returned) refresh token is left behind in storage.
    return this.generateTokens(user, client, scope, undefined, request.refresh_token);
  }

  /**
   * Issues and stores an access token, optionally an ID token and a refresh token.
   *
   * @param user Subject the tokens are issued for.
   * @param client Client the tokens are issued to.
   * @param scope Space-delimited granted scope.
   * @param nonce Nonce to embed in the ID token, if any.
   * @param existingRefreshToken When given, this value is returned as the
   *   `refresh_token` and no new refresh token is generated or stored.
   */
  private async generateTokens(
    user: OIDCUser,
    client: OIDCClient,
    scope: string,
    nonce?: string,
    existingRefreshToken?: string
  ): Promise<TokenResponse> {
    const now = new Date();

    const accessToken = await this.jwtUtils.generateAccessToken({
      issuer: this.config.issuer,
      subject: user.sub,
      audience: client.client_id,
      clientId: client.client_id,
      scope,
      expiresIn: this.tokenTTL.accessToken,
    });
    const refreshTokenValue = existingRefreshToken ?? nanoid(64);

    // Generate an ID token only when the scope contains the exact value "openid"
    // (a substring test would also match values such as "openid_x").
    let idToken: string | undefined;
    if (scope.split(' ').includes('openid')) {
      idToken = await this.jwtUtils.generateIDToken({
        issuer: this.config.issuer,
        subject: user.sub,
        audience: client.client_id,
        user,
        authTime: Math.floor(now.getTime() / 1000),
        nonce,
        expiresIn: this.tokenTTL.idToken,
      });
    }

    // Persist the issued tokens in parallel.
    const storeOperations: Promise<unknown>[] = [
      this.tokenManager.storeAccessToken({
        token: accessToken,
        client_id: client.client_id,
        user_id: user.sub,
        scope,
        expires_at: new Date(now.getTime() + this.tokenTTL.accessToken * 1000),
        created_at: now,
      }),
    ];

    if (existingRefreshToken === undefined) {
      storeOperations.push(
        this.tokenManager.storeRefreshToken({
          token: refreshTokenValue,
          client_id: client.client_id,
          user_id: user.sub,
          scope,
          expires_at: new Date(now.getTime() + this.tokenTTL.refreshToken * 1000),
          created_at: now,
        })
      );
    }

    if (idToken) {
      storeOperations.push(
        this.tokenManager.storeIDToken({
          token: idToken,
          client_id: client.client_id,
          user_id: user.sub,
          nonce,
          expires_at: new Date(now.getTime() + this.tokenTTL.idToken * 1000),
          created_at: now,
        })
      );
    }

    await Promise.all(storeOperations);

    const response: TokenResponse = {
      access_token: accessToken,
      token_type: 'Bearer',
      expires_in: this.tokenTTL.accessToken,
      refresh_token: refreshTokenValue,
      scope,
    };

    if (idToken) response.id_token = idToken;
    return response;
  }

  /**
   * Returns the user claims permitted by the scope of the given access token.
   *
   * @throws Error when the token is unknown or expired (expired tokens are
   *   deleted), fails JWT verification, or its user no longer exists.
   */
  async getUserInfo(accessToken: string): Promise<Partial<OIDCUser>> {
    const tokenData = await this.tokenManager.getAccessToken(accessToken);
    if (!tokenData || tokenData.expires_at < new Date()) {
      if (tokenData) await this.tokenManager.deleteAccessToken(accessToken);
      throw new Error('Token无效或已过期');
    }

    // Verify the JWT itself in addition to the storage lookup.
    await this.jwtUtils.verifyToken(accessToken);

    const user = await this.findUser(tokenData.user_id);
    if (!user) throw new Error('用户不存在');

    const claims = this.getClaimsForScope(tokenData.scope);
    return this.filterUserClaims(user, claims);
  }

  /**
   * Maps a space-delimited scope string to the claim names it grants.
   * `sub` is always included; unknown scope values contribute nothing.
   */
  private getClaimsForScope(scope: string): string[] {
    const claimsMap = {
      profile: ['name', 'given_name', 'family_name', 'picture'],
      email: ['email', 'email_verified'],
      phone: ['phone_number', 'phone_number_verified'],
      address: ['address'],
    };

    const claims = new Set<string>(['sub']);
    for (const s of scope.split(' ')) {
      if (s in claimsMap) {
        claimsMap[s as keyof typeof claimsMap].forEach(c => claims.add(c));
      }
    }

    return Array.from(claims);
  }

  /** Picks the listed claims from the user, skipping absent or undefined ones. */
  private filterUserClaims(user: OIDCUser, claims: string[]): Partial<OIDCUser> {
    return Object.fromEntries(
      claims
        .filter(claim => claim in user && user[claim as keyof OIDCUser] !== undefined)
        .map(claim => [claim, user[claim as keyof OIDCUser]])
    );
  }

  /**
   * Revokes an access or refresh token. Unknown or expired tokens are ignored.
   * Revoking a refresh token also deletes the access tokens stored for the same
   * user and client.
   *
   * @throws Error when `clientId` is given and does not own the token.
   */
  async revokeToken(token: string, clientId?: string): Promise<void> {
    const result = await this.findToken(token);
    if (!result) return; // Token does not exist or has expired.

    const { tokenData, type } = result;

    if (clientId && tokenData.client_id !== clientId) {
      throw new Error('Token不属于此客户端');
    }

    if (type === 'access') {
      await this.tokenManager.deleteAccessToken(token);
      return;
    }

    // Revoking a refresh token also revokes the related access tokens.
    await Promise.all([
      this.tokenManager.deleteRefreshToken(token),
      this.tokenManager.deleteAccessTokensByUserAndClient(tokenData.user_id, tokenData.client_id),
    ]);
  }

  /**
   * Describes a token for the introspection endpoint.
   *
   * @returns `{ active: false }` when the token is unknown, expired, or does not
   *   belong to `clientId`; otherwise the token metadata.
   */
  async introspectToken(token: string, clientId?: string) {
    const result = await this.findToken(token);
    if (!result) return { active: false };

    const { tokenData, type } = result;

    if (clientId && tokenData.client_id !== clientId) {
      return { active: false };
    }

    const user = await this.findUser(tokenData.user_id);

    return {
      active: true,
      scope: tokenData.scope,
      client_id: tokenData.client_id,
      username: user?.username,
      token_type: type === 'access' ? 'Bearer' : 'refresh_token',
      exp: Math.floor(tokenData.expires_at.getTime() / 1000),
      iat: Math.floor(tokenData.created_at.getTime() / 1000),
      sub: tokenData.user_id,
      aud: tokenData.client_id,
      iss: this.config.issuer,
    };
  }

  /** Returns the JSON Web Key Set produced by the JWT helper. */
  async getJWKS() {
    return this.jwtUtils.generateJWKS();
  }

  // ---------------------------------------------------------------------------
  // HTTP handlers
  // ---------------------------------------------------------------------------

  /**
   * Logs the error and converts it into an OAuth-style error body. The error
   * code is always `invalid_request`; the description is the error message.
   */
  private createErrorResponse(error: unknown): { error: string; error_description: string } {
    const message = error instanceof Error ? error.message : '服务器内部错误';
    console.error('OIDC Provider Error:', error);
    return { error: 'invalid_request', error_description: message };
  }

  /**
   * Handles the login form post: authenticates via `PasswordAuth` and delegates
   * to its success/failure handlers. Unexpected errors produce a 500 JSON response.
   */
  async handleLogin(c: Context): Promise<Response> {
    try {
      const formData = await c.req.formData();
      // Carry the original authorization parameters through the login form.
      const authRequest = Object.fromEntries(
        [
          'response_type',
          'client_id',
          'redirect_uri',
          'scope',
          'state',
          'nonce',
          'code_challenge',
          'code_challenge_method',
        ].map(field => [field, formData.get(field)?.toString() || ''])
      );

      const authResult = await this.passwordAuth.authenticate(c);

      return authResult.success
        ? await this.passwordAuth.handleAuthenticationSuccess(c, authResult)
        : await this.passwordAuth.handleAuthenticationFailure(c, authResult, authRequest as any);
    } catch (error) {
      return c.json(this.createErrorResponse(error), 500);
    }
  }

  /** Delegates logout to `PasswordAuth`. */
  async handleLogout(c: Context): Promise<Response> {
    return this.passwordAuth.logout(c);
  }

  /**
   * Authorization endpoint: asks `PasswordAuth` to handle unauthenticated
   * users, otherwise issues a code and redirects back to the client.
   * Failures produce a 400 JSON response.
   */
  async handleAuthorization(c: Context): Promise<Response> {
    try {
      const authRequest = this.parseAuthRequest(c.req.query());
      const userId = await this.passwordAuth.getCurrentUser(c);

      if (!userId) {
        return this.passwordAuth.handleAuthenticationRequired(c, authRequest);
      }

      const result = await this.handleAuthorizationRequest(authRequest, userId);
      const params = new URLSearchParams();

      if (result.code) params.set('code', result.code);
      if (result.state) params.set('state', result.state);

      // A registered redirect URI may already carry a query string; append with
      // "&" in that case instead of adding a second "?".
      const separator = authRequest.redirect_uri.includes('?') ? '&' : '?';
      return c.redirect(`${authRequest.redirect_uri}${separator}${params.toString()}`);
    } catch (error) {
      return c.json(this.createErrorResponse(error), 400);
    }
  }

  /**
   * Token endpoint. Successful responses are marked non-cacheable; failures
   * produce a 400 JSON response.
   */
  async handleToken(c: Context): Promise<Response> {
    try {
      const body = await c.req.formData();
      const tokenRequest = this.parseTokenRequest(body, c.req.header('Authorization'));
      const response = await this.handleTokenRequest(tokenRequest);

      c.header('Cache-Control', 'no-store');
      c.header('Pragma', 'no-cache');
      return c.json(response);
    } catch (error) {
      return c.json(this.createErrorResponse(error), 400);
    }
  }

  /**
   * UserInfo endpoint. Requires a Bearer token; any failure yields a 401 with a
   * `WWW-Authenticate` challenge and a generic error body.
   */
  async handleUserInfo(c: Context): Promise<Response> {
    try {
      const authHeader = c.req.header('Authorization');
      if (!authHeader?.startsWith('Bearer ')) {
        c.header('WWW-Authenticate', 'Bearer realm="userinfo"');
        return c.json({ error: 'invalid_token', error_description: 'Bearer token required' }, 401);
      }

      const token = bearerTokenSchema.parse(authHeader);
      const user = await this.getUserInfo(token);
      return c.json(user);
    } catch {
      // The failure reason is deliberately not echoed back to the caller.
      c.header('WWW-Authenticate', 'Bearer realm="userinfo"');
      return c.json({ error: 'invalid_token', error_description: 'Invalid token' }, 401);
    }
  }

  /**
   * Revocation endpoint. Responds 200 with an empty body on success (including
   * for unknown tokens) and 400 on invalid input or ownership mismatch.
   */
  async handleRevoke(c: Context): Promise<Response> {
    try {
      const body = await c.req.formData();
      const { token, client_id } = revokeTokenRequestSchema.parse({
        token: body.get('token')?.toString(),
        client_id: body.get('client_id')?.toString(),
      });

      if (!token) {
        return c.json({ error: 'invalid_request', error_description: 'Missing token' }, 400);
      }

      await this.revokeToken(token, client_id);
      return c.body(null, 200);
    } catch (error) {
      return c.json(this.createErrorResponse(error), 400);
    }
  }

  /** Introspection endpoint. Responds with the introspection result or a 400 error. */
  async handleIntrospect(c: Context): Promise<Response> {
    try {
      const body = await c.req.formData();
      const { token, client_id } = introspectTokenRequestSchema.parse({
        token: body.get('token')?.toString(),
        client_id: body.get('client_id')?.toString(),
      });

      if (!token) {
        return c.json({ error: 'invalid_request', error_description: 'Missing token' }, 400);
      }

      const result = await this.introspectToken(token, client_id);
      return c.json(result);
    } catch (error) {
      return c.json(this.createErrorResponse(error), 400);
    }
  }
}