diff --git a/src/com/wotbrew/cinq/eager_loop.clj b/src/com/wotbrew/cinq/eager_loop.clj index 1fba054..828790b 100644 --- a/src/com/wotbrew/cinq/eager_loop.clj +++ b/src/com/wotbrew/cinq/eager_loop.clj @@ -646,10 +646,11 @@ (for [v agg-bindings] (update v 1 (partial mapv #(into [(vswap! acc-index inc)] %)))))) -(defn emit-group-project-all [ra agg-bindings new-projection body] +(defn emit-group-let-all [ra agg-bindings new-projection body] (let [arr (gensym "arr") agg-bindings (assign-agg-binding-indexes agg-bindings) - acc-count (reduce + 0 (map (fn [[_ agg]] (count agg)) agg-bindings)) + acc-count (inc (reduce + 0 (map (fn [[_ agg]] (count agg)) agg-bindings))) + cnt-index (dec acc-count) acc-bindings (for [[_ agg] agg-bindings [i sym _] agg form [sym `(aget ~arr ~i)]] @@ -658,12 +659,15 @@ ~@(for [[_ agg] agg-bindings [i _ init _] agg] `(aset ~arr ~i (RT/box ~init))) + (aset ~arr ~cnt-index (RT/box 0)) ~(emit-loop ra `(let [~@acc-bindings] ~@(for [[_ agg] agg-bindings [i _ _ expr] agg] `(aset ~arr ~i (RT/box ~(rewrite-expr [ra] expr)))) + (aset ~arr ~cnt-index (unchecked-inc (aget ~arr ~cnt-index))) nil)) (let [~@acc-bindings + ~plan/%count-sym (aget ~arr ~cnt-index) ~@(for [[sym _ completion] agg-bindings form [sym (rewrite-expr [] completion)]] form) @@ -672,15 +676,16 @@ form)] ~body)))) -(defn emit-group-project [ra bindings agg-bindings new-projection body] +(defn emit-group-let [ra bindings agg-bindings new-projection body] (if (empty? bindings) - (emit-group-project-all ra agg-bindings new-projection body) + (emit-group-let-all ra agg-bindings new-projection body) ;; group bindings (let [k (t/key-local (mapv second bindings)) arr (with-meta (gensym "arr") {:tag 'objects}) ht (gensym "ht") agg-bindings (assign-agg-binding-indexes agg-bindings) - acc-count (reduce + 0 (map (fn [[_ agg]] (count agg)) agg-bindings)) + acc-count (inc (reduce + 0 (map (fn [[_ agg]] (count agg)) agg-bindings))) + cnt-index (dec acc-count) acc-bindings (for [[_ agg] agg-bindings [i sym _] agg form [sym `(aget ~arr ~i)]] @@ -694,11 +699,13 @@ (let [~arr (or arr# (doto (object-array ~acc-count) ~@(for [[_ agg] agg-bindings [i _ init] agg] - `(aset ~i (RT/box ~(rewrite-expr [] init)))))) + `(aset ~i (RT/box ~(rewrite-expr [] init)))) + (aset ~cnt-index (RT/box 0)))) ~@acc-bindings] ~@(for [[_ agg] agg-bindings [i _ _ expr] agg] `(aset ~arr ~i (RT/box ~(rewrite-expr [ra] expr)))) + (aset ~arr ~cnt-index (unchecked-inc (aget ~arr ~cnt-index))) ~arr)))] (.compute ~ht ~k f#) nil)) @@ -706,6 +713,7 @@ (fn [_# [~k ~arr]] (let [~@(t/emit-key-bindings k (map first bindings)) ~@acc-bindings + ~plan/%count-sym (aget ~arr ~cnt-index) ~@(for [[sym _ completion] agg-bindings form [sym (rewrite-expr [] completion)]] form) @@ -846,8 +854,8 @@ (emit-group-all ?ra body) (emit-group-by ?ra ?bindings body)) - [::plan/group-project ?ra ?bindings ?aggs ?new-projection] - (emit-group-project ?ra ?bindings ?aggs ?new-projection body) + [::plan/group-let ?ra ?bindings ?aggs ?new-projection] + (emit-group-let ?ra ?bindings ?aggs ?new-projection body) [::plan/order-by ?ra ?order-clauses] (emit-order-by ?ra ?order-clauses body) diff --git a/src/com/wotbrew/cinq/plan.clj b/src/com/wotbrew/cinq/plan.clj index c9d74ce..fa7fba5 100644 --- a/src/com/wotbrew/cinq/plan.clj +++ b/src/com/wotbrew/cinq/plan.clj @@ -8,7 +8,7 @@ (:import (clojure.lang IRecord) (com.wotbrew.cinq CinqUtil) (com.wotbrew.cinq.column Column DoubleColumn LongColumn) - (java.lang.reflect Field))) + (java.lang.reflect Field)(java.util HashSet))) (declare arity column-map columns) @@ -181,8 +181,8 @@ [::project ?ra ?bindings] (mapv first ?bindings) - [::group-project ?ra ?bindings ?aggregates ?projection] - (mapv first ?projection) + [::group-let ?ra ?bindings ?aggregates ?projection] + (conj (into (mapv first ?bindings) (mapv first ?projection)) %count-sym) [::apply :left-join ?left ?right] (into (columns ?left) (mapv optional-tag (columns ?right))) @@ -717,7 +717,7 @@ ;; can permit faster aggregation where grouped columns do not need to be materialized and multiple aggregates ;; can be computed in one loop ;; where columns do not leak out of the projection -(def ^:dynamic *group-project-fusion* false) +(def ^:dynamic *group-let-fusion* true) (defn infer-type [cols expr] (let [col-types (zipmap cols (map (comp :tag meta) cols)) @@ -772,7 +772,7 @@ (m/match expr [::count] [`0] [::count ?expr] [(zero ?expr)] - [::count-distinct ?expr] [(zero ?expr)] + [::count-distinct ?expr] [(zero ?expr) `(HashSet.)] [::sum ?expr] [(zero ?expr)] [::avg ?expr] [(zero ?expr) 0] [::min ?expr] [nil] @@ -785,12 +785,18 @@ (m/match expr [::count] [`(unchecked-inc ~acc-sym)] [::count ?expr] [`(if ~?expr (unchecked-inc ~acc-sym) ~acc-sym)] - [::count-distinct ?expr] (throw (ex-info "Compile error: unexpected count-distinct in aggregate reduction" {})) + [::count-distinct ?expr] + (let [[count-sym hashset-sym] acc-syms + hashset-sym (with-meta hashset-sym {:tag `HashSet})] + [`(let [e# ~?expr] + (cond (nil? e#) ~count-sym + (.add ~hashset-sym e#) (unchecked-inc ~count-sym) + :else ~count-sym)) + hashset-sym]) [::sum ?expr] [`(CinqUtil/sumStep ~acc-sym ~?expr)] [::avg ?expr] [`(CinqUtil/sumStep ~acc-sym ~?expr) `(unchecked-inc ~(second acc-syms))] [::min ?expr] [`(CinqUtil/minStep ~acc-sym ~?expr)] [::max ?expr] [`(CinqUtil/maxStep ~acc-sym ~?expr)] - ;; todo min/max _ (throw (ex-info "Unknown aggregate" {:expr expr}))))) (defn aggregate-completion [acc-syms expr] @@ -801,7 +807,8 @@ (defn aggregate? [expr] (and (vector? expr) - (contains? aggregate-keywords (nth expr 0 nil)))) + (contains? aggregate-keywords (nth expr 0 nil)) + (not= [::count] expr))) (defn hoist-aggregates [group-columns projection-bindings] (let [smap (atom {}) @@ -815,6 +822,18 @@ (when no-leakage [aggregates new-projections]))) +(def project-let + (r/match + [::project ?ra ?projection] + (let [new-sym (memoize *gensym*)] + [::project + [::let ?ra (mapv (fn [[sym expr]] [(new-sym sym) expr]) ?projection)] + (mapv + (fn [[sym]] [sym (new-sym sym)]) + ?projection)]))) + +(def rewrite-project-let (r/bottom-up (r/attempt #'project-let))) + (def fuse (r/match [::where ?ra true] @@ -823,11 +842,17 @@ [::where [::where ?ra ?pred-a] ?pred-b] [::where ?ra (conjoin-predicates ?pred-a ?pred-b)] + [::let [::let ?ra ?a] ?b] + [::let ?ra (into ?a ?b)] + + [::let [::order-by ?ra ?clauses] ?bindings] + [::order-by [::let ?ra ?bindings] ?clauses] + ;; it might be better to have something like ::group-let and an ana pass ;; this only works in a tiny subset of occasions ;; to determine whether group columns leak - (m/and [::project [::group-by ?ra ?bindings] ?projection] - (m/guard *group-project-fusion*)) + (m/and [::let [::group-by ?ra ?bindings] ?projection] + (m/guard *group-let-fusion*)) (let [;; filter out shadowed group columns group-columns (filterv (complement (set (map first ?bindings))) (columns ?ra)) [agg-bindings new-projection :as no-leakage] (hoist-aggregates group-columns ?projection) @@ -837,8 +862,8 @@ exprs (aggregate-reduction acc-syms agg)]] [sym (mapv vector acc-syms inits exprs) (aggregate-completion acc-syms agg)])] (if no-leakage - [::group-project ?ra ?bindings (vec agg-bindings) new-projection] - [::project [::group-by ?ra ?bindings] ?projection])))) + [::group-let ?ra ?bindings (vec agg-bindings) new-projection] + [::let [::group-by ?ra ?bindings] ?projection])))) (def rewrite-logical (-> #'rewrites @@ -1044,6 +1069,7 @@ rewrite-logical push-lookups-sub-queries push-lookups + rewrite-project-let rewrite-fuse rewrite-join-collect rewrite-join-order @@ -1182,7 +1208,7 @@ ?bindings))] ;; todo group-project - #_#_[::group-project ?ra ?bindings ?aggregates ?projection] + #_#_[::group-let ?ra ?bindings ?aggregates ?projection] nil [::apply ?mode ?left ?right] diff --git a/test/com/wotbrew/cinq/unnesting_test.clj b/test/com/wotbrew/cinq/unnesting_test.clj index e32f017..3dc2aef 100644 --- a/test/com/wotbrew/cinq/unnesting_test.clj +++ b/test/com/wotbrew/cinq/unnesting_test.clj @@ -27,8 +27,13 @@ l2__2:orderkey)] [:without #{l1__3:orderkey}] + [:let + [[col__1__5 + l1__3]]] + [:without + #{l1__3}] [:project - {col__1 l1__3}]] + {col__1 col__1__5}]] (c/plan (c/q [l1 lineitem :when (c/exists? [l2 lineitem :when (= l1:orderkey l2:orderkey)])] l1)))))