Skip to content

Commit c4fb3ac

Browse files
committed
Support for bigdec in arrow.
1 parent 577f103 commit c4fb3ac

File tree

4 files changed

+168
-45
lines changed

4 files changed

+168
-45
lines changed

scripts/arrow_decimal.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Generate and read back the ``test/data/bigdec.arrow`` decimal fixture.

Writes a single-column Arrow IPC file whose ``id`` column is
``decimal128(5, 2)`` and then memory-maps the file again to print the
column, as a quick check that the decimals round-trip.
"""
from decimal import Decimal
import uuid  # noqa: F401  (unused here; retained from the original script)

import pyarrow as pa

ARROW_PATH = 'test/data/bigdec.arrow'
SCHEMA = pa.schema([pa.field('id', pa.decimal128(5, 2))])
# Explicit Decimal values so the column does not depend on pyarrow's
# implicit int -> decimal coercion.
VALUES = [Decimal('1.00'), Decimal('0.00'), Decimal('2.00')]


def write_fixture(path=ARROW_PATH):
    """Write ``VALUES`` as one record batch to an Arrow IPC file at *path*."""
    column = pa.array(VALUES, type=SCHEMA.field('id').type)
    batch = pa.record_batch([column], schema=SCHEMA)
    with pa.OSFile(path, 'wb') as sink:
        with pa.ipc.new_file(sink, schema=SCHEMA) as writer:
            writer.write(batch)


def read_fixture(path=ARROW_PATH):
    """Memory-map the Arrow IPC file at *path* and return it as a table."""
    with pa.memory_map(path, 'r') as source:
        return pa.ipc.open_file(source).read_all()


def main():
    """Regenerate the fixture and print its first (only) column."""
    write_fixture()
    loaded_table = read_fixture()
    print(loaded_table[0])


if __name__ == "__main__":
    main()

src/tech/v3/libs/arrow.clj

Lines changed: 147 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
[tech.v3.datatype.array-buffer :as array-buffer]
8989
[tech.v3.datatype.ffi :as dt-ffi]
9090
[tech.v3.dataset.impl.column :as col-impl]
91+
[tech.v3.dataset.column :as ds-col]
9192
[tech.v3.dataset.protocols :as ds-proto]
9293
[tech.v3.dataset.impl.dataset :as ds-impl]
9394
[tech.v3.dataset.dynamic-int-list :as dyn-int-list]
@@ -119,7 +120,8 @@
119120
[org.apache.arrow.vector.types.pojo Field Schema ArrowType$Int
120121
ArrowType$Utf8 ArrowType$Timestamp ArrowType$Time DictionaryEncoding FieldType
121122
ArrowType$FloatingPoint ArrowType$Bool ArrowType$Date ArrowType$Duration
122-
ArrowType$LargeUtf8 ArrowType$Null ArrowType$List ArrowType$Binary ArrowType$FixedSizeBinary]
123+
ArrowType$LargeUtf8 ArrowType$Null ArrowType$List ArrowType$Binary ArrowType$FixedSizeBinary
124+
ArrowType$Decimal]
123125
[org.apache.arrow.flatbuf CompressionType]
124126
[org.apache.arrow.vector.types MetadataVersion]
125127
[org.apache.arrow.vector.ipc WriteChannel]
@@ -132,7 +134,7 @@
132134
[java.io OutputStream InputStream ByteArrayOutputStream ByteArrayInputStream]
133135
[java.nio ByteBuffer ByteOrder ShortBuffer IntBuffer LongBuffer DoubleBuffer
134136
FloatBuffer]
135-
[java.util List ArrayList Map HashMap Map$Entry Iterator Set UUID]
137+
[java.util List ArrayList Map HashMap Map$Entry Iterator Set UUID Arrays]
136138
[java.util.concurrent ForkJoinTask]
137139
[java.time ZoneId]
138140
[java.nio.channels WritableByteChannel]
@@ -438,6 +440,11 @@ Dependent block frames are not supported!!")
438440
ArrowType$FixedSizeBinary
439441
(datafy [this] {:datatype :fixed-size-binary
440442
:byte-width (.getByteWidth this)})
443+
ArrowType$Decimal
444+
(datafy [this] {:datatype :decimal
445+
:scale (.getScale this)
446+
:precision (.getPrecision this)
447+
:bit-width (.getBitWidth this)})
441448
ArrowType$List
442449
(datafy [this]
443450
{:datatype :list}))
@@ -664,18 +671,24 @@ Dependent block frames are not supported!!")
664671
(map (fn [[k v]] [(json/write-json-str k) (json/write-json-str v)]))
665672
(into {})))
666673

667-
668-
(defonce ^:private uuid-warn-counter (atom 0))
669-
674+
;; Field-metadata key that is looked up when deciding whether a
;; fixed-size-binary column carries an extension type.
(def ^:private ^String ARROW_EXTENSION_NAME "ARROW:extension:name")

;; Value stored under ARROW_EXTENSION_NAME that marks a column as holding UUIDs.
(def ^:private ^String ARROW_UUID_NAME "arrow.uuid")
670678

671679
(defn- datatype->field-type
672680
(^FieldType [datatype & [nullable? metadata extra-data]]
673681
(let [nullable? (or nullable? (= :object (casting/flatten-datatype datatype)))
674682
metadata (->str-str-meta (dissoc metadata
675683
:name :datatype :categorical?
676-
::previous-string-table) )
684+
::previous-string-table
685+
::complex-datatype))
677686
ft-fn (fn [arrow-type & [dict-encoding]]
678687
(field-type nullable? arrow-type dict-encoding metadata))
688+
complex-datatype datatype
689+
datatype (if (map? complex-datatype)
690+
(get complex-datatype :datatype)
691+
datatype)
679692
datatype (packing/unpack-datatype datatype)]
680693
(case (if #{:epoch-microseconds :epoch-milliseconds
681694
:epoch-days}
@@ -697,7 +710,7 @@ Dependent block frames are not supported!!")
697710
:instant (ft-fn (ArrowType$Timestamp. TimeUnit/MICROSECOND
698711
(str (:timezone extra-data))))
699712
:epoch-microseconds (ft-fn (ArrowType$Timestamp. TimeUnit/MICROSECOND
700-
(str (:timezone extra-data))))
713+
(str (:timezone extra-data))))
701714
:epoch-nanoseconds (ft-fn (ArrowType$Timestamp. TimeUnit/NANOSECOND
702715
(str (:timezone extra-data))))
703716
:epoch-days (ft-fn (ArrowType$Date. DateUnit/DAY))
@@ -713,6 +726,9 @@ Dependent block frames are not supported!!")
713726
(ft-fn (ArrowType$Utf8.) encoding)
714727
;;If no encoding is provided then just save the string as text
715728
(ft-fn (ArrowType$Utf8.)))
729+
:decimal (ft-fn (ArrowType$Decimal. (unchecked-int (get complex-datatype :precision))
730+
(unchecked-int (get complex-datatype :scale))
731+
(unchecked-int (get complex-datatype :bit-width))))
716732
:uuid (ft-fn (ArrowType$FixedSizeBinary. 16))
717733
:text (ft-fn (ArrowType$Utf8.))
718734
:encoded-text (ft-fn (ArrowType$Utf8.))))))
@@ -728,7 +744,7 @@ Dependent block frames are not supported!!")
728744
nullable? (boolean
729745
(or (:nullable? colmeta)
730746
(not (empty? (ds-proto/missing col)))))
731-
col-dtype (:datatype colmeta)
747+
col-dtype (or (::complex-datatype colmeta) (:datatype colmeta))
732748
colname (:name colmeta)
733749
extra-data (merge (select-keys (meta col) [:timezone])
734750
(when (and (not strings-as-text?)
@@ -1240,6 +1256,24 @@ Dependent block frames are not supported!!")
12401256
(.putLong wbuf 0) (.putLong wbuf 0))))
12411257
nil cbuf)
12421258
[(java.nio.ByteBuffer/wrap data)])
1259+
:decimal (let [colmeta (meta col)
1260+
{:keys [scale precision bit-width]} (get colmeta ::complex-datatype)
1261+
byte-width (quot (+ (long bit-width) 7) 8)
1262+
ne (.lsize cbuf)
1263+
byte-data (byte-array (* ne byte-width))
1264+
le? (identical? :little-endian (tech.v3.datatype.protocols/platform-endianness))]
1265+
(dotimes [idx ne]
1266+
(when-let [^BigDecimal d (.readObject cbuf idx)]
1267+
(let [^BigInteger bb (.unscaledValue d)
1268+
bb-bytes (.toByteArray bb)
1269+
offset (* idx byte-width)]
1270+
(if le?
1271+
(let [bb-len (alength bb-bytes)]
1272+
(dotimes [bidx bb-len]
1273+
(let [write-pos (+ offset (- bb-len bidx 1))]
1274+
(ArrayHelpers/aset byte-data write-pos (aget bb-bytes bidx)))))
1275+
(System/arraycopy bb-bytes 0 byte-data 0 (alength bb-bytes))))))
1276+
[(java.nio.ByteBuffer/wrap byte-data)])
12431277
:string (let [str-t (ds-base/ensure-column-string-table col)
12441278
indices (dtype-proto/->array-buffer (str-table/indices str-t))]
12451279
[(nio-buffer/as-nio-buffer indices)])
@@ -1648,11 +1682,6 @@ Dependent block frames are not supported!!")
16481682
(field-metadata field)
16491683
(node-buf->missing node validity-buf))))))
16501684

1651-
(def ^{:private true
1652-
:tag String} ARROW_EXTENSION_NAME "ARROW:extension:name")
1653-
(def ^{:private true
1654-
:tag String} ARROW_UUID_NAME "arrow.uuid")
1655-
16561685
(defmethod ^:private preparse-field :fixed-size-binary
16571686
[field ^Iterator node-iter ^Iterator buf-iter dict-map options]
16581687
(let [node (.next node-iter)
@@ -1671,19 +1700,58 @@ Dependent block frames are not supported!!")
16711700
(if (= ARROW_UUID_NAME (get fm ARROW_EXTENSION_NAME))
16721701
(let [longsdata (-> (java.nio.ByteBuffer/wrap data-ary)
16731702
(.order (java.nio.ByteOrder/BIG_ENDIAN)))]
1674-
(println "is uuid")
16751703
(dtype/make-reader :uuid n-elems
16761704
(let [lidx (* idx 16)]
16771705
(java.util.UUID. (.getLong longsdata lidx)
16781706
(.getLong longsdata (+ lidx 8))))))
16791707
(let [ll (ArrayLists/toList data-ary)]
1680-
(println "is obj")
16811708
(dtype/make-reader :object n-elems
16821709
(let [lidx (* idx field-width)]
16831710
(.subList ll lidx (+ lidx field-width))))))
16841711
fm
16851712
(node-buf->missing node validity-buf))))))
16861713

1714+
(defn- copy-bytes
  "Return a new byte array containing the bytes of `data` in the
  half-open index range [sidx, eidx), in their original order."
  ^bytes [^bytes data ^long sidx ^long eidx]
  (Arrays/copyOfRange data sidx eidx))
1717+
1718+
(defn- copy-reverse-bytes
  "Return a new byte array containing the bytes of `data` in the
  half-open index range [sidx, eidx), in reversed order - the first
  output byte is `data[eidx - 1]` and the last is `data[sidx]`."
  ^bytes [^bytes data ^long sidx ^long eidx]
  (let [n-bytes (- eidx sidx)
        reversed (byte-array n-bytes)
        last-idx (dec eidx)]
    ;; Walk the source range backwards while filling the result forwards.
    (dotimes [i n-bytes]
      (ArrayHelpers/aset reversed i (aget data (- last-idx i))))
    reversed))
1727+
1728+
(defmethod ^:private preparse-field :decimal
  [field ^Iterator node-iter ^Iterator buf-iter dict-map options]
  ;; Consumes one node and two buffers (validity, data) from the iterators
  ;; and returns a function that, given a decompressor, builds a :decimal
  ;; column whose elements are BigDecimals decoded on demand.
  (let [node (.next node-iter)
        buffers [(.next buf-iter) (.next buf-iter)]
        n-elems (long (:n-elems node))
        {:keys [^long scale ^long bit-width]} (:field-type field)
        ;; bytes per element, rounding the bit width up to a whole byte
        byte-width (quot (+ bit-width 7) 8)]
    (fn parse-decimal
      [decompressor]
      (let [[validity-buf data-buf] (decompressor buffers)
            ^bytes data-ary (if (instance? NativeBuffer data-buf)
                              (native-buffer/->jvm-array data-buf 0 (dtype/ecount data-buf))
                              (dtype/->array data-buf))
            ;; BigInteger's byte[] constructor expects big-endian bytes, so on a
            ;; little-endian platform each element's bytes are reversed first.
            ;; https://github.com/apache/arrow-java/blob/main/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java#L53
            ;; NOTE(review): the reversal is keyed on platform endianness, not on
            ;; the endianness declared by the file - confirm this is intended.
            little-endian? (identical? :little-endian
                                       (tech.v3.datatype.protocols/platform-endianness))
            element-bytes (if little-endian? copy-reverse-bytes copy-bytes)
            values (dtype/make-reader
                    :decimal n-elems
                    (let [start (* idx byte-width)
                          unscaled (BigInteger. ^bytes (element-bytes data-ary start (+ start byte-width)))]
                      (BigDecimal. unscaled scale)))]
        (col-impl/new-column (:name field)
                             values
                             (field-metadata field)
                             (node-buf->missing node validity-buf))))))
1754+
16871755

16881756
(defmethod ^:private preparse-field :default
16891757
[field ^Iterator node-iter ^Iterator buf-iter dict-map options]
@@ -2112,6 +2180,36 @@ Please use stream->dataset-seq.")))
21122180
:text
21132181
datatype)))
21142182

2183+
(defn decimal-column-metadata
  "Scan the BigDecimal values of `col` and return the map describing its
  Arrow decimal type:

    {:scale s :precision p :bit-width bw :datatype :decimal}

  * `:scale` is the scale shared by every non-nil value.  A
    RuntimeException is thrown when two values disagree on scale.
  * `:precision` is the largest precision found (at least 1).
  * `:bit-width` is 128 when every unscaled value (plus a sign bit) fits in
    128 bits, otherwise 256.  A warning is logged when a value needs more
    than 256 bits.

  nil (missing) entries are skipped.  When the column has no non-nil values
  the scale falls back to 2, which is what earlier versions of this function
  reported for an empty column."
  [col]
  (let [colname (:name (meta col))
        [scale precision bit-width]
        (reduce (fn [[scale precision bit-width :as acc] ^BigDecimal value]
                  (if (nil? value)
                    ;; missing entry - contributes nothing to the type
                    acc
                    (let [ss (.scale value)
                          pp (.precision value)
                          ;; +1 for the sign bit of the two's complement form
                          bw (inc (.bitLength (.unscaledValue value)))]
                      (when (and (some? scale) (not (== (long scale) ss)))
                        (throw (RuntimeException.
                                (str "column \"" colname "\" has different scale than previous bigdecs\n\texpected " scale " and got " ss))))
                      [ss
                       (max pp (long precision))
                       (max (long bit-width) bw)])))
                ;; scale starts out unknown (nil) so the first value defines it;
                ;; smallest arrow java supports is 128 bit width
                [nil 1 128]
                col)
        bit-width (long bit-width)
        bit-width (if (> bit-width 128)
                    (do
                      (when (> bit-width 256)
                        (log/warn (str "Column \"" colname "\" uses more bit-width than arrow supports:\n\tMax supported - 256 - found - " bit-width)))
                      256)
                    bit-width)]
    {:scale (if (nil? scale) 2 scale)
     :bit-width bit-width
     :precision precision
     :datatype :decimal}))
2212+
21152213

21162214
(defn ^:no-doc prepare-dataset-for-write
21172215
"Normalize schemas and convert datatypes to datatypes appropriate for arrow
@@ -2146,35 +2244,42 @@ Please use stream->dataset-seq.")))
21462244
;;datatypes
21472245
(reduce
21482246
(fn [ds col]
2149-
(if (and (= :string (dtype/elemwise-datatype col))
2150-
(not (:strings-as-text? options)))
2151-
(if (and (nil? prev-ds)
2152-
(instance? StringTable (.data ^Column col)))
2153-
ds
2154-
(let [missing (ds-proto/missing col)
2155-
metadata (meta col)]
2156-
(if (nil? prev-ds)
2157-
(assoc ds (metadata :name)
2158-
#:tech.v3.dataset{:data (tech.v3.dataset.base/ensure-column-string-table col)
2159-
:missing missing
2160-
:metadata metadata
2161-
:name (metadata :name)})
2162-
(let [prev-col (ds-base/column prev-ds (:name metadata))
2163-
prev-str-t (ds-base/ensure-column-string-table prev-col)
2164-
int->str (ArrayList. ^List (.int->str prev-str-t))
2165-
str->int (HashMap. ^Map (.str->int prev-str-t))
2166-
n-rows (dtype/ecount col)
2167-
data (StringTable. int->str str->int
2168-
(dyn-int-list/dynamic-int-list 0))]
2169-
(dotimes [idx n-rows]
2170-
(.add data (or (col idx) "")))
2247+
(let [col-dt (dtype/elemwise-datatype col)]
2248+
(cond
2249+
(and (identical? :string col-dt)
2250+
(not (:strings-as-text? options)))
2251+
(if (and (nil? prev-ds)
2252+
(instance? StringTable (.data ^Column col)))
2253+
ds
2254+
(let [missing (ds-proto/missing col)
2255+
metadata (meta col)]
2256+
(if (nil? prev-ds)
21712257
(assoc ds (metadata :name)
2172-
#:tech.v3.dataset{:data data
2258+
#:tech.v3.dataset{:data (tech.v3.dataset.base/ensure-column-string-table col)
21732259
:missing missing
2174-
:metadata (assoc metadata
2175-
::previous-string-table prev-str-t)
2176-
:name (metadata :name)})))))
2177-
ds))
2260+
:metadata metadata
2261+
:name (metadata :name)})
2262+
(let [prev-col (ds-base/column prev-ds (:name metadata))
2263+
prev-str-t (ds-base/ensure-column-string-table prev-col)
2264+
int->str (ArrayList. ^List (.int->str prev-str-t))
2265+
str->int (HashMap. ^Map (.str->int prev-str-t))
2266+
n-rows (dtype/ecount col)
2267+
data (StringTable. int->str str->int
2268+
(dyn-int-list/dynamic-int-list 0))]
2269+
(dotimes [idx n-rows]
2270+
(.add data (or (col idx) "")))
2271+
(assoc ds (metadata :name)
2272+
#:tech.v3.dataset{:data data
2273+
:missing missing
2274+
:metadata (assoc metadata
2275+
::previous-string-table prev-str-t)
2276+
:name (metadata :name)})))))
2277+
;;detect precision, scale and whether we need 128 or 256 bytes of accuracy
2278+
(identical? :decimal col-dt)
2279+
(assoc ds (ds-col/column-name col)
2280+
(vary-meta col assoc ::complex-datatype (decimal-column-metadata col)))
2281+
:else
2282+
ds)))
21782283
ds
21792284
(ds-base/columns ds))))
21802285

test/data/bigdec.arrow

506 Bytes
Binary file not shown.

test/tech/v3/libs/arrow_test.clj

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,13 @@
3131
:strings (map str (range n))
3232
:text (map (comp #(Text. %) str) (range n))
3333
:instants (repeatedly n dtype-dt/instant)
34+
:bigdec (repeatedly n #(BigDecimal/valueOf (+ 100 (rand-int 1700)) 2))
35+
;; :bigint (let [rng (java.util.Random.)]
36+
;; (repeatedly n #(BigInteger. 256 rng )))
3437
;;external formats often don't support dash-case
3538
:local_dates (repeatedly n dtype-dt/local-date)
3639
:local_times (repeatedly n dtype-dt/local-time)
37-
})
40+
:uuids (repeatedly n #(java.util.UUID/randomUUID))})
3841
(vary-meta assoc :name :testtable)))
3942
([]
4043
(supported-datatype-ds 10)))
@@ -304,14 +307,12 @@
304307
(is (= (vec (range (ds/row-count ds)))
305308
(vec (ds/missing (ds "nullcol")))))))
306309

307-
308310
(deftest list-datatypes-read-only
309311
(let [ds (ds/->dataset "test/data/arrow_list.arrow")]
310312
(is (= [["dog" "car"]
311313
["dog" "flower"]
312314
["car" "flower"]]
313315
(mapv vec (ds "class-name"))))))
314316

315-
316317
(deftest empty-array-dataset
317318
(is (nil? (arrow/stream->dataset "test/data/empty.arrow"))))

0 commit comments

Comments
 (0)