open HolKernel Parse boolLib bossLib;

(* Begin the theory e4_arrays.  Definitions made with Define and theorems
   saved with store_thm below become part of it; export_theory at the end
   of the script writes it out. *)
val _ = new_theory "e4_arrays";

(* MY_NUM_INDUCT: an induction principle over the non-zero naturals,
   defined here for demonstration.  To prove P n for every n <> 0 it
   suffices to prove the base case P 1 and, for every n with 2 <= n, to
   derive P n from P m for all non-zero m < n.  It is the induction scheme
   used by num2boolList_INJ below. *)

val MY_NUM_INDUCT = store_thm ("MY_NUM_INDUCT",
  ``!P. P 1 /\ (!n. (2 <= n /\ (!m. (m < n /\ m <> 0) ==> P m)) ==> P n) ==> (!n. n <> 0 ==> P n)``,
REPEAT STRIP_TAC >>
(* derived from complete (strong) induction on n *)
completeInduct_on `n` >>
(* n = 0 is excluded by n <> 0; otherwise n is a successor *)
Cases_on `n` >> FULL_SIMP_TAC arith_ss [] >>
(* the predecessor being 0 gives the base case P 1; otherwise 2 <= n
   and the step assumption applies *)
Cases_on `n'` >> ASM_SIMP_TAC arith_ss [])

(* num2boolList n lists the binary digits of n, least significant first,
   with the leading 1 bit dropped.  Each digit is recorded as EVEN of the
   current value, so T encodes a 0 bit and F encodes a 1 bit; e.g. 5
   (binary 101) gives [F; T].  Both 0 and 1 map to the empty list.  The
   last clause only fires for n >= 2, where the recursive argument
   n DIV 2 is strictly smaller than n. *)
val num2boolList_def = Define `
  (num2boolList 0 = []) /\
  (num2boolList 1 = []) /\
  (num2boolList n = (EVEN n) :: num2boolList (n DIV 2))`

(* Rewrite-friendly restatement of num2boolList_def: the recursive clause
   is stated for every n with 2 <= n, so the simplifier can use it as a
   conditional rewrite (as the proofs below do). *)
val num2boolList_REWRS = store_thm ("num2boolList_REWRS",
 ``(num2boolList 0 = []) /\
   (num2boolList 1 = []) /\
   (!n. 2 <= n ==> ((num2boolList n = (EVEN n) :: num2boolList (n DIV 2))))``,
REPEAT STRIP_TAC >| [
  METIS_TAC[num2boolList_def],
  METIS_TAC[num2boolList_def],

  (* rule out the 0 and 1 clauses so that the general clause applies *)
  `n <> 0 /\ n <> 1` by DECIDE_TAC >>
  METIS_TAC[num2boolList_def]
]);

(* num2boolList returns the empty list exactly for 0 and 1. *)
val num2boolList_EQ_NIL = store_thm ("num2boolList_EQ_NIL",
  ``!n. (num2boolList n = []) <=> ((n = 0) \/ (n = 1))``,
GEN_TAC >> EQ_TAC >| [
  (* left to right: if n were neither 0 nor 1 the recursive clause would
     produce a cons, contradicting the assumption *)
  REPEAT STRIP_TAC >>
  CCONTR_TAC >>
  FULL_SIMP_TAC list_ss [num2boolList_REWRS],

  (* right to left: both cases are base clauses *)
  REPEAT STRIP_TAC >> (
    ASM_SIMP_TAC std_ss [num2boolList_REWRS]
  )
]);


(* num2boolList is injective on the non-zero naturals.  Proved with
   MY_NUM_INDUCT:
   - Base case n = 1: follows from num2boolList_EQ_NIL.
   - Step case: both lists are conses, so their heads (parity) and their
     tails (num2boolList of the halves) agree.  n = m then follows from
     the division identity for divisor 2. *)
