module Nes.APU.State.Pulse (
    -- * Pulse
    Pulse (..),
    newPulse,
    tickPulse,
    modifySweep,
    withSweep,

    -- * Sweep Unit
    SweepUnit (..),
    tickSweepUnit,
    updateTargetPeriod,

    -- * Output
    getPulseOutput,
) where

import Data.Bits
import Data.List ((!?))
import Data.Maybe (fromMaybe)
import Nes.APU.State.Envelope
import Nes.APU.State.LengthCounter

-- | Complete state of one pulse (square-wave) channel.
data Pulse = MkP
    { dutyIndex :: {-# UNPACK #-} !Int
    -- ^ Row of the 'dutySequences' table, i.e. which duty cycle is selected
    , dutyStep :: {-# UNPACK #-} !Int
    -- ^ Position inside the selected 'dutySequences' row
    , lengthCounter :: !LengthCounter
    -- ^ Length counter consulted by 'isSilencedByLengthCounter'
    , period :: {-# UNPACK #-} !Int
    -- ^ Value 'timer' is reloaded with once it reaches 0
    , timer :: {-# UNPACK #-} !Int
    -- ^ Counts down by one per 'tickPulse', from 'period' to 0, then reloads
    , sweepUnit :: !SweepUnit
    -- ^ Sweep unit that can rewrite 'period' (see 'tickSweepUnit')
    , envelope :: {-# UNPACK #-} !Envelope
    -- ^ Envelope whose output is the channel level when it is not silenced
    }

-- | Create a pulse channel in its reset state.
--
-- Pass 'True' when building pulse 1 and 'False' for pulse 2. The flag is
-- stored as the sweep unit's 'isPulse1'; every other counter and setting
-- starts at zero / 'False', and the length counter and envelope start from
-- 'newLengthCounter' and 'newEnvelope'.
newPulse :: Bool -> Pulse
newPulse isPulseOne =
    MkP
        { dutyIndex = 0
        , dutyStep = 0
        , lengthCounter = newLengthCounter
        , period = 0
        , timer = 0
        , sweepUnit = initialSweep
        , envelope = newEnvelope
        }
  where
    -- All sweep settings cleared; only the channel identity is recorded.
    initialSweep :: SweepUnit
    initialSweep =
        MkSU
            { enabled = False
            , dividerPeriod = 0
            , dividerCounter = 0
            , negateDelta = False
            , targetPeriod = 0
            , shiftCount = 0
            , reloadFlag = False
            , isPulse1 = isPulseOne
            }

--

-- | State of a pulse channel's sweep unit, which can periodically move the
-- channel's 'period' up or down.
--
-- Fields are strict (matching 'Pulse'). The record is rebuilt on every
-- 'tickSweepUnit' and 'updateTargetPeriod'. With lazy fields, unevaluated
-- arithmetic could accumulate; the 'targetPeriod' computation in particular
-- captures the previous 'Pulse', so each pending thunk keeps older channel
-- states alive until something forces it.
data SweepUnit = MkSU
    { enabled :: !Bool
    -- ^ Whether the sweep unit is allowed to rewrite the channel's 'period'
    , dividerPeriod :: {-# UNPACK #-} !Int
    -- ^ Value the divider counter is reloaded with
    , dividerCounter :: {-# UNPACK #-} !Int
    -- ^ Divider countdown; a period update is only considered when it is 0
    , negateDelta :: !Bool
    -- ^ When set, the shifted period is subtracted instead of added
    , targetPeriod :: {-# UNPACK #-} !Int
    -- ^ Period the channel would switch to (computed by 'updateTargetPeriod')
    , shiftCount :: {-# UNPACK #-} !Int
    -- ^ Right shift applied to 'period' to obtain the sweep delta
    , reloadFlag :: !Bool
    -- ^ When set, the next 'tickSweepUnit' reloads the divider
    , isPulse1 :: !Bool
    -- ^ 'True' for pulse 1, which subtracts an extra 1 when negating
    }

{-# INLINE modifySweep #-}
-- | Transform the pulse's 'sweepUnit' with the given function; every other
-- field of the pulse is kept as is.
modifySweep :: (SweepUnit -> SweepUnit) -> Pulse -> Pulse
modifySweep f pulse@MkP{sweepUnit = sweep} = pulse{sweepUnit = f sweep}

{-# INLINE withSweep #-}
-- | Apply a reader function to the pulse's 'sweepUnit'.
withSweep :: (SweepUnit -> a) -> Pulse -> a
withSweep f = f . sweepUnit

-- | Recompute the sweep unit's 'targetPeriod' from the pulse's current
-- 'period'.
--
-- The delta is the period shifted right by 'shiftCount'. It is added to the
-- period, or, when 'negateDelta' is set, subtracted from it. When negating,
-- pulse 1 ('isPulse1') subtracts one more than pulse 2.
updateTargetPeriod :: Pulse -> Pulse
updateTargetPeriod pulse = modifySweep retarget pulse
  where
    current :: Int
    current = period pulse

    retarget :: SweepUnit -> SweepUnit
    retarget sweep = sweep{targetPeriod = target}
      where
        delta :: Int
        delta = current `shiftR` shiftCount sweep

        target :: Int
        target
            | negateDelta sweep =
                current - delta - (if isPulse1 sweep then 1 else 0)
            | otherwise = current + delta

-- | Clock the pulse channel's timer once.
--
-- While the timer is non-zero it is simply decremented. When it is 0 it is
-- reloaded from 'period' and the duty sequencer advances to its next step,
-- wrapping around after 8 steps.
tickPulse :: Pulse -> Pulse
tickPulse pulse
    | timer pulse == 0 =
        pulse
            { dutyStep = (dutyStep pulse + 1) `mod` 8
            , timer = period pulse
            }
    | otherwise = pulse{timer = timer pulse - 1}

-- | Clock the pulse's sweep unit once.
--
-- Two steps are performed:
--
-- 1. If the divider counter is 0, the sweep is 'enabled', 'shiftCount' is
--    non-zero and the sweep is not muting the channel ('period' >= 8 and
--    'targetPeriod' <= 0x7FF), then 'period' becomes 'targetPeriod' and the
--    target is recomputed with 'updateTargetPeriod'. Otherwise nothing
--    happens in this step; in particular the divider is left alone.
--
-- 2. If the divider counter is 0 or 'reloadFlag' is set, the divider is
--    reloaded from 'dividerPeriod' and the reload flag is cleared. Otherwise
--    the divider counter is decremented.
--
-- Step 1 only touches 'period' and 'targetPeriod', so step 2 sees the same
-- divider counter and reload flag the tick started with.
tickSweepUnit :: Pulse -> Pulse
tickSweepUnit p = modifySweep clockDivider adjusted
  where
    sweep :: SweepUnit
    sweep = sweepUnit p

    -- The sweep mutes the channel when the current period is below 8 or the
    -- target no longer fits in 11 bits.
    notMuting :: Bool
    notMuting = period p >= 8 && targetPeriod sweep <= 0x7ff

    -- Step 1: possibly apply the target period.
    -- (Previously the muted case also reloaded the divider here, which made
    -- step 2 decrement it again and leave it one below 'dividerPeriod'.)
    adjusted :: Pulse
    adjusted
        | dividerCounter sweep == 0
            && enabled sweep
            && shiftCount sweep > 0
            && notMuting =
            updateTargetPeriod p{period = targetPeriod sweep}
        | otherwise = p

    -- Step 2: reload or decrement the divider.
    clockDivider :: SweepUnit -> SweepUnit
    clockDivider s
        | reloadFlag s || dividerCounter s == 0 =
            s{dividerCounter = dividerPeriod s, reloadFlag = False}
        | otherwise = s{dividerCounter = dividerCounter s - 1}

--

-- | Exposes the pulse's 'lengthCounter' field through 'HasLengthCounter'.
instance HasLengthCounter Pulse where
    getLengthCounter pulse = lengthCounter pulse
    setLengthCounter newCounter pulse = pulse{lengthCounter = newCounter}

-- | Exposes the pulse's 'envelope' field through 'HasEnvelope'.
instance HasEnvelope Pulse where
    getEnvelope pulse = envelope pulse
    setEnvelope newEnvelope pulse = pulse{envelope = newEnvelope}

{-# INLINE getPulseOutput #-}
-- | Current output level of the pulse channel.
--
-- Yields 0 while the channel is silenced and the envelope output otherwise.
-- The channel is silenced when any of these holds:
--
-- * 'isSilencedByLengthCounter' says so;
-- * 'period' is below 8;
-- * the current 'dutySequences' entry is 0 (an out-of-range
--   'dutyIndex' / 'dutyStep' is treated as 0);
-- * the sweep is not negating and its 'targetPeriod' exceeds 0x7FF. This
--   check does not look at whether the sweep unit is 'enabled'.
getPulseOutput :: Pulse -> Int
getPulseOutput p
    | silenced = 0
    | otherwise = getEnvelopeOutput (envelope p)
  where
    sweep :: SweepUnit
    sweep = sweepUnit p

    step :: Int
    step = fromMaybe 0 $ do
        row <- dutySequences !? dutyIndex p
        row !? dutyStep p

    overflowing :: Bool
    overflowing = not (negateDelta sweep) && targetPeriod sweep > 0x7ff

    silenced :: Bool
    silenced =
        isSilencedByLengthCounter p
            || period p < 8
            || step == 0
            || overflowing

-- | Output waveform of each duty-cycle setting: rows are selected by
-- 'dutyIndex', columns by 'dutyStep' (8 steps per row).
--
-- Rows, in order: 12.5 %, 25 %, 50 % and 25 % negated (75 %). These match
-- the output-waveform column of the NESdev "APU Pulse" duty table. The 50 %
-- row previously had only three high steps (37.5 %).
dutySequences :: [[Int]]
dutySequences =
    [ [0, 1, 0, 0, 0, 0, 0, 0] -- 12.5 %
    , [0, 1, 1, 0, 0, 0, 0, 0] -- 25 %
    , [0, 1, 1, 1, 1, 0, 0, 0] -- 50 %: four high steps out of eight
    , [1, 0, 0, 1, 1, 1, 1, 1] -- 25 % negated
    ]