diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java index 25db19452..58e54c6ba 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2008-2013 the original author or authors. + * Copyright 2008-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,7 @@ import org.springframework.util.Assert; * * @param the type of the repository * @author Oliver Gierke + * @author Thomas Darimont */ public abstract class RepositoryFactoryBeanSupport, S, ID extends Serializable> implements InitializingBean, RepositoryFactoryInformation, FactoryBean, BeanClassLoaderAware { @@ -56,6 +57,8 @@ public abstract class RepositoryFactoryBeanSupport, private T repository; + private RepositoryMetadata repositoryMetadata; + /** * Setter to inject the repository interface to implement. * @@ -130,7 +133,6 @@ public abstract class RepositoryFactoryBeanSupport, @SuppressWarnings("unchecked") public EntityInformation getEntityInformation() { - RepositoryMetadata repositoryMetadata = factory.getRepositoryMetadata(repositoryInterface); return (EntityInformation) factory.getEntityInformation(repositoryMetadata.getDomainType()); } @@ -140,9 +142,8 @@ public abstract class RepositoryFactoryBeanSupport, */ public RepositoryInformation getRepositoryInformation() { - RepositoryMetadata metadata = factory.getRepositoryMetadata(repositoryInterface); - return this.factory.getRepositoryInformation(metadata, - customImplementation == null ? null : customImplementation.getClass()); + return this.factory.getRepositoryInformation(repositoryMetadata, customImplementation == null ? null + : customImplementation.getClass()); } /* @@ -155,8 +156,7 @@ public abstract class RepositoryFactoryBeanSupport, return null; } - RepositoryMetadata metadata = factory.getRepositoryMetadata(repositoryInterface); - return mappingContext.getPersistentEntity(metadata.getDomainType()); + return mappingContext.getPersistentEntity(repositoryMetadata.getDomainType()); } /* (non-Javadoc) @@ -197,11 +197,15 @@ public abstract class RepositoryFactoryBeanSupport, */ public void afterPropertiesSet() { + Assert.notNull(repositoryInterface, "Repository interface must not be null on initialization!"); + this.factory = createRepositoryFactory(); this.factory.setQueryLookupStrategyKey(queryLookupStrategyKey); this.factory.setNamedQueries(namedQueries); this.factory.setBeanClassLoader(classLoader); + this.repositoryMetadata = this.factory.getRepositoryMetadata(repositoryInterface); + if (!lazyInit) { initAndReturn(); } diff --git a/src/main/java/org/springframework/data/repository/support/Repositories.java b/src/main/java/org/springframework/data/repository/support/Repositories.java index 82529d796..1682e1696 100644 --- a/src/main/java/org/springframework/data/repository/support/Repositories.java +++ b/src/main/java/org/springframework/data/repository/support/Repositories.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2013 the original author or authors. + * Copyright 2012-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,8 @@ package org.springframework.data.repository.support; import java.io.Serializable; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -43,22 +41,25 @@ import org.springframework.util.ClassUtils; * Wrapper class to access repository instances obtained from a {@link ListableBeanFactory}. * * @author Oliver Gierke + * @author Thomas Darimont */ public class Repositories implements Iterable> { static final Repositories NONE = new Repositories(); - - private final Map, RepositoryFactoryInformation> domainClassToBeanName = new HashMap, RepositoryFactoryInformation>(); - private final Map, String> repositories = new HashMap, String>(); + private static final RepositoryFactoryInformation EMPTY_REPOSITORY_FACTORY_INFO = EmptyRepositoryFactoryInformation.INSTANCE; private final BeanFactory beanFactory; - private final Set repositoryFactoryBeanNames = new HashSet(); + private final Map, String> repositoryBeanNames; + private final Map, RepositoryFactoryInformation> repositoryFactoryInfos; /** * Constructor to create the {@link #NONE} instance. */ private Repositories() { + this.beanFactory = null; + this.repositoryBeanNames = Collections., String> emptyMap(); + this.repositoryFactoryInfos = Collections., RepositoryFactoryInformation> emptyMap(); } /** @@ -70,11 +71,31 @@ public class Repositories implements Iterable> { public Repositories(ListableBeanFactory factory) { Assert.notNull(factory); + this.beanFactory = factory; + this.repositoryFactoryInfos = new HashMap, RepositoryFactoryInformation>(); + this.repositoryBeanNames = new HashMap, String>(); + + populateRepositoryFactoryInformation(factory); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void populateRepositoryFactoryInformation(ListableBeanFactory factory) { - String[] beanNamesForType = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(factory, - RepositoryFactoryInformation.class, false, false); - this.repositoryFactoryBeanNames.addAll(Arrays.asList(beanNamesForType)); + Set> repositoryFactoryBeans = BeanFactoryUtils + .beansOfTypeIncludingAncestors(factory, RepositoryFactoryInformation.class).entrySet(); + + for (Map.Entry entry : repositoryFactoryBeans) { + + String beanName = entry.getKey(); + + RepositoryFactoryInformation repositoryFactoryInformation = entry.getValue(); + Class userDomainType = ClassUtils.getUserClass(repositoryFactoryInformation.getRepositoryInformation() + .getDomainType()); + + this.repositoryFactoryInfos.put(userDomainType, repositoryFactoryInformation); + this.repositoryBeanNames.put(userDomainType, BeanFactoryUtils.transformedBeanName(beanName)); + } } /** @@ -84,8 +105,10 @@ public class Repositories implements Iterable> { * @return */ public boolean hasRepositoryFor(Class domainClass) { - lookupRepositoryFactoryInformationFor(domainClass); - return domainClassToBeanName.containsKey(domainClass); + + Assert.notNull(domainClass, "Domain class must not be null!"); + + return repositoryFactoryInfos.containsKey(domainClass); } /** @@ -96,13 +119,27 @@ public class Repositories implements Iterable> { */ public Object getRepositoryFor(Class domainClass) { - RepositoryFactoryInformation information = getRepoInfoFor(domainClass); + Assert.notNull(domainClass, "Domain class must not be null!"); - if (information == null) { - return null; - } + String repositoryBeanName = repositoryBeanNames.get(domainClass); + return repositoryBeanName == null || beanFactory == null ? null : beanFactory.getBean(repositoryBeanName); + } + + /** + * Returns the {@link RepositoryFactoryInformation} for the given domain class. The given code is + * converted to the actual user class if necessary, @see ClassUtils#getUserClass. + * + * @param domainClass must not be {@literal null}. + * @return the {@link RepositoryFactoryInformation} for the given domain class or {@literal null} if no repository + * registered for this domain class. + */ + private RepositoryFactoryInformation getRepositoryFactoryInfoFor(Class domainClass) { - return beanFactory.getBean(repositories.get(information)); + Assert.notNull(domainClass, "Domain class must not be null!"); + + RepositoryFactoryInformation repositoryInfo = repositoryFactoryInfos.get(ClassUtils + .getUserClass(domainClass)); + return repositoryInfo == null ? EMPTY_REPOSITORY_FACTORY_INFO : repositoryInfo; } /** @@ -114,8 +151,9 @@ public class Repositories implements Iterable> { @SuppressWarnings("unchecked") public EntityInformation getEntityInformationFor(Class domainClass) { - RepositoryFactoryInformation information = getRepoInfoFor(domainClass); - return information == null ? null : (EntityInformation) information.getEntityInformation(); + Assert.notNull(domainClass, "Domain class must not be null!"); + + return (EntityInformation) getRepositoryFactoryInfoFor(domainClass).getEntityInformation(); } /** @@ -127,8 +165,10 @@ public class Repositories implements Iterable> { */ public RepositoryInformation getRepositoryInformationFor(Class domainClass) { - RepositoryFactoryInformation information = getRepoInfoFor(domainClass); - return information == null ? null : information.getRepositoryInformation(); + Assert.notNull(domainClass, "Domain class must not be null!"); + + RepositoryFactoryInformation information = getRepositoryFactoryInfoFor(domainClass); + return information == EMPTY_REPOSITORY_FACTORY_INFO ? null : information.getRepositoryInformation(); } /** @@ -141,8 +181,8 @@ public class Repositories implements Iterable> { */ public PersistentEntity getPersistentEntity(Class domainClass) { - RepositoryFactoryInformation information = getRepoInfoFor(domainClass); - return information == null ? null : information.getPersistentEntity(); + Assert.notNull(domainClass, "Domain class must not be null!"); + return getRepositoryFactoryInfoFor(domainClass).getPersistentEntity(); } /** @@ -153,14 +193,13 @@ public class Repositories implements Iterable> { */ public List getQueryMethodsFor(Class domainClass) { - RepositoryFactoryInformation information = getRepoInfoFor(domainClass); - return information == null ? Collections. emptyList() : information.getQueryMethods(); + Assert.notNull(domainClass, "Domain class must not be null!"); + return getRepositoryFactoryInfoFor(domainClass).getQueryMethods(); } @SuppressWarnings("unchecked") public CrudInvoker getCrudInvoker(Class domainClass) { - RepositoryInformation information = getRepositoryInformationFor(domainClass); Object repository = getRepositoryFor(domainClass); Assert.notNull(repository, String.format("No repository found for domain class: %s", domainClass)); @@ -168,67 +207,45 @@ public class Repositories implements Iterable> { if (repository instanceof CrudRepository) { return new CrudRepositoryInvoker((CrudRepository) repository); } else { - return new ReflectionRepositoryInvoker(repository, information.getCrudMethods()); + return new ReflectionRepositoryInvoker(repository, getRepositoryInformationFor(domainClass).getCrudMethods()); } } - private RepositoryFactoryInformation getRepoInfoFor(Class domainClass) { - - Assert.notNull(domainClass); - - // Create defensive copy of the keys to allow threads to potentially add values while iterating over them - Set> keys = Collections.unmodifiableSet(repositories.keySet()); - Class type = ClassUtils.getUserClass(domainClass); - - for (RepositoryFactoryInformation information : keys) { - if (type.equals(information.getEntityInformation().getJavaType())) { - return information; - } - } - - return lookupRepositoryFactoryInformationFor(type); - } - /* * (non-Javadoc) * @see java.lang.Iterable#iterator() */ public Iterator> iterator() { - lookupRepositoryFactoryInformationFor(null); - return domainClassToBeanName.keySet().iterator(); + return repositoryFactoryInfos.keySet().iterator(); } /** - * Looks up the {@link RepositoryFactoryInformation} for a given domain type. Will inspect the {@link BeanFactory} for - * beans implementing {@link RepositoryFactoryInformation} and cache the domain class to repository bean name mappings - * for further lookups. If a {@link RepositoryFactoryInformation} for the given domain type is found we interrupt the - * lookup proces to prevent beans from being looked up early. + * Null-object to avoid nasty {@literal null} checks in cache lookups. * - * @param domainType - * @return + * @author Thomas Darimont */ - @SuppressWarnings("unchecked") - private RepositoryFactoryInformation lookupRepositoryFactoryInformationFor(Class domainType) { - - if (domainClassToBeanName.containsKey(domainType)) { - return domainClassToBeanName.get(domainType); - } + private static enum EmptyRepositoryFactoryInformation implements RepositoryFactoryInformation { - for (String repositoryFactoryName : repositoryFactoryBeanNames) { + INSTANCE; - RepositoryFactoryInformation information = beanFactory.getBean(repositoryFactoryName, - RepositoryFactoryInformation.class); - - RepositoryInformation info = information.getRepositoryInformation(); + @Override + public EntityInformation getEntityInformation() { + return null; + } - repositories.put(information, BeanFactoryUtils.transformedBeanName(repositoryFactoryName)); - domainClassToBeanName.put(info.getDomainType(), information); + @Override + public RepositoryInformation getRepositoryInformation() { + return null; + } - if (info.getDomainType().equals(domainType)) { - return information; - } + @Override + public PersistentEntity getPersistentEntity() { + return null; } - return null; + @Override + public List getQueryMethods() { + return Collections. emptyList(); + } } } diff --git a/src/test/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupportUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupportUnitTests.java index be0316fb4..e27d03afd 100644 --- a/src/test/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupportUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupportUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2013 the original author or authors. + * Copyright 2013-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,14 @@ import static org.mockito.Mockito.*; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.springframework.data.repository.CrudRepository; import org.springframework.test.util.ReflectionTestUtils; /** * Unit tests for {@link RepositoryFactoryBeanSupport}. * * @author Oliver Gierke + * @author Thomas Darimont */ public class RepositoryFactoryBeanSupportUnitTests { @@ -37,7 +39,7 @@ public class RepositoryFactoryBeanSupportUnitTests { * @see DATACMNS-341 */ @Test - @SuppressWarnings("rawtypes") + @SuppressWarnings({ "rawtypes", "unchecked" }) public void setsConfiguredClassLoaderOnRepositoryFactory() { ClassLoader classLoader = mock(ClassLoader.class); @@ -45,12 +47,16 @@ public class RepositoryFactoryBeanSupportUnitTests { RepositoryFactoryBeanSupport factoryBean = new DummyRepositoryFactoryBean(); factoryBean.setBeanClassLoader(classLoader); factoryBean.setLazyInit(true); + factoryBean.setRepositoryInterface(CrudRepository.class); factoryBean.afterPropertiesSet(); Object factory = ReflectionTestUtils.getField(factoryBean, "factory"); assertThat(ReflectionTestUtils.getField(factory, "classLoader"), is((Object) classLoader)); } + /** + * @see DATACMNS-432 + */ @Test @SuppressWarnings("rawtypes") public void initializationFailsWithMissingRepositoryInterface() {