/*
Copyright 2017 Erik Erlandson

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package org.apache.spark.isarnproject.sketches.udt

import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap

/** Wraps a TDigest so it can be used as a Dataset column type, via TDigestUDT. */
@SQLUserDefinedType(udt = classOf[TDigestUDT])
case class TDigestSQL(tdigest: TDigest)

/** A Spark SQL UserDefinedType for serializing a single t-digest sketch. */
class TDigestUDT extends UserDefinedType[TDigestSQL] {
  def userClass: Class[TDigestSQL] = classOf[TDigestSQL]

  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("nclusters", IntegerType, false) ::
    StructField("clustX", ArrayType(DoubleType, false), false) ::
    StructField("clustM", ArrayType(DoubleType, false), false) ::
    Nil)
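
  // Illustrative example of the serialized shape (values here are hypothetical,
  // not from the original source): a digest with three clusters becomes a row like
  //   (delta = 0.5, maxDiscrete = 0, nclusters = 3,
  //    clustX = [1.0, 2.0, 3.0], clustM = [1.0, 2.0, 1.0])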

  def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

  def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

  private[sketches] def serializeTD(td: TDigest): InternalRow = {
    val TDigest(delta, maxDiscrete, nclusters, clusters) = td
    val row = new GenericInternalRow(5)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.setInt(2, nclusters)
    // flatten cluster centers and masses into primitive arrays for UnsafeArrayData
    val clustX = clusters.keys.toArray
    val clustM = clusters.values.toArray
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  private[sketches] def deserializeTD(datum: Any): TDigest = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val nclusters = row.getInt(2)
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      // rebuild the cluster map from the paired (center, mass) arrays
      val clusters = clustX.zip(clustM)
        .foldLeft(TDigestMap.empty) { case (td, e) => td + e }
      TDigest(delta, maxDiscrete, nclusters, clusters)
    case u => throw new IllegalArgumentException(s"failed to deserialize: $u")
  }
}

case object TDigestUDT extends TDigestUDT
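
// Illustrative usage sketch (an assumption, not part of the original file): with the
// @SQLUserDefinedType annotation above, TDigestSQL values can appear directly as
// Dataset columns. The column name "sketch" and the Gaussian sample data are
// hypothetical.
//
// {{{
//   import org.apache.spark.sql.SparkSession
//   import org.isarnproject.sketches.TDigest
//
//   val spark = SparkSession.builder.getOrCreate()
//   import spark.implicits._
//
//   val td = TDigest.sketch(Array.fill(1000)(scala.util.Random.nextGaussian))
//   val df = Seq(TDigestSQL(td)).toDF("sketch")
//   df.printSchema() // sketch: struct<delta:double, maxDiscrete:int, ...>
// }}}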

@SQLUserDefinedType(udt = classOf[TDigestArrayUDT])
case class TDigestArraySQL(tdigests: Array[TDigest])

/** A Spark SQL UserDefinedType for serializing an array of t-digest sketches. */
class TDigestArrayUDT extends UserDefinedType[TDigestArraySQL] {
  def userClass: Class[TDigestArraySQL] = classOf[TDigestArraySQL]

  // Spark seems to have trouble with ArrayType data that isn't serialized
  // using UnsafeArrayData (SPARK-21277), so my workaround is to store all the
  // cluster information flattened into single Unsafe arrays.
  // To deserialize, I unpack the slices.
  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("clusterS", ArrayType(IntegerType, false), false) ::
    StructField("clusterX", ArrayType(DoubleType, false), false) ::
    StructField("clusterM", ArrayType(DoubleType, false), false) ::
    Nil)
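
  // Illustrative layout of the flattening (hypothetical values, not from the
  // original source): two digests with 2 and 3 clusters respectively become
  //   clusterS = [2, 3]
  //   clusterX = [x00, x01, x10, x11, x12]  (cluster centers, concatenated)
  //   clusterM = [m00, m01, m10, m11, m12]  (cluster masses, concatenated)
  // and deserialize recovers each digest by slicing at the offsets implied by clusterS.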

  def serialize(tdasql: TDigestArraySQL): Any = {
    val row = new GenericInternalRow(5)
    val tda: Array[TDigest] = tdasql.tdigests
    // delta and maxDiscrete are assumed uniform across the array;
    // an empty array gets placeholder values
    val delta = if (tda.isEmpty) 0.0 else tda.head.delta
    val maxDiscrete = if (tda.isEmpty) 0 else tda.head.maxDiscrete
    val clustS = tda.map(_.nclusters)
    val clustX = tda.flatMap(_.clusters.keys)
    val clustM = tda.flatMap(_.clusters.values)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.update(2, UnsafeArrayData.fromPrimitiveArray(clustS))
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  def deserialize(datum: Any): TDigestArraySQL = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val clustS = row.getArray(2).toIntArray()
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      // walk the flattened center/mass arrays, carving out one slice per digest
      var beg = 0
      val tda = clustS.map { nclusters =>
        val x = clustX.slice(beg, beg + nclusters)
        val m = clustM.slice(beg, beg + nclusters)
        val clusters = x.zip(m).foldLeft(TDigestMap.empty) { case (td, e) => td + e }
        val td = TDigest(delta, maxDiscrete, nclusters, clusters)
        beg += nclusters
        td
      }
      TDigestArraySQL(tda)
    case u => throw new IllegalArgumentException(s"failed to deserialize: $u")
  }
}

case object TDigestArrayUDT extends TDigestArrayUDT

// VectorUDT is private[spark], but since this file lives under the org.apache.spark
// package namespace, I can expose what I need this way:
object TDigestUDTInfra {
  private val udtML = new org.apache.spark.ml.linalg.VectorUDT
  def udtVectorML: DataType = udtML

  private val udtMLLib = new org.apache.spark.mllib.linalg.VectorUDT
  def udtVectorMLLib: DataType = udtMLLib
}
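
// Illustrative usage sketch (assumed, not part of the original file): the exposed
// UDT instances are plain DataTypes and can be used when declaring a schema by hand.
// The field name "features" is hypothetical.
//
// {{{
//   import org.apache.spark.sql.types.{StructField, StructType}
//
//   val schema = StructType(
//     StructField("features", TDigestUDTInfra.udtVectorML, nullable = false) :: Nil)
// }}}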