Monads
Enter the Monad
A monad is a datatype that represents some kind of computational behavior and supports a "well-behaved" notion of sequential composition. Examples of such behaviors include:
- mutable state
- failure and exceptions
- nondeterminism
- I/O
- divergence
Modeling Imperative Programs
Inductive ceval : com → mem → mem → Prop :=
...
| E_Seq : ∀ c1 c2 st st' st'',
st =[ c1 ]=> st' →
st' =[ c2 ]=> st'' →
st =[ c1 ; c2 ]=> st''
Evaluation as a function (revisited)
What if we could model a command as a function denote c : mem → mem?
Fixpoint denote (c:com) : mem → mem :=
match c with
| ... (* omitted *)
| c1 ; c2 ⇒ fun st ⇒
let st' := denote c1 st in
let st'' := denote c2 st' in
st''
| CWhile b c ⇒ ??? (* A while loop may run forever, but every Coq function must terminate, so no Coq function of type mem → mem can work. *)
end.
Towards a solution
Stateful computations returning values
Consider how to model this "Imp-like" program:
Definition fact5_body :=
<{ X := 5;
Z := X;
Y := 1;
while Z ≠ 0 do
Y := Y × Z;
Z := Z - 1
end;
return Y; (* <== NEW! *)
}>
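Such a program both updates the memory and returns a value, so we might model it as a function of type
mem → mem × nat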
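To make the type mem → mem × nat concrete, here is a minimal sketch of fact5_body hand-translated into an ordinary Coq function, written in plain ASCII Coq syntax. It assumes a hypothetical mem record holding only X, Y, and Z; the helper fact_loop is also hypothetical and sidesteps the general problem with while by recursing on the value of Z, which works only because this particular loop counts Z down to 0.

```coq
(* Hypothetical memory holding only the variables fact5_body uses. *)
Record mem : Type := { X : nat; Y : nat; Z : nat }.

(* The while loop, by structural recursion on the value of Z:
   each iteration performs Y := Y * Z; Z := Z - 1. *)
Fixpoint fact_loop (z y : nat) : nat :=
  match z with
  | O => y
  | S z' => fact_loop z' (y * z)
  end.

(* fact5 assigns every variable before reading it, so it ignores the
   initial memory. It returns the final memory paired with Y. *)
Definition fact5 : mem -> mem * nat :=
  fun _ =>
    let y := fact_loop 5 1 in
    ({| X := 5; Y := y; Z := 0 |}, y).

Compute snd (fact5 {| X := 0; Y := 0; Z := 0 |}).  (* = 120 : nat *)
```

Only the shape of the type matters here: a memory goes in, and a memory together with a number comes out.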
State monads
Definition state (S : Type) (B:Type) : Type := S → S × B.
state mem unit = mem → mem × unit
state mem nat = mem → mem × nat
Sequential composition of commands
<{ c1 ; c2 }> (* <== Imp sequential composition *)
is modeled as
fun σ ⇒ (* <== with "state plumbing" *)
let (σ', _) := c1 σ in
c2 tt σ' (* tt is Coq's unit value *)
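A small sketch of this plumbing as a reusable Coq function, in plain ASCII syntax. The names seq, incr, and double are hypothetical and chosen only for illustration; as in the plumbing above, c2 receives c1's unit result tt before the updated state.

```coq
(* State transformers returning a result, as in the state monad. *)
Definition state (S A : Type) : Type := S -> S * A.

(* Sequencing with explicit state plumbing: run c1, then pass its (unit)
   result tt and the updated state to c2. *)
Definition seq {S : Type} (c1 : state S unit) (c2 : unit -> state S unit)
  : state S unit :=
  fun s => let (s', _) := c1 s in c2 tt s'.

(* Two hypothetical commands over a nat state. *)
Definition incr : state nat unit := fun n => (n + 1, tt).
Definition double : state nat unit := fun n => (2 * n, tt).

Compute seq incr (fun _ => double) 3.  (* = (8, tt) : nat * unit *)
Compute seq double (fun _ => incr) 3.  (* = (7, tt) : nat * unit *)
```

Swapping the two commands changes the final state, which is why sequencing must thread the state through in order.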
Let notation vs. Monad notation
The explicit state plumbing, bind, and monadic notation forms below are equal by unfolding definitions:
STANDARD COQ:    EXPLICIT STATE PLUMBING:    BIND:              MONADIC NOTATION:
                 fun σ ⇒
let x := m in    let (σ', x) := m σ in       bind m (fun x ⇒    x <- m ;;
rest x           rest x σ'                   rest x)            rest x
The bind operation
Definition state_bind {S A B} (m : state S A) (k:A → state S B) : state S B :=
fun σ ⇒ let '(σ', a) := m σ in
k a σ'.
We introduce some notation to make using bind more palatable.
Notation "x <- m ;; body" := (state_bind m (fun x ⇒ body))
(at level 61, m at next level, right associativity).
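The claim that these forms are equal by unfolding definitions can be checked directly. This sketch restates state_bind and the notation in ASCII Coq syntax (the lemma names are hypothetical) and proves both equations by reflexivity, which succeeds because the two sides of each equation are convertible.

```coq
Definition state (S A : Type) : Type := S -> S * A.

Definition state_bind {S A B} (m : state S A) (k : A -> state S B)
  : state S B :=
  fun s => let '(s', a) := m s in k a s'.

Notation "x <- m ;; body" := (state_bind m (fun x => body))
  (at level 61, m at next level, right associativity).

(* Monadic notation is just bind ... *)
Lemma notation_is_bind :
  forall S A B (m : state S A) (rest : A -> state S B),
    (x <- m ;; rest x) = state_bind m (fun x => rest x).
Proof. intros. reflexivity. Qed.

(* ... and bind is just the explicit state plumbing. *)
Lemma bind_is_plumbing :
  forall S A B (m : state S A) (rest : A -> state S B),
    state_bind m (fun x => rest x)
    = (fun s => let '(s', x) := m s in rest x s').
Proof. intros. reflexivity. Qed.
```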
The ret operation
Definition state_ret {S A} : A → state S A :=
fun (a:A) (σ:S) ⇒ (σ, a).
We introduce some handy notation for returning values.
Notation "'ret' a" := (state_ret a) (at level 50).
get and put
get returns the current state as its result, leaving the state unchanged.
Definition get {S} : state S S :=
fun σ ⇒ (σ, σ).
put updates the state, returning the trivial unit value.
Definition put {S} (σ:S) : state S unit :=
fun _ ⇒ (σ, tt).
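As a further sketch of get and put working together, here is a hypothetical counter, tick, that returns the current value and then increments the state. It restates the definitions above in ASCII Coq syntax and calls state_ret directly rather than through the 'ret' notation.

```coq
Definition state (S A : Type) : Type := S -> S * A.

Definition state_bind {S A B} (m : state S A) (k : A -> state S B)
  : state S B :=
  fun s => let '(s', a) := m s in k a s'.
Definition state_ret {S A} (a : A) : state S A := fun s => (s, a).
Definition get {S} : state S S := fun s => (s, s).
Definition put {S} (s : S) : state S unit := fun _ => (s, tt).

Notation "x <- m ;; body" := (state_bind m (fun x => body))
  (at level 61, m at next level, right associativity).

(* Hypothetical counter: return the current value, then increment it. *)
Definition tick : state nat nat :=
  n <- get ;;
  u <- put (n + 1) ;;
  state_ret n.

(* Run tick twice and return both observed values. *)
Definition tick_twice : state nat (nat * nat) :=
  a <- tick ;;
  b <- tick ;;
  state_ret (a, b).

Compute tick_twice 10.  (* = (12, (10, 11)) : nat * (nat * nat) *)
```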
Example state computation
Our state monad works with any state type, for example bool. Below, b is the result of reading the state via get; we bind the result of put to _ because it is the trivial unit value; and the computation returns 3 if b is true and 17 otherwise.
Example boolean_state_example : state bool nat :=
b <- get ;;
_ <- put (negb b) ;;
if b then ret 3 else ret 17.
Eval compute in (boolean_state_example true).
==> (false, 3) : bool × nat
Eval compute in (boolean_state_example false).
==> (true, 17) : bool × nat
A monad is a type constructor M : Type → Type that supports:
- ret, a way to embed pure computations, and
- bind, a way to sequence computations.
We summarize these requirements in a typeclass:
Class Monad (M : Type → Type) := {
ret : ∀ A, A → M A;
bind : ∀ A B, M A → (A → M B) → M B;
}.
We then introduce the monad notation x <- m ;; k as shorthand for bind m (fun x ⇒ k). The state monad is an instance of this class:
Definition state (S:Type) : Type → Type := fun A ⇒ S → (S × A).
#[export] Instance stateM S : Monad (state S) :=
{|
ret := fun A (a : A) σ ⇒ (σ, a);
bind := fun A B (m : state S A) (k : A → state S B) ⇒
fun σ ⇒ let '(σ', a) := m σ in
k a σ'
|}.
Definition get {S} : state S S :=
fun σ ⇒ (σ, σ).
Definition put {S} (σ:S) : state S unit :=
fun _ ⇒ (σ, tt).
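One payoff of the typeclass is that a program can be written once for every monad. This sketch assumes ret and bind take their type arguments implicitly (the class above declares them explicitly), and add_m and prog are hypothetical names; running prog reads the state twice and adds the two results.

```coq
(* Variant of the Monad class with implicit type arguments. *)
Class Monad (M : Type -> Type) := {
  ret : forall {A : Type}, A -> M A;
  bind : forall {A B : Type}, M A -> (A -> M B) -> M B
}.

Notation "x <- m ;; k" := (bind m (fun x => k))
  (at level 61, m at next level, right associativity).

Definition state (S : Type) : Type -> Type := fun A => S -> S * A.

#[export] Instance stateM (S : Type) : Monad (state S) := {|
  ret := fun A (a : A) (s : S) => (s, a);
  bind := fun A B (m : state S A) (k : A -> state S B) (s : S) =>
            let '(s', a) := m s in k a s'
|}.

Definition get {S : Type} : state S S := fun s => (s, s).

(* Hypothetical generic program: runs in any monad M. *)
Definition add_m {M : Type -> Type} `{Monad M} (m1 m2 : M nat) : M nat :=
  x <- m1 ;;
  y <- m2 ;;
  ret (x + y).

(* Instantiated at the state monad: read the state twice and add. *)
Definition prog : state nat nat := add_m get get.

Compute prog 21.  (* = (21, 42) : nat * nat *)
```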
Other common monads.
- identity
  id A := A
- option, supporting fail
  option A := None | Some A
- nondeterminism, supporting choice
  nondet A := list A
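A minimal sketch of the option and list monads as instances of one class, assuming a variant of Monad whose ret and bind take their type arguments implicitly. safe_div is a hypothetical helper that fails on a zero divisor; list bind is implemented with the standard library's flat_map.

```coq
From Coq Require Import List.
Import ListNotations.

Class Monad (M : Type -> Type) := {
  ret : forall {A : Type}, A -> M A;
  bind : forall {A B : Type}, M A -> (A -> M B) -> M B
}.

Notation "x <- m ;; k" := (bind m (fun x => k))
  (at level 61, m at next level, right associativity).

(* option: bind stops at the first failure (None). *)
#[export] Instance optionM : Monad option := {|
  ret := fun A (a : A) => Some a;
  bind := fun A B (m : option A) (k : A -> option B) =>
            match m with
            | None => None
            | Some a => k a
            end
|}.

(* list: bind runs k on every choice and concatenates the results. *)
#[export] Instance listM : Monad list := {|
  ret := fun A (a : A) => [a];
  bind := fun A B (m : list A) (k : A -> list B) => flat_map k m
|}.

(* Hypothetical partial operation: division that fails on zero. *)
Definition safe_div (n d : nat) : option nat :=
  if Nat.eqb d 0 then None else Some (Nat.div n d).

Compute (x <- safe_div 12 3 ;; safe_div 100 x).  (* = Some 25 : option nat *)
Compute (x <- safe_div 12 0 ;; safe_div 100 x).  (* = None : option nat *)
Compute (x <- [1; 2] ;; y <- [10; 20] ;; ret (x + y)).
(* = [11; 21; 12; 22] : list nat *)
```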
What does it mean for sequential composition to be "well behaved"? Consider these two Imp commands:
(c1 ; c2) ; c3 vs
c1 ; (c2 ; c3)
These programs should be equivalent; that is, ';' should be associative.
let is associative too
Example nested_lets1 :=
let x :=
(let y := 3 in y + y)
in x × 2 .
Example nested_lets2 :=
let y := 3 in
let x := y + y in
x × 2.
Both of these expressions evaluate to the same value:
Eval compute in nested_lets1.
==> 12 : nat
Eval compute in nested_lets2.
==> 12 : nat
Generalizing from let to an arbitrary monad M gives the bind associativity law:
Definition bind_associativity_law {M} `{Monad M} {A B C} :=
∀ (ma : M A) (kb : A → M B) (kc : B → M C),
x <- (y <- ma ;; kb y) ;; kc x
=
y <- ma ;; x <- kb y ;; kc x.
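For a concrete monad the associativity law can be proved outright. This sketch specializes it to option, using a hypothetical option_bind in place of the x <- m ;; k notation; case analysis on whether the first computation fails finishes the proof.

```coq
Definition option_bind {A B} (m : option A) (k : A -> option B) : option B :=
  match m with
  | None => None
  | Some a => k a
  end.

(* bind_associativity_law for option, with option_bind spelled out. *)
Lemma option_bind_assoc :
  forall A B C (ma : option A) (kb : A -> option B) (kc : B -> option C),
    option_bind (option_bind ma (fun y => kb y)) (fun x => kc x)
    = option_bind ma (fun y => option_bind (kb y) (fun x => kc x)).
Proof.
  intros A B C ma kb kc.
  (* Either ma succeeds with some a, or it fails with None. *)
  destruct ma as [a |]; reflexivity.
Qed.
```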
Example nested_lets_monadic2' :=
let y := 3 in
let x := y + y in
x × 2.
After substitution:
Example nested_lets2_subst :=
let x := 3 + 3 in (* n.b. substitute 3 for y in y + y *)
x × 2.
After the substitution, we still obtain the same result:
Eval compute in (nested_lets2_subst).
==> 12 : nat
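This substitution principle corresponds to the left identity law: binding the result of ret a is the same as applying the continuation directly to a.
Definition bind_ret_l_law {M} `{Monad M} {A B} :=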
∀ (a : A) (kb : A → M B),
x <- ret a ;; kb x
=
kb a.
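The left identity law does not always hold by computation alone. Taking the list monad with ret a = [a] and bind m k = flat_map k m (an assumption of this sketch), unfolding leaves kb a ++ [], so the proof needs the standard library lemma app_nil_r. The lemma name list_bind_ret_l is hypothetical.

```coq
From Coq Require Import List.
Import ListNotations.

(* Left identity for lists: bind (ret a) kb = flat_map kb [a]. *)
Lemma list_bind_ret_l :
  forall A B (a : A) (kb : A -> list B),
    flat_map kb [a] = kb a.
Proof.
  intros A B a kb.
  simpl.            (* goal: kb a ++ [] = kb a *)
  apply app_nil_r.
Qed.
```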
One last law
To see what the last monad law means, consider the following example, again stated with let notation:
Example let_return (m : nat) :=
let x := m in
x.
Rewriting this example using monadic notation we have:
Example let_return_monadic {M} `{Monad M} (m : M nat) : M nat :=
x <- m ;;
ret x.
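Binding m's result only to return it again is the same as m itself, which gives the right identity law:
Definition bind_ret_r_law {M} `{Monad M} {A} :=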
∀ (m : M A),
x <- m ;; ret x
=
m.
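Similarly, the right identity law for lists needs induction, since flat_map (fun x ⇒ [x]) m only computes once the shape of m is known. A sketch assuming ret a = [a] and bind m k = flat_map k m; the lemma name is hypothetical.

```coq
From Coq Require Import List.
Import ListNotations.

(* Right identity for lists: bind m ret = flat_map (fun x => [x]) m. *)
Lemma list_bind_ret_r :
  forall A (m : list A), flat_map (fun x => [x]) m = m.
Proof.
  intros A m.
  induction m as [| a m' IH].
  - reflexivity.
  - (* goal: a :: flat_map (fun x => [x]) m' = a :: m' *)
    simpl. exact (f_equal (cons a) IH).
Qed.
```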
Class MonadLaws {M} `{Monad M} := {
bind_associativity :
∀ A B C (ma : M A) (kb : A → M B) (kc : B → M C),
x <- (y <- ma ;; kb y) ;; kc x
=
y <- ma ;; x <- kb y ;; kc x
; bind_ret_l :
∀ A B (a : A) (kb : A → M B),
x <- ret a ;; kb x
=
kb a
; bind_ret_r :
∀ A (m : M A),
x <- m ;; ret x
=
m
}.
Example nested_lets_generic {M} `{Monad M} : M nat :=
x <-
(y <- ret 3 ;; ret (y+y)) ;;
ret (x × 2).
Example monadic_proof {M} `{Monad M} `{MonadLaws M} :
nested_lets_generic = ret 12.
Proof.
unfold nested_lets_generic.
rewrite bind_associativity.
rewrite bind_ret_l.
rewrite bind_ret_l.
reflexivity.
Qed.
Monad Equivalences
Lemma bind_associativity_state_problematic :
∀ S A B C (ma : state S A) (kb : A → state S B) (kc : B → state S C),
x <- (y <- ma ;; kb y) ;; kc x
=
y <- ma ;; x <- kb y ;; kc x.
Proof.
intros. unfold bind. simpl.
(* Because state S A is defined as S → (S × A), saying that two stateful
computations are equal means saying that two _functions_ are equal. To
proceed from here we would need _functional extensionality_, which is
consistent to add to Coq as an axiom but cannot be proved within Coq
itself. *)
Abort.
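A tiny illustration of why functional extensionality is needed: n + 0 does not compute to n when n is a variable, because addition recurses on its first argument, so reflexivity cannot prove the two functions below equal. With the standard library's functional_extensionality axiom, it suffices to show that they agree at every argument. The example name is hypothetical.

```coq
From Coq Require Import FunctionalExtensionality.

Example pointwise_equal_functions :
  (fun n : nat => n + 0) = (fun n : nat => n).
Proof.
  (* reflexivity fails here: the bodies are not convertible *)
  apply functional_extensionality.
  intro n.
  simpl.              (* goal: n + 0 = n *)
  rewrite <- plus_n_O.
  reflexivity.
Qed.
```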
Rather than insisting on Coq's =, we let each monad supply its own notion of equivalence:
Class EqM (M : Type → Type) : Type :=
eqM : ∀ A, M A → M A → Prop.
We also require a proof that eqM is an equivalence relation.
Class EqMEquivalence (M : Type → Type) `{EqM M} :=
eqM_equiv : ∀ A, Equivalence (eqM (A := A)).
Because monadic equivalence appears so often, we introduce
≈ as a succinct infix notation for it.
Infix "≈" := eqM (at level 70) : monad_scope.
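Parameterizing over eqM leaves each monad free to choose its equivalence. As one hypothetical alternative for state (state_eq is an illustrative name, not part of the development above), two computations could be called equivalent when they agree on every initial state; this sketch checks that this is an equivalence relation, without using functional extensionality.

```coq
From Coq Require Import RelationClasses.

Definition state (S A : Type) : Type := S -> S * A.

(* Hypothetical pointwise equivalence on stateful computations. *)
Definition state_eq {S A} (m1 m2 : state S A) : Prop :=
  forall s, m1 s = m2 s.

Lemma state_eq_equiv {S A : Type} : Equivalence (@state_eq S A).
Proof.
  constructor.
  - (* reflexive *)
    intros m s. reflexivity.
  - (* symmetric *)
    intros m1 m2 H s. exact (eq_sym (H s)).
  - (* transitive *)
    intros m1 m2 m3 H12 H23 s. exact (eq_trans (H12 s) (H23 s)).
Qed.
```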
Class MonadLaws M `{Monad M} `{EqM M} := {
bind_associativity :
∀ A B C (ma : M A) (kb : A → M B) (kc : B → M C),
x <- (y <- ma ;; kb y) ;; kc x
≈ (* <---- NEW! *)
y <- ma ;; x <- kb y ;; kc x
; bind_ret_l :
∀ A B (a : A) (kb : A → M B),
x <- ret a ;; kb x
≈
kb a
; bind_ret_r :
∀ A (m : M A),
x <- m ;; ret x
≈
m
(* NEW! *)
; Proper_bind : ∀ {A B},
@Proper (M A → (A → M B) → M B)
(eqM ==> pointwise_relation _ eqM ==> eqM) bind
}.
Properness: a logical relation
Lemma Proper_bind_def : ∀ M `{Monad M} `{EqM M} `{MonadLaws M} A B,
(@Proper (M A → (A → M B) → M B)
(eqM ==> pointwise_relation _ eqM ==> eqM) bind) ↔
∀ (m1 m2 : M A),
m1 ≈ m2 →
∀ k1 k2 : A → M B,
(∀ a1 a2 : A, a1 = a2 → k1 a1 ≈ k2 a2)
→
(x <- m1;; k1 x) ≈ (x <- m2;; k2 x).
Proof.
split; intros.
- apply H4; auto.
repeat red. intros. apply H6. auto.
- repeat red. intros. apply H4; auto. intros. subst. apply H6.
Qed.
Important!
Monad Laws for state S
#[export]
Instance eqM_state {S} : EqM (state S) := fun A ⇒ (@eq (state S A)).
#[export]
Instance eqM_state_equiv {S} : EqMEquivalence (state S).
Proof.
constructor; unfold eqM, eqM_state; typeclasses eauto.
Qed.
Next we prove the monad laws:
#[export]
Instance eqm_state_monad_laws {S} : MonadLaws (state S).
Proof.
constructor.
- intros. unfold eqM, eqM_state.
(* use functional extensionality(!) *)
apply functional_extensionality.
intro σ.
simpl.
destruct (ma σ).
reflexivity.
- intros. unfold eqM, eqM_state.
apply functional_extensionality.
intro σ.
reflexivity.
- intros. unfold eqM, eqM_state.
apply functional_extensionality.
intro σ.
simpl.
destruct (m σ).
reflexivity.
- repeat red.
intros A B m1 m2 EQ a1 a2 HP.
apply functional_extensionality.
intro σ.
simpl.
rewrite EQ.
destruct (m2 σ).
rewrite HP.
reflexivity.
Qed.
Lemma put_put {S} : ∀ (σ1 σ2 : S),
put σ1 ;; put σ2 ≈ put σ2.
Proof.
intros σ1 σ2.
apply functional_extensionality.
intros σ0.
reflexivity.
Qed.
Lemma put_get {S A} : ∀ (σ : S) (k : S → state S A),
put σ ;; get ≈ put σ ;; ret σ.
Proof.
intros σ1 k.
apply functional_extensionality.
intros σ0.
reflexivity.
Qed.
Lemma get_put {S} :
s <- get ;; put s ≈ (ret tt : state S unit).
Proof.
apply functional_extensionality.
intros σ0.
reflexivity.
Qed.
Lemma get_get {S A} : ∀ (k : S → S → state S A),
s1 <- get ;; s2 <- get ;; k s1 s2 ≈ s1 <- get ;; k s1 s1.
Proof.
intros k.
apply functional_extensionality.
intros σ0.
reflexivity.
Qed.