val num2boolList_INJ = store_thm ("num2boolList_INJ",
  ``!n. n <> 0 ==> !m. m <> 0 ==> (num2boolList n = num2boolList m) ==> (n = m)``,

HO_MATCH_MP_TAC MY_NUM_INDUCT >>
CONJ_TAC >- (
  SIMP_TAC list_ss [num2boolList_REWRS, num2boolList_EQ_NIL]
) >>
GEN_TAC >> STRIP_TAC >> GEN_TAC >> STRIP_TAC >>
(* m = 1 is impossible: num2boolList 1 is empty but num2boolList n is a
   cons because 2 <= n *)
Cases_on `m = 1` >- (
  ASM_SIMP_TAC list_ss [num2boolList_REWRS]
) >>
(* with m <> 0 and m <> 1, both sides unfold to a cons *)
ASM_SIMP_TAC list_ss [num2boolList_REWRS] >>
REPEAT STRIP_TAC >>
(* tails agree: apply the induction hypothesis to n DIV 2 *)
`n DIV 2 = m DIV 2` by (
  `(m DIV 2 <> 0) /\ (n DIV 2 <> 0) /\ (n DIV 2 < n)` suffices_by METIS_TAC[] >>

  ASM_SIMP_TAC arith_ss [arithmeticTheory.NOT_ZERO_LT_ZERO,
    arithmeticTheory.X_LT_DIV]
) >>
(* heads agree: equal parity, hence equal remainders modulo 2 *)
`n MOD 2 = m MOD 2` by (
  ASM_SIMP_TAC std_ss [arithmeticTheory.MOD_2]
) >>
`0 < 2` by DECIDE_TAC >>
(* combine quotient and remainder via the division theorem *)
METIS_TAC[arithmeticTheory.DIVISION]);



(* num2arrayIndex maps an array index to its bool-list key.  The shift by
   SUC avoids the collision of 0 and 1 in num2boolList (both give the
   empty list), making the map injective on all of num; see
   num2arrayIndex_INJ. *)
val num2arrayIndex_def = Define `num2arrayIndex n = (num2boolList (SUC n))`

(* num2arrayIndex is injective.  SUC n is never 0, so num2boolList_INJ
   applies to SUC n and SUC m, and SUC itself is injective. *)
val num2arrayIndex_INJ = store_thm ("num2arrayIndex_INJ",
  ``!n m. (num2arrayIndex n = num2arrayIndex m) <=> (n = m)``,

SIMP_TAC list_ss [num2arrayIndex_def] >>
METIS_TAC [numTheory.NOT_SUC, num2boolList_INJ, numTheory.INV_SUC]);


(* arrayIndex2num_aux reads a bool-list key back into a positive number,
   inverting num2boolList:
   - the empty list gives 1;
   - F adds a low-order 1 bit (2x + 1);
   - T adds a low-order 0 bit (2x).
   Positivity and the inverse property are proved in
   arrayIndex2num_aux_GT_0 and arrayIndex2num_aux_inv. *)
val arrayIndex2num_aux_def = Define `
  (arrayIndex2num_aux [] = 1) /\
  (arrayIndex2num_aux (F::idx) = 2 * arrayIndex2num_aux idx + 1) /\
  (arrayIndex2num_aux (T::idx) = 2 * arrayIndex2num_aux idx)`

(* arrayIndex2num converts a key back to an index.  PRE undoes the SUC
   applied by num2arrayIndex. *)
val arrayIndex2num_def = Define `arrayIndex2num idx = PRE (arrayIndex2num_aux idx)`

(* arrayIndex2num_aux never returns 0.  This is a local lemma, proved with
   prove and therefore not stored in the theory. *)
val arrayIndex2num_aux_GT_0 = prove (``!idx. 0 < arrayIndex2num_aux idx``,
Induct >- SIMP_TAC arith_ss [arrayIndex2num_aux_def] >>
(* split on the head bit and unfold the definition *)
Cases >> ASM_SIMP_TAC arith_ss [arrayIndex2num_aux_def]);


(* num2boolList undoes arrayIndex2num_aux (local lemma).  By induction on
   the key; in each step the number is 2x or 2x + 1 with 0 < x, hence at
   least 2, so the recursive clause of num2boolList applies. *)
