-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
spring support for scala default parameter values
- Loading branch information
Showing
8 changed files
with
281 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
commons-spring/src/main/scala/com/avsystem/commons/spring/ScalaDefaultValuesInjector.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package com.avsystem.commons | ||
package spring | ||
|
||
import java.lang.reflect.{Constructor, Method, Modifier} | ||
|
||
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder | ||
import org.springframework.beans.factory.config.{BeanDefinition, BeanDefinitionHolder, ConfigurableListableBeanFactory} | ||
import org.springframework.beans.factory.support.{BeanDefinitionRegistry, BeanDefinitionRegistryPostProcessor, ManagedList, ManagedMap, ManagedSet} | ||
import org.springframework.core.ParameterNameDiscoverer | ||
|
||
import scala.beans.BeanProperty | ||
import scala.reflect.{ScalaLongSignature, ScalaSignature} | ||
|
||
class ScalaDefaultValuesInjector extends BeanDefinitionRegistryPostProcessor { | ||
@BeanProperty var paramNameDiscoverer: ParameterNameDiscoverer = | ||
new ScalaParameterNameDiscoverer | ||
|
||
def classLoader: ClassLoader = | ||
Thread.currentThread.getContextClassLoader.opt getOrElse getClass.getClassLoader | ||
|
||
def loadClass(name: String): Class[_] = Class.forName(name, false, classLoader) | ||
|
||
def postProcessBeanDefinitionRegistry(registry: BeanDefinitionRegistry): Unit = { | ||
def traverse(value: Any): Unit = value match { | ||
case bd: BeanDefinition => | ||
bd.getConstructorArgumentValues.getGenericArgumentValues.asScala.foreach(traverse) | ||
bd.getConstructorArgumentValues.getIndexedArgumentValues.values.asScala.foreach(traverse) | ||
bd.getPropertyValues.getPropertyValueList.asScala.foreach(pv => traverse(pv.getValue)) | ||
injectDefaultValues(bd) | ||
case bdw: BeanDefinitionHolder => | ||
traverse(bdw.getBeanDefinition) | ||
case vh: ValueHolder => | ||
traverse(vh.getValue) | ||
case ml: ManagedList[_] => | ||
ml.asScala.foreach(traverse) | ||
case ms: ManagedSet[_] => | ||
ms.asScala.foreach(traverse) | ||
case mm: ManagedMap[_, _] => | ||
mm.asScala.foreach { | ||
case (k, v) => | ||
traverse(k) | ||
traverse(v) | ||
} | ||
case _ => | ||
} | ||
|
||
registry.getBeanDefinitionNames | ||
.foreach(n => traverse(registry.getBeanDefinition(n))) | ||
} | ||
|
||
private def isScalaClass(cls: Class[_]): Boolean = cls.getEnclosingClass match { | ||
case null => cls.getAnnotation(classOf[ScalaSignature]) != null || | ||
cls.getAnnotation(classOf[ScalaLongSignature]) != null | ||
case encls => isScalaClass(encls) | ||
} | ||
|
||
private def injectDefaultValues(bd: BeanDefinition): Unit = { | ||
val className = bd.getFactoryBeanName.opt getOrElse bd.getBeanClassName | ||
loadClass(className).recoverToOpt[ClassNotFoundException].filter(isScalaClass).foreach { clazz => | ||
val usingConstructor = bd.getFactoryMethodName == null | ||
val factoryExecs = | ||
if (usingConstructor) clazz.getConstructors.toVector | ||
else clazz.getMethods.iterator.filter(_.getName == bd.getFactoryMethodName).toVector | ||
val factorySymbolName = | ||
if (usingConstructor) "$lessinit$greater" else bd.getFactoryMethodName | ||
|
||
if (factoryExecs.size == 1) { | ||
val constrVals = bd.getConstructorArgumentValues | ||
val factoryExec = factoryExecs.head | ||
val paramNames = factoryExec match { | ||
case c: Constructor[_] => paramNameDiscoverer.getParameterNames(c) | ||
case m: Method => paramNameDiscoverer.getParameterNames(m) | ||
} | ||
(0 until factoryExec.getParameterCount).foreach { i => | ||
def defaultValueMethod = clazz.getMethod(s"$factorySymbolName$$default$$${i + 1}") | ||
.recoverToOpt[NoSuchMethodException].filter(m => Modifier.isStatic(m.getModifiers)) | ||
def specifiedNamed = paramNames != null && | ||
constrVals.getGenericArgumentValues.asScala.exists(_.getName == paramNames(i)) | ||
def specifiedIndexed = | ||
constrVals.getIndexedArgumentValues.get(i) != null | ||
if (!specifiedNamed && !specifiedIndexed) { | ||
defaultValueMethod.foreach { dvm => | ||
constrVals.addIndexedArgumentValue(i, dvm.invoke(null)) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
def postProcessBeanFactory(beanFactory: ConfigurableListableBeanFactory): Unit = () | ||
} |
90 changes: 90 additions & 0 deletions
90
commons-spring/src/main/scala/com/avsystem/commons/spring/ScalaParameterNameDiscoverer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package com.avsystem.commons | ||
package spring | ||
|
||
import java.lang.reflect.{Constructor, Executable, Method, Modifier} | ||
|
||
import org.springframework.core.{JdkVersion, ParameterNameDiscoverer} | ||
|
||
import scala.annotation.tailrec | ||
import scala.ref.WeakReference | ||
import scala.reflect.api.JavaUniverse | ||
import scala.reflect.{ScalaLongSignature, ScalaSignature} | ||
|
||
object ScalaParameterNameDiscoverer { | ||
final val ScalaSignatureClasses = | ||
List(classOf[ScalaSignature], classOf[ScalaLongSignature]) | ||
|
||
final val JdkAtLeast8 = | ||
JdkVersion.getMajorJavaVersion >= JdkVersion.JAVA_18 | ||
|
||
// we don't want to keep the universe in memory forever, so we don't use scala.reflect.runtime.universe | ||
private var universeRef: WeakReference[JavaUniverse] = _ | ||
|
||
private def universe: JavaUniverse = { | ||
universeRef.option.flatMap(_.get) match { | ||
case Some(result) => result | ||
case None => | ||
val result = new scala.reflect.runtime.JavaUniverse | ||
universeRef = new WeakReference[JavaUniverse](result) | ||
result | ||
} | ||
} | ||
} | ||
|
||
class ScalaParameterNameDiscoverer extends ParameterNameDiscoverer { | ||
|
||
import ScalaParameterNameDiscoverer._ | ||
|
||
@tailrec private def isScala(cls: Class[_]): Boolean = cls.getEnclosingClass match { | ||
case null => ScalaSignatureClasses.exists(ac => cls.getAnnotation(ac) != null) | ||
case encls => isScala(encls) | ||
} | ||
|
||
private def discoverNames(u: JavaUniverse)(executable: Executable, symbolPredicate: u.Symbol => Boolean): Array[String] = { | ||
import u._ | ||
|
||
val declaringClass = executable.getDeclaringClass | ||
val mirror = runtimeMirror(declaringClass.getClassLoader) | ||
val ownerSymbol = | ||
if (Modifier.isStatic(executable.getModifiers)) mirror.moduleSymbol(declaringClass).moduleClass.asType | ||
else mirror.classSymbol(declaringClass) | ||
|
||
def argErasuresMatch(ms: MethodSymbol) = | ||
ms.paramLists.flatten.map(s => mirror.runtimeClass(s.typeSignature)) == executable.getParameterTypes.toList | ||
|
||
def paramNames(ms: MethodSymbol) = | ||
ms.paramLists.flatten.map(_.name.toString).toArray | ||
|
||
ownerSymbol.toType.members | ||
.find(s => symbolPredicate(s) && argErasuresMatch(s.asMethod)) | ||
.map(s => paramNames(s.asMethod)) | ||
.orNull | ||
} | ||
|
||
def getParameterNames(ctor: Constructor[_]): Array[String] = | ||
if (JdkAtLeast8 && ctor.getParameters.forall(_.isNamePresent)) | ||
ctor.getParameters.map(_.getName) | ||
else if (isScala(ctor.getDeclaringClass)) | ||
discoverNames(universe)(ctor, s => s.isConstructor) | ||
else null | ||
|
||
def getParameterNames(method: Method): Array[String] = { | ||
val declaringCls = method.getDeclaringClass | ||
if (JdkAtLeast8 && method.getParameters.forall(_.isNamePresent)) | ||
method.getParameters.map(_.getName) | ||
else if (isScala(declaringCls)) { | ||
// https://github.com/scala/bug/issues/10650 | ||
val forStaticForwarder = | ||
if (Modifier.isStatic(method.getModifiers)) | ||
Class.forName(declaringCls.getName + "$", false, declaringCls.getClassLoader) | ||
.recoverToOpt[ClassNotFoundException] | ||
.flatMap(_.getMethod(method.getName, method.getParameterTypes: _*).recoverToOpt[NoSuchMethodException]) | ||
.map(getParameterNames) | ||
else | ||
Opt.Empty | ||
forStaticForwarder.getOrElse( | ||
discoverNames(universe)(method, s => s.isMethod && s.name.toString == method.getName)) | ||
} | ||
else null | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,58 @@ | ||
beanClass = com.avsystem.commons.spring.TestBean | ||
abstract { | ||
testBean { | ||
%class = com.avsystem.commons.spring.TestBean | ||
} | ||
constrTestBean = ${abstract.testBean} { | ||
%construct = true | ||
} | ||
fmTestBean = ${abstract.constrTestBean} { | ||
%factory-method = create | ||
} | ||
} | ||
|
||
beans { | ||
testBean { | ||
%class = ${beanClass} | ||
%constructor-args = [42, "lolzsy"] | ||
testBean = ${abstract.testBean} { | ||
%constructor-args = [42, lolzsy] | ||
int = 5 | ||
string = "lol" | ||
string = lol | ||
strIntMap { | ||
"fuu" = 42 | ||
fuu = 42 | ||
} | ||
strList = ["a", "b"] | ||
strSet = ["A", "B"] | ||
nestedBean { | ||
%class = ${beanClass} | ||
strList = [a, b] | ||
strSet = [A, B] | ||
nestedBean = ${abstract.testBean} { | ||
%constructor-args { | ||
constrString = "wut" | ||
constrString = wut | ||
constrInt = 1 | ||
} | ||
int = 6 | ||
nestedBean { | ||
%class = ${beanClass} | ||
%construct = true | ||
constrString = "yes" | ||
nestedBean = ${abstract.constrTestBean} { | ||
constrString = yes | ||
constrInt = 2 | ||
} | ||
} | ||
config.%config { | ||
srsly = dafuq | ||
} | ||
} | ||
|
||
testBeanDefInt = ${abstract.constrTestBean} { | ||
constrString = constrNonDefault | ||
} | ||
|
||
testBeanDefString = ${abstract.constrTestBean} { | ||
constrInt = 2 | ||
} | ||
|
||
testBeanDefAll = ${abstract.constrTestBean} | ||
|
||
testBeanFMDefInt = ${abstract.fmTestBean} { | ||
theString = factoryNonDefault | ||
} | ||
|
||
testBeanFMDefString = ${abstract.fmTestBean} { | ||
theInt = -2 | ||
} | ||
|
||
testBeanFMDefAll = ${abstract.fmTestBean} | ||
} |
14 changes: 0 additions & 14 deletions
14
...spring/src/test/scala/com/avsystem/commons/spring/AnnotationParameterNameDiscoverer.scala
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.