Commit 41f6da1

UDAFs for Datasets (#1)
* sbt infra
* update .gitignore
* TDigestUDT.scala
* A working UDAF for TDigest
* ignore null entries
* bug repro
* workaround SPARK-21277 with flattened unsafe storage
* remove silex dep
* complete draft of DataFrame UDAF suite
* rc1
1 parent 34689ea commit 41f6da1
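
The files below add the build plumbing and the UDT serialization layer; the UDAF implementation and its DataFrame test suite are part of the same commit but are not shown in this listing. As a rough sketch of the intended usage (hypothetical: the factory name `tdigestUDAF` and the exact result type are assumptions, not confirmed by the listing below), aggregating a numeric DataFrame column into a t-digest would look something like:

import org.apache.spark.sql.SparkSession
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.udaf._              // UDAF entry points added by this commit
import org.apache.spark.isarnproject.sketches.udt._  // TDigestUDT / TDigestSQL, defined below

val spark = SparkSession.builder.master("local[2]").appName("tdigest-demo").getOrCreate()
import spark.implicits._

val data = spark.sparkContext
  .parallelize(Seq.fill(10000)(scala.util.Random.nextGaussian))
  .toDF("x")

// Hypothetical factory name; the real UDAF lives in the (not listed) udaf package source.
val udaf = tdigestUDAF[Double]
val agg = data.agg(udaf($"x").alias("xsketch"))

// Assuming the UDAF's result column uses TDigestUDT, the value comes back as TDigestSQL:
val td: TDigest = agg.first.getAs[TDigestSQL](0).tdigest
println(td.cdf(0.0))  // approximately 0.5 for a standard normal sample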

File tree

7 files changed: +469 -0 lines changed

.gitignore

Lines changed: 15 additions & 0 deletions
@@ -1,2 +1,17 @@
 *.class
 *.log
+
+# sbt specific
+.cache
+.history
+.lib/
+dist/*
+target/
+lib_managed/
+src_managed/
+project/boot/
+project/plugins/project/
+
+# Scala-IDE specific
+.scala_dependencies
+.worksheet

build.sbt

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
name := "isarn-sketches-spark"

organization := "org.isarnproject"

bintrayOrganization := Some("isarn")

version := "0.1.0.rc1"

scalaVersion := "2.11.8"

crossScalaVersions := Seq("2.10.6", "2.11.8")

def commonSettings = Seq(
  libraryDependencies ++= Seq(
    "org.isarnproject" %% "isarn-sketches" % "0.1.0",
    "org.apache.spark" %% "spark-core" % "2.1.0" % Provided,
    "org.apache.spark" %% "spark-sql" % "2.1.0" % Provided,
    "org.apache.spark" %% "spark-mllib" % "2.1.0" % Provided,
    "org.isarnproject" %% "isarn-scalatest" % "0.0.1" % Test,
    "org.scalatest" %% "scalatest" % "2.2.4" % Test,
    "org.apache.commons" % "commons-math3" % "3.6.1" % Test),
  initialCommands in console := """
    |import org.apache.spark.SparkConf
    |import org.apache.spark.SparkContext
    |import org.apache.spark.sql.SparkSession
    |import org.apache.spark.SparkContext._
    |import org.apache.spark.rdd.RDD
    |import org.apache.spark.ml.linalg.Vectors
    |import org.isarnproject.sketches.TDigest
    |import org.isarnproject.sketches.udaf._
    |import org.apache.spark.isarnproject.sketches.udt._
    |val initialConf = new SparkConf().setAppName("repl").set("spark.serializer", "org.apache.spark.serializer.KryoSerializer").set("spark.kryoserializer.buffer", "16mb")
    |val spark = SparkSession.builder.config(initialConf).master("local[2]").getOrCreate()
    |import spark._
    |val sc = spark.sparkContext
    |import org.apache.log4j.{Logger, ConsoleAppender, Level}
    |Logger.getRootLogger().getAppender("console").asInstanceOf[ConsoleAppender].setThreshold(Level.WARN)
  """.stripMargin,
  cleanupCommands in console := "spark.stop"
)

seq(commonSettings:_*)

licenses += ("Apache-2.0", url("http://opensource.org/licenses/Apache-2.0"))

scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature")

scalacOptions in (Compile, doc) ++= Seq("-doc-root-content", baseDirectory.value+"/root-doc.txt")

site.settings

site.includeScaladoc()

// Re-enable if/when we want to support gh-pages w/ jekyll
// site.jekyllSupport()

ghpages.settings

git.remoteRepo := "git@github.com:isarn/isarn-sketches-spark.git"
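
With the initialCommands above, `sbt console` opens a local SparkSession (Kryo serialization enabled) with the sketch and UDT imports already in scope. A minimal smoke test in that session might be the following (assuming `TDigest.sketch` from isarn-sketches 0.1.0, which is not part of this diff):

// run inside `sbt console`; the imports come from initialCommands
val xs = Vector.fill(100000)(scala.util.Random.nextGaussian)
val td = TDigest.sketch(xs)   // assumed isarn-sketches constructor
td.cdf(0.0)                   // approximately 0.5 for a standard normal sample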

project/build.properties

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
sbt.version=0.13.11

project/plugins.sbt

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
resolvers += Resolver.url(
  "bintray-sbt-plugin-releases",
  url("http://dl.bintray.com/content/sbt/sbt-plugin-releases"))(
    Resolver.ivyStylePatterns)

resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"

resolvers += "jgit-repo" at "http://download.eclipse.org/jgit/maven"

addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")

addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.5.4")

addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0")

// scoverage and coveralls deps are at old versions to avoid a bug in the current versions
// update these when this fix is released: https://github.com/scoverage/sbt-coveralls/issues/73
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.0.4")

addSbtPlugin("org.scoverage" % "sbt-coveralls" % "1.0.0")

TDigestUDT.scala

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
/*
Copyright 2017 Erik Erlandson

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package org.apache.spark.isarnproject.sketches.udt

import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap

@SQLUserDefinedType(udt = classOf[TDigestUDT])
case class TDigestSQL(tdigest: TDigest)

class TDigestUDT extends UserDefinedType[TDigestSQL] {
  def userClass: Class[TDigestSQL] = classOf[TDigestSQL]

  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("nclusters", IntegerType, false) ::
    StructField("clustX", ArrayType(DoubleType, false), false) ::
    StructField("clustM", ArrayType(DoubleType, false), false) ::
    Nil)

  def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

  def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

  private[sketches] def serializeTD(td: TDigest): InternalRow = {
    val TDigest(delta, maxDiscrete, nclusters, clusters) = td
    val row = new GenericInternalRow(5)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.setInt(2, nclusters)
    val clustX = clusters.keys.toArray
    val clustM = clusters.values.toArray
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  private[sketches] def deserializeTD(datum: Any): TDigest = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val nclusters = row.getInt(2)
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      val clusters = clustX.zip(clustM)
        .foldLeft(TDigestMap.empty) { case (td, e) => td + e }
      TDigest(delta, maxDiscrete, nclusters, clusters)
    case u => throw new Exception(s"failed to deserialize: $u")
  }
}

case object TDigestUDT extends TDigestUDT

@SQLUserDefinedType(udt = classOf[TDigestArrayUDT])
case class TDigestArraySQL(tdigests: Array[TDigest])

class TDigestArrayUDT extends UserDefinedType[TDigestArraySQL] {
  def userClass: Class[TDigestArraySQL] = classOf[TDigestArraySQL]

  // Spark seems to have trouble with ArrayType data that isn't
  // serialized using UnsafeArrayData (SPARK-21277), so my workaround
  // is to store all the cluster information flattened into single Unsafe arrays.
  // To deserialize, I unpack the slices.
  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("clusterS", ArrayType(IntegerType, false), false) ::
    StructField("clusterX", ArrayType(DoubleType, false), false) ::
    StructField("ClusterM", ArrayType(DoubleType, false), false) ::
    Nil)

  def serialize(tdasql: TDigestArraySQL): Any = {
    val row = new GenericInternalRow(5)
    val tda: Array[TDigest] = tdasql.tdigests
    val delta = if (tda.isEmpty) 0.0 else tda.head.delta
    val maxDiscrete = if (tda.isEmpty) 0 else tda.head.maxDiscrete
    val clustS = tda.map(_.nclusters)
    val clustX = tda.flatMap(_.clusters.keys)
    val clustM = tda.flatMap(_.clusters.values)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.update(2, UnsafeArrayData.fromPrimitiveArray(clustS))
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  def deserialize(datum: Any): TDigestArraySQL = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val clustS = row.getArray(2).toIntArray()
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      var beg = 0
      val tda = clustS.map { nclusters =>
        val x = clustX.slice(beg, beg + nclusters)
        val m = clustM.slice(beg, beg + nclusters)
        val clusters = x.zip(m).foldLeft(TDigestMap.empty) { case (td, e) => td + e }
        val td = TDigest(delta, maxDiscrete, nclusters, clusters)
        beg += nclusters
        td
      }
      TDigestArraySQL(tda)
    case u => throw new Exception(s"failed to deserialize: $u")
  }
}

case object TDigestArrayUDT extends TDigestArrayUDT

// VectorUDT is private[spark], but I can expose what I need this way:
object TDigestUDTInfra {
  private val udtML = new org.apache.spark.ml.linalg.VectorUDT
  def udtVectorML: DataType = udtML

  private val udtMLLib = new org.apache.spark.mllib.linalg.VectorUDT
  def udtVectorMLLib: DataType = udtMLLib
}
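
The TDigestArrayUDT above works around SPARK-21277 by flattening every digest's cluster centers and masses into single UnsafeArrayData buffers, with a parallel array of per-digest cluster counts; deserialize() then re-slices those buffers by a running offset. The slicing step is easy to see in isolation (a standalone illustration with made-up numbers, not code from this commit):

// three digests with 2, 3, and 1 clusters respectively
val clustS = Array(2, 3, 1)                               // nclusters per digest
val clustX = Array(1.0, 2.0, 10.0, 11.0, 12.0, 42.0)      // all cluster centers, concatenated
val clustM = Array(5.0, 3.0, 1.0, 1.0, 2.0, 7.0)          // all cluster masses, concatenated

// re-slice by running offset, exactly as deserialize() does
var beg = 0
val perDigest = clustS.map { n =>
  val slice = (clustX.slice(beg, beg + n), clustM.slice(beg, beg + n))
  beg += n
  slice
}
// perDigest(1) now holds the second digest's clusters:
// centers (10.0, 11.0, 12.0) with masses (1.0, 1.0, 2.0)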
