module UVMHS.Lib.AD where

import UVMHS.Core
import UVMHS.Lib.Pretty

--------------------------
-- Dual Number Forward ---
--------------------------

-- | Forward dual number: a value paired with a derivative of the same type.
-- 'Eq', 'Ord' and 'Show' are derived structurally from the two fields.
data ADF a = ADF
  { adfVal ∷ a  -- ^ the value component
  , adfDer ∷ a  -- ^ the derivative component
  } deriving (Eq,Ord,Show)
makeLenses ''ADF
makePrettySum ''ADF

-- ∂ns₁ ∂ns₂ ms = ADF (𝕄S ns₁ (ADF (𝕄S ns₂ (𝕄S ms a))))
--
-- 𝕄S ms (ADF (𝕄S ns₁ (ADF (𝕄S ns₂ a)))) ≈
-- valval: 𝕄S ns₁ (𝕄S ns₂ (𝕄S ms a))
-- valder: 𝕄S ns₁ (𝕄S ns₂ (𝕄S ms a))
-- derval: 𝕄S ns₁ (𝕄S ns₂ (𝕄S ms a))
-- derder: 𝕄S ns₁ (𝕄S ns₂ (𝕄S ms a))

-- | Lift a plain value into 'ADF' as a constant: the value is kept and the
-- derivative component is 'zero'.
constADF ∷ (Zero a) ⇒ a → ADF a
constADF c = ADF c zero

-- | Build an 'ADF' from an explicit value and an explicit derivative.
-- This is the 'ADF' constructor under a descriptive name.
sensADF ∷ a → a → ADF a
sensADF v d = ADF { adfVal = v , adfDer = d }

-- | Add two forward dual numbers componentwise: values are added together
-- and derivatives are added together.
plusADF ∷ (Plus a) ⇒ ADF a → ADF a → ADF a
plusADF (ADF x dx) (ADF y dy) = ADF val der
  where
    val = x + y
    der = dx + dy

-- | Multiply two forward dual numbers. The value is the product of the
-- values; the derivative is @dx × y + dy × x@ (operands in exactly this
-- order).
timesADF ∷ (Plus a,Times a) ⇒ ADF a → ADF a → ADF a
timesADF (ADF x dx) (ADF y dy) = ADF val der
  where
    val = x × y
    der = dx × y + dy × x

---------------------------
-- Dual Number Backward ---
---------------------------

-- | Backward dual number: a value paired with a two-argument derivative
-- function instead of a derivative value.
data ADB a = ADB
  { adbVal ∷ a          -- ^ the value component
  , adbDer ∷ a → a → a  -- ^ derivative function; the combinators in this
                        --   module pass the same first argument to each
                        --   operand and thread the second argument through
                        --   by composition
  }
makeLenses ''ADB
makePrettySum ''ADB

-- ∂ns₁ ∂ns₂ ms = ADB (𝕄S ns₁ (ADB (𝕄S ns₂ (𝕄S ms a))))

-- | Lift a plain value into 'ADB' as a constant. The derivative function is
-- @const id@: it ignores its first argument and returns its second argument
-- unchanged.
--
-- No class constraint is needed because the body never uses 'zero'; this
-- matches the unconstrained signature of 'constADFB'. Callers that supply a
-- 'Zero' instance continue to typecheck.
constADB ∷ a → ADB a
constADB x = ADB x $ const id

-- | Build an 'ADB' from an explicit value and an explicit derivative
-- function. This is the 'ADB' constructor under a descriptive name.
sensADB ∷ a → (a → a → a) → ADB a
sensADB v k = ADB { adbVal = v , adbDer = k }

-- | Add two backward dual numbers. Values are added; the resulting
-- derivative function hands its first argument unchanged to both operands'
-- derivative functions and composes the results (right operand applied
-- first, then left).
plusADB ∷ (Plus a) ⇒ ADB a → ADB a → ADB a
plusADB (ADB x kx) (ADB y ky) = ADB (x + y) back
  where
    back s = kx s ∘ ky s

-- | Multiply two backward dual numbers. Values are multiplied; the resulting
-- derivative function scales its first argument by the /other/ operand's
-- value before handing it to each operand's derivative function, and
-- composes the results (right operand applied first, then left).
timesADB ∷ (Times a) ⇒ ADB a → ADB a → ADB a
timesADB (ADB x kx) (ADB y ky) = ADB (x × y) back
  where
    back s = kx (s × y) ∘ ky (s × x)

-- }}}

------------------------------
-- Dual Number Flat Forward --
------------------------------

-- this should just be a newtype over ADF --
-- | Flat forward dual number over an indexed container @f@. Both components
-- have type @f (ns ⧺ ms) a@, i.e. @f@ indexed by the concatenation of the
-- two index lists.
data ADFF (ns ∷ [𝐍]) (f ∷ [𝐍] → ★ → ★) (ms ∷ [𝐍]) (a ∷ ★) = ADFF
  { adffVal ∷ f (ns ⧺ ms) a  -- ^ the value component
  , adffDer ∷ f (ns ⧺ ms) a  -- ^ the derivative component
  }
makeLenses ''ADFF
makePrettySum ''ADFF

