diff --git a/src/main/java/net/siegeln/cameleer/saas/account/AccountService.java b/src/main/java/net/siegeln/cameleer/saas/account/AccountService.java new file mode 100644 index 0000000..b6d6ab7 --- /dev/null +++ b/src/main/java/net/siegeln/cameleer/saas/account/AccountService.java @@ -0,0 +1,240 @@ +package net.siegeln.cameleer.saas.account; + +import net.siegeln.cameleer.saas.identity.LogtoManagementClient; +import net.siegeln.cameleer.saas.notification.PasswordResetNotificationService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.web.server.ResponseStatusException; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +@Service +public class AccountService { + + private static final Logger log = LoggerFactory.getLogger(AccountService.class); + private static final String BASE32_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + + private final LogtoManagementClient logtoClient; + private final PasswordResetNotificationService passwordNotificationService; + + public AccountService(LogtoManagementClient logtoClient, + PasswordResetNotificationService passwordNotificationService) { + this.logtoClient = logtoClient; + this.passwordNotificationService = passwordNotificationService; + } + + // --- Records --- + + public record ProfileData(String userId, String name, String email) {} + public record MfaStatusData(boolean enrolled, boolean hasBackupCodes, boolean passkeyEnrolled, int passkeyCount) {} + public record MfaSetupData(String secret, String secretQrCode) {} + public record BackupCodesData(List codes) {} + public record PasskeyCredential(String id, String name, String agent, String createdAt) {} + + // --- Profile --- + + public ProfileData getProfile(String userId) { + var user = logtoClient.getUser(userId); + if (user == null) { + throw new ResponseStatusException(HttpStatus.NOT_FOUND, "User not found"); + } + return new ProfileData( + userId, + String.valueOf(user.getOrDefault("name", "")), + String.valueOf(user.getOrDefault("primaryEmail", "")) + ); + } + + public void updateDisplayName(String userId, String name) { + if (name == null || name.isBlank()) { + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Display name must not be blank"); + } + logtoClient.updateUserProfile(userId, Map.of("name", name.trim())); + } + + // --- Password --- + + public void validatePassword(String password) { + if (password == null || password.length() < 8) { + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Password must be at least 8 characters"); + } + } + + public void changePassword(String userId, String currentPassword, String newPassword) { + validatePassword(newPassword); + if (!logtoClient.verifyUserPassword(userId, currentPassword)) { + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Current password is incorrect"); + } + logtoClient.updateUserPassword(userId, newPassword); + + // Send confirmation email asynchronously + try { + var user = logtoClient.getUser(userId); + if (user != null) { + String email = String.valueOf(user.getOrDefault("primaryEmail", "")); + if (!email.isBlank()) { + passwordNotificationService.sendNotification(email); + } + } + } catch (Exception e) { + log.warn("Failed to send password change notification for user {}: {}", userId, e.getMessage()); + } + } + + // --- MFA --- + + public MfaStatusData getMfaStatus(String userId) { + var verifications = logtoClient.getUserMfaVerifications(userId); + boolean enrolled = verifications.stream() + .anyMatch(v -> "Totp".equals(String.valueOf(v.get("type")))); + boolean hasBackupCodes = verifications.stream() + .anyMatch(v -> "BackupCode".equals(String.valueOf(v.get("type")))); + long passkeyCount = verifications.stream() + .filter(v -> "WebAuthn".equals(String.valueOf(v.get("type")))) + .count(); + return new MfaStatusData(enrolled, hasBackupCodes, passkeyCount > 0, (int) passkeyCount); + } + + public MfaSetupData setupTotp(String userId) { + byte[] secretBytes = new byte[20]; + new SecureRandom().nextBytes(secretBytes); + String secret = base32Encode(secretBytes); + + var result = logtoClient.createTotpVerification(userId, secret); + String qrCode = result.containsKey("secretQrCode") + ? String.valueOf(result.get("secretQrCode")) + : String.valueOf(result.getOrDefault("qrCode", "")); + return new MfaSetupData(secret, qrCode); + } + + public boolean verifyTotpCode(String secret, String code) { + if (code == null || code.length() != 6) return false; + long currentStep = Instant.now().getEpochSecond() / 30; + for (int drift = -1; drift <= 1; drift++) { + String computed = computeTotp(secret, currentStep + drift); + if (code.equals(computed)) return true; + } + return false; + } + + public BackupCodesData generateBackupCodes(String userId) { + var result = logtoClient.createBackupCodes(userId); + @SuppressWarnings("unchecked") + List codes = (List) result.get("codes"); + return new BackupCodesData(codes != null ? codes : List.of()); + } + + public void removeMfa(String userId) { + var verifications = logtoClient.getUserMfaVerifications(userId); + for (var v : verifications) { + logtoClient.deleteMfaVerification(userId, String.valueOf(v.get("id"))); + } + } + + // --- Passkeys --- + + public List listPasskeys(String userId) { + var credentials = logtoClient.getWebAuthnCredentials(userId); + return credentials.stream() + .map(c -> new PasskeyCredential( + String.valueOf(c.get("id")), + c.get("name") != null ? String.valueOf(c.get("name")) : null, + c.get("agent") != null ? String.valueOf(c.get("agent")) : null, + c.get("createdAt") != null ? String.valueOf(c.get("createdAt")) : null + )) + .toList(); + } + + public void renamePasskey(String userId, String credentialId, String name) { + var credentials = logtoClient.getWebAuthnCredentials(userId); + boolean owns = credentials.stream() + .anyMatch(c -> credentialId.equals(String.valueOf(c.get("id")))); + if (!owns) { + throw new ResponseStatusException(HttpStatus.NOT_FOUND, "Passkey not found"); + } + logtoClient.renameMfaVerification(userId, credentialId, name); + } + + public void deletePasskey(String userId, String credentialId) { + var credentials = logtoClient.getWebAuthnCredentials(userId); + boolean owns = credentials.stream() + .anyMatch(c -> credentialId.equals(String.valueOf(c.get("id")))); + if (!owns) { + throw new ResponseStatusException(HttpStatus.NOT_FOUND, "Passkey not found"); + } + logtoClient.deleteMfaVerification(userId, credentialId); + } + + // --- MFA Preference --- + + public void setMfaMethodPreference(String userId, String preference) { + if (!"totp".equals(preference) && !"webauthn".equals(preference)) { + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid MFA preference: must be 'totp' or 'webauthn'"); + } + logtoClient.updateUserCustomData(userId, Map.of("mfa_method_preference", preference)); + } + + // --- TOTP helpers (moved from TenantPortalService) --- + + private String computeTotp(String base32Secret, long timeStep) { + try { + byte[] key = base32Decode(base32Secret); + byte[] data = ByteBuffer.allocate(8).putLong(timeStep).array(); + Mac mac = Mac.getInstance("HmacSHA1"); + mac.init(new SecretKeySpec(key, "HmacSHA1")); + byte[] hash = mac.doFinal(data); + int offset = hash[hash.length - 1] & 0x0F; + int code = ((hash[offset] & 0x7F) << 24) + | ((hash[offset + 1] & 0xFF) << 16) + | ((hash[offset + 2] & 0xFF) << 8) + | (hash[offset + 3] & 0xFF); + return String.format("%06d", code % 1_000_000); + } catch (Exception e) { + log.error("TOTP computation failed", e); + return ""; + } + } + + String base32Encode(byte[] data) { + StringBuilder sb = new StringBuilder(); + int buffer = 0, bitsLeft = 0; + for (byte b : data) { + buffer = (buffer << 8) | (b & 0xFF); + bitsLeft += 8; + while (bitsLeft >= 5) { + sb.append(BASE32_ALPHABET.charAt((buffer >> (bitsLeft - 5)) & 0x1F)); + bitsLeft -= 5; + } + } + if (bitsLeft > 0) { + sb.append(BASE32_ALPHABET.charAt((buffer << (5 - bitsLeft)) & 0x1F)); + } + return sb.toString(); + } + + byte[] base32Decode(String encoded) { + String clean = encoded.replaceAll("[=\\s]", "").toUpperCase(); + int byteCount = clean.length() * 5 / 8; + byte[] result = new byte[byteCount]; + int buffer = 0, bitsLeft = 0, index = 0; + for (char c : clean.toCharArray()) { + int val = BASE32_ALPHABET.indexOf(c); + if (val < 0) continue; + buffer = (buffer << 5) | val; + bitsLeft += 5; + if (bitsLeft >= 8) { + result[index++] = (byte) (buffer >> (bitsLeft - 8)); + bitsLeft -= 8; + } + } + return result; + } +}