(* HyPre.pred.Sync *)

From Coq Require Import
  Classes.Morphisms
  Classes.RelationClasses.

From HyPre Require Export
  ListExtra
  Misc
  NatExtra
  Tac
  VectorNat.

(* notations *)
Import ListNotations.
Import VecNatNotations.

Section with_nat_nat.
  (* Stand-alone (parameterised) versions of the derived operations of
     [TotalSynchroniser]; the lemmas [t_sync_step_eq] .. [t_reachable_eq]
     later in this file show they coincide definitionally with the class
     fields. *)
  Context (t_sync' : nat -> nat).

  (** One synchronisation step on a (source, target) position pair: the
      source moves from [n] to [S n] and the target gains [t_sync' n]. *)
  Definition t_sync_step' (nm : nat * nat) : nat * nat :=
    match nm with
    | (n, m) => (S n, t_sync' n + m)
    end.

  (** [n] consecutive synchronisation steps. *)
  Definition t_sync_n_steps' (n : nat) : nat * nat -> nat * nat :=
    Nat.iter n t_sync_step'.

  (** Target position reached after [n] source steps, starting from [(0,0)]. *)
  Definition t_sync_snd' (n : nat) : nat :=
    snd (t_sync_n_steps' n (0, 0)).

  (** [(n, m)] is reachable iff [m] is exactly the target position reached
      after [n] source steps. *)
  Definition t_reachable' (nm : nat * nat) : Prop :=
    match nm with
    | (n, m) => t_sync_snd' n = m
    end.
End with_nat_nat.

(*Reserved Notation "`t_sync_step`" (at level 65).
Reserved Notation "`t_sync_n_steps`" (at level 65).
Reserved Notation "`t_sync_snd`" (at level 65).
Reserved Notation "`t_reachable`" (at level 65).*)


(* t_sync returns, for a given source position, how many steps the target
   performs (starting from related positions).
   t_next returns, for a given source position, how often t_sync may be
   applied (i.e. how many source steps may be taken) until the target moves.
 *)

(** A total synchroniser for one (source, target) pair of traces.
    - [t_sync n] is the number of target steps that accompany the source
      step from position [n] to [S n].
    - [t_next n] bounds how many source steps may be taken from [n] before
      the target position must strictly increase (field [t_sync_fair]).
    The fields [t_sync_step] .. [t_reachable] are defined (let-in) fields.
    Several proofs later in this file unfold them and match on their exact
    bodies, so their shape should not be changed casually. *)
Class TotalSynchroniser
      (t_sync : nat -> nat)
      (t_next : nat -> nat) :=
  {
    (* one step: the source advances by one, the target by [t_sync n] *)
    t_sync_step := fun (nm : nat * nat) => let (n,m) := nm in (S n, t_sync n + m);
    (* [n] steps, from an arbitrary pair *)
    t_sync_n_steps := fun n => Nat.iter n t_sync_step;
    (* target position after [n] source steps, starting from [(0,0)] *)
    t_sync_snd := fun (n : nat) => snd (t_sync_n_steps n (0,0));
    (* [(n,m)] is reachable iff [m] is the target position after [n] source steps *)
    t_reachable := fun (nm : nat * nat) => let (n,m) := nm in t_sync_snd n = m;
    (* fairness: within [t_next n] further source steps the target strictly advances *)
    t_sync_fair : forall n, t_sync_snd n < t_sync_snd (n + (t_next n))
  }.
(*
  (* the general formulation here would be [nat * M] where M is some monoid *)
  Class TotalSynchroniser
        (t_sync : nat -> nat) (* argument is fst, result is snd of successor *)
        (t_next : nat -> nat) :=
    {
    t_sync_step' := fun (nm : nat * nat) => let (n,m) := nm in (S n, m + t_sync n);
    t_sync_snd := fun n => snd (Nat.iter n t_sync_step' (0,0));
    t_sync_n_steps := fun n m => (n + m, t_sync_snd (n + m));
    t_sync_step := t_sync_n_steps 1;
    t_reachable := fun (nm : nat * nat) => let (n,m) := nm in t_sync_snd n = m;
    t_sync_fair := forall n, t_sync_snd n < t_sync_snd (n + (t_next n))
    }.
 *)


Section with_k.
  Context {k : nat}.

  (* By defining the synchroniser on Vector.t nat k, determinism is baked in:
     it is not possible to have a synchroniser that picks two different next
     steps on different hyper states with the same progress on every trace.
     That is fine, since by determinism there are no two such different hyper
     states. Using determinism, we can build a synchroniser on Vector.t nat k
     from a synchroniser on Vector.t (list _) k. *)

  Definition opt_S (b : bool) (n : nat) := if b then S n else n.

  (* probably, we won't need this version *)
  (** Variant of [Synchroniser] (below) in which each step advances every
      trace by at most one. [os_sync v] says, per trace, whether that trace
      advances (applied through [opt_S]). [os_next a i] bounds the number of
      steps until trace [i] progresses. *)
  Class OneStepSynchroniser
        (os_sync : Vector.t nat k -> Vector.t bool k)
        (os_next : Vector.t nat k -> Fin.t k -> nat) :=
    {
    os_sync_step := fun v => Vector.map2 opt_S (os_sync v) v;
    os_sync_n := fun n => Nat.iter n os_sync_step;
    (* fairness, required only for vectors satisfying [VecNatO opt_S os_sync]
       (an inductive from VectorNat; presumably reachability from the start
       vector, mirroring [reachable] in [Synchroniser] - confirm) *)
    os_sync_fair : forall a, VecNatO opt_S os_sync a -> forall i, a [[i]] < Vector.nth (os_sync_n (os_next a i) a) i;
    }.

  (** The pairs generated from [(0,0)] by the oracle [Orcl]: from a generated
      pair [(x1,x2)], the next pair is [(S x1, Orcl (x1,x2))].
      NOTE(review): not referenced in the visible part of this file, and it
      does not use the section variable [k]. *)
  Inductive LeftTotalNatO (Orcl : (nat * nat) -> nat)
    : (nat * nat) -> Prop
    :=
    | UNOnil : LeftTotalNatO Orcl (0,0)
    | UNOcons : forall x1 x2 y, LeftTotalNatO Orcl (x1,x2) -> y = (S x1, Orcl (x1,x2)) -> LeftTotalNatO Orcl y.

  (* sync returns, given a vector of positions, how many steps are to be
     performed next on each trace.
     next returns, given a vector of positions and a trace index i,
     how often sync may be performed until there is progress on i.
   *)

  (** A synchroniser for [k] traces.
      - [sync v] gives, for each trace, the number of steps to take from the
        position vector [v]. [sync_step] adds it to [v] via [+vn+]
        (presumably pointwise addition from VectorNat).
      - [reachable] is [VecNatO Nat.add sync]: presumably the vectors reached
        from the start vector by [sync] steps (proofs below build it with
        [econstructor]).
      - [sync_fair]: from any reachable [a], after [next a i] sync steps the
        position on trace [i] has strictly increased. *)
  Class Synchroniser
        (sync : Vector.t nat k -> Vector.t nat k)
        (next : Vector.t nat k -> Fin.t k -> nat) :=
    {
    sync_step := fun v => (sync v) +vn+ v;
    sync_n_steps := fun n => Nat.iter n sync_step;
    reachable := VecNatO Nat.add sync;
    sync_fair : forall a, reachable a -> forall i, a [[i]] < Vector.nth (sync_n_steps (next a i) a) i;
    }.

End with_k.

Section with_sync.
  Context `{Sync : Synchroniser}.

  Definition sync_n_steps_cont n v
    := vn_sub (sync_n_steps n v) v.

  (** A single sync step never decreases the position on any trace.
      NOTE(review): the statement below appears to have lost the relation
      symbol between [a] and [(sync_step a)]; as written, [a] is applied like
      a function, which cannot typecheck for a vector. The lemma name and the
      proof (Vec_Forall2_forall2, then lia per index) suggest that the
      pointwise vector order from VectorNat was intended. TODO restore the
      exact notation. *)
  Lemma sync_step_vn_le
    : forall a, a (sync_step a).
  (* Plus.le_plus_l *)
  Proof.
    intro a.
    (* reduce the vector statement to one statement per index *)
    eapply Vec_Forall2_forall2.
    intros.
    unfold sync_step.
    setoid_rewrite Vec_nth_map2.
    lia.
  Qed.

  (** Unfolding equation: [S n] sync steps are [n] sync steps followed by
      one more. *)
  Lemma sync_n_steps_S
    : forall n a, sync_n_steps (S n) a = sync_step (sync_n_steps n a).
  Proof.
    unfold sync_n_steps, sync_step, Nat.iter.
    cbn.
    reflexivity.
  Qed.

  (** Taking more sync steps never decreases any position.
      NOTE(review): as in [sync_step_vn_le], the relation symbol between the
      two [sync_n_steps] terms appears to be missing from the statement.
      TODO restore it. The proof needs that relation to be reflexive and
      usable with [rewrite]. *)
  Lemma sync_n_steps_monotone
    : forall n m, n <= m -> forall a, sync_n_steps n a sync_n_steps m a.
  Proof.
    intros n m Hnm a.
    (* induction on the derivation of [n <= m] *)
    induction Hnm.
    - reflexivity.
    - (* one more step on top of the [m]-step result *)
      rewrite sync_n_steps_S. rewrite IHHnm.
      eapply sync_step_vn_le.
  Qed.

  (** [reachable] is closed under one sync step; proved by applying a
      constructor of [VecNatO]. *)
  Lemma sync_step_reachable
    : forall a, reachable a -> reachable (sync_step a).
  Proof.
    intros.
    econstructor;eauto.
  Qed.

  (** [reachable] is closed under any number of sync steps (induction on [n]). *)
  Lemma sync_n_steps_reachable
    : forall a, reachable a -> forall n, reachable (sync_n_steps n a).
  Proof.
    intros a Ha n.
    revert a Ha.
    induction n; intros a Ha.
    - cbn. assumption.
    - cbn. setoid_rewrite sync_n_steps_S. eapply sync_step_reachable;eauto.
  Qed.

  (** Additivity: [m + n] sync steps equal [n] sync steps after [m] sync steps.
      NOTE(review): the [reachable a] hypothesis appears unused by the proof. *)
  Lemma sync_n_steps_plus
    : forall a, reachable a -> forall m n, sync_n_steps (m + n) a = sync_n_steps n (sync_n_steps m a).
  Proof.
    intros.
    (* [m + n] becomes [n + m], presumably to match the shape of nat_rect_plus *)
    rewrite PeanoNat.Nat.add_comm.
    unfold sync_n_steps, Nat.iter.
    eapply nat_rect_plus.
  Qed.

  Fixpoint next_n i n a : nat
    := match n with
       | O => O
       | S n => let m := next_n i n a in
               m + next (sync_n_steps m a ) i
       end.

  (** Unfolding equation for [next_n] at a successor. *)
  Lemma next_n_S i n a
    : next_n i (S n) a = next_n i n a + next (sync_n_steps (next_n i n a) a) i.
  Proof.
    reflexivity.
  Qed.

  (** From a reachable [a], after [next_n i n a] sync steps, trace [i] has
      advanced by at least [n]. *)
  Lemma fair_next_n
    : forall a, reachable a -> forall i n, Vector.nth a i + n <= Vector.nth (sync_n_steps (next_n i n a) a) i.
  Proof.
    intros a Ha i.
    (*setoid_rewrite nextn_nextn'.*)
    induction n.
    - cbn. lia.
    - (* apply fairness at the state reached after [next_n i n a] steps *)
      specialize (sync_fair (sync_n_steps (next_n i n a) a)) as Hfair.
      (* [exploit'] is a project tactic; here the first goal it leaves is the
         reachability premise of Hfair *)
      exploit' Hfair. 1: eapply sync_n_steps_reachable;eauto.
      specialize (Hfair i).
      rewrite next_n_S.
      (* split the combined step count into the two phases *)
      setoid_rewrite sync_n_steps_plus;eauto.
      lia.
  Qed.

  (** Existential form of [fair_next_n]: any amount of progress on trace [i]
      is reached after some number of sync steps. *)
  Lemma fair_next_n_ex
    : forall a, reachable a -> forall i n, exists m, Vector.nth a i + n <= Vector.nth (sync_n_steps m a) i.
  Proof.
    intros.
    exists (next_n i n a).
    eapply fair_next_n;eauto.
  Qed.

  Definition vn_next_postfix a b : nat
    := Vector.fold_left max O (vec_mapi2 (fun i ai bi => next_n i (bi - ai) a) a b).

  (** From a reachable [a], any target vector [b] (related to [a] by the
      vector order) is reached within [vn_next_postfix a b] sync steps.
      NOTE(review): the statement appears to have lost its relation symbols
      (between [a] and [b], and between [b] and [sync_n_steps ...]). The proof
      (Vec_Forall2_forall2, vn_le_impl_forall_le) suggests the pointwise
      vector order from VectorNat. TODO restore the exact notation. *)
  Lemma sync_reach
    : forall a, reachable a -> forall b, a b -> b sync_n_steps (vn_next_postfix a b) a.
  Proof.
    intros.
    eapply Vec_Forall2_forall2.
    intros i.
    (* presumably rewrites [b] as [(b - a) + a], using the hypothesis on [a] and [b] *)
    setoid_rewrite (vn_eq_plus_minus _ _ _ H0).
    rewrite fair_next_n. 2:eauto.
    eapply vn_le_impl_forall_le.
    eapply sync_n_steps_monotone.
    (* the per-trace step count is below the maximum over all traces *)
    unfold vn_next_postfix.
    setoid_rewrite <-Vec_fold_left_max with (i:=i).
    rewrite Vec_nth_mapi2.
    reflexivity.
  Qed.

  (** Reachability of [(s - v) +vn+ v], where [s] is the state reached from the
      zero vector after [vn_next_postfix (vn_zero k) v] steps. The proof
      simplifies the expression with [vn_sub_add], which needs [sync_reach]
      from the zero vector as a side condition. *)
  Lemma sync_n_steps_reach_sub_add v
    : VecNatO Nat.add sync (vn_sub (sync_n_steps (vn_next_postfix (vn_zero k) v) (vn_zero k)) v +vn+ v).
  Proof.
    rewrite vn_sub_add.
    - eapply sync_n_steps_reachable. econstructor.
    - eapply sync_reach.
      + econstructor.
      + eapply vn_zero_minimal.
  Qed.

End with_sync.

(** The trivial synchroniser: each sync step advances every trace by one,
    and [next] is constantly 1. *)
Lemma sync_trivial (k : nat)
  : Synchroniser (fun _ => Vector.const 1 k) (fun _ _ => 1).
Proof.
  econstructor.
  intros.
  cbn.
  rewrite Vec_nth_map2. rewrite Vector.const_nth. lia.
Qed.

Section with_tsync.
  Context `{TS : TotalSynchroniser}.

  (** The stand-alone [t_sync_step'] agrees definitionally with the class field. *)
  Lemma t_sync_step_eq
    : t_sync_step' t_sync = t_sync_step.
  Proof.
    reflexivity.
  Qed.

  (** The stand-alone [t_sync_n_steps'] agrees definitionally with the class field. *)
  Lemma t_sync_n_steps_eq
    : t_sync_n_steps' t_sync = t_sync_n_steps.
  Proof.
    reflexivity.
  Qed.

  (** The stand-alone [t_sync_snd'] agrees definitionally with the class field. *)
  Lemma t_sync_snd_eq
    : t_sync_snd' t_sync = t_sync_snd.
  Proof.
    reflexivity.
  Qed.

  (** The stand-alone [t_reachable'] agrees definitionally with the class field. *)
  Lemma t_reachable_eq
    : t_reachable' t_sync = t_reachable.
  Proof.
    reflexivity.
  Qed.

  (* TODO Prove t_sync lemmas *)
  (** Restatement of the fairness field [t_sync_fair]. *)
  Lemma t_sync_snd_t_next_lt a
    : t_sync_snd a
      < t_sync_snd (a + (t_next a)).
  Proof.
    unfold t_sync_snd.
    eapply t_sync_fair.
  Qed.

  (** The target position is monotone in the number of source steps, since
      each step adds the natural number [t_sync m]. *)
  Lemma t_sync_snd_monotone n m
    : n <= m -> t_sync_snd n <= t_sync_snd m.
  Proof.
    intro H.
    unfold t_sync_snd, t_sync_n_steps.
    induction H.
    - reflexivity.
    - rewrite iter_S.
      (* expose the pair produced by the last step *)
      unfold t_sync_step at 2.
      destruct (Nat.iter m t_sync_step (0, 0)) eqn:E.
      cbn.
      replace n1 with (snd (Nat.iter m t_sync_step (0,0))) in *.
      2: rewrite surjective_pairing in E at 1;inversion E;eauto.
      cbn in IHle.
      lia.
  Qed.

  (** Inversion for [t_reachable]: if [(S m, n)] is reachable, then so is
      [(m, n - t_sync m)].
      NOTE(review): the inner induction, which shows that the source
      component after [m] steps from [(0,0)] is [m], is duplicated verbatim
      in [t_reachable_inv2]. It could be factored into a separate lemma. *)
  Lemma t_reachable_inv m n
    : t_reachable (S m, n) -> t_reachable (m , n - t_sync m).
  Proof.
    revert n.
    induction m;intros.
    - unfold t_reachable, t_sync_snd, t_sync_n_steps, t_sync_step in *.
      cbn in *. lia.
    - unfold t_reachable, t_sync_snd, t_sync_n_steps, t_sync_step in *.
      rewrite <-H.
      setoid_rewrite iter_S.
      rewrite iter_S.
      destruct (Nat.iter m (fun nm : nat * nat => let (n0, m0) := nm in (S n0, t_sync n0 + m0)) (0, 0)) as [p q] eqn:E.
      cbn.
      (* it remains to show that the source component after [m] steps is [m] *)
      enough (p = m).
      {
        subst p.
        lia.
      }
      clear - E.
      revert p q E.
      induction m;intros.
      + cbn in *.
        inversion E.
        reflexivity.
      + rewrite iter_S in E.
        destruct (Nat.iter m (fun nm : nat * nat => let (n0, m0) := nm in (S n0, t_sync n0 + m0))) as [x y].
        specialize (IHm x y).
        rewrite <-IHm;eauto.
        inversion E.
        reflexivity.
  Qed.

  (** If both [(S m, n)] and [(m, n')] are reachable, then the last step
      contributed exactly [t_sync m]: [n = t_sync m + n'].
      NOTE(review): shares its inner induction (the source component after [m]
      steps is [m]) verbatim with [t_reachable_inv]. *)
  Lemma t_reachable_inv2 m n n'
    : t_reachable (S m, n) -> t_reachable (m , n') -> n = t_sync m + n'.
  Proof.
    revert n.
    induction m;intros.
    - unfold t_reachable, t_sync_snd, t_sync_n_steps, t_sync_step in *.
      cbn in *. lia.
    - unfold t_reachable, t_sync_snd, t_sync_n_steps, t_sync_step in *.
      rewrite <-H.
      rewrite <-H0.
      setoid_rewrite iter_S.
      rewrite iter_S.
      destruct (Nat.iter m (fun nm : nat * nat => let (n0, m0) := nm in (S n0, t_sync n0 + m0)) (0, 0)) as [p q] eqn:E.
      cbn.
      (* it remains to show that the source component after [m] steps is [m] *)
      enough (p = m).
      {
        subst p.
        lia.
      }
      clear - E.
      revert p q E.
      induction m;intros.
      + cbn in *.
        inversion E.
        reflexivity.
      + rewrite iter_S in E.
        destruct (Nat.iter m (fun nm : nat * nat => let (n0, m0) := nm in (S n0, t_sync n0 + m0))) as [x y].
        specialize (IHm x y).
        rewrite <-IHm;eauto.
        inversion E.
        reflexivity.
  Qed.

  (** Induction principle for reachable pairs: a property that holds at
      [(0,0)] and is preserved by one synchronisation step holds for every
      reachable [(a, α)]. *)
  Lemma tsync_ind (P : nat -> nat -> Prop)
    : P 0 0 -> (forall a α, P a α -> P (S a) (t_sync a + α)) -> forall a α, t_reachable (a, α) -> P a α.
  Proof.
    intros.
    revert α H1.
    induction a;intros.
    - (* the only reachable pair with source 0 is (0,0) *)
      unfold t_reachable in H1.
      unfold t_sync_snd in H1.
      unfold t_sync_n_steps in H1.
      unfold t_sync_step in H1.
      cbn in H1.
      subst α.
      eapply H.
    - (* step back to [(a, α - t_sync a)], apply the IH, then step forward *)
      eapply t_reachable_inv in H1 as H2.
      eapply IHa in H2 as H3.
      eapply H0 in H3;eauto.
      eapply t_reachable_inv2 in H2;eauto.
      rewrite H2.
      eapply H3.
  Qed.
End with_tsync.

Require Import Coq.Classes.Equivalence.

(** [t_sync_related t_sync l1 l2]: for every suffix [a :: l1'] of [l1], the
    head [a] is [R]-related to the element of [l2] at index
    [t_sync_snd' t_sync (length (a :: l1'))], provided that index exists.
    Indices past the end of [l2] impose no constraint.
    NOTE(review): a suffix of length [n] is compared with target index
    [t_sync_snd' t_sync n], not [n - 1]. Confirm that this matches the
    intended list orientation and indexing. *)
Fixpoint t_sync_related (t_sync : nat -> nat) `{Equivalence} (l1 l2 : list A)
  := match l1 with
     | []
       => True
     | a :: l1'
       => t_sync_related t_sync l1' l2
         /\ match nth_error l2 (t_sync_snd' t_sync (length l1)) with
           | Some b => R a b
           | None => True
           end
     end.

(** The trivial total synchroniser: the target takes exactly one step per
    source step, so the target advances after a single source step. *)
Lemma t_sync_trivial
  : TotalSynchroniser (fun _ => 1) (fun _ => 1).
Proof.
  econstructor.
  intros.
  induction n.
  - cbn. lia.
  - (* normalise [S n + 1] so that iter_S can peel off the outer steps *)
    rewrite iter_S.
    replace (S n + 1) with (S (n + 1)) by lia.
    rewrite iter_S.
    replace (n + 1) with (S n) by lia.
    rewrite iter_S.
    setoid_rewrite surjective_pairing with (p:=Nat.iter n _ _).
    cbn.
    lia.
Qed.

(** Under the trivial total synchroniser, [(n, m)] is reachable iff [n = m]. *)
Lemma t_sync_trivial_iff (n m : nat)
  : @t_reachable _ _ (t_sync_trivial) (n, m) <-> n = m.
Proof.
  unfold t_reachable,t_sync_snd,t_sync_n_steps,t_sync_step.
  cbn. revert m.
  - induction n;intros.
    + cbn in *; tauto.
    + rewrite iter_S.
      destruct (Nat.iter n (fun nm : nat * nat => let (n, m) := nm in (S n, S m)) (0, 0)). cbn.
      destruct m.
      * split;congruence.
      * cbn in IHn. specialize (IHn m).
        destruct IHn.
        split;intros;f_equal;eauto.
Qed.

From Coq Require Import Equality.
From HyPre Require Import VectorPair.

Section with_sync_tsynci.
  Context `{Sync : Synchroniser}.
  Context {t_synci t_nexti : Fin.t k -> nat -> nat}.
  Context {TSi : forall i, TotalSynchroniser (t_synci i) (t_nexti i)}.

  (* MAYDO idea: generalize synchronisers to any partially ordered monoid.
         Lists would then be generalised as another monoid with a monotone
         injective homomorphism towards the first.

         Definition merge_sync (v : Vector.t (Vector.t nat k) n) : Vector.t (Vector.t nat k) n.
   *)


  (** Step function of the merged synchroniser on vectors of length [k + k].
      - The part selected by [fst_vec] advances by [sync], exactly like the
        underlying synchroniser (see [merge_sync_fst_vec_eq]).
      - For each trace [i], the part selected by [snd_vec] advances by the
        amount that brings it to [t_sync_snd] (of [TSi i]) of the new
        [fst_vec] position (see [merge_sync_snd_vec_eq] and [merge_t_reach]).
      [-vn-] is presumably pointwise truncated subtraction. For reachable
      vectors it does not truncate: this is the side condition proved inside
      [merge_t_reach]. *)
  Definition merge_sync (v : vecnat (k+k)) : vecnat (k+k)
    := let v1 := sync (fst_vec v) in
       Vector.append ((vec_mapi (fun i n => @t_sync_snd _ _ (TSi i) n) (v1 +vn+ fst_vec v)) -vn- (snd_vec v)) v1.

  (** Fairness bound of the merged synchroniser. [depair_i i] splits a merged
      index into [(is_dom, j)].
      - For dominant indices ([is_dom = true]), reuse [next] of the underlying
        synchroniser at trace [j].
      - Otherwise, take enough sync steps for the dominant position of trace
        [j] to advance by [t_nexti j (fst_vec v [[j]])]. By [t_sync_fair], the
        derived position on trace [j] then strictly increases (see
        [us_sync_sync_sync']). *)
  Definition merge_next (v : vecnat (k+k)) (i : Fin.t (k+k)) : nat
    := let (is_dom,j) := depair_i i in
       if is_dom then next (fst_vec v) j else next_n j (t_nexti j (fst_vec v [[j]])) (fst_vec v).

  (*      Lemma vn_ind_plus
        : forall (P : vecnat k -> Prop), P (vn_zero k) -> (forall (a b : vecnat k), P a -> P (a b*)


  (* Local shorthands for the derived operations of the merged synchroniser
     (mirroring the fields of [Synchroniser]) and for the fields of the
     [j]-th total synchroniser [TSi j]. *)
  Local Notation merge_sync_step := (fun v => (merge_sync v) +vn+ v).
  Local Notation merge_sync_n_steps n := (Nat.iter n merge_sync_step).
  Local Notation merge_reachable := (VecNatO Nat.add merge_sync).
  Local Notation t_reachable__i j := (@t_reachable _ _ (TSi j)).
  Local Notation t_sync_snd__i j := (@t_sync_snd _ _ (TSi j)).

  (** The [fst_vec] part of a merged step is the underlying [sync] step. *)
  Lemma merge_sync_fst_vec_eq a
    : fst_vec (merge_sync a) = sync (fst_vec a).
  Proof.
    setoid_rewrite to_kk_fst_vec.
    reflexivity.
  Qed.

  (** The [snd_vec] part of a merged step: for each trace, the target position
      of the new dominant position minus the current [snd_vec] entry. *)
  Lemma merge_sync_snd_vec_eq a
    : snd_vec (merge_sync a) = vec_mapi (fun (i : Fin.t k) (n : nat) => t_sync_snd__i i n)
                                        (sync (fst_vec a) +vn+ fst_vec a)
                               -vn- snd_vec a.
  Proof.
    setoid_rewrite to_kk_snd_vec.
    reflexivity.
  Qed.

  (** Merged reachability projects to reachability of the dominant part. *)
  Lemma merge_reach a
    : merge_reachable a -> reachable (fst_vec a).
  Proof.
    intros. dependent induction H.
    - rewrite fst_vec_const. econstructor.
    - (* project the step equation onto the [fst_vec] part *)
      eapply f_equal with (f:=fst_vec) in H0.
      rewrite fst_vec_map2 in H0.
      rewrite merge_sync_fst_vec_eq in H0.
      econstructor;eauto.
  Qed.

  (** For a merged-reachable [a], the [j]-th pair [nth_pair j a] (presumably
      the [j]-th entries of [fst_vec a] and [snd_vec a]) is reachable for
      the total synchroniser [TSi j]. *)
  Lemma merge_t_reach a
    : merge_reachable a -> forall j, t_reachable__i j (nth_pair j a).
  Proof.
    intros.
    unfold t_reachable__i. cbn.
    induction H.
    - rewrite fst_vec_const.
      rewrite snd_vec_const.
      rewrite Vector.const_nth.
      cbn. reflexivity.
    - subst yt.
      rewrite fst_vec_map2. rewrite snd_vec_map2.
      rewrite merge_sync_fst_vec_eq.
      rewrite merge_sync_snd_vec_eq.
      repeat rewrite Vec_nth_map2.
      rewrite vec_mapi_nth.
      rewrite Vec_nth_map2.
      (* the goal has the shape [_ = x - y + z]; show the subtraction does
         not truncate ([y <= x]) so that [x - y + z] simplifies to [x] *)
      match goal with
        |- _ = ?x - ?y + ?z => assert (y <= x) as Hge;
                                 [|replace (x - y + z) with x by lia]
      end.
      {
        rewrite <-IHVecNatO.
        eapply t_sync_snd_monotone.
        lia.
      }
      reflexivity.
  Qed.

  (** On a dominant index, [merge_next] is the underlying [next]. *)
  Lemma merge_next_fst_next i j a
    : (true, j) = depair_i i -> merge_next a i = next (fst_vec a) j.
  Proof.
    intros.
    unfold merge_next. inversion H.
    reflexivity.
  Qed.

  (** Iterated merged steps project to iterated underlying sync steps on
      [fst_vec]. *)
  Lemma merge_sync_n_steps_fst_sync_n_steps n a
    : fst_vec (merge_sync_n_steps n a) = sync_n_steps n (fst_vec a).
  Proof.
    intros.
    induction n.
    - cbn. reflexivity.
    - rewrite iter_S.
      rewrite fst_vec_map2.
      rewrite IHn.
      rewrite merge_sync_fst_vec_eq.
      (* the IH is needed a second time, inside the argument of [sync] *)
      rewrite IHn.
      cbn.
      reflexivity.
  Qed.

  (** [fair_next_n] transferred to the dominant part of the merged synchroniser.
      NOTE(review): the hypothesis [0 < k] is not used by the proof. *)
  Lemma merge_fair_next_n n a j
    : 0 < k -> reachable (fst_vec a)
      -> fst_vec a [[j]] + n <= fst_vec (merge_sync_n_steps (next_n j n (fst_vec a)) a) [[j]].
  Proof.
    intros.
    rewrite merge_sync_n_steps_fst_sync_n_steps.
    eapply fair_next_n.
    assumption.
  Qed.

  (** Merged reachability is closed under any number of merged steps. *)
  Lemma merge_reachable_merge_sync_n_steps a n
    : merge_reachable a -> merge_reachable (merge_sync_n_steps n a).
  Proof.
    intros. revert dependent a.
    induction n;intros.
    - cbn. assumption.
    - rewrite iter_S.
      econstructor.
      + eapply IHn. eapply H.
      + reflexivity.
  Qed.

  (** For [0 < k], [merge_sync] with [merge_next] is a synchroniser.
      - Dominant indices: fairness is inherited from [sync_fair] on [fst_vec].
      - Other indices: the [snd_vec] entry equals [t_sync_snd] of the dominant
        entry ([merge_t_reach]). [t_sync_snd_t_next_lt] gives a strict
        increase once the dominant entry has advanced by [t_nexti].
        [merge_fair_next_n] guarantees that advance, and monotonicity of
        [t_sync_snd] closes the argument. *)
  Lemma us_sync_sync_sync'
    : 0 < k -> Synchroniser merge_sync merge_next.
  Proof.
    econstructor.
    intros a Hreach i.
    eapply merge_reach in Hreach as Hreach1.
    remember (depair_i i) as Q. destruct Q as [is_dom j].
    eapply merge_t_reach with (j:=j) in Hreach as Hreach2.
    destruct is_dom.
    - (* dominant index *)
      setoid_rewrite <-fst_vec_nth. 2,3:eauto.
      setoid_rewrite merge_sync_n_steps_fst_sync_n_steps.
      setoid_rewrite merge_next_fst_next. 2:eauto.
      eapply sync_fair with (i:=j) in Hreach1.
      now eauto.
    - (* non-dominant index: go through the total synchroniser TSi j *)
      unfold merge_next. cbn.
      setoid_rewrite <-snd_vec_nth. 2,3:eauto.
      destruct (depair_i i).
      inversion HeqQ. subst b t.
      unfold t_reachable in Hreach2. cbn in Hreach2.
      setoid_rewrite <-Hreach2.
      eapply PeanoNat.Nat.lt_le_trans.
      + eapply t_sync_snd_t_next_lt.
      + match goal with |- context [merge_sync_n_steps ?x _]
                        => eapply merge_reachable_merge_sync_n_steps in Hreach;
                            eapply merge_t_reach with (j:=j) in Hreach
        end.
        unfold t_reachable in Hreach. cbn in Hreach.
        rewrite <-Hreach.
        eapply t_sync_snd_monotone.
        eapply merge_fair_next_n;eauto.
  Qed.

(*
       The most general form would be:
       a k-synchroniser times a j-synchroniser, where the traces overlap. This doesn't work:
           e.g. you may go 2 steps on all the k traces and 1 on the j traces, and then
           2 steps on all the j traces and 1 on the k traces, and the overlap never syncs.
       The weakest sufficient condition is that the overlapping traces always eventually sync.
       Easier version: the j-synchroniser always does exactly one step on the trace that is in the k-synchroniser.
 *)

End with_sync_tsynci.

(** Instance form of [us_sync_sync_sync'] without the [0 < k] hypothesis.
    For [k = 0], there is no index in [Fin.t (0 + 0)], so fairness holds
    vacuously. Otherwise the lemma applies. *)
Global Instance us_sync_sync_sync
       `{Sync : Synchroniser}
       {t_synci t_nexti : Fin.t k -> nat -> nat}
       {TSi : forall i : Fin.t k, TotalSynchroniser (t_synci i) (t_nexti i)}
  : Synchroniser (@merge_sync k sync t_synci t_nexti _) (@merge_next k sync next _ t_nexti).
Proof.
  destruct k.
  - econstructor.
    intros.
    inversion i.
  - eapply us_sync_sync_sync'.
    lia.
Qed.