Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add brackets if needed for hint "Avoid lambda" #1634

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 104 additions & 92 deletions src/GHC/Util/HsExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ niceDotApp a b = dotApp a b

-- Generate a lambda expression but prettier if possible.
niceLambda :: [String] -> LHsExpr GhcPs -> LHsExpr GhcPs
niceLambda ss e = fst (niceLambdaR ss e)-- We don't support refactorings yet.
niceLambda ss e = fst (niceLambdaR Nothing ss e)-- We don't support refactorings yet.

allowRightSection :: String -> Bool
allowRightSection x = x `notElem` ["-","#"]
Expand All @@ -150,99 +150,111 @@ allowLeftSection x = x /= "#"

-- Implementation. Try to produce special forms (e.g. sections,
-- compositions) where we can.
niceLambdaR :: [String]
-> LHsExpr GhcPs
niceLambdaR :: Maybe (LHsExpr GhcPs) -- parent expression
-> [String]
-> LHsExpr GhcPs -- the expression being processed
-> (LHsExpr GhcPs, R.SrcSpan -> [Refactoring R.SrcSpan])
-- Rewrite @\ -> e@ as @e@
-- These are encountered as recursive calls.
niceLambdaR xs (SimpleLambda [] x) = niceLambdaR xs x

-- Rewrite @\xs -> (e)@ as @\xs -> e@.
niceLambdaR xs (L _ (HsPar _ x)) = niceLambdaR xs x

