{-# LANGUAGE GADTs #-}

-- | 'Builder' for a case expression.
-- You can find an example usage [here](https://github.com/Arthi-chaud/template-haskell-natural/tree/main/examples/packed).
module Language.Haskell.TH.Natural.Syntax.Case (
    -- * Builder
    case_,
    CaseExprBuilder,

    -- * Match
    matchConst,
    matchWild,
    matchList,
    matchCon,

    -- * Pattern match on constructors

    -- ** Type
    ConMatchBuilder,
    ConPatternBuilder (..),
    field,
    body,

    -- ** Pattern Builder

    -- *** Types
    PatternBuilder,
    Pattern,

    -- *** Functions
    var,
    constant,
    constructor,

    -- * Reexports
    module Language.Haskell.TH.Natural.Syntax.Builder.Monad,
) where

import Control.Lens hiding (Empty)
import Control.Monad (replicateM)
import Data.Constructor.Extract
import Language.Haskell.TH (Exp)
import qualified Language.Haskell.TH as TH
import Language.Haskell.TH.Gen
import Language.Haskell.TH.Natural.Internal.Utils
import Language.Haskell.TH.Natural.Syntax.Builder
import qualified Language.Haskell.TH.Natural.Syntax.Builder as B
import Language.Haskell.TH.Natural.Syntax.Builder.Monad
import Language.Haskell.TH.Syntax.ExtractedCons hiding (body)

-- | A builder for the matches and branches in a case expression.
--
-- Its state is the 'CaseE' under construction (created by 'case_');
-- 'matchConst', 'matchWild', 'matchList' and 'matchCon' each append one
-- match to it.
type CaseExprBuilder = ConstBuilder CaseE

-- | Builds a case expression.
--
-- The first argument is the scrutinee (the expression being pattern-matched);
-- the 'CaseExprBuilder' then adds the matches, in order, to an initially
-- empty list of matches.
case_ :: (GenExpr b) => b -> CaseExprBuilder () -> TH.Q CaseE
case_ scrutinee addMatches = do
    scrutineeExp <- genExpr scrutinee
    let emptyCase = MkCaseE scrutineeExp []
    runBaseBuilder addMatches emptyCase

-- | Adds a match on a constant pattern (e.g. a literal).
--
-- The first argument is turned into the pattern (via 'genPat'), the second
-- into the body of the match (via 'genExpr'), in that order.
matchConst :: (GenPat b1, GenExpr b2) => b1 -> b2 -> CaseExprBuilder ()
matchConst lhs rhs = do
    lhsPat <- liftB (genPat lhs)
    rhsExp <- liftB (genExpr rhs)
    let clause = TH.Match lhsPat (TH.NormalB rhsExp) []
    matches |>= clause

-- | Adds a catch-all match (a wildcard pattern @_@).
--
-- The argument becomes the body of the match.
matchWild :: (GenExpr b) => b -> CaseExprBuilder ()
matchWild rhs = do
    rhsExp <- liftB (genExpr rhs)
    let clause = TH.Match TH.WildP (TH.NormalB rhsExp) []
    matches |>= clause

-- | Adds a match on the given constructor.
--
-- The 'ConMatchBuilder' allows deconstructing and accessing the fields of the
-- constructor (see 'field') and sets the body of the match (see 'body').
-- Every field starts out matched by a wildcard; fields the builder does not
-- touch stay that way. Fails if the builder did not produce a body.
matchCon :: TH.Name -> ConMatchBuilder Empty Ready () -> CaseExprBuilder ()
matchCon conName cmb = do
    arity <- liftB (conFieldCount conName)
    let initialState = MkMBS (MkConP conName [] (replicate arity TH.WildP)) Nothing
    MkMBS {_conPat = builtPat, _matchBody = builtBody} <- liftB (runBaseBuilder cmb initialState)
    maybe
        (B.fail "Match's Expression is missing")
        (\rhsExp -> matches |>= TH.Match (fromEC builtPat) (TH.NormalB rhsExp) [])
        builtBody

-- | Adds a match on a list of exactly the given size.
--
-- Each element is bound to a fresh variable; the second argument receives
-- those variables (as 'TH.VarE' expressions, in list order) and produces the
-- body of the match.
matchList :: (GenExpr b) => Int -> ([Exp] -> b) -> CaseExprBuilder ()
matchList listSize mkBody = do
    itemNames <- liftB (replicateM listSize (TH.newName "_i"))
    let itemPats = map TH.VarP itemNames
        itemExps = map TH.VarE itemNames
    rhsExp <- liftB (genExpr (mkBody itemExps))
    matches |>= TH.Match (TH.ListP itemPats) (TH.NormalB rhsExp) []

-- | A builder whose state is a constructor pattern ('ConP').
-- 'constructor' uses it to describe the field patterns of a nested
-- constructor match; fields are set with 'field'.
type PatternBuilder = ConstBuilder ConP

-- | Describes how a single constructor field is matched by 'field'.
-- The type index is the value 'field' returns for that field.
data Pattern a where
    -- | Bind the field to a fresh variable; 'field' returns that variable as
    -- an expression. Built by 'var'.
    Var :: Pattern TH.Exp
    -- | Replace the field's pattern with the pattern produced by the given
    -- action. Built by 'constant'.
    Constant :: (TH.Q TH.Pat) -> Pattern ()
    -- | Match the field against a nested constructor. The function is given
    -- the nested constructor's field count and builds its field patterns.
    -- Built by 'constructor'.
    NestedMatch :: TH.Name -> (Int -> PatternBuilder a) -> Pattern a

-- | Binds a constructor's field to a fresh variable.
-- When passed to 'field', the bound variable is returned as an expression.
var :: Pattern TH.Exp
var = Var

-- | Pattern-matches a constructor's field on a nested constructor.
--
-- When used with 'field', the second argument is called once, with the
-- number of fields of the nested constructor. It builds the patterns for
-- those fields (typically with 'field'). Fields it leaves untouched are
-- matched with wildcards. Its result is what 'field' returns.
constructor :: TH.Name -> (Int -> PatternBuilder a) -> Pattern a
constructor nestedCon mkFieldPatterns = NestedMatch nestedCon mkFieldPatterns

-- | Pattern-matches a constructor's field against a constant (e.g. a
-- literal). The argument is turned into a pattern with 'genPat'.
constant :: (GenPat b) => b -> Pattern ()
constant value = Constant (genPat value)

-- | Builders that construct a constructor pattern and can set the pattern of
-- one of its fields. 'field' works in any such builder.
class ConPatternBuilder m where
    -- | Replaces the pattern of the field at the given (0-based) index.
    setFieldPattern :: Int -> TH.Pat -> m ()

-- | In a pattern that deconstructs a value, sets how the field at the given
-- (0-based) index is matched, according to the 'Pattern':
--
-- * 'var' binds the field to a fresh variable and returns it as an expression;
-- * 'constant' replaces the field's pattern with the given constant pattern;
-- * 'constructor' matches the field against a nested constructor pattern and
--   returns whatever the nested 'PatternBuilder' returns.
field :: (ConPatternBuilder (Builder s step step)) => Int -> Pattern a -> Builder s step step a
field fieldIdx Var = do
    boundName <- liftB (TH.newName "f")
    setFieldPattern fieldIdx (TH.VarP boundName)
    return (TH.VarE boundName)
field fieldIdx (Constant mkPat) = do
    constPat <- liftB mkPat
    setFieldPattern fieldIdx constPat
field fieldIdx (NestedMatch nestedCon mkNested) = do
    arity <- liftB (conFieldCount nestedCon)
    -- Every field of the nested constructor starts out as a wildcard.
    let wildcards = replicate arity TH.WildP
    (result, nestedPat) <- liftB (runBaseBuilder' (mkNested arity) (MkConP nestedCon [] wildcards))
    setFieldPattern fieldIdx (fromEC nestedPat)
    return result

-- | Builds a case match for a predefined constructor (see 'matchCon').
--
-- Its state is a 'ConMatchBuilderState'. Fields are deconstructed with
-- 'field' and the body of the match is set with 'body'.
type ConMatchBuilder = Builder ConMatchBuilderState

-- | State threaded through a 'ConMatchBuilder'.
data ConMatchBuilderState = MkMBS
    { _conPat :: ConP
    -- ^ The constructor pattern being built ('matchCon' seeds every field
    -- with a wildcard).
    , _matchBody :: Maybe TH.Exp
    -- ^ The body of the match, set by 'body'; 'Nothing' until then.
    }

makeLenses ''ConMatchBuilderState

-- | Sets a field of the constructor pattern held in the 'ConMatchBuilderState'.
--
-- An index outside the constructor's fields is reported as an error.
-- Previously it was silently ignored. With 'var' that produced a reference
-- to a never-bound variable, and with 'constant' the constant restriction
-- was dropped, making the generated match more permissive than requested.
instance ConPatternBuilder (ConMatchBuilder step step) where
    setFieldPattern fidx patt = do
        fieldCount <- uses (conPat . pats) (length :: [TH.Pat] -> Int)
        if fidx < 0 || fidx >= fieldCount
            then
                B.fail $
                    "field: index "
                        ++ show fidx
                        ++ " is out of range for a constructor with "
                        ++ show fieldCount
                        ++ " field(s)"
            else conPat . pats . ix fidx .= patt

-- | Sets a field of the nested constructor pattern being built.
--
-- An index outside the constructor's fields is reported as an error.
-- Previously it was silently ignored. With 'var' that produced a reference
-- to a never-bound variable, and with 'constant' the constant restriction
-- was dropped, making the generated match more permissive than requested.
instance ConPatternBuilder PatternBuilder where
    setFieldPattern fidx patt = do
        fieldCount <- uses pats (length :: [TH.Pat] -> Int)
        if fidx < 0 || fidx >= fieldCount
            then
                B.fail $
                    "field: index "
                        ++ show fidx
                        ++ " is out of range for a constructor with "
                        ++ show fieldCount
                        ++ " field(s)"
            else pats . ix fidx .= patt

-- | Sets the body of the match.
--
-- This is the step that moves a 'ConMatchBuilder' from @Empty@ to @Ready@,
-- which is the state 'matchCon' expects.
body :: (GenExpr b) => b -> ConMatchBuilder Empty Ready ()
body rhs = impure $ do
    rhsExp <- liftB (genExpr rhs)
    matchBody ?= rhsExp