Extension
- JUnit 5에서 단위 테스트 코드를 확장하고 사용자 정의 기능을 추가할 수 있는 메커니즘
- Test class, method를 확장할 수 있다.
- 테스트 이벤트, Life Cycle에 관여하게 된다.
- 우리가 사용하는 @SpringBootTest, @WebMvcTest, @DataJpaTest에도 @ExtendWith(SpringExtension.class)가 포함되어 있다.
- JUnit 4의 Runner는 클래스당 하나(@RunWith)만 지정할 수 있는 단일 상속 구조라 여러 기능을 조합할 수 없었음 (JUnit 5 Extension은 여러 개를 함께 등록 가능)
- Store를 통한 상태 공유
SpringExtension
public class SpringExtension implements BeforeAllCallback,
AfterAllCallback,TestInstancePostProcessor, BeforeEachCallback,
AfterEachCallback,BeforeTestExecutionCallback,
AfterTestExecutionCallback, ParameterResolver {
Extension 실행 순서 (기본 Lifecycle.PER_METHOD 기준, "테스트 인스턴스 생성" ~ [TestInstancePreDestroyCallback] 구간은 테스트 메서드마다 반복)
- [BeforeAllCallback]
- @BeforeAll
- 테스트 인스턴스 생성
- [TestInstancePostProcessor] -> SpringExtension이 DI 수행
- [BeforeEachCallback]
- @BeforeEach
- [BeforeTestExecutionCallback]
- @Test 실행
- [AfterTestExecutionCallback]
- @AfterEach
- [AfterEachCallback]
- [TestInstancePreDestroyCallback]
- @AfterAll
- [AfterAllCallback]
ExtensionContext
/**
 * Abridged excerpt of JUnit Jupiter's {@code org.junit.jupiter.api.extension.ExtensionContext},
 * the object handed to every extension callback.
 */
public interface ExtensionContext {

    // --- Metadata about the test being executed -----------------------------

    /** Returns the test class associated with this context, if any. */
    Optional<Class<?>> getTestClass();

    /** Returns the test instance associated with this context, if any. */
    Optional<Object> getTestInstance();

    /** Returns the test method associated with this context, if any. */
    Optional<Method> getTestMethod();

    // --- Position in the context hierarchy ----------------------------------

    /** Returns the enclosing context, or empty when there is none. */
    Optional<ExtensionContext> getParent();

    /** Returns the top-most context of the hierarchy. */
    ExtensionContext getRoot();

    // --- Shared storage -----------------------------------------------------

    /** Returns the store for the given namespace; extensions use it to share data. */
    Store getStore(Namespace namespace);

    // --- Configuration ------------------------------------------------------

    /** Looks up a configuration parameter by its key. */
    Optional<String> getConfigurationParameter(String key);
}
Store 역할
// Sharing state between extensions through ExtensionContext.Store
public class MyExtension implements BeforeEachCallback {

    /** Namespace that scopes this extension's entries so its keys cannot clash with other extensions' keys. */
    private static final Namespace NAMESPACE = Namespace.create(MyExtension.class);

    /**
     * Before each test, puts a fresh {@code MyResource} under {@code "key"} in this extension's
     * namespace of the store belonging to the given context. The entry lives as long as that
     * context's store does.
     */
    @Override
    public void beforeEach(ExtensionContext context) {
        context.getStore(NAMESPACE).put("key", new MyResource());
    }
}
SpringExtension
SpringExtension 기본 구조
public class SpringExtension implements
BeforeAllCallback, //클래스 레벨 초기화
AfterAllCallback, //클래스 레벨 정리
TestInstancePostProcessor, //DI 수행
BeforeEachCallback, //메소드 전 준비
AfterEachCallback, //메소드 후 정리
BeforeTestExecutionCallback, //트랜잭션 시작 등
AfterTestExecutionCallback, //트랜잭션 커밋/롤백
ParameterResolver { //메소드 파라미터 주입
private static final ExtensionContext.Namespace TEST_CONTEXT_MANAGER_NAMESPACE = Namespace.create(new Object[]{SpringExtension.class});
private static final ExtensionContext.Namespace AUTOWIRED_VALIDATION_NAMESPACE = Namespace.create(new Object[]{SpringExtension.class.getName() + "#autowired.validation"});
private static final String NO_VIOLATIONS_DETECTED = "";
private static final ExtensionContext.Namespace RECORD_APPLICATION_EVENTS_VALIDATION_NAMESPACE = Namespace.create(new Object[]{SpringExtension.class.getName() + "#recordApplicationEvents.validation"});
private static final List<Class<? extends Annotation>> JUPITER_ANNOTATION_TYPES = List.of(BeforeAll.class, AfterAll.class, BeforeEach.class, AfterEach.class, Testable.class);
private static final ReflectionUtils.MethodFilter autowiredTestOrLifecycleMethodFilter;
책임
- JUnit-Spring 브릿지: JUnit 생명주기 → Spring TestContext Framework 연결
- TestContextManager 관리 : Spring 테스트 생명 주기 관리 → 클래스당 1개 생성 및 캐싱 → ExtensionContext.Store 활용
- ApplicationContext 로딩 조율 → 캐시 확인 -> 재사용 혹은 새로 생성 → ContextLoader 전략 위임
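- DI 위임 → TestInstancePostProcessor에서 수행 → DependencyInjectionTestExecutionListener 활용
- 트랜잭션 관리 → TransactionalTestExecutionListener가 담당 → BeforeEachCallback(beforeTestMethod)에서 시작 → AfterEachCallback(afterTestMethod)에서 롤백(기본값, @Commit 시 커밋) → 따라서 @BeforeEach/@AfterEach 메서드도 같은 테스트 트랜잭션 안에서 실행됨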
TestInstancePostProcessor
/**
 * JUnit {@code TestInstancePostProcessor} hook: validates the test class configuration, then
 * hands the freshly created test instance to the Spring TestContext framework.
 *
 * @param testInstance the test instance JUnit just created
 * @param context      the current extension context
 * @throws Exception if validation fails or a listener fails while preparing the instance
 */
public void postProcessTestInstance(Object testInstance, ExtensionContext context) throws Exception {
    // Configuration checks run first so that an invalid test class fails before any Spring work.
    validateAutowiredConfig(context);
    validateRecordApplicationEventsConfig(context);

    TestContextManager manager = getTestContextManager(context);
    registerMethodInvoker(manager, context);
    // Runs every listener's prepareTestInstance callback; dependency injection happens here.
    manager.prepareTestInstance(testInstance);
}
/**
 * Rejects test classes whose test or lifecycle methods are annotated with {@code @Autowired}.
 * The outcome is computed once per test class and cached in a dedicated store namespace; an empty
 * string means "no violations".
 *
 * @param context the current extension context
 * @throws IllegalStateException if offending methods were found
 */
private void validateAutowiredConfig(ExtensionContext context) {
    ExtensionContext.Store validationStore = context.getStore(AUTOWIRED_VALIDATION_NAMESPACE);
    String violations = validationStore.getOrComputeIfAbsent(context.getRequiredTestClass(), testClass -> {
        Method[] offendingMethods =
                ReflectionUtils.getUniqueDeclaredMethods(testClass, autowiredTestOrLifecycleMethodFilter);
        if (offendingMethods.length == 0) {
            return "";
        }
        return String.format("Test methods and test lifecycle methods must not be annotated with @Autowired. You should instead annotate individual method parameters with @Autowired, @Qualifier, or @Value. Offending methods in test class %s: %s", testClass.getName(), Arrays.toString(offendingMethods));
    }, String.class);

    if (violations.isEmpty()) {
        return;
    }
    throw new IllegalStateException(violations);
}
/**
 * Rejects test classes that record application events while using the PER_CLASS instance
 * lifecycle with a non-SAME_THREAD execution mode. The outcome is computed once per test class
 * and cached in a dedicated store namespace; an empty string means "no violation".
 *
 * @param context the current extension context
 * @throws IllegalStateException if the configuration is not allowed
 */
private void validateRecordApplicationEventsConfig(ExtensionContext context) {
    ExtensionContext.Store validationStore = context.getStore(RECORD_APPLICATION_EVENTS_VALIDATION_NAMESPACE);
    String violation = validationStore.getOrComputeIfAbsent(context.getRequiredTestClass(), testClass -> {
        // Only classes annotated with @RecordApplicationEvents are subject to this check.
        if (!TestContextAnnotationUtils.hasAnnotation(testClass, RecordApplicationEvents.class)) {
            return "";
        }
        // PER_METHOD (also assumed when no lifecycle is reported) is always accepted.
        if (context.getTestInstanceLifecycle().orElse(Lifecycle.PER_METHOD) == Lifecycle.PER_METHOD) {
            return "";
        }
        // PER_CLASS is accepted only when the class runs in SAME_THREAD mode.
        if (context.getExecutionMode() == ExecutionMode.SAME_THREAD) {
            return "";
        }
        return "Test classes or @Nested test classes that @RecordApplicationEvents must not be run in parallel with the @TestInstance(PER_CLASS) lifecycle mode. Configure either @Execution(SAME_THREAD) or @TestInstance(PER_METHOD) semantics, or disable parallel execution altogether. Note that when recording events in parallel, one might see events published by other tests since the application context may be shared.";
    }, String.class);

    if (!violation.isEmpty()) {
        throw new IllegalStateException(violation);
    }
}
prepareTestInstance
public void prepareTestInstance(Object testInstance) throws Exception {
try {
this.getTestContext().updateState(testInstance, (Method)null, (Throwable)null);
Iterator var2 = this.getTestExecutionListeners().iterator();
while(var2.hasNext()) {
TestExecutionListener testExecutionListener = (TestExecutionListener)var2.next();
try {
testExecutionListener.prepareTestInstance(this.getTestContext());
//각 Listener 실행
} catch (Throwable var8) {
//생략
}
}
}
}
/**
*1. ServletTestExecutionListener // Mock ServletContext 준비
*2. DirtiesContextBeforeModesTestExecutionListener
*3. ApplicationEventsTestExecutionListener
*4. DependencyInjectionTestExecutionListener // ★ DI 수행
*5. DirtiesContextTestExecutionListener
*6. TransactionalTestExecutionListener // 트랜잭션 설정
*7. SqlScriptsTestExecutionListener // @Sql 실행
*8. EventPublishingTestExecutionListener
*/
DependencyInjectionTestExecutionListener
//org.springframework.spring-test
package org.springframework.test.context.support;
public class DependencyInjectionTestExecutionListener extends AbstractTestExecutionListener {
public static final int ORDER = 2000;
public static final String REINJECT_DEPENDENCIES_ATTRIBUTE = Conventions.getQualifiedAttributeName(DependencyInjectionTestExecutionListener.class, "reinjectDependencies");
private static final Log logger = LogFactory.getLog(DependencyInjectionTestExecutionListener.class);
private final AotTestContextInitializers aotTestContextInitializers = new AotTestContextInitializers();
public DependencyInjectionTestExecutionListener() {
}
public final int getOrder() {
return 2000;
}
public void prepareTestInstance(TestContext testContext) throws Exception {
//생략
//AOT 모드 확인 후 적절한 DI 전략 선택
//Ahead-Of-Time: Spring Application의 시작 속도를 높이고 메모리 사용량을 줄이기 위한 기술, 프로그램 실행 전 바이트 코드를 미리 네이티브 이미지로 컴파일하는 방식
if (this.runningInAotMode(testContext.getTestClass())) {
this.injectDependenciesInAotMode(testContext);
} else {
this.injectDependencies(testContext);
}
}
public void beforeTestMethod(TestContext testContext) throws Exception {
if (Boolean.TRUE.equals(testContext.getAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE))) {
if (logger.isTraceEnabled()) {
logger.trace("Reinjecting dependencies for test context " + String.valueOf(testContext));
} else if (logger.isDebugEnabled()) {
logger.debug("Reinjecting dependencies for test class " + testContext.getTestClass().getName());
}
if (this.runningInAotMode(testContext.getTestClass())) {
this.injectDependenciesInAotMode(testContext);
} else {
this.injectDependencies(testContext);
}
}
}
//일반 DI 모드
protected void injectDependencies(TestContext testContext) throws Exception {
//1. 테스트 인스턴스와 클래스 정보 추출
Object bean = testContext.getTestInstance();
Class<?> clazz = testContext.getTestClass();
//2. AutowireCapableBeanFactory 획득
AutowireCapableBeanFactory beanFactory = testContext.getApplicationContext().getAutowireCapableBeanFactory();
//3. 의존성 주입
//3-1. 필드/ Setter 주입
beanFactory.autowireBeanProperties(bean, 0, false);
//3-2. initializeBean: Bean 후처리
beanFactory.initializeBean(bean, clazz.getName() + ".ORIGINAL");
//3-3. 재주입 플래그 제거
testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE);
}
//4. AOT 모드 DI
private void injectDependenciesInAotMode(TestContext testContext) throws Exception {
ApplicationContext applicationContext = testContext.getApplicationContext();
// GeneralApplicationContext만 지원
if (applicationContext instanceof GenericApplicationContext gac) {
Object bean = testContext.getTestInstance();
String beanName = testContext.getTestClass().getName() + ".ORIGINAL";
ConfigurableListableBeanFactory beanFactory = gac.getBeanFactory();
//@Autowire처리
AutowiredAnnotationBeanPostProcessor autowiredAnnotationBpp = new AutowiredAnnotationBeanPostProcessor();
autowiredAnnotationBpp.setBeanFactory(beanFactory);
autowiredAnnotationBpp.processInjection(bean);
//@Resource, @PostConstructor
CommonAnnotationBeanPostProcessor commonAnnotationBpp = new CommonAnnotationBeanPostProcessor();
commonAnnotationBpp.setBeanFactory(beanFactory);
commonAnnotationBpp.processInjection(bean);
//Bean 초기화
beanFactory.initializeBean(bean, beanName);
//재주입
testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE);
} else {
}
}
private boolean runningInAotMode(Class<?> testClass) {
return this.aotTestContextInitializers.isSupportedTestClass(testClass);
}
}
BeforeAllCallback
/**
 * JUnit {@code BeforeAllCallback} hook. Obtains the TestContextManager for the current test
 * class, which is created and cached on first access, and lets it notify its
 * TestExecutionListeners that the test class is starting.
 */
@Override
public void beforeAll(ExtensionContext context) throws Exception {
    TestContextManager manager = getTestContextManager(context);
    manager.beforeTestClass();
}
/**
 * Returns the TestContextManager for the context's test class. The manager is created on the
 * first call and cached in the extension store, so later calls for the same class reuse it.
 *
 * @param context the current extension context; must have a test class
 * @return the cached or newly created TestContextManager for that class
 */
private static TestContextManager getTestContextManager(
        ExtensionContext context) {
    Class<?> testClass = context.getRequiredTestClass();
    // Use the namespace SpringExtension declares for its TestContextManager cache.
    // NOTE(review): the actual spring-test implementation may read from the root context's store
    // (context.getRoot()) instead; confirm against the version in use.
    Store store = context.getStore(TEST_CONTEXT_MANAGER_NAMESPACE);
    // Keyed by test class: one manager per class.
    return store.getOrComputeIfAbsent(testClass, TestContextManager::new, TestContextManager.class);
}