-- | Internal utilities
module Language.Haskell.TH.Natural.Internal.Utils (conFieldCount) where

import Data.Constructor.Extract (ExtractedConstructor (toEC))
import Data.List (find)
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.Syntax.ExtractedCons.ExtractedCons

-- | Get the number of fields of a data constructor, or the number of
-- arguments of a pattern synonym.
--
-- * For a data constructor, the parent type declaration is reified and the
--   field count of the matching constructor is returned (@0@ if no
--   constructor in the parent declaration matches).
-- * For a pattern synonym, leading @forall@s (and their contexts) are
--   stripped from its type and the top-level arrows of the rest are counted.
--
-- Fails in 'Q' if the name is neither a data constructor nor a pattern
-- synonym, or if the constructor's parent cannot be extracted as a type
-- constructor declaration.
conFieldCount :: Name -> Q Int
conFieldCount conName_ = do
    info <- reify conName_
    case info of
        DataConI reifiedName _ parentName -> do
            parentInfo <- reify parentName
            case toEC parentInfo of
                Just (MkTyConI dec) -> pure $ countFieldForCon dec reifiedName
                _ -> fail "Expected a type constructor"
        PatSynI _ ty -> pure $ countArgsForTy ty
        _ ->
            fail "Expected the name to be a data constructor or a pattern synonym"
  where
    -- Field count of the constructor named @n@ inside declaration @dec@.
    -- Callers pass the name reported by 'reify' rather than 'conName_': a
    -- caller-supplied name (e.g. one built with 'mkName') need not compare
    -- equal to the constructor names stored in the reified declaration.
    countFieldForCon :: Dec -> Name -> Int
    countFieldForCon dec n = case dec of
        DataD _ _ _ _ cons _ ->
            maybe 0 snd $ find fst $ fmap (countConArg n) cons
        NewtypeD _ _ _ _ con _ -> snd $ countConArg n con
        _ -> 0

    -- Whether a constructor declares the name @n@, paired with its field
    -- count.
    countConArg :: Name -> Con -> (Bool, Int)
    countConArg n (NormalC n' args) = (n == n', length args)
    countConArg n (RecC n' args) = (n == n', length args)
    countConArg n (InfixC _ n' _) = (n == n', 2)
    countConArg n (ForallC _ _ con) = countConArg n con
    -- A single GADT signature can declare several constructors at once.
    countConArg n (GadtC ns args _) = (n `elem` ns, length args)
    countConArg n (RecGadtC ns args _) = (n `elem` ns, length args)

    -- Argument count of a pattern synonym type: strip any (possibly nested)
    -- leading foralls, then count the arrows on the remaining spine.
    countArgsForTy :: Type -> Int
    countArgsForTy (ForallT _ _ ty) = countArgsForTy ty
    countArgsForTy (ForallVisT _ ty) = countArgsForTy ty
    countArgsForTy ty = countArrowArgs ty

    -- Number of arguments on an arrow spine. Both ordinary arrows
    -- (@a -> b@) and multiplicity arrows (@a %m -> b@) count as one argument.
    countArrowArgs :: Type -> Int
    countArrowArgs (AppT (AppT ArrowT _) rest) = 1 + countArrowArgs rest
    countArrowArgs (AppT (AppT (AppT MulArrowT _) _) rest) =
        1 + countArrowArgs rest
    countArrowArgs _ = 0