-- @\vs v -> ($) e v@ ==> @\vs -> e@
-- @\vs v -> e $ v@ ==> @\vs -> e@
niceLambdaR (unsnoc -> Just (vs, v)) (view -> App2 f e (view -> Var_ v'))
| isDol f
, v == v'
, vars e `disjoint` [v]
= niceLambdaR vs e

-- @\v -> thing + v@ ==> @\v -> (thing +)@ (heuristic: @v@ must be a single
-- lexeme, or it all gets too complex)
niceLambdaR [v] (L _ (OpApp _ e f (view -> Var_ v')))
| isLexeme e
, v == v'
, vars e `disjoint` [v]
, L _ (HsVar _ (L _ fname)) <- f
, isSymOcc $ rdrNameOcc fname
= let res = nlHsPar $ noLocA $ SectionL noExtField e f
in (res, \s -> [Replace Expr s [] (unsafePrettyPrint res)])

-- @\vs v -> f x v@ ==> @\vs -> f x@
niceLambdaR (unsnoc -> Just (vs, v)) (L _ (HsApp _ f (view -> Var_ v')))
| v == v'
, vars f `disjoint` [v]
= niceLambdaR vs f

-- @\vs v -> (v `f`)@ ==> @\vs -> f@
niceLambdaR (unsnoc -> Just (vs, v)) (L _ (SectionL _ (view -> Var_ v') f))
| v == v' = niceLambdaR vs f

-- Strip one variable pattern from the end of a lambdas match, and place it in our list of factoring variables.
niceLambdaR xs (SimpleLambda ((view -> PVar_ v):vs) x)
| v `notElem` xs = niceLambdaR (xs++[v]) $ lambda vs x

-- Rewrite @\x -> x + a@ as @(+ a)@ (heuristic: @a@ must be a single
-- lexeme, or it all gets too complex).
niceLambdaR [x] (view -> App2 op@(L _ (HsVar _ (L _ tag))) l r)
| isLexeme r, view l == Var_ x, x `notElem` vars r, allowRightSection (occNameStr tag) =
let e = rebracket1 $ addParen (noLocA $ SectionR noExtField op r)
in (e, \s -> [Replace Expr s [] (unsafePrettyPrint e)])
-- Rewrite (1) @\x -> f (b x)@ as @f . b@, (2) @\x -> f $ b x@ as @f . b@.
niceLambdaR [x] y
| Just (z, subts) <- factor y, x `notElem` vars z = (z, \s -> [mkRefact subts s])
niceLambdaR parent = go
where
-- Factor the expression with respect to x.
factor :: LHsExpr GhcPs -> Maybe (LHsExpr GhcPs, [LHsExpr GhcPs])
factor (L _ (HsApp _ ini lst)) | view lst == Var_ x = Just (ini, [ini])
factor (L _ (HsApp _ ini lst)) | Just (z, ss) <- factor lst
= let r = niceDotApp ini z
in if astEq r z then Just (r, ss) else Just (r, ini : ss)
factor (L _ (OpApp _ y op (factor -> Just (z, ss))))| isDol op
= let r = niceDotApp y z
in if astEq r z then Just (r, ss) else Just (r, y : ss)
factor (L _ (HsPar _ y@(L _ HsApp{}))) = factor y
factor _ = Nothing
mkRefact :: [LHsExpr GhcPs] -> R.SrcSpan -> Refactoring R.SrcSpan
mkRefact subts s =
let tempSubts = zipWith (\a b -> (a, toSSA b)) substVars subts
template = dotApps (map (strToVar . fst) tempSubts)
in Replace Expr s tempSubts (unsafePrettyPrint template)
-- Rewrite @\x y -> x + y@ as @(+)@.
niceLambdaR [x,y] (L _ (OpApp _ (view -> Var_ x1) op@(L _ HsVar {}) (view -> Var_ y1)))
| x == x1, y == y1, vars op `disjoint` [x, y] = (op, \s -> [Replace Expr s [] (unsafePrettyPrint op)])
-- Rewrite @\x y -> f y x@ as @flip f@.
niceLambdaR [x, y] (view -> App2 op (view -> Var_ y1) (view -> Var_ x1))
| x == x1, y == y1, vars op `disjoint` [x, y] =
( gen op
, \s -> [Replace Expr s [("x", toSSA op)] (unsafePrettyPrint $ gen (strToVar "x"))]
)
where
gen :: LHsExpr GhcPs -> LHsExpr GhcPs
gen = noLocA . HsApp noExtField (strToVar "flip")
. if isAtom op then id else addParen

-- We're done factoring, but have no variables left, so we shouldn't make a lambda.
-- @\ -> e@ ==> @e@
niceLambdaR [] e = (e, \s -> [Replace Expr s [("a", toSSA e)] "a"])
-- Base case. Just a good old fashioned lambda.
niceLambdaR ss e =
let grhs = noLocA $ GRHS noAnn [] e :: LGRHS GhcPs (LHsExpr GhcPs)
grhss = GRHSs {grhssExt = emptyComments, grhssGRHSs=[grhs], grhssLocalBinds=EmptyLocalBinds noExtField}
match = noLocA $ Match {m_ext=noExtField, m_ctxt=LamAlt LamSingle, m_pats=noLocA $ map strToPat ss, m_grhss=grhss} :: LMatch GhcPs (LHsExpr GhcPs)
matchGroup = MG {mg_ext=Generated OtherExpansion SkipPmc, mg_alts=noLocA [match]}
in (noLocA $ HsLam noAnn LamSingle matchGroup, const [])
-- Rewrite @\ -> e@ as @e@
-- These are encountered as recursive calls.
go xs (SimpleLambda [] x) = go xs x

-- Rewrite @\xs -> (e)@ as @\xs -> e@.
go xs (L _ (HsPar _ x)) = go xs x

-- @\vs v -> ($) e v@ ==> @\vs -> e@
-- @\vs v -> e $ v@ ==> @\vs -> e@
go (unsnoc -> Just (vs, v)) (view -> App2 f e (view -> Var_ v'))
| isDol f
, v == v'
, vars e `disjoint` [v]
= go vs e

-- @\v -> thing + v@ ==> @\v -> (thing +)@ (heuristic: @v@ must be a single
-- lexeme, or it all gets too complex)
go [v] (L _ (OpApp _ e f (view -> Var_ v')))
| isLexeme e
, v == v'
, vars e `disjoint` [v]
, L _ (HsVar _ (L _ fname)) <- f
, isSymOcc $ rdrNameOcc fname
= let res = nlHsPar $ noLocA $ SectionL noExtField e f
in (res, \s -> [Replace Expr s [] (unsafePrettyPrint res)])

-- @\vs v -> f x v@ ==> @\vs -> f x@
go (unsnoc -> Just (vs, v)) (L _ (HsApp _ f (view -> Var_ v')))
| v == v'
, vars f `disjoint` [v]
= go vs f

-- @\vs v -> (v `f`)@ ==> @\vs -> f@
go (unsnoc -> Just (vs, v)) (L _ (SectionL _ (view -> Var_ v') f))
| v == v' = go vs f

-- Strip one variable pattern from the end of a lambdas match, and place it in our list of factoring variables.
go xs (SimpleLambda ((view -> PVar_ v):vs) x)
| v `notElem` xs = go (xs++[v]) $ lambda vs x

-- Rewrite @\x -> x + a@ as @(+ a)@ (heuristic: @a@ must be a single
-- lexeme, or it all gets too complex).
go [x] (view -> App2 op@(L _ (HsVar _ (L _ tag))) l r)
| isLexeme r, view l == Var_ x, x `notElem` vars r, allowRightSection (occNameStr tag) =
let e = rebracket1 $ addParen (noLocA $ SectionR noExtField op r)
in (e, \s -> [Replace Expr s [] (unsafePrettyPrint e)])
-- Rewrite (1) @\x -> f (b x)@ as @f . b@, (2) @\x -> f $ b x@ as @f . b@.
go [x] y
| Just (z, subts) <- factor y, x `notElem` vars z = (z, \s -> [mkRefact subts s])
where
-- Factor the expression with respect to x.
factor :: LHsExpr GhcPs -> Maybe (LHsExpr GhcPs, [LHsExpr GhcPs])
factor (L _ (HsApp _ ini lst)) | view lst == Var_ x = Just (ini, [ini])
factor (L _ (HsApp _ ini lst)) | Just (z, ss) <- factor lst
= let r = niceDotApp ini z
in if astEq r z then Just (r, ss) else Just (r, ini : ss)
factor (L _ (OpApp _ y op (factor -> Just (z, ss))))| isDol op
= let r = niceDotApp y z
in if astEq r z then Just (r, ss) else Just (r, y : ss)
factor (L _ (HsPar _ y@(L _ HsApp{}))) = factor y
factor _ = Nothing
mkRefact :: [LHsExpr GhcPs] -> R.SrcSpan -> Refactoring R.SrcSpan
mkRefact subts s =
let tempSubts = zipWith (\a b -> (a, toSSA b)) substVars subts
template = dotApps (map (strToVar . fst) tempSubts)
in Replace Expr s tempSubts (unsafePrettyPrint template)
-- Rewrite @\x y -> x + y@ as @(+)@.
go [x,y] (L _ (OpApp _ (view -> Var_ x1) op@(L _ HsVar {}) (view -> Var_ y1)))
| x == x1, y == y1, vars op `disjoint` [x, y] = (op, \s -> [Replace Expr s [] (unsafePrettyPrint op)])
-- Rewrite @\x y -> f y x@ as @flip f@.
go [x, y] (view -> App2 op (view -> Var_ y1) (view -> Var_ x1))
| x == x1, y == y1, vars op `disjoint` [x, y] =
( gen op
, \s -> [Replace Expr s [("x", toSSA op)] (unsafePrettyPrint $ gen (strToVar "x"))]
)
where
gen :: LHsExpr GhcPs -> LHsExpr GhcPs
gen = noLocA . HsApp noExtField (strToVar "flip")
. if isAtom op then id else addParen

-- We're done factoring, but have no variables left, so we shouldn't make a lambda.
-- @\ -> e@ ==> @e@
go [] e =
let -- Add brackets if needed, primarily for handling BlockArguments.
-- e.g., parent = `f \x -> g 3 x`; e = `g 3`.
-- Brackets should be placed around `e` to produce `f (g 3)` instead of `f g 3`.
addBrackets = case parent of
Just p -> isApp p && not (isVar e)
Nothing -> False
e' = if addBrackets then mkHsPar e else e
tpl = if addBrackets then "(a)" else "a"
in (e', \s -> [Replace Expr s [("a", toSSA e)] tpl])
-- Base case. Just a good old fashioned lambda.
go ss e =
let grhs = noLocA $ GRHS noAnn [] e :: LGRHS GhcPs (LHsExpr GhcPs)
grhss = GRHSs {grhssExt = emptyComments, grhssGRHSs=[grhs], grhssLocalBinds=EmptyLocalBinds noExtField}
match = noLocA $ Match {m_ext=noExtField, m_ctxt=LamAlt LamSingle, m_pats=noLocA $ map strToPat ss, m_grhss=grhss} :: LMatch GhcPs (LHsExpr GhcPs)
matchGroup = MG {mg_ext=Generated OtherExpansion SkipPmc, mg_alts=noLocA [match]}
in (noLocA $ HsLam noAnn LamSingle matchGroup, const [])


-- 'case' and 'if' expressions have branches, nothing else does (this
Expand Down
6 changes: 5 additions & 1 deletion src/Hint/Lambda.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ f = foo (\y -> g x . h $ y) -- g x . h
f = foo (\y -> g x . h $ y) -- @Message Avoid lambda
f = foo ((*) x) -- (x *)
f = (*) x
f = g \x -> h 3 x -- (h 3)
f = g (\x -> h 3 x) -- h 3
f = g \x -> (`h` 3) x -- (`h` 3)
f = g \x -> h x -- h
f = foo (flip op x) -- (`op` x)
f = foo (flip op x) -- @Message Use section
f = foo (flip x y) -- (`x` y)
Expand Down Expand Up @@ -217,7 +221,7 @@ lambdaExp _ o@(L _ (HsPar _ (view -> App2 (view -> Var_ "flip") origf@(view -> R

lambdaExp p o@(L _ (HsLam _ LamSingle _))
| not $ any isOpApp p
, (res, refact) <- niceLambdaR [] o
, (res, refact) <- niceLambdaR p [] o
, not $ isLambda res
, not $ any isQuasiQuoteExpr $ universe res
, not $ "runST" `Set.member` Set.map occNameString (freeVars o)
Expand Down
Loading