-- ∂ns₁ ∂ns₂ ms = ADFF ns₁ (ADFF ns₂ 𝕄S) ms a
-- val:    (ADFF (ns₁ ⧺ ns₂) 𝕄S ms a)²
-- der:    (ADFF (ns₁ ⧺ ns₂) 𝕄S ms a)²
-- valval: 𝕄S (ns₁ ⧺ ns₂ ⧺ ms) a
-- valder: 𝕄S (ns₁ ⧺ ns₂ ⧺ ms) a
-- derval: 𝕄S (ns₁ ⧺ ns₂ ⧺ ms) a
-- derder: 𝕄S (ns₁ ⧺ ns₂ ⧺ ms) a

-- | Lift a container into 'ADFF' as a constant: the value is kept and the
-- derivative component is 'zero'.
constADFF ∷ (Zero (f (ns ⧺ ms) a)) ⇒ f (ns ⧺ ms) a → ADFF ns f ms a
constADFF c = ADFF c zero

-- | Build an 'ADFF' from an explicit value and an explicit derivative.
-- This is the 'ADFF' constructor under a descriptive name.
sensADFF ∷ f (ns ⧺ ms) a → f (ns ⧺ ms) a → ADFF ns f ms a
sensADFF v d = ADFF { adffVal = v , adffDer = d }

-- | Add two flat forward dual numbers componentwise: values are added
-- together and derivatives are added together.
plusADFF
  ∷ (Plus (f (ns ⧺ ms) a))
  ⇒ ADFF ns f ms a → ADFF ns f ms a → ADFF ns f ms a
plusADFF (ADFF x dx) (ADFF y dy) = ADFF val der
  where
    val = x + y
    der = dx + dy

-- | Multiply two flat forward dual numbers. The value is the product of the
-- values; the derivative is @dx × y + dy × x@ (operands in exactly this
-- order).
timesADFF
  ∷ (Plus (f (ns ⧺ ms) a),Times (f (ns ⧺ ms) a))
  ⇒ ADFF ns f ms a → ADFF ns f ms a → ADFF ns f ms a
timesADFF (ADFF x dx) (ADFF y dy) = ADFF val der
  where
    val = x × y
    der = dx × y + dy × x

-------------------------------
-- Dual Number Flat Backward --
-------------------------------

-- this should just be a newtype over ADB --
-- | Flat backward dual number over an indexed container @f@: a value of type
-- @f (ns ⧺ ms) a@ paired with a two-argument derivative function over that
-- same type.
data ADFB (ns ∷ [𝐍]) (f ∷ [𝐍] → ★ → ★) (ms ∷ [𝐍]) (a ∷ ★) = ADFB
  { adfbVal ∷ f (ns ⧺ ms) a
    -- ^ the value component
  , adfbDer ∷ f (ns ⧺ ms) a → f (ns ⧺ ms) a → f (ns ⧺ ms) a
    -- ^ derivative function; the combinators in this module pass the same
    --   first argument to each operand and thread the second argument
    --   through by composition
  }
makeLenses ''ADFB
makePrettySum ''ADFB

-- ∂ns₁ ∂ns₂ ms = ADFB ns₁ (ADFB ns₂ 𝕄S) ms a
-- val:   𝕄S ms
-- der:   𝕄S ms → X → X
-- X.val: 𝕄S (ns₁ ⧺ ms)
-- X.der: 𝕄S (ns₁ ⧺ ms) → 𝕄S (ns₂ ⧺ ns₁ ⧺ ms) → 𝕄S (ns₂ ⧺ ns₁ ⧺ ms)

-- | Lift a container into 'ADFB' as a constant. The derivative function
-- ignores its first argument and returns its second argument unchanged.
constADFB ∷ f (ns ⧺ ms) a → ADFB ns f ms a
constADFB c = ADFB c (const id)

-- | Build an 'ADFB' from an explicit value and an explicit derivative
-- function. This is the 'ADFB' constructor under a descriptive name.
sensADFB
  ∷ f (ns ⧺ ms) a
  → (f (ns ⧺ ms) a → f (ns ⧺ ms) a → f (ns ⧺ ms) a)
  → ADFB ns f ms a
sensADFB v k = ADFB { adfbVal = v , adfbDer = k }

-- | Add two flat backward dual numbers. Values are added; the resulting
-- derivative function hands its first argument unchanged to both operands'
-- derivative functions and composes the results (right operand applied
-- first, then left).
plusADFB
  ∷ (Plus (f (ns ⧺ ms) a))
  ⇒ ADFB ns f ms a → ADFB ns f ms a → ADFB ns f ms a
plusADFB (ADFB x kx) (ADFB y ky) = ADFB (x + y) back
  where
    back s = kx s ∘ ky s

-- | Multiply two flat backward dual numbers. Values are multiplied; the
-- resulting derivative function scales its first argument by the /other/
-- operand's value before handing it to each operand's derivative function,
-- and composes the results (right operand applied first, then left).
--
-- Only 'Times' is required: the body performs no addition, so the signature
-- mirrors 'timesADB'. Callers that also supply a 'Plus' instance continue to
-- typecheck.
timesADFB
  ∷ (Times (f (ns ⧺ ms) a))
  ⇒ ADFB ns f ms a → ADFB ns f ms a → ADFB ns f ms a
timesADFB (ADFB v₁ 𝒹₁) (ADFB v₂ 𝒹₂) =
  ADFB (v₁ × v₂) $ \ d → 𝒹₁ (d × v₂) ∘ 𝒹₂ (d × v₁)