module Nes.CPU.Instructions.Jump (
    -- * Jump
    jmp,
    jsr,

    -- * Return
    rts,
    rti,
) where

import Data.Bits
import Nes.CPU.Instructions.Addressing
import Nes.CPU.Monad
import Nes.CPU.State
import Nes.Internal.MonadState
import Nes.Memory

-- | Sets the program counter to the address specified by the operand
--
-- Only 'Absolute' and 'Indirect' addressing are supported; any other mode
-- fails in the 'CPU' monad via 'fail'.
--
-- https://www.nesdev.org/obelisk-6502-guide/reference.html#JMP
jmp :: AddressingMode -> CPU r ()
jmp Absolute = usesM pc (`readAddr` ()) >>= (pc .=)
-- See https://www.nesdev.org/wiki/Instruction_reference#JMP
-- And https://github.com/bugzmanov/nes_ebook/blob/785b9ed8b803d9f4bd51274f4d0c68c14a1b3a8b/code/ch3.4/src/cpu.rs#L692
jmp Indirect = do
    -- The operand at the program counter is a pointer to the real target.
    pointer <- usesM pc (`readAddr` ())
    target <- readJmpIndirectTarget pointer
    pc .= target
jmp _ = fail "Unsupported addressing mode"

-- | Reads the jump target stored at @pointer@ for an indirect JMP.
--
-- When the pointer's low byte is @0xFF@, the low byte of the target is read
-- from @pointer@ and the high byte from @pointer .&. 0xFF00@ (the start of
-- the same page). This is the page-wrap quirk described in the references
-- on 'jmp'. Otherwise the target is read with 'readAddr' at @pointer@.
readJmpIndirectTarget :: Addr -> CPU r Addr
readJmpIndirectTarget pointer
    | pointer .&. 0x00FF == 0x00FF = do
        low <- byteToAddr <$> readByte pointer ()
        high <- byteToAddr <$> readByte (pointer .&. 0xFF00) ()
        return $ shiftL high 8 .|. low
    | otherwise = readAddr pointer ()

-- | Jump to Subroutine
--
-- Steps, in order:
--
-- 1. Pushes @pc + 2 - 1@, which is the current program counter plus one.
-- 2. Spends one extra cycle.
-- 3. Loads the program counter with the address read at the current
--    program counter.
--
-- https://www.nesdev.org/obelisk-6502-guide/reference.html#JSR
jsr :: CPU r ()
jsr = do
    operandAddr <- use pc
    -- The target address is read from operandAddr below, so the operand
    -- starts there. The pushed return address points at its last byte
    -- (operand start + 2 bytes - 1); 'rts' adds 1 when it pops it.
    pushAddrStack (operandAddr + 2 - 1)
    tickOnce
    target <- readAddr operandAddr ()
    pc .= target

-- | Return from Subroutine
--
-- Pulls the return address from the stack and sets the program counter to
-- that address plus one.
--
-- https://www.nesdev.org/obelisk-6502-guide/reference.html#RTS
rts :: CPU r ()
rts = do
    tick 2
    returnAddr <- popStackAddr
    -- https://www.nesdev.org/wiki/Cycle_counting
    --  plus 1 cycle to post-increment the program counter
    tickOnce
    -- Parenthesised explicitly: the fixity of '.=' comes from another module.
    pc .= (returnAddr + 1)

-- | Return from interrupt
--
-- Steps, in order:
--
-- 1. Pops the status register.
-- 2. Pops the program counter from the stack. Unlike 'rts', the popped
--    address is used as-is, with no +1.
-- 3. Accounts for two further cycles.
--
-- https://www.nesdev.org/obelisk-6502-guide/reference.html#RTI
rti :: CPU r ()
rti = do
    popStatusRegister
    returnAddr <- popStackAddr
    pc .= returnAddr
    tick 2