-- | This module allows generating lenses for non-record ADTs.
--
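-- To allow overloading (e.g. using the same lens name for multiple ADTs), it generates a typeclass named after the lens and an instance of it for the ADT.
-- If a class with that name is already in scope (e.g. one generated for a lens of the same name by an earlier splice), only the instance is generated.
-- The lenses are implemented with generic-lens's 'position', so the ADT needs a 'Generic' instance.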
--
-- @
-- data A = A Int
-- data B = B Int
--
-- 'makeADTLenses' ''A ["value"]
--
-- -- ===> Generates the following
-- class HasValue a b | a -> b where
--  value :: 'Lens\'' a b
--
-- instance HasValue A Int where
--  value = 'position' \@1
--
-- 'makeADTLenses' ''B ["value"]
--
-- -- ===> Generates the following
--
-- instance HasValue B Int where
--  value = 'position' \@1
-- @
module Control.Lens.TH.ADT (
    makeADTLenses,
    makeADTLens,
    lensClassName,
) where

import Control.Lens
import Control.Monad
import Data.Char
import Data.Generics.Product
import Data.List ((!?))
import Data.Maybe
import Language.Haskell.TH

-- | Generates lenses for the fields of the given type, one per name, using
-- 'makeADTLens' for each.
--
-- The n-th name (counting from 0) is bound to the n-th field. Fields past the
-- end of the name list get no lens. A name past the last field makes the
-- splice fail.
makeADTLenses :: Name -> [String] -> DecsQ
makeADTLenses tyName fieldNames = do
    perField <- zipWithM (makeADTLens tyName) fieldNames [0 ..]
    return (concat perField)

-- | Generates a lens with the given name for the field at the given
-- (0-based) index of the given type.
--
-- It always emits an instance of @Has\<LensName\>@. It also emits the class
-- itself when no type of that name is in scope yet.
--
-- @
-- 'makeADTLens' ''(,) "left" 0
--
-- -- ===> Generates the following (the class only when not already in scope)
-- class HasLeft a b | a -> b where
--  left :: 'Lens\'' a b
--
-- instance HasLeft (a,b) a where
--  left = 'position' \@1
-- @
--
-- Any type named @Has\<LensName\>@ that is in scope suppresses the class
-- declaration, even an unrelated one.
-- NOTE(review): 'lookupTypeName' presumably cannot see a class emitted
-- earlier in the same splice. Generating the same lens name for two types in
-- one splice may therefore duplicate the class. Confirm before relying on it.
makeADTLens :: Name -> String -> Int -> DecsQ
makeADTLens tyName lensStrName fieldIdx = do
    let lensName = mkName lensStrName
    inst <- makeADTLensInstance tyName lensName fieldIdx
    classInScope <- lookupTypeName (nameBase (lensClassName lensName))
    classDecs <- case classInScope of
        Nothing -> (: []) <$> makeADTLensClass lensName
        Just _ -> pure []
    pure (classDecs ++ [inst])

-- | Builds the @Has\<Lens\>@ instance for the field at (0-based) index
-- @fieldIdx@ of the type named @tyName@.
--
-- The instance head is @Has\<Lens\> (T a1 .. an) FieldType@. The type
-- variables and the field type are taken from the reified declaration. The
-- method is defined as 'position' applied to the type-level literal
-- @fieldIdx + 1@, because 'position' counts fields from 1. generic-lens's
-- 'position' works through "GHC.Generics", so the type needs a 'Generic'
-- instance.
--
-- @tyName@ may name the type constructor (@''T@) or its data constructor
-- (@'T@). In the latter case, the constructor's parent type is used.
-- Accepted declarations are a @data@ type with exactly one constructor, or a
-- @newtype@, whose constructor is in prefix or infix form. Any other
-- declaration, or an out-of-range index, fails the splice.
makeADTLensInstance :: Name -> Name -> Int -> DecQ
makeADTLensInstance tyName lensName fieldIdx = do
    (typeCon, tyArgs, fields) <- reify tyName >>= singleConShape
    fieldType <-
        maybe
            ( fail $
                "Invalid field index "
                    ++ show fieldIdx
                    ++ " for "
                    ++ show typeCon
                    ++ ", which has "
                    ++ show (length fields)
                    ++ " field(s)"
            )
            (return . snd)
            (fields !? fieldIdx)
    -- Apply the type constructor to all of its parameters: T a1 .. an
    let sourceTy = foldl (\rest t -> rest `AppT` VarT t) (ConT typeCon) tyArgs
    return $
        InstanceD
            Nothing
            []
            (ConT (lensClassName lensName) `AppT` sourceTy `AppT` fieldType)
            [ FunD
                lensName
                [Clause [] (NormalB (VarE 'position `AppTypeE` LitT (NumTyLit $ fromIntegral fieldIdx + 1))) []]
            ]
  where
    -- From the reified declaration, extracts the type constructor's name, its
    -- type parameters, and the fields of its only constructor.
    singleConShape :: Info -> Q (Name, [Name], [BangType])
    singleConShape = \case
        TyConI (DataD _ name bndrs _ [con] _)
            | Just fs <- conFields con ->
                return (name, map bndrName bndrs, fs)
        TyConI (NewtypeD _ name bndrs _ con _)
            | Just fs <- conFields con ->
                return (name, map bndrName bndrs, fs)
        -- A data constructor was given: describe its parent type instead.
        DataConI _ _ parent -> reify parent >>= singleConShape
        info ->
            fail $
                "Expected a data type with exactly one constructor, got: "
                    ++ show info

    -- Positional fields of a non-record constructor, in declaration order.
    conFields :: Con -> Maybe [BangType]
    conFields = \case
        NormalC _ fs -> Just fs
        InfixC l _ r -> Just [l, r]
        _ -> Nothing

    bndrName :: TyVarBndr flag -> Name
    bndrName = \case
        PlainTV n _ -> n
        KindedTV n _ _ -> n

-- | Builds the typeclass declaration for a lens. For a lens named @value@, it
-- produces:
--
-- @
-- class HasValue a b | a -> b where
--  value :: 'Lens\'' a b
-- @
--
-- The functional dependency lets the structure type determine the field type.
makeADTLensClass :: Name -> DecQ
makeADTLensClass lensName = pure classDecl
  where
    structVar, fieldVar :: Name
    structVar = mkName "a"
    fieldVar = mkName "b"

    classDecl :: Dec
    classDecl =
        ClassD
            []
            (lensClassName lensName)
            [PlainTV structVar BndrReq, PlainTV fieldVar BndrReq]
            [FunDep [structVar] [fieldVar]]
            [SigD lensName methodType]

    -- Lens' a b
    methodType :: Type
    methodType = foldl AppT (ConT ''Lens') [VarT structVar, VarT fieldVar]

-- | Name of the typeclass generated for a lens: @Has@ followed by the lens'
-- base name with its first character upper-cased. For example, @value@ gives
-- @HasValue@. Any module qualifier on the lens name is dropped.
lensClassName :: Name -> Name
lensClassName = mkName . ("Has" ++) . capitalize . nameBase
  where
    capitalize :: String -> String
    capitalize "" = ""
    capitalize (c : cs) = toUpper c : cs