diff --git a/src/main/java/net/siegeln/cameleer/saas/config/SecurityConfig.java b/src/main/java/net/siegeln/cameleer/saas/config/SecurityConfig.java index b9047ad..5b19438 100644 --- a/src/main/java/net/siegeln/cameleer/saas/config/SecurityConfig.java +++ b/src/main/java/net/siegeln/cameleer/saas/config/SecurityConfig.java @@ -10,6 +10,7 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.server.resource.web.authentication.BearerTokenAuthenticationFilter; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; @@ -19,9 +20,11 @@ import org.springframework.security.web.authentication.UsernamePasswordAuthentic public class SecurityConfig { private final JwtAuthenticationFilter machineTokenFilter; + private final TenantResolutionFilter tenantResolutionFilter; - public SecurityConfig(JwtAuthenticationFilter machineTokenFilter) { + public SecurityConfig(JwtAuthenticationFilter machineTokenFilter, TenantResolutionFilter tenantResolutionFilter) { this.machineTokenFilter = machineTokenFilter; + this.tenantResolutionFilter = tenantResolutionFilter; } @Bean @@ -50,7 +53,8 @@ public class SecurityConfig { .anyRequest().authenticated() ) .oauth2ResourceServer(oauth2 -> oauth2.jwt(jwt -> {})) - .addFilterBefore(machineTokenFilter, UsernamePasswordAuthenticationFilter.class); + .addFilterBefore(machineTokenFilter, UsernamePasswordAuthenticationFilter.class) + .addFilterAfter(tenantResolutionFilter, BearerTokenAuthenticationFilter.class); return http.build(); } diff --git a/src/main/java/net/siegeln/cameleer/saas/config/TenantContext.java b/src/main/java/net/siegeln/cameleer/saas/config/TenantContext.java new file mode 100644 index 0000000..0c94585 --- /dev/null +++ b/src/main/java/net/siegeln/cameleer/saas/config/TenantContext.java @@ -0,0 +1,22 @@ +package net.siegeln.cameleer.saas.config; + +import java.util.UUID; + +public final class TenantContext { + + private static final ThreadLocal CURRENT_TENANT = new ThreadLocal<>(); + + private TenantContext() {} + + public static UUID getTenantId() { + return CURRENT_TENANT.get(); + } + + public static void setTenantId(UUID tenantId) { + CURRENT_TENANT.set(tenantId); + } + + public static void clear() { + CURRENT_TENANT.remove(); + } +} diff --git a/src/main/java/net/siegeln/cameleer/saas/config/TenantResolutionFilter.java b/src/main/java/net/siegeln/cameleer/saas/config/TenantResolutionFilter.java new file mode 100644 index 0000000..90f7fb0 --- /dev/null +++ b/src/main/java/net/siegeln/cameleer/saas/config/TenantResolutionFilter.java @@ -0,0 +1,47 @@ +package net.siegeln.cameleer.saas.config; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import net.siegeln.cameleer.saas.tenant.TenantService; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; + +@Component +public class TenantResolutionFilter extends OncePerRequestFilter { + + private final TenantService tenantService; + + public TenantResolutionFilter(TenantService tenantService) { + this.tenantService = tenantService; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + try { + var authentication = SecurityContextHolder.getContext().getAuthentication(); + + if (authentication instanceof JwtAuthenticationToken jwtAuth) { + Jwt jwt = jwtAuth.getToken(); + String orgId = jwt.getClaimAsString("organization_id"); + + if (orgId != null) { + tenantService.getByLogtoOrgId(orgId) + .ifPresent(tenant -> TenantContext.setTenantId(tenant.getId())); + } + } + + filterChain.doFilter(request, response); + } finally { + TenantContext.clear(); + } + } +}