Skip to content

Commit

Permalink
spring support for scala default parameter values
Browse files Browse the repository at this point in the history
  • Loading branch information
ghik committed Dec 8, 2017
1 parent 05bdc3d commit b95332b
Show file tree
Hide file tree
Showing 8 changed files with 281 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ object SharedExtensions extends SharedExtensions {
def evalFuture: Future[A] = FutureCompanionOps.eval(a())

def evalTry: Try[A] = Try(a())

def recoverFrom[T <: Throwable : ClassTag](fallbackValue: => A): A =
try a() catch {
case _: T => fallbackValue
}

def recoverToOpt[T <: Throwable : ClassTag]: Opt[A] =
try Opt(a()) catch {
case _: T => Opt.Empty
}
}

class NullableOps[A >: Null](private val a: A) extends AnyVal {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class HoconBeanDefinitionReader(registry: BeanDefinitionRegistry)
getProps(obj).foreach {
case (key, value) =>
if (construct) {
addConstructorArg(readConstructorArg(value, Some(key)))
addConstructorArg(readConstructorArg(value, forcedName = key))
} else {
propertyValues.addPropertyValue(readPropertyValue(key, value))
}
Expand Down Expand Up @@ -275,31 +275,35 @@ class HoconBeanDefinitionReader(registry: BeanDefinitionRegistry)
private def readConstructorArgs(value: ConfigValue) = {
value.as[Option[Either[ConfigList, ConfigObject]]] match {
case Some(Left(list)) =>
list.iterator.asScala.map(configValue => readConstructorArg(configValue))
list.iterator.asScala.zipWithIndex.map { case (configValue, idx) =>
readConstructorArg(configValue, forcedIndex = idx)
}
case Some(Right(obj)) =>
validateObj(props = true)(obj)
getProps(obj).iterator.map { case (name, configValue) =>
val (idxOpt, holder) = readConstructorArg(configValue)
holder.setName(name)
(idxOpt, holder)
readConstructorArg(configValue, forcedName = name)
}
case None =>
Iterator.empty
}
}

private def readConstructorArg(value: ConfigValue, forcedName: Option[String] = None) = value match {
private def readConstructorArg(
value: ConfigValue,
forcedIndex: OptArg[Int] = OptArg.Empty,
forcedName: OptArg[String] = OptArg.Empty
) = value match {
case ValueDefinition(obj) =>
validateObj(required = Set(ValueAttr), allowed = Set(IndexAttr, TypeAttr, NameAttr))(obj)
val vh = new ValueHolder(read(obj.get(ValueAttr)))
obj.get(TypeAttr).as[Option[String]].foreach(vh.setType)
(forcedName orElse obj.get(NameAttr).as[Option[String]]).foreach(vh.setName)
val indexOpt = obj.get(IndexAttr).as[Option[Int]]
(forcedName.toOption orElse obj.get(NameAttr).as[Option[String]]).foreach(vh.setName)
val indexOpt = forcedIndex.toOption orElse obj.get(IndexAttr).as[Option[Int]]
(indexOpt, vh)
case _ =>
val vh = new ValueHolder(read(value))
forcedName.foreach(vh.setName)
(None, vh)
(forcedIndex.toOption, vh)
}

private def readPropertyValue(name: String, value: ConfigValue) = value match {
Expand Down Expand Up @@ -344,6 +348,6 @@ class HoconBeanDefinitionReader(registry: BeanDefinitionRegistry)
result
}

def loadBeanDefinitions(resource: Resource) =
def loadBeanDefinitions(resource: Resource): Int =
loadBeanDefinitions(ConfigFactory.parseURL(resource.getURL).resolve)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.avsystem.commons
package spring

import java.lang.reflect.{Constructor, Method, Modifier}

import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder
import org.springframework.beans.factory.config.{BeanDefinition, BeanDefinitionHolder, ConfigurableListableBeanFactory}
import org.springframework.beans.factory.support.{BeanDefinitionRegistry, BeanDefinitionRegistryPostProcessor, ManagedList, ManagedMap, ManagedSet}
import org.springframework.core.ParameterNameDiscoverer

import scala.beans.BeanProperty
import scala.reflect.{ScalaLongSignature, ScalaSignature}

class ScalaDefaultValuesInjector extends BeanDefinitionRegistryPostProcessor {
@BeanProperty var paramNameDiscoverer: ParameterNameDiscoverer =
new ScalaParameterNameDiscoverer

def classLoader: ClassLoader =
Thread.currentThread.getContextClassLoader.opt getOrElse getClass.getClassLoader

def loadClass(name: String): Class[_] = Class.forName(name, false, classLoader)

def postProcessBeanDefinitionRegistry(registry: BeanDefinitionRegistry): Unit = {
def traverse(value: Any): Unit = value match {
case bd: BeanDefinition =>
bd.getConstructorArgumentValues.getGenericArgumentValues.asScala.foreach(traverse)
bd.getConstructorArgumentValues.getIndexedArgumentValues.values.asScala.foreach(traverse)
bd.getPropertyValues.getPropertyValueList.asScala.foreach(pv => traverse(pv.getValue))
injectDefaultValues(bd)
case bdw: BeanDefinitionHolder =>
traverse(bdw.getBeanDefinition)
case vh: ValueHolder =>
traverse(vh.getValue)
case ml: ManagedList[_] =>
ml.asScala.foreach(traverse)
case ms: ManagedSet[_] =>
ms.asScala.foreach(traverse)
case mm: ManagedMap[_, _] =>
mm.asScala.foreach {
case (k, v) =>
traverse(k)
traverse(v)
}
case _ =>
}

registry.getBeanDefinitionNames
.foreach(n => traverse(registry.getBeanDefinition(n)))
}

private def isScalaClass(cls: Class[_]): Boolean = cls.getEnclosingClass match {
case null => cls.getAnnotation(classOf[ScalaSignature]) != null ||
cls.getAnnotation(classOf[ScalaLongSignature]) != null
case encls => isScalaClass(encls)
}

private def injectDefaultValues(bd: BeanDefinition): Unit = {
val className = bd.getFactoryBeanName.opt getOrElse bd.getBeanClassName
loadClass(className).recoverToOpt[ClassNotFoundException].filter(isScalaClass).foreach { clazz =>
val usingConstructor = bd.getFactoryMethodName == null
val factoryExecs =
if (usingConstructor) clazz.getConstructors.toVector
else clazz.getMethods.iterator.filter(_.getName == bd.getFactoryMethodName).toVector
val factorySymbolName =
if (usingConstructor) "$lessinit$greater" else bd.getFactoryMethodName

if (factoryExecs.size == 1) {
val constrVals = bd.getConstructorArgumentValues
val factoryExec = factoryExecs.head
val paramNames = factoryExec match {
case c: Constructor[_] => paramNameDiscoverer.getParameterNames(c)
case m: Method => paramNameDiscoverer.getParameterNames(m)
}
(0 until factoryExec.getParameterCount).foreach { i =>
def defaultValueMethod = clazz.getMethod(s"$factorySymbolName$$default$$${i + 1}")
.recoverToOpt[NoSuchMethodException].filter(m => Modifier.isStatic(m.getModifiers))
def specifiedNamed = paramNames != null &&
constrVals.getGenericArgumentValues.asScala.exists(_.getName == paramNames(i))
def specifiedIndexed =
constrVals.getIndexedArgumentValues.get(i) != null
if (!specifiedNamed && !specifiedIndexed) {
defaultValueMethod.foreach { dvm =>
constrVals.addIndexedArgumentValue(i, dvm.invoke(null))
}
}
}
}
}
}

def postProcessBeanFactory(beanFactory: ConfigurableListableBeanFactory): Unit = ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.avsystem.commons
package spring

import java.lang.reflect.{Constructor, Executable, Method, Modifier}

import org.springframework.core.{JdkVersion, ParameterNameDiscoverer}

import scala.annotation.tailrec
import scala.ref.WeakReference
import scala.reflect.api.JavaUniverse
import scala.reflect.{ScalaLongSignature, ScalaSignature}

object ScalaParameterNameDiscoverer {
final val ScalaSignatureClasses =
List(classOf[ScalaSignature], classOf[ScalaLongSignature])

final val JdkAtLeast8 =
JdkVersion.getMajorJavaVersion >= JdkVersion.JAVA_18

// we don't want to keep the universe in memory forever, so we don't use scala.reflect.runtime.universe
private var universeRef: WeakReference[JavaUniverse] = _

private def universe: JavaUniverse = {
universeRef.option.flatMap(_.get) match {
case Some(result) => result
case None =>
val result = new scala.reflect.runtime.JavaUniverse
universeRef = new WeakReference[JavaUniverse](result)
result
}
}
}

class ScalaParameterNameDiscoverer extends ParameterNameDiscoverer {

import ScalaParameterNameDiscoverer._

@tailrec private def isScala(cls: Class[_]): Boolean = cls.getEnclosingClass match {
case null => ScalaSignatureClasses.exists(ac => cls.getAnnotation(ac) != null)
case encls => isScala(encls)
}

private def discoverNames(u: JavaUniverse)(executable: Executable, symbolPredicate: u.Symbol => Boolean): Array[String] = {
import u._

val declaringClass = executable.getDeclaringClass
val mirror = runtimeMirror(declaringClass.getClassLoader)
val ownerSymbol =
if (Modifier.isStatic(executable.getModifiers)) mirror.moduleSymbol(declaringClass).moduleClass.asType
else mirror.classSymbol(declaringClass)

def argErasuresMatch(ms: MethodSymbol) =
ms.paramLists.flatten.map(s => mirror.runtimeClass(s.typeSignature)) == executable.getParameterTypes.toList

def paramNames(ms: MethodSymbol) =
ms.paramLists.flatten.map(_.name.toString).toArray

ownerSymbol.toType.members
.find(s => symbolPredicate(s) && argErasuresMatch(s.asMethod))
.map(s => paramNames(s.asMethod))
.orNull
}

def getParameterNames(ctor: Constructor[_]): Array[String] =
if (JdkAtLeast8 && ctor.getParameters.forall(_.isNamePresent))
ctor.getParameters.map(_.getName)
else if (isScala(ctor.getDeclaringClass))
discoverNames(universe)(ctor, s => s.isConstructor)
else null

def getParameterNames(method: Method): Array[String] = {
val declaringCls = method.getDeclaringClass
if (JdkAtLeast8 && method.getParameters.forall(_.isNamePresent))
method.getParameters.map(_.getName)
else if (isScala(declaringCls)) {
// https://github.com/scala/bug/issues/10650
val forStaticForwarder =
if (Modifier.isStatic(method.getModifiers))
Class.forName(declaringCls.getName + "$", false, declaringCls.getClassLoader)
.recoverToOpt[ClassNotFoundException]
.flatMap(_.getMethod(method.getName, method.getParameterTypes: _*).recoverToOpt[NoSuchMethodException])
.map(getParameterNames)
else
Opt.Empty
forStaticForwarder.getOrElse(
discoverNames(universe)(method, s => s.isMethod && s.name.toString == method.getName))
}
else null
}
}
56 changes: 41 additions & 15 deletions commons-spring/src/test/resources/testBean.conf
Original file line number Diff line number Diff line change
@@ -1,32 +1,58 @@
beanClass = com.avsystem.commons.spring.TestBean
abstract {
testBean {
%class = com.avsystem.commons.spring.TestBean
}
constrTestBean = ${abstract.testBean} {
%construct = true
}
fmTestBean = ${abstract.constrTestBean} {
%factory-method = create
}
}

beans {
testBean {
%class = ${beanClass}
%constructor-args = [42, "lolzsy"]
testBean = ${abstract.testBean} {
%constructor-args = [42, lolzsy]
int = 5
string = "lol"
string = lol
strIntMap {
"fuu" = 42
fuu = 42
}
strList = ["a", "b"]
strSet = ["A", "B"]
nestedBean {
%class = ${beanClass}
strList = [a, b]
strSet = [A, B]
nestedBean = ${abstract.testBean} {
%constructor-args {
constrString = "wut"
constrString = wut
constrInt = 1
}
int = 6
nestedBean {
%class = ${beanClass}
%construct = true
constrString = "yes"
nestedBean = ${abstract.constrTestBean} {
constrString = yes
constrInt = 2
}
}
config.%config {
srsly = dafuq
}
}

testBeanDefInt = ${abstract.constrTestBean} {
constrString = constrNonDefault
}

testBeanDefString = ${abstract.constrTestBean} {
constrInt = 2
}

testBeanDefAll = ${abstract.constrTestBean}

testBeanFMDefInt = ${abstract.fmTestBean} {
theString = factoryNonDefault
}

testBeanFMDefString = ${abstract.fmTestBean} {
theInt = -2
}

testBeanFMDefAll = ${abstract.fmTestBean}
}

This file was deleted.

Loading

0 comments on commit b95332b

Please sign in to comment.