fenghuo/packages/oidc-provider/src/provider.ts

663 lines
24 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { nanoid } from 'nanoid';
import type { Context } from 'hono';
import { z } from 'zod';
import {
authorizationRequestSchema,
tokenRequestSchema,
authorizationQuerySchema,
tokenFormDataSchema,
bearerTokenSchema,
basicAuthSchema,
revokeTokenRequestSchema,
introspectTokenRequestSchema,
} from './schemas';
import type {
OIDCProviderConfig,
OIDCClient,
OIDCUser,
AuthorizationCode,
AuthorizationRequest,
TokenRequest,
TokenResponse,
OIDCError,
DiscoveryDocument,
} from './types';
import type { StorageAdapter } from './storage/adapter';
import { TokenManager } from './auth/token-manager';
import { JWTUtils } from './utils/jwt';
import type { KeyPair } from './utils/jwt';
import { PKCEUtils } from './utils/pkce';
import { ValidationUtils } from './utils/validation';
import { PasswordAuth } from './auth';
/**
* OIDC Provider - a simplified implementation that follows the specification while keeping the code concise
*/
export class OIDCProvider {
private readonly config: OIDCProviderConfig;
private readonly tokenTTL: Required<NonNullable<OIDCProviderConfig['tokenTTL']>>;
private readonly storage: StorageAdapter;
private readonly tokenManager: TokenManager;
private readonly jwtUtils: JWTUtils;
private readonly findUser: (userId: string) => Promise<OIDCUser | null>;
private readonly findClient: (clientId: string) => Promise<OIDCClient | null>;
private passwordAuth: PasswordAuth;
/**
 * Creates a provider from the supplied configuration.
 *
 * The configuration is validated first. Token lifetimes (in seconds) fall back
 * to built-in defaults whenever the configured value is missing or falsy.
 *
 * @param config - Provider configuration; supplies storage, the user/client
 *   lookup callbacks, the signing key and the password-auth settings.
 * @throws Error when a required configuration field is missing.
 */
constructor(config: OIDCProviderConfig) {
  this.validateConfig(config);

  const ttl = config.tokenTTL;
  const auth = config.authConfig;

  this.config = config;
  this.tokenTTL = {
    accessToken: ttl?.accessToken || 3600,
    refreshToken: ttl?.refreshToken || 2592000,
    authorizationCode: ttl?.authorizationCode || 600,
    idToken: ttl?.idToken || 3600,
  };

  this.storage = config.storage;
  this.tokenManager = new TokenManager(config.storage);

  this.findUser = config.findUser;
  this.findClient = config.findClient;

  // The JWT helper is built directly from the configured signing key.
  this.jwtUtils = new JWTUtils(config.signingKey);
  this.passwordAuth = new PasswordAuth(config, auth.passwordValidator, auth);
}
/**
 * Ensures that every mandatory configuration field is present (truthy).
 *
 * `signingKey` is deliberately not required: a string key means HMAC is used,
 * and when no key is supplied an RSA key is generated automatically.
 *
 * @param config - Configuration object to check.
 * @throws Error naming the first missing field.
 */
private validateConfig(config: OIDCProviderConfig): void {
  const mandatory = ['issuer', 'storage', 'findUser', 'findClient', 'authConfig'];
  const missing = mandatory.find(name => !config[name as keyof OIDCProviderConfig]);
  if (missing !== undefined) {
    throw new Error(`配置项 ${missing} 是必需的`);
  }
}
/** Returns `config.scopes` when set, otherwise `openid`, `profile` and `email`. */
private get defaultScopes(): string[] {
  const configured = this.config.scopes;
  return configured ? configured : ['openid', 'profile', 'email'];
}
/** Returns `config.responseTypes` when set, otherwise only `code`. */
private get supportedResponseTypes(): string[] {
  const configured = this.config.responseTypes;
  return configured ? configured : ['code'];
}
/** Returns `config.grantTypes` when set, otherwise `authorization_code` and `refresh_token`. */
private get supportedGrantTypes(): string[] {
  const configured = this.config.grantTypes;
  return configured ? configured : ['authorization_code', 'refresh_token'];
}
/** Returns `config.claims` when set, otherwise `sub`, `name`, `email` and `email_verified`. */
private get supportedClaims(): string[] {
  const configured = this.config.claims;
  return configured ? configured : ['sub', 'name', 'email', 'email_verified'];
}
/** Whether PKCE is enabled; defaults to `true` when `config.enablePKCE` is null or undefined. */
private get enablePKCE(): boolean {
  const flag = this.config.enablePKCE;
  return flag == null ? true : flag;
}
/** Whether PKCE is mandatory; defaults to `false` when `config.requirePKCE` is null or undefined. */
private get requirePKCE(): boolean {
  const flag = this.config.requirePKCE;
  return flag == null ? false : flag;
}
/** Whether refresh tokens are rotated on use; defaults to `true` when `config.rotateRefreshTokens` is null or undefined. */
private get rotateRefreshTokens(): boolean {
  const flag = this.config.rotateRefreshTokens;
  return flag == null ? true : flag;
}
/**
 * Looks a token value up as both an access token and a refresh token.
 *
 * Both lookups run in parallel. When a record is found but has expired, every
 * record that matched is deleted and `null` is returned. When both lookups
 * match, the access-token record takes precedence.
 *
 * @param token - Raw token value to look up.
 * @returns The stored record plus its kind, or `null` when the token is
 *   unknown or expired.
 */
async findToken(token: string): Promise<{ tokenData: any; type: 'access' | 'refresh' } | null> {
  const [asAccess, asRefresh] = await Promise.all([
    this.tokenManager.getAccessToken(token),
    this.tokenManager.getRefreshToken(token)
  ]);

  const record = asAccess || asRefresh;
  if (!record) return null;

  if (record.expires_at < new Date()) {
    // Remove the expired record(s) so they are not found again.
    const cleanups: unknown[] = [];
    if (asAccess) cleanups.push(this.tokenManager.deleteAccessToken(token));
    if (asRefresh) cleanups.push(this.tokenManager.deleteRefreshToken(token));
    await Promise.all(cleanups);
    return null;
  }

  return { tokenData: record, type: asAccess ? 'access' : 'refresh' };
}
/**
 * Converts raw query-string parameters into a validated authorization request.
 *
 * The query is first normalized by `authorizationQuerySchema`; the result is
 * then re-validated by `authorizationRequestSchema`. Missing `response_type`,
 * `client_id` and `redirect_uri` become empty strings, a missing `scope`
 * becomes `openid`, and `max_age` is parsed as a base-10 integer.
 *
 * @param query - Query parameters of the authorization request.
 * @returns The parsed authorization request.
 * @throws Whatever the schemas throw when validation fails.
 */
parseAuthRequest(query: Record<string, string | string[]>): AuthorizationRequest {
  const params = authorizationQuerySchema.parse(query);

  let maxAge: number | undefined = undefined;
  if (params.max_age) {
    maxAge = parseInt(params.max_age, 10);
  }

  const candidate = {
    response_type: params.response_type || '',
    client_id: params.client_id || '',
    redirect_uri: params.redirect_uri || '',
    scope: params.scope || 'openid',
    state: params.state,
    nonce: params.nonce,
    code_challenge: params.code_challenge,
    code_challenge_method: params.code_challenge_method as 'plain' | 'S256' | undefined,
    prompt: params.prompt,
    max_age: maxAge,
    id_token_hint: params.id_token_hint,
    login_hint: params.login_hint,
    acr_values: params.acr_values,
  };

  return authorizationRequestSchema.parse(candidate);
}
/**
 * Converts a token-endpoint form body (plus optional `Authorization` header)
 * into a validated token request.
 *
 * When the header uses the `Basic` scheme, its username/password take
 * precedence over `client_id`/`client_secret` from the form body.
 *
 * @param body - Form data posted to the token endpoint.
 * @param authHeader - Value of the `Authorization` header, if any.
 * @returns The parsed token request.
 * @throws Whatever the schemas throw when validation fails.
 */
parseTokenRequest(body: FormData, authHeader?: string): TokenRequest {
  const form = tokenFormDataSchema.parse(body);

  // Client authentication via HTTP Basic overrides body credentials.
  const usesBasicAuth = authHeader?.startsWith('Basic ');
  if (usesBasicAuth) {
    const credentials = basicAuthSchema.parse(authHeader);
    form.client_id = credentials.username || form.client_id || '';
    form.client_secret = credentials.password || form.client_secret || '';
  }

  const candidate = {
    grant_type: form.grant_type || '',
    client_id: form.client_id || '',
    code: form.code,
    redirect_uri: form.redirect_uri,
    client_secret: form.client_secret,
    refresh_token: form.refresh_token,
    code_verifier: form.code_verifier,
    scope: form.scope,
  };

  return tokenRequestSchema.parse(candidate);
}
/**
 * Looks up a client and checks the presented credentials against it.
 *
 * Confidential clients must present a non-empty secret that matches the stored
 * one. A confidential client without a stored secret is always rejected, so a
 * missing secret on both sides can no longer count as a match. The comparison
 * inspects every character instead of stopping at the first difference, to
 * avoid leaking how much of the secret was correct through response timing.
 *
 * Public clients must not send a secret at all.
 *
 * @param clientId - Identifier of the client to authenticate.
 * @param clientSecret - Secret presented by the client, if any.
 * @returns The client record when authentication succeeds.
 * @throws Error when the client is unknown, the secret does not match, or a
 *   public client sent a secret.
 */
async validateClient(clientId: string, clientSecret?: string): Promise<OIDCClient> {
  const client = await this.findClient(clientId);
  if (!client) throw new Error('客户端不存在');

  // Confidential clients must prove possession of their secret.
  if (client.client_type === 'confidential') {
    const expected: unknown = client.client_secret;
    const presented: unknown = clientSecret;

    if (
      typeof expected !== 'string' || expected.length === 0 ||
      typeof presented !== 'string' || presented.length === 0
    ) {
      throw new Error('客户端认证失败');
    }

    // Accumulate differences over the full length rather than short-circuiting.
    let difference = expected.length ^ presented.length;
    const span = Math.max(expected.length, presented.length);
    for (let i = 0; i < span; i++) {
      // charCodeAt yields NaN past the end; the bitwise XOR coerces it to 0.
      difference |= expected.charCodeAt(i) ^ presented.charCodeAt(i);
    }
    if (difference !== 0) {
      throw new Error('客户端认证失败');
    }
  }

  // Public clients are not supposed to send a secret.
  if (client.client_type === 'public' && clientSecret) {
    throw new Error('公开客户端不应发送密钥');
  }

  return client;
}
/**
 * Validates the PKCE parameters of an authorization request.
 *
 * Without a `code_challenge`, the request is rejected only when PKCE is
 * required and the client is public. With a challenge, it is checked via
 * `PKCEUtils.isValidCodeChallenge`, using `plain` when no method was sent.
 *
 * @param request - Authorization request being processed.
 * @param client - Client that issued the request.
 * @throws Error when PKCE is required but absent, or the challenge is invalid.
 */
validatePKCE(request: AuthorizationRequest, client: OIDCClient): void {
  const challenge = request.code_challenge;

  if (!challenge) {
    if (this.requirePKCE && client.client_type === 'public') {
      throw new Error('公开客户端必须使用PKCE');
    }
    return;
  }

  const challengeMethod = request.code_challenge_method || 'plain';
  if (!PKCEUtils.isValidCodeChallenge(challenge, challengeMethod)) {
    throw new Error('无效的PKCE参数');
  }
}
/**
 * Builds the discovery document for this provider.
 *
 * All endpoint URLs are formed by appending a fixed path to the configured
 * issuer. The advertised ID token signing algorithm is `HS256` when the
 * signing key is a string and `RS256` otherwise. PKCE methods are advertised
 * only when PKCE is enabled.
 *
 * @returns The discovery document.
 */
getDiscoveryDocument(): DiscoveryDocument {
  const issuer = this.config.issuer;
  const endpoint = (path: string): string => `${issuer}${path}`;

  // The kind of signing key determines the signing algorithm.
  const usesSharedSecret = typeof this.config.signingKey === 'string';
  const idTokenAlg = usesSharedSecret ? 'HS256' : 'RS256';

  const pkceMethods = this.enablePKCE ? ['plain', 'S256'] : undefined;

  return {
    issuer: this.config.issuer,
    authorization_endpoint: endpoint('/auth'),
    token_endpoint: endpoint('/token'),
    userinfo_endpoint: endpoint('/userinfo'),
    jwks_uri: endpoint('/.well-known/jwks.json'),
    revocation_endpoint: endpoint('/revoke'),
    introspection_endpoint: endpoint('/introspect'),
    end_session_endpoint: endpoint('/logout'),
    response_types_supported: this.supportedResponseTypes,
    grant_types_supported: this.supportedGrantTypes,
    scopes_supported: this.defaultScopes,
    claims_supported: this.supportedClaims,
    token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post', 'none'],
    id_token_signing_alg_values_supported: [idTokenAlg],
    subject_types_supported: ['public'],
    code_challenge_methods_supported: pkceMethods,
    response_modes_supported: ['query', 'fragment'],
  };
}
/**
 * Processes a parsed authorization request for an authenticated user and
 * issues an authorization code.
 *
 * Checks run in this order: required parameters, client existence and
 * registered redirect URI, `ValidationUtils.validateAuthorizationRequest`,
 * presence of a user id, PKCE parameters, and finally the response type
 * (only `code` is handled).
 *
 * @param request - Parsed authorization request.
 * @param userId - Id of the authenticated user; the request fails without it.
 * @returns The newly stored authorization code and the request's `state`.
 * @throws Error describing the first check that failed.
 */
async handleAuthorizationRequest(request: AuthorizationRequest, userId?: string) {
  // Basic parameter presence.
  const hasRequiredParams = Boolean(request.client_id && request.response_type && request.redirect_uri);
  if (!hasRequiredParams) {
    throw new Error('缺少必需参数');
  }

  const client = await this.findClient(request.client_id);
  const redirectRegistered = client ? client.redirect_uris.includes(request.redirect_uri) : false;
  if (!client || !redirectRegistered) {
    throw new Error('无效的客户端或重定向URI');
  }

  // Full request validation.
  const outcome = ValidationUtils.validateAuthorizationRequest(
    request,
    client,
    this.defaultScopes,
    this.supportedResponseTypes
  );
  if (!outcome.valid) {
    throw new Error(outcome.errors.join(', '));
  }

  if (!userId) {
    throw new Error('需要用户认证');
  }

  this.validatePKCE(request, client);

  if (request.response_type !== 'code') {
    throw new Error(`不支持的响应类型: ${request.response_type}`);
  }

  // Issue and persist the authorization code.
  const code = nanoid(32);
  const record: AuthorizationCode = {
    code,
    client_id: request.client_id,
    user_id: userId,
    redirect_uri: request.redirect_uri,
    scope: request.scope,
    code_challenge: request.code_challenge,
    code_challenge_method: request.code_challenge_method,
    nonce: request.nonce,
    state: request.state,
    expires_at: new Date(Date.now() + this.tokenTTL.authorizationCode * 1000),
    created_at: new Date(),
  };
  await this.tokenManager.storeAuthorizationCode(record);

  return { code, state: request.state };
}
/**
 * Entry point for token requests: authenticates the client and dispatches on
 * the grant type.
 *
 * @param request - Parsed token request.
 * @returns The token response produced by the matching grant handler.
 * @throws Error when required parameters are missing, client validation
 *   fails, or the grant type is not `authorization_code` / `refresh_token`.
 */
async handleTokenRequest(request: TokenRequest): Promise<TokenResponse> {
  const grantType = request.grant_type;
  if (!grantType || !request.client_id) {
    throw new Error('缺少必需参数');
  }

  const client = await this.validateClient(request.client_id, request.client_secret);

  if (grantType === 'authorization_code') {
    return this.handleAuthorizationCodeGrant(request, client);
  }
  if (grantType === 'refresh_token') {
    return this.handleRefreshTokenGrant(request, client);
  }
  throw new Error(`不支持的授权类型: ${grantType}`);
}
/**
 * Exchanges an authorization code for tokens.
 *
 * The stored code must exist, be unexpired, and match the request's client id
 * and redirect URI. If the code was issued with a PKCE challenge, the request
 * must carry a verifier that passes `PKCEUtils.verifyCodeChallenge`. The code
 * is deleted once the user has been resolved, just before tokens are issued;
 * an expired code is deleted as soon as it is detected.
 *
 * @param request - Token request carrying `code`, `redirect_uri` and
 *   optionally `code_verifier`.
 * @param client - Already-validated client.
 * @returns The generated token response.
 * @throws Error describing the first check that failed.
 */
private async handleAuthorizationCodeGrant(request: TokenRequest, client: OIDCClient): Promise<TokenResponse> {
  const code = request.code;
  const redirectUri = request.redirect_uri;
  if (!code || !redirectUri) {
    throw new Error('缺少授权码或重定向URI');
  }

  const stored = await this.tokenManager.getAuthorizationCode(code);
  if (!stored) {
    throw new Error('授权码无效或已过期');
  }
  if (stored.expires_at < new Date()) {
    await this.tokenManager.deleteAuthorizationCode(code);
    throw new Error('授权码无效或已过期');
  }

  // The code must belong to this client and redirect URI.
  const sameClient = stored.client_id === request.client_id;
  const sameRedirect = stored.redirect_uri === redirectUri;
  if (!(sameClient && sameRedirect)) {
    throw new Error('授权码不匹配');
  }

  // PKCE verification, only when a challenge was recorded with the code.
  if (stored.code_challenge) {
    const verifier = request.code_verifier;
    const verified = !!verifier && PKCEUtils.verifyCodeChallenge(
      verifier,
      stored.code_challenge,
      stored.code_challenge_method || 'plain'
    );
    if (!verified) {
      throw new Error('PKCE验证失败');
    }
  }

  const user = await this.findUser(stored.user_id);
  if (!user) throw new Error('用户不存在');

  // The code has been used: remove it.
  await this.tokenManager.deleteAuthorizationCode(code);

  return this.generateTokens(user, client, stored.scope, stored.nonce);
}
/**
 * Exchanges a refresh token for a new set of tokens.
 *
 * The stored refresh token must exist, be unexpired and belong to the
 * requesting client. A requested `scope` may only narrow the original scope.
 *
 * With rotation enabled the presented refresh token is deleted and the new one
 * is returned. With rotation disabled the presented token is returned again,
 * and the extra refresh token that `generateTokens` always creates and stores
 * is deleted so it does not linger in storage as an unused but valid
 * credential.
 *
 * @param request - Token request carrying `refresh_token` and optionally `scope`.
 * @param client - Already-validated client.
 * @returns The generated token response.
 * @throws Error describing the first check that failed.
 */
private async handleRefreshTokenGrant(request: TokenRequest, client: OIDCClient): Promise<TokenResponse> {
  const presented = request.refresh_token;
  if (!presented) {
    throw new Error('缺少刷新token');
  }

  const stored = await this.tokenManager.getRefreshToken(presented);
  if (!stored || stored.expires_at < new Date()) {
    // Expired records are removed as soon as they are detected.
    if (stored) await this.tokenManager.deleteRefreshToken(presented);
    throw new Error('刷新token无效或已过期');
  }

  if (stored.client_id !== request.client_id) {
    throw new Error('刷新token客户端不匹配');
  }

  const user = await this.findUser(stored.user_id);
  if (!user) throw new Error('用户不存在');

  // Scope check: the request may keep or narrow the original scope, never widen it.
  let scope = stored.scope;
  if (request.scope) {
    const requestedScopes = request.scope.split(' ');
    const originalScopes = stored.scope.split(' ');
    if (!requestedScopes.every(s => originalScopes.includes(s))) {
      throw new Error('请求的scope超出原始scope');
    }
    scope = request.scope;
  }

  const tokens = await this.generateTokens(user, client, scope);

  if (this.rotateRefreshTokens) {
    // Rotation: the presented token is retired in favour of the new one.
    await this.tokenManager.deleteRefreshToken(presented);
  } else {
    // No rotation: drop the freshly minted refresh token, which would
    // otherwise stay valid in storage without ever reaching the client.
    const unused = tokens.refresh_token;
    if (unused && unused !== presented) {
      await this.tokenManager.deleteRefreshToken(unused);
    }
    tokens.refresh_token = presented;
  }

  return tokens;
}
/**
 * Mints and stores an access token, a refresh token and, when the granted
 * scope contains the `openid` scope value, an ID token.
 *
 * The `openid` check compares whole space-delimited scope values, so a scope
 * that merely contains the substring (for example `myopenid`) does not trigger
 * ID token issuance.
 *
 * @param user - User the tokens are issued for; `user.sub` is the subject.
 * @param client - Client the tokens are issued to; `client.client_id` is the audience.
 * @param scope - Space-delimited granted scope.
 * @param nonce - Nonce to embed in the ID token, if any.
 * @returns Token response with `access_token`, `refresh_token`, `scope`,
 *   `expires_in`, and `id_token` when one was issued.
 */
private async generateTokens(
  user: OIDCUser,
  client: OIDCClient,
  scope: string,
  nonce?: string
): Promise<TokenResponse> {
  const now = new Date();
  // Expiry for a lifetime given in seconds, relative to `now`.
  const expiresAfter = (seconds: number): Date => new Date(now.getTime() + seconds * 1000);

  const accessToken = await this.jwtUtils.generateAccessToken({
    issuer: this.config.issuer,
    subject: user.sub,
    audience: client.client_id,
    clientId: client.client_id,
    scope,
    expiresIn: this.tokenTTL.accessToken,
  });
  // The refresh token is an opaque random string, not a JWT.
  const refreshTokenValue = nanoid(64);

  // Issue an ID token only when `openid` is one of the scope values.
  let idToken: string | undefined;
  if (scope.split(' ').includes('openid')) {
    idToken = await this.jwtUtils.generateIDToken({
      issuer: this.config.issuer,
      subject: user.sub,
      audience: client.client_id,
      user,
      authTime: Math.floor(now.getTime() / 1000),
      nonce,
      expiresIn: this.tokenTTL.idToken,
    });
  }

  // Persist every issued token; the writes are independent and run in parallel.
  const storeOperations = [
    this.tokenManager.storeAccessToken({
      token: accessToken,
      client_id: client.client_id,
      user_id: user.sub,
      scope,
      expires_at: expiresAfter(this.tokenTTL.accessToken),
      created_at: now,
    }),
    this.tokenManager.storeRefreshToken({
      token: refreshTokenValue,
      client_id: client.client_id,
      user_id: user.sub,
      scope,
      expires_at: expiresAfter(this.tokenTTL.refreshToken),
      created_at: now,
    })
  ];
  if (idToken) {
    storeOperations.push(this.tokenManager.storeIDToken({
      token: idToken,
      client_id: client.client_id,
      user_id: user.sub,
      nonce,
      expires_at: expiresAfter(this.tokenTTL.idToken),
      created_at: now,
    }));
  }
  await Promise.all(storeOperations);

  const response: TokenResponse = {
    access_token: accessToken,
    token_type: 'Bearer',
    expires_in: this.tokenTTL.accessToken,
    refresh_token: refreshTokenValue,
    scope,
  };
  if (idToken) response.id_token = idToken;
  return response;
}
/**
 * Resolves the user claims that an access token is allowed to see.
 *
 * The token must be present in storage and unexpired (an expired record is
 * deleted), and must pass JWT verification. The returned claims are limited to
 * those mapped from the token's scope.
 *
 * @param accessToken - Raw access token value.
 * @returns The subset of user claims permitted by the token's scope.
 * @throws Error when the token is unknown/expired, verification fails, or the
 *   user no longer exists.
 */
async getUserInfo(accessToken: string): Promise<Partial<OIDCUser>> {
  const record = await this.tokenManager.getAccessToken(accessToken);
  if (!record) {
    throw new Error('Token无效或已过期');
  }
  if (record.expires_at < new Date()) {
    await this.tokenManager.deleteAccessToken(accessToken);
    throw new Error('Token无效或已过期');
  }

  // Verify the JWT itself.
  await this.jwtUtils.verifyToken(accessToken);

  const user = await this.findUser(record.user_id);
  if (!user) throw new Error('用户不存在');

  const allowedClaims = this.getClaimsForScope(record.scope);
  return this.filterUserClaims(user, allowedClaims);
}
/**
 * Maps a space-delimited scope string to the claim names it grants.
 *
 * `sub` is always included. Only the `profile`, `email`, `phone` and `address`
 * scope values contribute claims; every other value is ignored. The lookup
 * uses a `Map`, so scope values that collide with `Object.prototype` members
 * (such as `toString` or `constructor`) are treated as unknown instead of
 * raising a TypeError.
 *
 * @param scope - Space-delimited scope string.
 * @returns De-duplicated claim names, starting with `sub`.
 */
private getClaimsForScope(scope: string): string[] {
  const claimsByScope = new Map<string, readonly string[]>([
    ['profile', ['name', 'given_name', 'family_name', 'picture']],
    ['email', ['email', 'email_verified']],
    ['phone', ['phone_number', 'phone_number_verified']],
    ['address', ['address']],
  ]);

  const claims = new Set<string>(['sub']);
  for (const scopeValue of scope.split(' ')) {
    const granted = claimsByScope.get(scopeValue);
    if (!granted) continue; // unknown or empty scope value
    for (const claim of granted) {
      claims.add(claim);
    }
  }
  return Array.from(claims);
}
/**
 * Picks the requested claims from a user record.
 *
 * A claim is kept only when the key exists on the user object and its value is
 * not `undefined`.
 *
 * @param user - Full user record.
 * @param claims - Claim names to pick.
 * @returns An object containing just the picked claims.
 */
private filterUserClaims(user: OIDCUser, claims: string[]): Partial<OIDCUser> {
  const pairs = claims.flatMap(name => {
    if (!(name in user)) return [];
    const value = user[name as keyof OIDCUser];
    return value === undefined ? [] : [[name, value]];
  });
  return Object.fromEntries(pairs);
}
/**
 * Revokes an access or refresh token.
 *
 * Unknown or expired tokens are silently ignored. Revoking a refresh token
 * also deletes the access tokens stored for the same user and client.
 *
 * @param token - Raw token value to revoke.
 * @param clientId - When given, the token must belong to this client.
 * @throws Error when `clientId` is given and does not own the token.
 */
async revokeToken(token: string, clientId?: string): Promise<void> {
  const found = await this.findToken(token);
  if (found === null) return; // token does not exist or has already expired

  const owner = found.tokenData;
  if (clientId && owner.client_id !== clientId) {
    throw new Error('Token不属于此客户端');
  }

  if (found.type === 'access') {
    await this.tokenManager.deleteAccessToken(token);
    return;
  }

  // Refresh token: revoke it together with the related access tokens.
  await Promise.all([
    this.tokenManager.deleteRefreshToken(token),
    this.tokenManager.deleteAccessTokensByUserAndClient(owner.user_id, owner.client_id)
  ]);
}
/**
 * Reports the state of a token.
 *
 * Returns `{ active: false }` when the token is unknown, expired, or (if a
 * client id is supplied) owned by a different client. Otherwise returns the
 * token's metadata, with timestamps expressed in whole seconds.
 *
 * @param token - Raw token value to inspect.
 * @param clientId - When given, tokens of other clients are reported inactive.
 * @returns The introspection result.
 */
async introspectToken(token: string, clientId?: string) {
  const found = await this.findToken(token);
  if (!found) return { active: false };

  const record = found.tokenData;
  const ownedByOtherClient = clientId && record.client_id !== clientId;
  if (ownedByOtherClient) {
    return { active: false };
  }

  const user = await this.findUser(record.user_id);
  const toSeconds = (when: Date): number => Math.floor(when.getTime() / 1000);

  return {
    active: true,
    scope: record.scope,
    client_id: record.client_id,
    username: user?.username,
    token_type: found.type === 'access' ? 'Bearer' : 'refresh_token',
    exp: toSeconds(record.expires_at),
    iat: toSeconds(record.created_at),
    sub: record.user_id,
    aud: record.client_id,
    iss: this.config.issuer,
  };
}
/** Returns the JSON Web Key Set produced by the JWT helper. */
async getJWKS() {
  const keySet = await this.jwtUtils.generateJWKS();
  return keySet;
}
// HTTP handlers
/**
 * Converts a thrown value into the JSON error body returned to clients.
 *
 * The error is logged to the console. The `error` code is always
 * `invalid_request`; the description is the error's message when the value is
 * an `Error`, or a generic internal-error text otherwise.
 *
 * @param error - The caught value.
 * @returns The error body.
 */
private createErrorResponse(error: unknown): { error: string; error_description: string } {
  let description = '服务器内部错误';
  if (error instanceof Error) {
    description = error.message;
  }

  console.error('OIDC Provider Error:', error);

  return { error: 'invalid_request', error_description: description };
}
/**
 * HTTP handler for the login form submission.
 *
 * Reads the authorization-request fields carried in the form (missing fields
 * become empty strings), authenticates through `passwordAuth`, and delegates
 * to its success or failure handler. Any thrown error becomes a 500 JSON body.
 *
 * @param c - Hono request context.
 * @returns The response produced by the auth handlers, or a 500 error response.
 */
async handleLogin(c: Context): Promise<Response> {
  try {
    const form = await c.req.formData();

    const carriedFields = [
      'response_type', 'client_id', 'redirect_uri', 'scope',
      'state', 'nonce', 'code_challenge', 'code_challenge_method',
    ];
    const authRequest: Record<string, string> = {};
    for (const name of carriedFields) {
      authRequest[name] = form.get(name)?.toString() || '';
    }

    const outcome = await this.passwordAuth.authenticate(c);
    if (outcome.success) {
      return await this.passwordAuth.handleAuthenticationSuccess(c, outcome);
    }
    return await this.passwordAuth.handleAuthenticationFailure(c, outcome, authRequest as any);
  } catch (error) {
    return c.json(this.createErrorResponse(error), 500);
  }
}
/**
 * HTTP handler for logout; delegates entirely to `passwordAuth.logout`.
 *
 * @param c - Hono request context.
 * @returns The response produced by the password-auth logout.
 */
async handleLogout(c: Context): Promise<Response> {
  const response = await this.passwordAuth.logout(c);
  return response;
}
/**
 * HTTP handler for the authorization endpoint.
 *
 * Parses the query into an authorization request. When no user is logged in,
 * hands over to `passwordAuth.handleAuthenticationRequired`; otherwise issues
 * an authorization code and redirects to the client's redirect URI.
 *
 * The redirect target is built with `URL`/`searchParams`, so `code` and
 * `state` are merged into any query string the redirect URI already has
 * (and placed before any fragment) instead of appending a second `?`.
 *
 * @param c - Hono request context.
 * @returns A redirect to the client, the authentication-required response, or
 *   a 400 JSON error.
 */
async handleAuthorization(c: Context): Promise<Response> {
  try {
    const authRequest = this.parseAuthRequest(c.req.query());

    const userId = await this.passwordAuth.getCurrentUser(c);
    if (!userId) {
      return this.passwordAuth.handleAuthenticationRequired(c, authRequest);
    }

    const result = await this.handleAuthorizationRequest(authRequest, userId);

    // Preserve any existing query component of the redirect URI.
    const target = new URL(authRequest.redirect_uri);
    if (result.code) target.searchParams.set('code', result.code);
    if (result.state) target.searchParams.set('state', result.state);

    return c.redirect(target.toString());
  } catch (error) {
    return c.json(this.createErrorResponse(error), 400);
  }
}
/**
 * HTTP handler for the token endpoint.
 *
 * Parses the form body and `Authorization` header, processes the token
 * request, and returns the token response with caching disabled. Any thrown
 * error becomes a 400 JSON body.
 *
 * @param c - Hono request context.
 * @returns The token response, or a 400 error response.
 */
async handleToken(c: Context): Promise<Response> {
  try {
    const form = await c.req.formData();
    const authorization = c.req.header('Authorization');
    const tokenRequest = this.parseTokenRequest(form, authorization);

    const tokens = await this.handleTokenRequest(tokenRequest);

    // Token responses must not be cached.
    c.header('Cache-Control', 'no-store');
    c.header('Pragma', 'no-cache');
    return c.json(tokens);
  } catch (error) {
    return c.json(this.createErrorResponse(error), 400);
  }
}
/**
 * HTTP handler for the userinfo endpoint.
 *
 * Requires an `Authorization: Bearer ...` header. Any failure, whether a
 * missing header, a header rejected by the schema, or a token lookup error,
 * yields a 401 with a `WWW-Authenticate` challenge and an `invalid_token` body.
 *
 * @param c - Hono request context.
 * @returns The user's claims as JSON, or a 401 error response.
 */
async handleUserInfo(c: Context): Promise<Response> {
  try {
    const authorization = c.req.header('Authorization');
    if (!authorization || !authorization.startsWith('Bearer ')) {
      c.header('WWW-Authenticate', 'Bearer realm="userinfo"');
      return c.json({ error: 'invalid_token', error_description: 'Bearer token required' }, 401);
    }

    const accessToken = bearerTokenSchema.parse(authorization);
    const claims = await this.getUserInfo(accessToken);
    return c.json(claims);
  } catch {
    // The cause is deliberately not exposed to the caller.
    c.header('WWW-Authenticate', 'Bearer realm="userinfo"');
    return c.json({ error: 'invalid_token', error_description: 'Invalid token' }, 401);
  }
}
/**
 * HTTP handler for the revocation endpoint.
 *
 * Reads `token` and `client_id` from the form body, revokes the token, and
 * answers with an empty 200. A missing token or any thrown error yields a 400
 * JSON body.
 *
 * @param c - Hono request context.
 * @returns An empty 200 response, or a 400 error response.
 */
async handleRevoke(c: Context): Promise<Response> {
  try {
    const form = await c.req.formData();
    const parsed = revokeTokenRequestSchema.parse({
      token: form.get('token')?.toString(),
      client_id: form.get('client_id')?.toString(),
    });

    if (!parsed.token) {
      return c.json({ error: 'invalid_request', error_description: 'Missing token' }, 400);
    }

    await this.revokeToken(parsed.token, parsed.client_id);
    return c.body(null, 200);
  } catch (error) {
    return c.json(this.createErrorResponse(error), 400);
  }
}
/**
 * HTTP handler for the introspection endpoint.
 *
 * Reads `token` and `client_id` from the form body and returns the
 * introspection result as JSON. A missing token or any thrown error yields a
 * 400 JSON body.
 *
 * NOTE(review): this handler does not authenticate the caller; confirm
 * whether access is restricted elsewhere (e.g. by middleware).
 *
 * @param c - Hono request context.
 * @returns The introspection result, or a 400 error response.
 */
async handleIntrospect(c: Context): Promise<Response> {
  try {
    const form = await c.req.formData();
    const parsed = introspectTokenRequestSchema.parse({
      token: form.get('token')?.toString(),
      client_id: form.get('client_id')?.toString(),
    });

    if (!parsed.token) {
      return c.json({ error: 'invalid_request', error_description: 'Missing token' }, 400);
    }

    const details = await this.introspectToken(parsed.token, parsed.client_id);
    return c.json(details);
  } catch (error) {
    return c.json(this.createErrorResponse(error), 400);
  }
}
}