val arrayIndex2num_aux_inv = prove (``!idx. num2boolList (arrayIndex2num_aux idx) = idx``,
Induct >- (
  SIMP_TAC arith_ss [arrayIndex2num_aux_def, num2boolList_REWRS]
) >>
`0 < arrayIndex2num_aux idx` by METIS_TAC[arrayIndex2num_aux_GT_0] >>
`0 < 2` by DECIDE_TAC >>
Cases >| [
  (* head T: the number is 2 * x, which is even and halves back to x *)
  `!x. (2 * x) MOD 2 = 0` by
     METIS_TAC[arithmeticTheory.MOD_EQ_0, arithmeticTheory.MULT_COMM] >>
  `!x. (2 * x) DIV 2 = x` by
     METIS_TAC[arithmeticTheory.MULT_DIV, arithmeticTheory.MULT_COMM] >>
  ASM_SIMP_TAC list_ss [arrayIndex2num_aux_def, num2boolList_REWRS,
    arithmeticTheory.EVEN_MOD2],

  (* head F: the number is 2 * x + 1, which is odd and halves back to x *)
  `!x y. (2 * x + y) MOD 2 = (y MOD 2)` by
     METIS_TAC[arithmeticTheory.MOD_TIMES, arithmeticTheory.MULT_COMM] >>
  `!x y. (2 * x + y) DIV 2 = x + y DIV 2` by
     METIS_TAC[arithmeticTheory.ADD_DIV_ADD_DIV, arithmeticTheory.MULT_COMM] >>
  ASM_SIMP_TAC list_ss [arrayIndex2num_aux_def, num2boolList_REWRS,
    arithmeticTheory.EVEN_MOD2]
]);


(* The round trip key -> index -> key is the identity, so every bool list
   is the key of some index.  arrayIndex2num_aux idx is positive, so the
   SUC introduced by num2arrayIndex cancels the PRE introduced by
   arrayIndex2num.  The goal then reduces to arrayIndex2num_aux_inv.
   No list reversal is involved: num2arrayIndex does not reverse its
   list, so no REVERSE lemma is needed. *)
val arrayIndex2num_inv = store_thm ("arrayIndex2num_inv",
  ``!idx. num2arrayIndex (arrayIndex2num idx) = idx``,
GEN_TAC >>
REWRITE_TAC[num2arrayIndex_def, arrayIndex2num_def] >>
`0 < arrayIndex2num_aux idx` by METIS_TAC[arrayIndex2num_aux_GT_0] >>
(* use the positivity fact to cancel SUC (PRE ...) *)
FULL_SIMP_TAC arith_ss [arithmeticTheory.SUC_PRE] >>
ASM_SIMP_TAC std_ss [arrayIndex2num_aux_inv]);


(* The round trip index -> key -> index is the identity.  Follows from
   arrayIndex2num_inv together with injectivity of num2arrayIndex. *)
val num2arrayIndex_inv = store_thm ("num2arrayIndex_inv",
  ``!n. arrayIndex2num (num2arrayIndex n) = n``,
METIS_TAC[ num2arrayIndex_INJ, arrayIndex2num_inv]);

(* arrayIndex2num is injective: applying num2arrayIndex to both sides
   recovers the keys by arrayIndex2num_inv. *)
val arrayIndex2num_INJ = store_thm ("arrayIndex2num_INJ",
  ``!idx1 idx2. (arrayIndex2num idx1 = arrayIndex2num idx2) <=> (idx1 = idx2)``,
METIS_TAC[ num2arrayIndex_INJ, arrayIndex2num_inv]);



(* Arrays storing values of type 'a, represented as binary tries keyed by
   the bool lists produced by num2arrayIndex.  At a node, an empty key
   addresses the value stored there (NONE if absent).  Otherwise the head
   of the key selects the subtree: T goes left, F goes right.  Leaf is the
   empty trie. *)
val _ = Datatype `array = Leaf | Node array ('a option) array`


(* The empty array stores no value at any key. *)
val EMPTY_ARRAY_def = Define `EMPTY_ARRAY : 'a array = Leaf`

(* Smart constructor used by IREMOVE.  A node that stores nothing and
   whose children are both Leaf is collapsed to Leaf, so that removal does
   not leave empty nodes behind.  This is what allows array-level laws such
   as removing a key right after inserting it giving the same array as
   removing it directly. *)
val mk_array_node_def = Define `
  (mk_array_node Leaf NONE Leaf = Leaf) /\
  (mk_array_node l x r = Node l x r)`

(* IUPDATE v a k stores v at key k in a, overwriting any previous value.
   When the key runs past the existing trie, empty nodes are created
   along the path. *)
val IUPDATE_def = Define `
  (IUPDATE v Leaf [] = Node Leaf (SOME v) Leaf) /\
  (IUPDATE v (Node l _ r) [] = Node l (SOME v) r) /\
  (IUPDATE v Leaf (T::k) = Node (IUPDATE v Leaf k) NONE Leaf) /\
  (IUPDATE v Leaf (F::k) = Node Leaf NONE (IUPDATE v Leaf k)) /\
  (IUPDATE v (Node l x r) (T::k) = Node (IUPDATE v l k) x r) /\
  (IUPDATE v (Node l x r) (F::k) = Node l x (IUPDATE v r k))`

