@@ -27,6 +27,22 @@ import com.cloudera.livy.util.LineBufferedProcess
27
27
28
28
object LivySparkUtils {
29
29
30
+ // For each Spark version we supported, we need to add this mapping relation in case Scala
31
+ // version cannot be detected from "spark-submit --version".
32
+ private val defaultSparkScalaVersion = Map (
33
+ // Spark 2.0 + Scala 2.11
34
+ (2 , 0 ) -> " 2.11" ,
35
+ // Spark 1.6 + Scala 2.10
36
+ (1 , 6 ) -> " 2.10"
37
+ )
38
+
39
+ // Supported Spark version
40
+ private val MIN_VERSION = (1 , 6 )
41
+ private val MAX_VERSION = (2 , 1 )
42
+
43
+ private val sparkVersionRegex = """ version (.*)""" .r.unanchored
44
+ private val scalaVersionRegex = """ Scala version (.*), Java""" .r.unanchored
45
+
30
46
/**
31
47
* Test that Spark home is configured and configured Spark home is a directory.
32
48
*/
@@ -45,7 +61,7 @@ object LivySparkUtils {
45
61
*/
46
62
def testSparkSubmit (livyConf : LivyConf ): Unit = {
47
63
try {
48
- testSparkVersion(sparkSubmitVersion(livyConf))
64
+ testSparkVersion(sparkSubmitVersion(livyConf)._1 )
49
65
} catch {
50
66
case e : IOException =>
51
67
throw new IOException (" Failed to run spark-submit executable" , e)
@@ -57,25 +73,21 @@ object LivySparkUtils {
57
73
* @param version Spark version
58
74
*/
59
75
def testSparkVersion (version : String ): Unit = {
60
- // This is exclusive. Version which equals to this will be rejected.
61
- val maxVersion = (2 , 1 )
62
- val minVersion = (1 , 6 )
63
-
64
76
val supportedVersion = formatSparkVersion(version) match {
65
77
case v : (Int , Int ) =>
66
- v >= minVersion && v < maxVersion
78
+ v >= MIN_VERSION && v < MAX_VERSION
67
79
case _ => false
68
80
}
69
81
require(supportedVersion, s " Unsupported Spark version $version. " )
70
82
}
71
83
72
84
/**
73
- * Return the version of the configured `spark-submit` version.
85
+ * Return the Spark and Scala version of the configured `spark-submit` version.
74
86
*
75
87
* @param livyConf
76
- * @return the version
88
+ * @return Tuple with Spark and Scala version
77
89
*/
78
- def sparkSubmitVersion (livyConf : LivyConf ): String = {
90
+ def sparkSubmitVersion (livyConf : LivyConf ): ( String , String ) = {
79
91
val sparkSubmit = livyConf.sparkSubmit()
80
92
val pb = new ProcessBuilder (sparkSubmit, " --version" )
81
93
pb.redirectErrorStream(true )
@@ -89,20 +101,28 @@ object LivySparkUtils {
89
101
val exitCode = process.waitFor()
90
102
val output = process.inputIterator.mkString(" \n " )
91
103
92
- val regex = """ version (.*)""" .r.unanchored
93
-
104
+ var sparkVersion = " "
94
105
output match {
95
- case regex (version) => version
106
+ case sparkVersionRegex (version) => sparkVersion = version
96
107
case _ =>
97
108
throw new IOException (f " Unable to determine spark-submit version [ $exitCode]: \n $output" )
98
109
}
110
+
111
+ var scalaVersion = " "
112
+ output match {
113
+ case scalaVersionRegex(version) => scalaVersion = version
114
+ case _ =>
115
+ }
116
+
117
+ (sparkVersion, scalaVersion)
99
118
}
100
119
101
120
/**
102
- * Return formatted Spark version.
103
- * @param version Spark version
104
- * @return Two element tuple, one is major version and the other is minor version
105
- */
121
+ * Return formatted Spark version.
122
+ *
123
+ * @param version Spark version
124
+ * @return Two element tuple, one is major version and the other is minor version
125
+ */
106
126
def formatSparkVersion (version : String ): (Int , Int ) = {
107
127
val versionPattern = """ (\d)+\.(\d)+(?:[\.-]\d*)*""" .r
108
128
version match {
@@ -112,4 +132,23 @@ object LivySparkUtils {
112
132
throw new IllegalArgumentException (s " Fail to parse Spark version from $version" )
113
133
}
114
134
}
135
+
136
+ /**
137
+ * Return Scala binary version, if it cannot be parsed from input version string, it will
138
+ * pick default Scala version related to Spark version.
139
+ *
140
+ * @param scalaVersion Scala binary version String
141
+ * @param sparkVersion formatted Spark version.
142
+ * @return Scala binary version String based on Spark version and livy conf.
143
+ */
144
+ def formatScalaVersion (scalaVersion : String , sparkVersion : (Int , Int )): String = {
145
+ val versionPattern = """ (\d)+\.(\d+)+.*""" .r
146
+ scalaVersion match {
147
+ case versionPattern(major, minor) =>
148
+ major + " ." + minor
149
+ case _ =>
150
+ defaultSparkScalaVersion.getOrElse(sparkVersion,
151
+ throw new IllegalArgumentException (s " Fail to get Scala version from Spark $sparkVersion" ))
152
+ }
153
+ }
115
154
}
0 commit comments