Implemented Cache for decrypted private key and handled refresh token

This commit is contained in:
2025-07-25 13:36:15 +05:30
parent 2622667de4
commit 063bfa794a
10 changed files with 277 additions and 68 deletions
+11
View File
@@ -121,6 +121,17 @@
<artifactId>spring-boot-starter-test</artifactId> <artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<!-- Caching -->
<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
<version>3.0.5</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-cache</artifactId>
</dependency>
</dependencies> </dependencies>
<build> <build>
@@ -0,0 +1,24 @@
package com.skycrate.backend.skycrateBackend.config;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import com.github.benmanes.caffeine.cache.Caffeine;
import org.springframework.cache.caffeine.CaffeineCacheManager;
import java.util.concurrent.TimeUnit;
@Configuration
@EnableCaching
public class CacheConfig {
@Bean
public CaffeineCacheManager cacheManager() {
CaffeineCacheManager cacheManager = new CaffeineCacheManager();
cacheManager.setCaffeine(Caffeine.newBuilder()
.expireAfterWrite(30, TimeUnit.MINUTES) // Cache expiry time
.maximumSize(100)); // Maximum cache size
return cacheManager;
}
}
@@ -29,7 +29,7 @@ public class SecurityConfig {
.sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) .sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
.authenticationProvider(authenticationProvider) .authenticationProvider(authenticationProvider)
.authorizeHttpRequests(auth -> auth .authorizeHttpRequests(auth -> auth
.requestMatchers("/api/auth/login", "/api/auth/register", "/actuator/**").permitAll() .requestMatchers("/api/auth/logout","/api/auth/login", "/api/auth/register", "/actuator/**").permitAll()
.requestMatchers(HttpMethod.GET, "/public/**").permitAll() .requestMatchers(HttpMethod.GET, "/public/**").permitAll()
.anyRequest().authenticated() .anyRequest().authenticated()
) )
@@ -14,6 +14,8 @@ import com.skycrate.backend.skycrateBackend.services.JwtService;
import com.skycrate.backend.skycrateBackend.services.RateLimiterService; import com.skycrate.backend.skycrateBackend.services.RateLimiterService;
import com.skycrate.backend.skycrateBackend.services.RefreshTokenService; import com.skycrate.backend.skycrateBackend.services.RefreshTokenService;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
@@ -84,6 +86,23 @@ public class AuthController {
return ResponseEntity.ok(new LoginResponse(accessToken, refreshToken.getToken())); return ResponseEntity.ok(new LoginResponse(accessToken, refreshToken.getToken()));
} }
// @PostMapping("/logout")
// public ResponseEntity<?> logout(HttpServletRequest request) {
// String authHeader = request.getHeader("Authorization");
// if (authHeader == null || !authHeader.startsWith("Bearer ")) {
// return ResponseEntity.badRequest().body("Missing or invalid Authorization header");
// }
//
// String token = authHeader.substring(7);
//
// tokenBlacklistService.blacklistToken(token);
//
// String email = jwtService.extractUsername(token);
// userRepository.findByEmail(email).ifPresent(refreshTokenService::deleteByUser);
//
// return ResponseEntity.ok("Logged out successfully");
// }
@PostMapping("/logout") @PostMapping("/logout")
public ResponseEntity<?> logout(HttpServletRequest request) { public ResponseEntity<?> logout(HttpServletRequest request) {
String authHeader = request.getHeader("Authorization"); String authHeader = request.getHeader("Authorization");
@@ -92,15 +111,38 @@ public class AuthController {
} }
String token = authHeader.substring(7); String token = authHeader.substring(7);
String username = jwtService.extractUsername(token);
userRepository.findByUsername(username).ifPresent(user -> {
// Clear the cached decrypted private key for the user
authenticationService.clearDecryptedPrivateKeyCache(user.getId().toString());
// Delete the refresh token associated with the user
refreshTokenService.logout(user); // This should delete the token
});
tokenBlacklistService.blacklistToken(token); tokenBlacklistService.blacklistToken(token);
String email = jwtService.extractUsername(token);
userRepository.findByEmail(email).ifPresent(refreshTokenService::deleteByUser);
return ResponseEntity.ok("Logged out successfully"); return ResponseEntity.ok("Logged out successfully");
} }
// @PostMapping("/refresh")
// public ResponseEntity<?> refresh(@RequestBody TokenRefreshRequest request) {
// String requestToken = request.getRefreshToken();
//
// return refreshTokenService.findByToken(requestToken)
// .map(token -> {
// if (refreshTokenService.isExpired(token)) {
// return ResponseEntity.status(403).body("Refresh token expired");
// }
//
// User user = token.getUser();
// String newAccessToken = jwtService.generateToken(user);
// return ResponseEntity.ok(new TokenRefreshResponse(newAccessToken, requestToken));
// })
// .orElseGet(() -> ResponseEntity.status(403).body("Invalid refresh token"));
// }
@PostMapping("/refresh") @PostMapping("/refresh")
public ResponseEntity<?> refresh(@RequestBody TokenRefreshRequest request) { public ResponseEntity<?> refresh(@RequestBody TokenRefreshRequest request) {
String requestToken = request.getRefreshToken(); String requestToken = request.getRefreshToken();
@@ -108,6 +150,8 @@ public class AuthController {
return refreshTokenService.findByToken(requestToken) return refreshTokenService.findByToken(requestToken)
.map(token -> { .map(token -> {
if (refreshTokenService.isExpired(token)) { if (refreshTokenService.isExpired(token)) {
// Clear the cached key on token expiry
authenticationService.clearDecryptedPrivateKeyCache(token.getUser().getId().toString());
return ResponseEntity.status(403).body("Refresh token expired"); return ResponseEntity.status(403).body("Refresh token expired");
} }
@@ -15,4 +15,5 @@ public interface RefreshTokenRepository extends JpaRepository<RefreshToken, Long
@Modifying @Modifying
@Query("DELETE FROM RefreshToken t WHERE t.user = :user") @Query("DELETE FROM RefreshToken t WHERE t.user = :user")
void deleteByUser(User user); void deleteByUser(User user);
}
}
@@ -9,6 +9,8 @@ import com.skycrate.backend.skycrateBackend.utils.EncryptionUtil;
import com.skycrate.backend.skycrateBackend.utils.RSAKeyUtil; import com.skycrate.backend.skycrateBackend.utils.RSAKeyUtil;
import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder;
@@ -17,6 +19,8 @@ import org.springframework.stereotype.Service;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@Service @Service
public class AuthenticationService { public class AuthenticationService {
@@ -24,13 +28,18 @@ public class AuthenticationService {
private final UserRepository userRepository; private final UserRepository userRepository;
private final PasswordEncoder passwordEncoder; private final PasswordEncoder passwordEncoder;
private final AuthenticationManager authenticationManager; private final AuthenticationManager authenticationManager;
private final KeyCacheService keyCacheService;
private static final Logger log = LoggerFactory.getLogger(AuthenticationService.class);
public AuthenticationService(UserRepository userRepository, public AuthenticationService(UserRepository userRepository,
AuthenticationManager authenticationManager, AuthenticationManager authenticationManager,
PasswordEncoder passwordEncoder) { PasswordEncoder passwordEncoder,
KeyCacheService keyCacheService) {
this.userRepository = userRepository; this.userRepository = userRepository;
this.passwordEncoder = passwordEncoder; this.passwordEncoder = passwordEncoder;
this.authenticationManager = authenticationManager; this.authenticationManager = authenticationManager;
this.keyCacheService = keyCacheService;
} }
public User signUp(RegisterUserDto inputUser) { public User signUp(RegisterUserDto inputUser) {
@@ -90,4 +99,23 @@ public class AuthenticationService {
return userRepository.findByEmail(inputUser.getEmail()) return userRepository.findByEmail(inputUser.getEmail())
.orElseThrow(() -> new RuntimeException("User not found")); .orElseThrow(() -> new RuntimeException("User not found"));
} }
@Cacheable(value = "decryptedPrivateKeys", key = "#userId")
public byte[] getDecryptedPrivateKey(String userId, String password) throws Exception {
User user = userRepository.findById(Integer.valueOf(userId))
.orElseThrow(() -> new RuntimeException("User not found: " + userId));
log.info("Caching decrypted private key for userId: {}", userId);
SecretKey derivedKey = EncryptionUtil.deriveKey(password.toCharArray(), user.getPrivateKeySalt());
byte[] decryptedPrivateKeyBytes = EncryptionUtil.decrypt(user.getPrivateKey(), derivedKey, user.getPrivateKeyIv());
return decryptedPrivateKeyBytes;
}
@CacheEvict(value = "decryptedPrivateKeys", key = "#userId")
public void clearDecryptedPrivateKeyCache(String userId) {
// This method will clear the cached decrypted private key for the given userId
log.info("Clearing Caching decrypted private key for userId: {}", userId);
keyCacheService.clearKey(Long.valueOf(userId));
}
} }
@@ -17,63 +17,63 @@ public class EncryptionUtil {
private static final int IV_LENGTH = 16; // for AES CBC private static final int IV_LENGTH = 16; // for AES CBC
private static final int ITERATIONS = 65536; private static final int ITERATIONS = 65536;
private static final int KEY_LENGTH = 256; // bits private static final int KEY_LENGTH = 256; // bits
//
// // --- AES key derivation using PBKDF2 --- // --- AES key derivation using PBKDF2 ---
// public static SecretKey deriveAESKey(char[] password, byte[] salt) public static SecretKey deriveAESKey(char[] password, byte[] salt)
// throws NoSuchAlgorithmException, InvalidKeySpecException { throws NoSuchAlgorithmException, InvalidKeySpecException {
//
// SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256"); SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
//
// KeySpec spec = new PBEKeySpec(password, salt, ITERATIONS, KEY_LENGTH); KeySpec spec = new PBEKeySpec(password, salt, ITERATIONS, KEY_LENGTH);
// byte[] keyBytes = factory.generateSecret(spec).getEncoded(); byte[] keyBytes = factory.generateSecret(spec).getEncoded();
//
// return new SecretKeySpec(keyBytes, "AES"); return new SecretKeySpec(keyBytes, "AES");
// } }
//
// // --- Encrypt data using AES-CBC --- // --- Encrypt data using AES-CBC ---
// public static byte[] encrypt(byte[] data, SecretKey key, byte[] iv) public static byte[] encrypt(byte[] data, SecretKey key, byte[] iv)
// throws GeneralSecurityException { throws GeneralSecurityException {
//
// Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
//
// IvParameterSpec ivSpec = new IvParameterSpec(iv); IvParameterSpec ivSpec = new IvParameterSpec(iv);
// cipher.init(Cipher.ENCRYPT_MODE, key, ivSpec); cipher.init(Cipher.ENCRYPT_MODE, key, ivSpec);
//
// return cipher.doFinal(data); return cipher.doFinal(data);
// } }
// --- Decrypt data using AES-CBC --- // --- Decrypt data using AES-CBC ---
// public static byte[] decrypt(byte[] encryptedData, SecretKey key, byte[] iv) public static byte[] decrypt(byte[] encryptedData, SecretKey key, byte[] iv)
// throws GeneralSecurityException { throws GeneralSecurityException {
//
// Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
//
// IvParameterSpec ivSpec = new IvParameterSpec(iv);
// cipher.init(Cipher.DECRYPT_MODE, key, ivSpec);
//
// return cipher.doFinal(encryptedData);
// }
//
// // --- Generate random salt ---
// public static byte[] generateSalt() {
// byte[] salt = new byte[SALT_LENGTH];
// new SecureRandom().nextBytes(salt);
// return salt;
// }
// // --- Generate random IV --- Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
// public static byte[] generateIV() {
// byte[] iv = new byte[IV_LENGTH]; IvParameterSpec ivSpec = new IvParameterSpec(iv);
// new SecureRandom().nextBytes(iv); cipher.init(Cipher.DECRYPT_MODE, key, ivSpec);
// return iv;
// } return cipher.doFinal(encryptedData);
// }
// // --- Optional: Utility to base64 encode data ---
// public static String encodeBase64(byte[] data) { // --- Generate random salt ---
// return Base64.getEncoder().encodeToString(data); public static byte[] generateSalt() {
// } byte[] salt = new byte[SALT_LENGTH];
// new SecureRandom().nextBytes(salt);
// public static byte[] decodeBase64(String base64) { return salt;
// return Base64.getDecoder().decode(base64); }
// }
// --- Generate random IV ---
public static byte[] generateIV() {
byte[] iv = new byte[IV_LENGTH];
new SecureRandom().nextBytes(iv);
return iv;
}
// --- Optional: Utility to base64 encode data ---
public static String encodeBase64(byte[] data) {
return Base64.getEncoder().encodeToString(data);
}
public static byte[] decodeBase64(String base64) {
return Base64.getDecoder().decode(base64);
}
} }
@@ -13,7 +13,9 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
@@ -24,15 +26,23 @@ import java.security.PublicKey;
public class FileService { public class FileService {
private static final Logger log = LoggerFactory.getLogger(FileService.class); private static final Logger log = LoggerFactory.getLogger(FileService.class);
private final AuthenticationService authenticationService;
private final FileMetadataRepository fileMetadataRepository; private final FileMetadataRepository fileMetadataRepository;
private final UserRepository userRepository; private final UserRepository userRepository;
public FileService(FileMetadataRepository fileMetadataRepository, UserRepository userRepository) { // public FileService(FileMetadataRepository fileMetadataRepository, UserRepository userRepository) {
// this.fileMetadataRepository = fileMetadataRepository;
// this.userRepository = userRepository;
// }
@Autowired
public FileService(FileMetadataRepository fileMetadataRepository, UserRepository userRepository, AuthenticationService authenticationService) {
this.fileMetadataRepository = fileMetadataRepository; this.fileMetadataRepository = fileMetadataRepository;
this.userRepository = userRepository; this.userRepository = userRepository;
this.authenticationService = authenticationService;
} }
@Transactional
public void uploadEncryptedFile(String username, byte[] fileContent, String filename) throws Exception { public void uploadEncryptedFile(String username, byte[] fileContent, String filename) throws Exception {
log.info("Starting upload for user={}, file={}", username, filename); log.info("Starting upload for user={}, file={}", username, filename);
try { try {
@@ -81,6 +91,37 @@ public class FileService {
} }
} }
// public byte[] downloadDecryptedFile(String username, String password, String filename) throws Exception {
// log.info("Download request: user={}, file={}", username, filename);
// try {
// User user = userRepository.findByUsername(username)
// .orElseThrow(() -> new RuntimeException("User not found: " + username));
//
// Path filePath = new Path("/" + username + "/" + filename);
// FileMetadata metadata = fileMetadataRepository.findByUsernameAndFilePath(username, filePath.toString())
// .orElseThrow(() -> new RuntimeException("File metadata not found for: " + filePath));
//
// SecretKey derivedKey = EncryptionUtil.deriveKey(password.toCharArray(), user.getPrivateKeySalt());
// byte[] decryptedPrivateKeyBytes = EncryptionUtil.decrypt(user.getPrivateKey(), derivedKey, user.getPrivateKeyIv());
// PrivateKey privateKey = RSAKeyUtil.decodePrivateKey(decryptedPrivateKeyBytes);
//
// byte[] aesKeyBytes = EncryptionUtil.decryptRSA(metadata.getEncryptedKey(), privateKey);
// SecretKey aesKey = EncryptionUtil.rebuildAESKey(aesKeyBytes);
//
// FileSystem fs = HDFSConfig.getHDFS();
// byte[] encryptedData;
// try (FSDataInputStream in = fs.open(filePath)) {
// encryptedData = in.readAllBytes();
// }
//
// return EncryptionUtil.decrypt(encryptedData, aesKey, metadata.getIv());
//
// } catch (Exception e) {
// log.error("Download failed for user={}, file={}: {}", username, filename, e.getMessage(), e);
// throw e;
// }
// }
public byte[] downloadDecryptedFile(String username, String password, String filename) throws Exception { public byte[] downloadDecryptedFile(String username, String password, String filename) throws Exception {
log.info("Download request: user={}, file={}", username, filename); log.info("Download request: user={}, file={}", username, filename);
try { try {
@@ -91,8 +132,8 @@ public class FileService {
FileMetadata metadata = fileMetadataRepository.findByUsernameAndFilePath(username, filePath.toString()) FileMetadata metadata = fileMetadataRepository.findByUsernameAndFilePath(username, filePath.toString())
.orElseThrow(() -> new RuntimeException("File metadata not found for: " + filePath)); .orElseThrow(() -> new RuntimeException("File metadata not found for: " + filePath));
SecretKey derivedKey = EncryptionUtil.deriveKey(password.toCharArray(), user.getPrivateKeySalt()); // Use the cached decrypted private key
byte[] decryptedPrivateKeyBytes = EncryptionUtil.decrypt(user.getPrivateKey(), derivedKey, user.getPrivateKeyIv()); byte[] decryptedPrivateKeyBytes = authenticationService.getDecryptedPrivateKey(String.valueOf(user.getId()), password);
PrivateKey privateKey = RSAKeyUtil.decodePrivateKey(decryptedPrivateKeyBytes); PrivateKey privateKey = RSAKeyUtil.decodePrivateKey(decryptedPrivateKeyBytes);
byte[] aesKeyBytes = EncryptionUtil.decryptRSA(metadata.getEncryptedKey(), privateKey); byte[] aesKeyBytes = EncryptionUtil.decryptRSA(metadata.getEncryptedKey(), privateKey);
@@ -0,0 +1,28 @@
package com.skycrate.backend.skycrateBackend.services;
import org.springframework.stereotype.Service;
import java.util.concurrent.ConcurrentHashMap;
@Service
public class KeyCacheService {
private final ConcurrentHashMap<Long, String> keyCache = new ConcurrentHashMap<>();
public void cacheKey(Long userId, String decryptedKey) {
keyCache.put(userId, decryptedKey);
}
public String getKey(Long userId) {
return keyCache.get(userId);
}
public void clearKey(Long userId) {
keyCache.remove(userId);
}
public void clearAllKeys() {
keyCache.clear();
}
}
@@ -16,13 +16,25 @@ public class RefreshTokenService {
private final RefreshTokenRepository refreshTokenRepo; private final RefreshTokenRepository refreshTokenRepo;
@Value("${security.jwt.refresh-expiry-ms:604800000}") // 7 days default @Value("${security.jwt.refresh-expiry-ms:86400000}") //1 day in milliseconds
private Long refreshTokenDurationMs; private Long refreshTokenDurationMs;
public RefreshTokenService(RefreshTokenRepository refreshTokenRepo) { public RefreshTokenService(RefreshTokenRepository refreshTokenRepo) {
this.refreshTokenRepo = refreshTokenRepo; this.refreshTokenRepo = refreshTokenRepo;
} }
// @Transactional
// public RefreshToken createRefreshToken(User user) {
// refreshTokenRepo.deleteByUser(user);
// refreshTokenRepo.flush();
//
// RefreshToken token = new RefreshToken();
// token.setUser(user);
// token.setExpiryDate(Instant.now().plusMillis(refreshTokenDurationMs));
// token.setToken(UUID.randomUUID().toString());
// return refreshTokenRepo.save(token);
// }
@Transactional @Transactional
public RefreshToken createRefreshToken(User user) { public RefreshToken createRefreshToken(User user) {
refreshTokenRepo.deleteByUser(user); refreshTokenRepo.deleteByUser(user);
@@ -35,6 +47,7 @@ public class RefreshTokenService {
return refreshTokenRepo.save(token); return refreshTokenRepo.save(token);
} }
public Optional<RefreshToken> findByToken(String token) { public Optional<RefreshToken> findByToken(String token) {
return refreshTokenRepo.findByToken(token); return refreshTokenRepo.findByToken(token);
} }
@@ -42,9 +55,28 @@ public class RefreshTokenService {
public boolean isExpired(RefreshToken token) { public boolean isExpired(RefreshToken token) {
return token.getExpiryDate().isBefore(Instant.now()); return token.getExpiryDate().isBefore(Instant.now());
} }
//
// @Transactional
// public void deleteByUser(User user) {
// refreshTokenRepo.deleteByUser(user);
// }
@Transactional @Transactional
public void deleteByUser(User user) { public void deleteByUser(User user) {
refreshTokenRepo.deleteByUser(user); try {
refreshTokenRepo.deleteByUser(user);
System.out.println("Successfully deleted refresh tokens for user: " + user.getId());
} catch (Exception e) {
System.err.println("Error deleting refresh tokens for user: " + user.getId() + " - " + e.getMessage());
}
}
@Transactional
public void logout(User user) {
deleteByUser(user); // This should call the repository method to delete the token
}
public Optional<RefreshToken> refreshAccessToken(String refreshToken) {
return findByToken(refreshToken).filter(token -> !isExpired(token));
} }
} }