(* ILOOKUP a k returns the value stored at key k, or NONE if there is
   none. *)
val ILOOKUP_def = Define `
  (ILOOKUP Leaf _ = NONE) /\
  (ILOOKUP (Node _ x _) [] = x) /\
  (ILOOKUP (Node l _ _) (T::k) = ILOOKUP l k) /\
  (ILOOKUP (Node _ _ r) (F::k) = ILOOKUP r k)`

(* IREMOVE a k deletes the value stored at key k (if any).  Nodes along
   the path are rebuilt with mk_array_node, so empty nodes are pruned. *)
val IREMOVE_def = Define `
  (IREMOVE Leaf _ = Leaf) /\
  (IREMOVE (Node l _ r) [] = mk_array_node l NONE r) /\
  (IREMOVE (Node l x r) (T::k) = mk_array_node (IREMOVE l k) x r) /\
  (IREMOVE (Node l x r) (F::k) = mk_array_node l x (IREMOVE r k))`



(* Public interface indexed by natural numbers.  Each operation converts
   its num index into a bool-list key with num2arrayIndex and delegates to
   the corresponding key-based operation. *)
val LOOKUP_def = Define `LOOKUP a n = ILOOKUP a (num2arrayIndex n)`
val UPDATE_def = Define `UPDATE v a n = IUPDATE v a (num2arrayIndex n)`
val REMOVE_def = Define `REMOVE a n = IREMOVE a (num2arrayIndex n)`


(* Properties of the num-indexed interface.  TODO: all of them are
   currently admitted with cheat (assumed without proof); replace the
   cheats with real proofs. *)

(* Looking up any index in the empty array yields NONE. *)
val LOOKUP_EMPTY = store_thm ("LOOKUP_EMPTY",
  ``!k. LOOKUP EMPTY_ARRAY k = NONE``,
cheat);

(* Updating index n makes n map to the new value and leaves every other
   index unchanged.  TODO: currently admitted with cheat. *)
val LOOKUP_UPDATE = store_thm ("LOOKUP_UPDATE",
  ``!v n n' a. LOOKUP (UPDATE v a n) n' =
       (if (n = n') then SOME v else LOOKUP a n')``,
cheat);

(* Removing index n makes n map to NONE and leaves every other index
   unchanged.  TODO: currently admitted with cheat. *)
val LOOKUP_REMOVE = store_thm ("LOOKUP_REMOVE",
  ``!n n' a. LOOKUP (REMOVE a n) n' =
       (if (n = n') then NONE else LOOKUP a n')``,
cheat);


(* On the same index, a later UPDATE or REMOVE overrides an earlier UPDATE
   or REMOVE.  These are stated as equalities between arrays, not between
   lookup results, so the chosen array representation must make them hold
   literally.  TODO: currently admitted with cheat. *)
val UPDATE_REMOVE_EQ = store_thm ("UPDATE_REMOVE_EQ", ``
  (!v1 v2 n a. UPDATE v1 (UPDATE v2 a n) n = UPDATE v1 a n) /\
  (!v n a. UPDATE v (REMOVE a n) n = UPDATE v a n) /\
  (!v n a. REMOVE (UPDATE v a n) n = REMOVE a n)
``,
cheat);


(* UPDATE and REMOVE operations on distinct indices commute, again stated
   as equalities between arrays.  TODO: currently admitted with cheat. *)
val UPDATE_REMOVE_NEQ = store_thm ("UPDATE_REMOVE_NEQ", ``
  (!v1 v2 a n1 n2. n1 <> n2 ==>
     ((UPDATE v1 (UPDATE v2 a n2) n1) = (UPDATE v2 (UPDATE v1 a n1) n2))) /\
  (!v a n1 n2. n1 <> n2 ==>
     ((UPDATE v (REMOVE a n2) n1) = (REMOVE (UPDATE v a n1) n2))) /\
  (!a n1 n2. n1 <> n2 ==>
     ((REMOVE (REMOVE a n2) n1) = (REMOVE (REMOVE a n1) n2)))``,
cheat);


(* Write out the e4_arrays theory begun by new_theory at the top of the
   script. *)
val _ = export_theory();
