{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Nes.APU.State.Filter.Iir (
    IirFilter (..),

    -- * Build predefined filters
    identityIirFilter,
    highPassIirFilter,
    lowPassIirFilter,
) where

import Nes.APU.State.Filter.Class
import Nes.APU.State.Filter.Constants

-- | A first-order infinite impulse response (IIR) filter.
--
-- The record pairs the filter coefficient with the small amount of history
-- the recurrence needs; 'pass' selects which recurrence 'output' evaluates.
data IirFilter = MkIirF
    { alpha :: {-# UNPACK #-} !Float
    -- ^ Filter coefficient, computed by the smart constructors of this module
    , previousOutput :: {-# UNPACK #-} !Sample
    -- ^ Output of the filter just before the most recent sample was consumed
    , previousInput :: {-# UNPACK #-} !Sample
    -- ^ Most recently consumed input sample
    , delta :: {-# UNPACK #-} !Float
    -- ^ Newest input sample minus the one consumed before it
    , pass :: {-# UNPACK #-} !IirFilterPass
    -- ^ Which recurrence is applied when producing output
    }

-- | Selects the recurrence an 'IirFilter' evaluates.
data IirFilterPass
    = -- | Pass-through: the output is the latest consumed input
      Identity
    | -- | Low-pass recurrence (see 'lowPassIirFilter')
      LowPass
    | -- | High-pass recurrence (see 'highPassIirFilter')
      HighPass
    deriving (Eq)

-- | A filter that leaves the signal untouched: its output is always the most
-- recently consumed sample.
--
-- Every numeric field (coefficient and history) starts at zero.
identityIirFilter :: IirFilter
identityIirFilter = MkIirF 0 0 0 0 Identity -- alpha, previousOutput, previousInput, delta, pass

-- | First-order high-pass filter for a stream sampled at @sampleRate@ with
-- corner frequency @cutoff@ (both in the same frequency unit, presumably Hz).
--
-- Uses the standard RC discretisation:
--
-- * @dt = 1 / sampleRate@
-- * @RC = 1 / (2 * pi * cutoff)@
-- * @alpha = RC / (RC + dt)@
--
-- All history starts at zero.
highPassIirFilter :: SampleRate -> Cutoff -> IirFilter
highPassIirFilter sampleRate cutoff =
    MkIirF
        { alpha = cutoffPeriod / (cutoffPeriod + period)
        , previousOutput = 0
        , previousInput = 0
        , delta = 0
        , pass = HighPass
        }
  where
    -- Sampling interval dt.
    period = 1 / sampleRate
    -- RC time constant of the analogue filter. The 2*pi factor matches
    -- 'lowPassIirFilter'; without it the effective corner frequency would be
    -- cutoff / (2*pi) rather than cutoff.
    cutoffPeriod = 1 / (2 * pi * cutoff)

-- | First-order low-pass filter for a stream sampled at @sampleRate@ with
-- corner frequency @cutoff@ (both in the same frequency unit, presumably Hz).
--
-- The stored coefficient is @alpha = RC / (RC + dt)@, where
-- @RC = 1 / (2 * pi * cutoff)@ and @dt = 1 / sampleRate@.
-- All history starts at zero.
lowPassIirFilter :: SampleRate -> Cutoff -> IirFilter
lowPassIirFilter sampleRate cutoff =
    MkIirF
        { pass = LowPass
        , alpha = rc / (rc + dt)
        , previousInput = 0
        , previousOutput = 0
        , delta = 0
        }
  where
    -- Sampling interval.
    dt = recip sampleRate
    -- Time constant of the equivalent RC circuit.
    rc = recip (2 * pi * cutoff)

-- | Runs an 'IirFilter' one sample at a time.
--
-- 'consume' makes three updates:
--
-- * the output computed from the current state is saved as 'previousOutput'
-- * the change in input is stored in 'delta'
-- * the new sample is kept as 'previousInput'
--
-- After feeding @x[n]@ the state therefore holds @y[n-1]@, @x[n]@ and
-- @x[n] - x[n-1]@. From these, 'output' computes @y[n]@:
--
-- * 'Identity': @y[n] = x[n]@
-- * 'LowPass':  @y[n] = alpha * y[n-1] + (1 - alpha) * x[n]@
-- * 'HighPass': @y[n] = alpha * y[n-1] + alpha * (x[n] - x[n-1])@
instance (Monad m) => Filter m IirFilter where
    {-# INLINE output #-}
    output f = return $ case pass f of
        Identity -> previousInput f
        -- Exponential smoothing: 'alpha' (RC / (RC + dt)) is the weight kept
        -- from the previous output and (1 - alpha) the weight given to the
        -- newest input. Using y[n-1] + alpha * (x[n] - x[n-1]) instead
        -- telescopes to alpha * x[n]: a plain gain with no frequency shaping.
        LowPass -> alpha f * previousOutput f + (1 - alpha f) * previousInput f
        HighPass -> alpha f * previousOutput f + alpha f * delta f

    {-# INLINE consume #-}
    consume sample f = do
        -- Snapshot y[n-1] from the current state before it is overwritten.
        prevOut <- output @m f
        return $
            f
                { previousOutput = prevOut
                , delta = sample - previousInput f
                , previousInput = sample
                }