diff --git a/README.md b/README.md index 82fa54e6..ebc824ad 100644 --- a/README.md +++ b/README.md @@ -69,9 +69,9 @@ FilesystemStore.Group.create store group_node;; let array_node = Node.Array.(group_node / "name");; (* creates an array with char data type and fill value '?' *) FilesystemStore.Array.create - ~codecs:[`Transpose [|2; 0; 1|]; `Bytes BE; `Gzip L2] - ~shape:[|100; 100; 50|] - ~chunks:[|10; 15; 20|] + ~codecs:[`Transpose [2; 0; 1]; `Bytes BE; `Gzip L2] + ~shape:[100; 100; 50] + ~chunks:[10; 15; 20] Ndarray.Char '?' array_node @@ -79,7 +79,7 @@ FilesystemStore.Array.create ``` ### read/write from/to an array ```ocaml -let slice = [|R [|0; 20|]; I 10; R [||]|];; +let slice = [R (0, 20); I 10; F];; let x = FilesystemStore.Array.read store array_node slice Ndarray.Char;; (* Do some computation on the array slice *) let x' = Zarr.Ndarray.map (fun _ -> Random.int 256 |> Char.chr) x;; @@ -90,8 +90,8 @@ assert (Ndarray.equal x' y);; ### create an array with sharding ```ocaml let config = - {chunk_shape = [|5; 3; 5|] - ;codecs = [`Transpose [|2; 0; 1|]; `Bytes LE; `Zstd (0, true)] + {chunk_shape = [5; 3; 5] + ;codecs = [`Transpose [2; 0; 1]; `Bytes LE; `Zstd (0, true)] ;index_codecs = [`Bytes BE; `Crc32c] ;index_location = Start};; @@ -99,8 +99,8 @@ let shard_node = Node.Array.(group_node / "another");; FilesystemStore.Array.create ~codecs:[`ShardingIndexed config] - ~shape:[|100; 100; 50|] - ~chunks:[|10; 15; 20|] + ~shape:[100; 100; 50] + ~chunks:[10; 15; 20] Ndarray.Complex32 Complex.zero shard_node @@ -114,7 +114,7 @@ List.map Node.Array.to_path a;; List.map Node.Group.to_path g;; (*- : string list = ["/"; "/some"; "/some/group"] *) -FilesystemStore.Array.reshape store array_node [|25; 32; 10|];; +FilesystemStore.Array.reshape store array_node [25; 32; 10];; let meta = FilesystemStore.Group.metadata store group_node;; Metadata.Group.show meta;; (* pretty prints the contents of the metadata *) diff --git a/zarr-eio/test/test_eio.ml b/zarr-eio/test/test_eio.ml 
index 1e27194d..7bf8b8ec 100644 --- a/zarr-eio/test/test_eio.ml +++ b/zarr-eio/test/test_eio.ml @@ -6,7 +6,6 @@ open Zarr_eio.Storage let string_of_list = [%show: string list] let print_node_pair = [%show: Node.Array.t list * Node.Group.t list] -let print_int_array = [%show : int array] module type EIO_STORE = Zarr.Storage.S with type 'a io := 'a @@ -40,19 +39,17 @@ let test_storage assert_equal ~printer:string_of_bool false exists; let cfg = - {chunk_shape = [|2; 5; 5|] + {chunk_shape = [2; 5; 5] ;index_location = End ;index_codecs = [`Bytes BE] ;codecs = [`Bytes LE]} in let anode = Node.Array.(gnode / "arrnode") in - let slice = [|R [|0; 20|]; I 10; R [|0; 29|]|] in - let exp = Ndarray.init Complex32 [|21; 1; 30|] (Fun.const Complex.one) in + let slice = [R (0, 20); I 10; R (0, 29)] in + let exp = Ndarray.init Complex32 [21; 1; 30] (Fun.const Complex.one) in List.iter (fun codecs -> - Array.create - ~codecs ~shape:[|100; 100; 50|] ~chunks:[|10; 15; 20|] - Complex32 Complex.one anode store; + Array.create ~codecs ~shape:[100; 100; 50] ~chunks:[10; 15; 20] Complex32 Complex.one anode store; Array.write store anode slice exp; let got = Array.read store anode slice Complex32 in assert_equal exp got; @@ -72,44 +69,29 @@ let test_storage let child = Node.Group.of_path "/some/child/group" in Group.create store child; let arrays, groups = Group.children store gnode in - assert_equal - ~printer:string_of_list ["/arrnode"] (List.map Node.Array.to_path arrays); - assert_equal - ~printer:string_of_list ["/some"] (List.map Node.Group.to_path groups); - + assert_equal ~printer:string_of_list ["/arrnode"] (List.map Node.Array.to_path arrays); + assert_equal ~printer:string_of_list ["/some"] (List.map Node.Group.to_path groups); let c = Group.children store @@ Node.Group.(root / "fakegroup") in assert_equal ([], []) c; - let ac, gc = hierarchy store in - let got = - List.fast_sort String.compare @@ - List.map Node.Array.show ac @ List.map Node.Group.show gc in - assert_equal - 
~printer:string_of_list - ["/"; "/arrnode"; "/some"; "/some/child"; "/some/child/group"] got; - + let got = List.fast_sort String.compare @@ List.map Node.Array.show ac @ List.map Node.Group.show gc in + assert_equal ~printer:string_of_list ["/"; "/arrnode"; "/some"; "/some/child"; "/some/child/group"] got; (* tests for renaming nodes *) let some = Node.Group.of_path "/some/child" in Array.rename store anode "ARRAYNODE"; Group.rename store some "CHILD"; let ac, gc = hierarchy store in - let got = - List.fast_sort String.compare @@ - List.map Node.Array.show ac @ List.map Node.Group.show gc in - assert_equal - ~printer:string_of_list - ["/"; "/ARRAYNODE"; "/some"; "/some/CHILD"; "/some/CHILD/group"] got; + let got = List.fast_sort String.compare (List.map Node.Array.show ac @ List.map Node.Group.show gc) in + assert_equal ~printer:string_of_list ["/"; "/ARRAYNODE"; "/some"; "/some/CHILD"; "/some/CHILD/group"] got; (* restore old array node name. *) Array.rename store (Node.Array.of_path "/ARRAYNODE") "arrnode"; - - let nshape = [|25; 32; 10|] in + let nshape = [25; 32; 10] in Array.reshape store anode nshape; let meta = Array.metadata store anode in - assert_equal ~printer:print_int_array nshape @@ Metadata.Array.shape meta; + assert_equal ~printer:[%show : int list] nshape @@ Metadata.Array.shape meta; assert_raises (Zarr.Storage.Key_not_found "fakegroup/zarr.json") (fun () -> Array.metadata store Node.Array.(gnode / "fakegroup")); - Array.delete store anode; clear store; let got = hierarchy store in @@ -120,22 +102,19 @@ let _ = "test eio-based stores" >:: (fun _ -> Eio_main.run @@ fun env -> - let rand_num = string_of_int @@ Random.int 10_000 in + let rand_num = string_of_int (Random.int 10_000) in let tmp_dir = Filename.(concat (get_temp_dir_name ()) (rand_num ^ ".zarr")) in let s = FilesystemStore.create ~env tmp_dir in assert_raises (Sys_error (Format.sprintf "%s: File exists" tmp_dir)) (fun () -> FilesystemStore.create ~env tmp_dir); - (* ensure it works with 
an extra "/" appended to directory name. *) ignore @@ FilesystemStore.open_store ~env (tmp_dir ^ "/"); - let fakedir = "non-existant-zarr-store112345.zarr" in assert_raises (Sys_error (Printf.sprintf "%s: No such file or directory" fakedir)) (fun () -> FilesystemStore.open_store ~env fakedir); - let fn = Filename.temp_file "nonexistantfile" ".zarr" in assert_raises (Zarr.Storage.Not_a_filesystem_store fn) @@ -146,6 +125,6 @@ let _ = ZipStore.with_open `Read_write zpath (fun z -> test_storage (module ZipStore) z); (* test just opening the now exisitant archive created by the previous test. *) ZipStore.with_open `Read_only zpath (fun _ -> ()); - test_storage (module MemoryStore) @@ MemoryStore.create (); + test_storage (module MemoryStore) (MemoryStore.create ()); test_storage (module FilesystemStore) s) ]) diff --git a/zarr-lwt/test/test_lwt.ml b/zarr-lwt/test/test_lwt.ml index 3221fe95..d3fadd27 100644 --- a/zarr-lwt/test/test_lwt.ml +++ b/zarr-lwt/test/test_lwt.ml @@ -6,7 +6,6 @@ open Zarr_lwt.Storage let string_of_list = [%show: string list] let print_node_pair = [%show: Node.Array.t list * Node.Group.t list] -let print_int_array = [%show : int array] module type LWT_STORE = Zarr.Storage.S with type 'a io := 'a Lwt.t @@ -41,18 +40,18 @@ let test_storage assert_equal ~printer:string_of_bool false exists; let cfg = - {chunk_shape = [|2; 5; 5|] + {chunk_shape = [2; 5; 5] ;index_location = End ;index_codecs = [`Bytes BE] ;codecs = [`Bytes LE]} in let anode = Node.Array.(gnode / "arrnode") in - let slice = [|R [|0; 20|]; I 10; R [|0; 29|]|] in - let exp = Ndarray.init Ndarray.Complex32 [|21; 1; 30|] (Fun.const Complex.one) in + let slice = [R (0, 20); I 10; R (0, 29)] in + let exp = Ndarray.init Ndarray.Complex32 [21; 1; 30] (Fun.const Complex.one) in Lwt_list.iter_s (fun codecs -> Array.create - ~codecs ~shape:[|100; 100; 50|] ~chunks:[|10; 15; 20|] + ~codecs ~shape:[100; 100; 50] ~chunks:[10; 15; 20] Ndarray.Complex32 Complex.one anode store >>= fun () -> 
Array.write store anode slice exp >>= fun () -> Array.read store anode slice Complex32 >>= fun got -> @@ -103,10 +102,10 @@ let test_storage (* restore old array node name. *) Array.rename store (Node.Array.of_path "/ARRAYNODE") "arrnode" >>= fun () -> - let nshape = [|25; 32; 10|] in + let nshape = [25; 32; 10] in Array.reshape store anode nshape >>= fun () -> Array.metadata store anode >>= fun meta -> - assert_equal ~printer:print_int_array nshape @@ Metadata.Array.shape meta; + assert_equal ~printer:[%show : int list] nshape @@ Metadata.Array.shape meta; Array.delete store anode >>= fun () -> clear store >>= fun () -> diff --git a/zarr-sync/test/test_sync.ml b/zarr-sync/test/test_sync.ml index 4d081b42..942949d1 100644 --- a/zarr-sync/test/test_sync.ml +++ b/zarr-sync/test/test_sync.ml @@ -6,7 +6,6 @@ open Zarr_sync.Storage let string_of_list = [%show: string list] let print_node_pair = [%show: Node.Array.t list * Node.Group.t list] -let print_int_array = [%show : int array] module type SYNC_STORE = Zarr.Storage.S with type 'a io := 'a @@ -21,6 +20,7 @@ let test_storage Group.create store gnode; let exists = Group.exists store gnode in assert_equal ~printer:string_of_bool true exists; + assert_equal ~printer:print_node_pair ([], [gnode]) (hierarchy store); let meta = Group.metadata store gnode in assert_equal ~printer:Metadata.Group.show Metadata.Group.default meta; @@ -40,25 +40,23 @@ let test_storage assert_equal ~printer:string_of_bool false exists; let cfg = - {chunk_shape = [|2; 5; 5|] + {chunk_shape = [2; 5; 5] ;index_location = End ;index_codecs = [`Bytes LE; `Crc32c] - ;codecs = [`Transpose [|2; 0; 1|]; `Bytes BE; `Zstd (0, false)]} in + ;codecs = [`Transpose [2; 0; 1]; `Bytes BE; `Zstd (0, false)]} in let cfg2 = - {chunk_shape = [|2; 5; 5|] + {chunk_shape = [2; 5; 5] ;index_location = Start ;index_codecs = [`Bytes BE] ;codecs = [`Bytes LE]} in let anode = Node.Array.(gnode / "arrnode") in - let slice = [|R [|0; 20|]; I 10; R [|0; 29|]|] in - let 
bigger_slice = [|R [|0; 21|]; L [|9; 10|] ; R [|0; 30|]|] in + let slice = [R (0, 20); I 10; R (0, 29)] in + let bigger_slice = [R (0, 21); L [9; 10] ; R (0, 30)] in List.iter (fun codecs -> - Array.create - ~codecs ~shape:[|100; 100; 50|] ~chunks:[|10; 15; 20|] - Complex32 Complex.one anode store; - let exp = Ndarray.init Complex32 [|21; 1; 30|] (Fun.const Complex.one) in + Array.create ~codecs ~shape:[100; 100; 50] ~chunks:[10; 15; 20] Complex32 Complex.one anode store; + let exp = Ndarray.init Complex32 [21; 1; 30] (Fun.const Complex.one) in let got = Array.read store anode slice Complex32 in assert_equal exp got; Ndarray.fill exp Complex.{re=2.0; im=0.}; @@ -68,7 +66,7 @@ let test_storage let _ = Array.read store anode bigger_slice Complex32 in assert_equal exp got; (* test writing a bigger slice to store *) - Array.write store anode bigger_slice @@ Ndarray.init Complex32 [|22; 2; 31|] (Fun.const Complex.{re=0.; im=3.0}); + Array.write store anode bigger_slice @@ Ndarray.init Complex32 [22; 2; 31] (Fun.const Complex.{re=0.; im=3.0}); let got = Array.read store anode slice Complex32 in Ndarray.fill exp Complex.{re=0.; im=3.0}; assert_equal exp got; @@ -76,12 +74,10 @@ let test_storage [[`ShardingIndexed cfg]; [`ShardingIndexed cfg2]]; (* repeat tests for non-sharding codec chain *) - Array.create - ~sep:`Dot ~codecs:[`Bytes BE] - ~shape:[|100; 100; 50|] ~chunks:[|10; 15; 20|] - Ndarray.Int Int.max_int anode store; + Array.create ~sep:`Dot ~codecs:[`Bytes BE] ~shape:[100; 100; 50] ~chunks:[10; 15; 20] Ndarray.Int Int.max_int anode store; + assert_equal ~printer:print_node_pair ([anode], [gnode]) (hierarchy store); (* test path where there is no chunk key present in store *) - let exp = Ndarray.init Int [|21; 1; 30|] (Fun.const Int.max_int) in + let exp = Ndarray.init Int [21; 1; 30] (Fun.const Int.max_int) in Array.write store anode slice exp; let got = Array.read store anode slice Int in assert_equal exp got; @@ -93,7 +89,7 @@ let test_storage assert_raises 
(Zarr.Storage.Invalid_data_type) (fun () -> Array.read store anode slice Ndarray.Char); - let badslice = [|R [|0; 20|]; I 10; R [||]; R [||] |] in + let badslice = [R (0, 20); I 10; F; F] in assert_raises (Zarr.Storage.Invalid_array_slice) (fun () -> Array.read store anode badslice Ndarray.Int); @@ -102,8 +98,8 @@ let test_storage (fun () -> Array.write store anode badslice exp); assert_raises (Zarr.Storage.Invalid_array_slice) - (fun () -> Array.write store anode [|R [|0; 20|]; R [||]; R [||]|] exp); - let badarray = Ndarray.init Float64 [|21; 1; 30|] (Fun.const 0.) in + (fun () -> Array.write store anode [R (0, 20); F; F] exp); + let badarray = Ndarray.init Float64 [21; 1; 30] (Fun.const 0.) in assert_raises (Zarr.Storage.Invalid_data_type) (fun () -> Array.write store anode slice badarray); @@ -147,13 +143,13 @@ let test_storage (* restore old array node name. *) Array.rename store (Node.Array.of_path "/ARRAYNODE") "arrnode"; - let nshape = [|25; 32; 10|] in + let nshape = [25; 32; 10] in Array.reshape store anode nshape; let meta = Array.metadata store anode in - assert_equal ~printer:print_int_array nshape @@ Metadata.Array.shape meta; + assert_equal ~printer:[%show : int list] nshape @@ Metadata.Array.shape meta; assert_raises (Zarr.Storage.Invalid_resize_shape) - (fun () -> Array.reshape store anode [|25; 10|]); + (fun () -> Array.reshape store anode [25; 10]); assert_raises (Zarr.Storage.Key_not_found "fakegroup/zarr.json") (fun () -> Array.metadata store Node.Array.(gnode / "fakegroup")); diff --git a/zarr/src/codecs.ml b/zarr/src/codecs.ml index 03b28fe9..23c20e5a 100644 --- a/zarr/src/codecs.ml +++ b/zarr/src/codecs.ml @@ -4,7 +4,7 @@ exception Invalid_sharding_chunk_shape exception Invalid_codec_ordering exception Invalid_zstd_level -type arraytoarray = [ `Transpose of int array ] +type arraytoarray = [ `Transpose of int list ] type compression_level = L0 | L1 | L2 | L3 | L4 | L5 | L6 | L7 | L8 | L9 type fixed_bytestobytes = [ `Crc32c ] type 
variable_bytestobytes = [ `Gzip of compression_level | `Zstd of int * bool ] @@ -14,41 +14,36 @@ type endianness = LE | BE type fixed_arraytobytes = [ `Bytes of endianness ] type variable_arraytobytes = [ `ShardingIndexed of internal_shard_config ] and internal_shard_config = - {chunk_shape : int array + {chunk_shape : int list ;codecs : ([fixed_arraytobytes | `ShardingIndexed of internal_shard_config ], bytestobytes) chain ;index_codecs : (fixed_arraytobytes, fixed_bytestobytes) chain ;index_location : loc} and ('a, 'b) chain = {a2a : arraytoarray list; a2b : 'a; b2b : 'b list} type arraytobytes = [ fixed_arraytobytes | variable_arraytobytes ] -type 'a array_repr = {kind : 'a Ndarray.dtype; shape : int array} +type 'a array_repr = {kind : 'a Ndarray.dtype; shape : int list} module ArrayToArray = struct module Transpose = struct let encoded_size : int -> int = Fun.id let encode order x = Ndarray.transpose ~axes:order x + let encoded_repr ~order (shape : int list) = List.map (fun x -> List.nth shape x) order - let encoded_repr : order:int array -> int array -> int array = fun ~order shape -> - Array.map (fun x -> shape.(x)) order - - let parse : order:int array -> int array -> unit = fun ~order shape -> - let o = Array.copy order in - Array.fast_sort Int.compare o; - let l = Array.length order in - if l = 0 || l <> Array.length shape || o <> Array.(init (length o) Fun.id) + let parse ~order (shape : int list) = + let o = List.fast_sort Int.compare order in + let l = List.length o in + if l = 0 || List.compare_length_with shape l <> 0 || o <> List.init l Fun.id then raise Invalid_transpose_order else () let decode o x = - let inv_order = Array.(make (length o) 0) in - Array.iteri (fun i x -> inv_order.(x) <- i) o; - Ndarray.transpose ~axes:inv_order x + let inv_order = Array.(make (List.length o) 0) in + List.iteri (fun i x -> inv_order.(x) <- i) o; + Ndarray.transpose ~axes:(Array.to_list inv_order) x - let to_yojson : int array -> Yojson.Safe.t = fun order -> - let 
o = `List (List.map (fun x -> `Int x) (Array.to_list order)) in + let to_yojson order : Yojson.Safe.t = + let o = `List (List.map (fun x -> `Int x) order) in `Assoc [("name", `String "transpose"); ("configuration", `Assoc ["order", o])] - let rec of_yojson : - int array -> Yojson.Safe.t -> ([`Transpose of int array], string) result - = fun chunk_shape x -> + let rec of_yojson chunk_shape x : ([`Transpose of int list], string) result = match Yojson.Safe.Util.(member "configuration" x) with | `Assoc [("order", `List o)] -> let accum a acc = Result.bind acc (add_as_int a) in @@ -56,35 +51,33 @@ module ArrayToArray = struct | _ -> Error "Invalid transpose configuration." and add_as_int v acc = match v with - | `Int i -> Ok (i :: acc) - | _ -> Error "transpose order values must be integers." + | `Int i when i >= 0 -> Ok (i :: acc) + | _ -> Error "transpose order values must be non-negative integers." - and to_codec ~chunk_shape o = - let order = Array.of_list o in - match parse ~order chunk_shape with + and to_codec ~chunk_shape order = match parse ~order chunk_shape with | exception Invalid_transpose_order -> Error "Invalid_transpose_order" | () -> Ok (`Transpose order) end - let parse : arraytoarray -> int array -> unit = fun t shape -> match t with + let parse (t : arraytoarray) shape = match t with | `Transpose order -> Transpose.parse ~order shape - let encoded_size : int -> arraytoarray -> int = fun input_size -> function + let encoded_size input_size (t : arraytoarray) = match t with | `Transpose _ -> Transpose.encoded_size input_size - let encoded_repr : int array -> arraytoarray -> int array = fun shape t -> match t with + let encoded_repr shape (t : arraytoarray) = match t with | `Transpose order -> Transpose.encoded_repr ~order shape - let encode : 'a Ndarray.t -> arraytoarray -> 'a Ndarray.t = fun x -> function + let encode x (t : arraytoarray) = match t with | `Transpose order -> Transpose.encode order x - let decode : arraytoarray -> 'a Ndarray.t -> 'a 
Ndarray.t = fun t x -> match t with + let decode (t : arraytoarray) x = match t with | `Transpose order -> Transpose.decode order x let to_yojson : arraytoarray -> Yojson.Safe.t = function | `Transpose order -> Transpose.to_yojson order - let of_yojson cs x = match Util.get_name x with + let of_yojson cs x : (arraytoarray, string) result = match Util.get_name x with | "transpose" -> Transpose.of_yojson cs x | s -> Error (Printf.sprintf "array->array codec %s not supported" s) end @@ -158,19 +151,19 @@ module BytesToBytes = struct | _ -> Error "Invalid Zstd configuration." end - let encoded_size : int -> fixed_bytestobytes -> int = fun input -> function + let encoded_size input (t : fixed_bytestobytes) = match t with | `Crc32c -> Crc32c.encoded_size input let parse : bytestobytes -> unit = function | `Zstd (l, _) -> Zstd.parse_clevel l | (`Gzip _ | `Crc32c) -> () - let encode : string -> bytestobytes -> string = fun x -> function + let encode x (t : bytestobytes) = match t with | `Gzip l -> Gzip.encode l x | `Crc32c -> Crc32c.encode x | `Zstd (l, c) -> Zstd.encode l c x - let decode : bytestobytes -> string -> string = fun t x -> match t with + let decode (t : bytestobytes) x = match t with | `Gzip _ -> Gzip.decode x | `Crc32c -> Crc32c.decode x | `Zstd _ -> Zstd.decode x @@ -180,14 +173,14 @@ module BytesToBytes = struct | `Crc32c -> Crc32c.to_yojson | `Zstd (l, c) -> Zstd.to_yojson l c - let of_yojson : Yojson.Safe.t -> (bytestobytes, string) result = fun x -> match Util.get_name x with + let of_yojson x : (bytestobytes, string) result = match Util.get_name x with | "gzip" -> Gzip.of_yojson x | "crc32c" -> Crc32c.of_yojson x | "zstd" -> Zstd.of_yojson x | s -> Error (Printf.sprintf "codec %s is not supported." 
s) end -module ArrayMap = Util.ArrayMap +module CoordMap = Util.CoordMap module RegularGrid = Extensions.RegularGrid module rec ArrayToBytes : sig @@ -195,14 +188,14 @@ module rec ArrayToBytes : sig type t = internal_shard_config type get_partial_values = Types.range list -> string list IO.t type set_fn = ?append:bool -> (int * string) list -> unit IO.t - val partial_encode : t -> get_partial_values -> set_fn -> int -> 'a array_repr -> (int array * 'a) list -> 'a -> unit IO.t - val partial_decode : t -> get_partial_values -> int -> 'a array_repr -> (int * int array) list -> 'a -> (int * 'a) list IO.t + val partial_encode : t -> get_partial_values -> set_fn -> int -> 'a array_repr -> (int list * 'a) list -> 'a -> unit IO.t + val partial_decode : t -> get_partial_values -> int -> 'a array_repr -> (int * int list) list -> 'a -> (int * 'a) list IO.t end - val parse : arraytobytes -> int array -> unit + val parse : arraytobytes -> int list -> unit val encoded_size : int -> fixed_arraytobytes -> int val encode : arraytobytes -> 'a Ndarray.t -> string val decode : arraytobytes -> 'a array_repr -> string -> 'a Ndarray.t - val of_yojson : int array -> Yojson.Safe.t -> (arraytobytes, string) result + val of_yojson : int list -> Yojson.Safe.t -> (arraytobytes, string) result val to_yojson : arraytobytes -> Yojson.Safe.t end = struct @@ -216,7 +209,7 @@ end = struct let add_binding ~grid acc (c, v) = let id, co = RegularGrid.index_coord_pair grid c in - ArrayMap.add_to_list id (co, v) acc + CoordMap.add_to_list id (co, v) acc (* specialized function for partially writing possibly multiple inner chunks to an empty shard of a designated array using the sharding indexed codec.*) @@ -226,12 +219,12 @@ end = struct List.iter (fun (c, v) -> Ndarray.set arr c v) z; let s = encode_chain t.codecs arr in let n = String.length s in - Ndarray.set index (Array.append i [|0|]) (Stdint.Uint64.of_int ofs); - Ndarray.set index (Array.append i [|1|]) (Stdint.Uint64.of_int n); + Ndarray.set 
index (i @ [0]) (Stdint.Uint64.of_int ofs); + Ndarray.set index (i @ [1]) (Stdint.Uint64.of_int n); ofs + n, (ofs, s) :: acc in - let cps = Array.map2 (/) repr.shape t.chunk_shape in - let index = Ndarray.create Uint64 (Array.append cps [|2|]) Stdint.Uint64.max_int in + let cps = List.map2 (/) repr.shape t.chunk_shape in + let index = Ndarray.create Uint64 (cps @ [2]) Stdint.Uint64.max_int in let init = match t.index_location with | Start -> index_size t.index_codecs cps | End -> 0 @@ -242,9 +235,8 @@ end = struct being a list of (coord-within-inner-chunk, new-value) pairs such that new-value is set for the coordinate coord-within-inner-chunk of the inner chunk represented by the associated key.*) - let m = List.fold_left (add_binding ~grid) ArrayMap.empty pairs in - let f = update_index ~t ~index ~fill_value ~repr in - let shard_size, ranges = ArrayMap.fold f m (init, []) in + let m = List.fold_left (add_binding ~grid) CoordMap.empty pairs in + let shard_size, ranges = CoordMap.fold (update_index ~t ~index ~fill_value ~repr) m (init, []) in let indexbytes = encode_index_chain t.index_codecs index in (* write all resultant (offset, bytes) pairs into the bytes of the new shard taking note to append/prepend the bytes of the shard's index array.*) @@ -255,16 +247,16 @@ end = struct (* function to partially write new values to one or more inner chunks of an existing shard using the sharding indexed codec. 
*) let partial_encode t get_partial (set_partial : set_fn) shardsize repr pairs fv = - let choose ~idx_arr ((i, _) as bd) = - let oc = Array.append i [|0|] and nc = Array.append i [|1|] in + let choose ~idx_arr key value (l, r) = + let oc = key @ [0] and nc = key @ [1] in match Ndarray.(get idx_arr oc, get idx_arr nc) with | o, n when Stdint.Uint64.(max_int = o && max_int = n) -> - Either.Left ((-1, None), (oc, nc, -1, 0, bd)) + ((-1, None), (oc, nc, -1, 0, value)) :: l, r | o, n -> let o', n' = Stdint.Uint64.(to_int o, to_int n) in - Either.Right ((o', Some n'), (oc, nc, o', n', bd)) + l, ((o', Some n'), (oc, nc, o', n', value)) :: r in - let accumulate_nonempty ~repr' ~idx_arr (acc, l, r) x (oc, nc, ofs, nb, (_, z)) = + let accumulate_nonempty ~repr' ~idx_arr (acc, l, r) x (oc, nc, ofs, nb, z) = let arr = decode_chain t.codecs repr' x in List.iter (fun (c, v) -> Ndarray.set arr c v) z; let s = encode_chain t.codecs arr in @@ -275,7 +267,7 @@ end = struct acc + nb', l, (acc, s) :: r end in - let accumulate_empty ~repr' ~idx_arr ~fv (ofs, l) (_, (oc, nc, _, _, (_, z))) = + let accumulate_empty ~repr' ~idx_arr ~fv (ofs, l) (_, (oc, nc, _, _, z)) = let arr = Ndarray.create repr'.kind repr'.shape fv in List.iter (fun (c, v) -> Ndarray.set arr c v) z; let s = encode_chain t.codecs arr in @@ -286,7 +278,7 @@ end = struct in (* begin *) if shardsize = 0 then partial_encode_empty_shard t set_partial repr pairs fv else - let cps = Array.map2 (/) repr.shape t.chunk_shape in + let cps = List.map2 (/) repr.shape t.chunk_shape in let is = index_size t.index_codecs cps in let* l = match t.index_location with | Start -> get_partial [0, Some is] @@ -295,11 +287,11 @@ end = struct let index_bytes = List.hd l in let idx_arr, _ = decode_index t cps index_bytes in let grid = RegularGrid.create ~array_shape:repr.shape t.chunk_shape in - let m = List.fold_left (add_binding ~grid) ArrayMap.empty pairs in + let m = List.fold_left (add_binding ~grid) CoordMap.empty pairs in (* split the 
finite map m into key-value pairs representing empty inner chunks and those that don't (using the fact that empty inner chunks have index array values equal to 2^64 - 1; then process these seperately.*) - let empty, nonempty = List.partition_map (choose ~idx_arr) (ArrayMap.bindings m) in + let empty, nonempty = CoordMap.fold (choose ~idx_arr) m ([], []) in let ranges, nonempty' = List.split nonempty in let* xs = get_partial ranges in let repr' = {repr with shape = t.chunk_shape} in @@ -335,31 +327,23 @@ end = struct | End -> set_partial ~append:true [bsize', ib] (* end *) - type indexed_coord = int * int array (* function to partially read values off of a non-empty shard previously encoded using the sharding indexed codec. *) - let partial_decode t get_partial shardsize repr (pairs : indexed_coord list) fill_value = + let partial_decode t get_partial shardsize repr (pairs : (int * int list) list) fill_value = let add_binding ~grid acc (i, y) = let id, c = RegularGrid.index_coord_pair grid y in - ArrayMap.add_to_list id (i, c) acc + CoordMap.add_to_list id (i, c) acc in - let choose ~index ((i, _) as bd) = - match Ndarray.(get index @@ Array.append i [|0|], - get index @@ Array.append i [|1|]) with - | o, n when Stdint.Uint64.(max_int = o && max_int = n) -> - Either.Left ((-1, None), bd) - | o, n -> - let o', n' = Stdint.Uint64.(to_int o, to_int n) in - Either.Right ((o', Some n'), bd) + let choose ~index key value (l, r) = match Ndarray.(get index (key @ [0]), get index (key @ [1])) with + | o, n when Stdint.Uint64.(max_int = o && max_int = n) -> ((-1, None), value) :: l, r + | o, n -> l, ((Stdint.Uint64.to_int o, Some (Stdint.Uint64.to_int n)), value) :: r in - let indexed_chunk_data ~repr' x (_, z) = + let indexed_chunk_data ~repr' acc x z = let arr = decode_chain t.codecs repr' x in - List.map (fun (i, c) -> i, Ndarray.get arr c) (z : indexed_coord list) - in - let indexed_empty_chunk (_, (_, z)) = - List.map (fun (i, _) -> i, fill_value) (z : indexed_coord 
list) + acc @ List.map (fun ((i, c) : int * int list) -> i, Ndarray.get arr c) z in - let cps = Array.map2 (/) repr.shape t.chunk_shape in + let indexed_empty_chunk (_, z) = List.map (fun ((i, _) : int * int list) -> i, fill_value) z in + let cps = List.map2 (/) repr.shape t.chunk_shape in let is = index_size t.index_codecs cps in let* l = match t.index_location with | Start -> get_partial [0, Some is] @@ -368,28 +352,28 @@ end = struct let index_bytes = List.hd l in let index, _ = decode_index t cps index_bytes in let grid = RegularGrid.create ~array_shape:repr.shape t.chunk_shape in - let m = List.fold_left (add_binding ~grid) ArrayMap.empty pairs in - let empty, nonempty = List.partition_map (choose ~index) (ArrayMap.bindings m) in + let m = List.fold_left (add_binding ~grid) CoordMap.empty pairs in + let empty, nonempty = CoordMap.fold (choose ~index) m ([], []) in let ranges, bindings = List.split nonempty in let+ xs = get_partial ranges in let repr' = {repr with shape = t.chunk_shape} in - let res1 = List.concat @@ List.map2 (indexed_chunk_data ~repr') xs bindings in + let res1 = List.fold_left2 (indexed_chunk_data ~repr') [] xs bindings in let res2 = List.concat_map indexed_empty_chunk empty in res1 @ res2 end - let parse : arraytobytes -> int array -> unit = fun t shape -> match t with + let parse (t : arraytobytes) shape = match t with | `Bytes _ -> () | `ShardingIndexed c -> ShardingIndexed.parse c shape - let encoded_size : int -> fixed_arraytobytes -> int = fun input_size -> function + let encoded_size input_size (t : fixed_arraytobytes) = match t with | `Bytes _ -> Bytes'.encoded_size input_size - let encode : arraytobytes -> 'a Ndarray.t -> string = fun t x -> match t with + let encode (t : arraytobytes) x = match t with | `Bytes endian -> Bytes'.encode x endian | `ShardingIndexed c -> ShardingIndexed.encode c x - let decode : arraytobytes -> 'a array_repr -> string -> 'a Ndarray.t = fun t repr b -> match t with + let decode (t : arraytobytes) repr b 
= match t with | `Bytes endian -> Bytes'.decode b repr endian | `ShardingIndexed c -> ShardingIndexed.decode c repr b @@ -397,10 +381,10 @@ end = struct | `Bytes endian -> Bytes'.to_yojson endian | `ShardingIndexed c -> ShardingIndexed.to_yojson c - let of_yojson : int array -> Yojson.Safe.t -> (arraytobytes, string) result = fun shp x -> + let of_yojson shape x : (arraytobytes, string) result = match Util.get_name x with | "bytes" -> Result.map (fun e -> `Bytes e) (Bytes'.of_yojson x) - | "sharding_indexed" -> Result.map (fun c -> `ShardingIndexed c) (ShardingIndexed.of_yojson shp x) + | "sharding_indexed" -> Result.map (fun c -> `ShardingIndexed c) (ShardingIndexed.of_yojson shape x) | _ -> Error ("array->bytes codec not supported: ") end @@ -417,7 +401,7 @@ end = struct | LE -> (module Ebuffer.Little : Ebuffer.S) | BE -> (module Ebuffer.Big : Ebuffer.S) - let encode (type a) (x : a Ndarray.t) (e : endianness) : string = + let encode (type a) (x : a Ndarray.t) e : string = let open (val endian_module e) in let buf = Bytes.create (Ndarray.byte_size x) in match Ndarray.data_type x with @@ -437,7 +421,7 @@ end = struct | Int -> Ndarray.iteri (set_int buf) x; Bytes.unsafe_to_string buf | Nativeint -> Ndarray.iteri (set_nativeint buf) x; Bytes.unsafe_to_string buf - let decode (type a) (str : string) (decoded : a array_repr) (e : endianness) : a Ndarray.t = + let decode (type a) (str : string) (decoded : a array_repr) e : a Ndarray.t = let open (val endian_module e) in let k, shape = decoded.kind, decoded.shape in let buf = Bytes.unsafe_of_string str in @@ -458,7 +442,7 @@ end = struct | Int, s -> Ndarray.init k shape (fun i -> get_int buf (i*s)) | Nativeint, s -> Ndarray.init k shape (fun i -> get_nativeint buf (i*s)) - let to_yojson : endianness -> Yojson.Safe.t = fun e -> + let to_yojson e : Yojson.Safe.t = let endian = match e with | LE -> "little" | BE -> "big" @@ -477,21 +461,21 @@ end and ShardingIndexed : sig type t = internal_shard_config - val parse : t -> 
int array -> unit + val parse : t -> int list -> unit val encode : t -> 'a Ndarray.t -> string val decode : t -> 'a array_repr -> string -> 'a Ndarray.t - val of_yojson : int array -> Yojson.Safe.t -> (t, string) result + val of_yojson : int list -> Yojson.Safe.t -> (t, string) result val to_yojson : t -> Yojson.Safe.t val encode_chain : (arraytobytes, bytestobytes) chain -> 'a Ndarray.t -> string val decode_chain : (arraytobytes, bytestobytes) chain -> 'a array_repr -> string -> 'a Ndarray.t - val decode_index : t -> int array -> string -> Stdint.uint64 Ndarray.t * string - val index_size : (fixed_arraytobytes, fixed_bytestobytes) chain -> int array -> int + val decode_index : t -> int list -> string -> Stdint.uint64 Ndarray.t * string + val index_size : (fixed_arraytobytes, fixed_bytestobytes) chain -> int list -> int val encode_index_chain : (fixed_arraytobytes, fixed_bytestobytes) chain -> Stdint.uint64 Ndarray.t -> string end = struct module Indexing = Ndarray.Indexing type t = internal_shard_config - let parse_chain : int array -> (arraytobytes, bytestobytes) chain -> unit = fun shape chain -> + let parse_chain : int list -> (arraytobytes, bytestobytes) chain -> unit = fun shape chain -> let shape' = match chain.a2a with | [] -> shape | x :: _ as xs -> @@ -501,12 +485,11 @@ end = struct ArrayToBytes.parse chain.a2b shape' let parse t shape = - if Array.(length shape <> length t.chunk_shape) - || not @@ Array.for_all2 (fun x y -> (x mod y) = 0) shape t.chunk_shape + if List.(length shape <> length t.chunk_shape) + || not @@ List.for_all2 (fun x y -> (x mod y) = 0) shape t.chunk_shape then raise Invalid_sharding_chunk_shape else parse_chain shape t.codecs; - let index_shape = Array.append shape [|2|] in - parse_chain index_shape (t.index_codecs :> (arraytobytes, bytestobytes) chain) + parse_chain (shape @ [2]) (t.index_codecs :> (arraytobytes, bytestobytes) chain) let encoded_size init chain = let a2a_size = List.fold_left ArrayToArray.encoded_size init 
chain.a2a in @@ -518,7 +501,7 @@ end = struct let b = ArrayToBytes.encode chain.a2b a in List.fold_left BytesToBytes.encode b chain.b2b - let encode_index_chain : (fixed_arraytobytes, fixed_bytestobytes) chain -> Stdint.uint64 Ndarray.t -> string = fun t x -> + let encode_index_chain (t : (fixed_arraytobytes, fixed_bytestobytes) chain) x = let y = match t.a2a with | [] -> x | `Transpose o :: _ -> Ndarray.transpose ~axes:o x @@ -531,26 +514,25 @@ end = struct let encode (type a) (t : t) (x : a Ndarray.t) = let add_coord ~grid ~arr coord acc = let k, c = RegularGrid.index_coord_pair grid coord in - ArrayMap.add_to_list k (c, Ndarray.get arr coord) acc + CoordMap.add_to_list k (c, Ndarray.get arr coord) acc in - let update_inner_chunk ~t ~shard_idx ~kind idx pairs (ofs, xs) = + let update_inner_chunk ~t ~shard_idx ~kind i pairs (ofs, xs) = let v = Array.of_list (List.map snd pairs) in let x' = Ndarray.of_array kind t.chunk_shape v in let b = encode_chain t.codecs x' in - let nb = Stdint.Uint64.of_int @@ String.length b in - Ndarray.set shard_idx (Array.append idx [|0|]) ofs; - Ndarray.set shard_idx (Array.append idx [|1|]) nb; + let nb = Stdint.Uint64.of_int (String.length b) in + Ndarray.set shard_idx (i @ [0]) ofs; + Ndarray.set shard_idx (i @ [1]) nb; Stdint.Uint64.(ofs + nb), b :: xs in let shard_shape = Ndarray.shape x in - let cps = Array.map2 (/) shard_shape t.chunk_shape in - let idx_shp = Array.append cps [|2|] in - let shard_idx = Ndarray.create Uint64 idx_shp Stdint.Uint64.max_int in + let cps = List.map2 (/) shard_shape t.chunk_shape in + let shard_idx = Ndarray.create Uint64 (cps @ [2]) Stdint.Uint64.max_int in let grid = RegularGrid.create ~array_shape:shard_shape t.chunk_shape in let kind = Ndarray.data_type x in - let coords = Indexing.coords_of_slice [||] shard_shape in - let m = Array.fold_right (add_coord ~grid ~arr:x) coords ArrayMap.empty in - let _, xs = ArrayMap.fold (update_inner_chunk ~t ~shard_idx ~kind) m (Stdint.Uint64.zero, []) in + let 
coords = Indexing.coords_of_slice [] shard_shape in + let m = List.fold_right (add_coord ~grid ~arr:x) coords CoordMap.empty in + let _, xs = CoordMap.fold (update_inner_chunk ~t ~shard_idx ~kind) m (Stdint.Uint64.zero, []) in let idx_bytes = encode_index_chain t.index_codecs shard_idx in match t.index_location with | Start -> String.concat String.empty (idx_bytes :: List.rev xs) @@ -562,19 +544,16 @@ end = struct let a2b = ArrayToBytes.decode t.a2b {repr with shape} b2b in List.fold_right ArrayToArray.decode t.a2a a2b - let decode_index_chain : (fixed_arraytobytes, fixed_bytestobytes) chain -> int array -> string -> Stdint.uint64 Ndarray.t = fun t shape x -> + let decode_index_chain (t: (fixed_arraytobytes, fixed_bytestobytes) chain) shape x = let shape' = List.fold_left ArrayToArray.encoded_repr shape t.a2a in let y = List.fold_right BytesToBytes.decode (t.b2b :> bytestobytes list) x in let arr = match t.a2b with | `Bytes e -> Bytes'.decode y {shape=shape'; kind=Uint64} e in match t.a2a with | [] -> arr - | `Transpose o :: _ -> - let inv_order = Array.(make (length o) 0) in - Array.iteri (fun i x -> inv_order.(x) <- i) o; - Ndarray.transpose ~axes:inv_order arr + | `Transpose o :: _ -> ArrayToArray.Transpose.decode o arr - let index_size index_chain cps = encoded_size (16 * Util.prod cps) index_chain + let index_size index_chain cps = encoded_size (16 * List.fold_left Int.mul 1 cps) index_chain let decode_index t cps b = let l = index_size t.index_codecs cps in @@ -583,31 +562,26 @@ end = struct | End -> String.sub b o l, String.sub b 0 o | Start -> String.sub b 0 l, String.sub b l o in - let idx_shape = Array.append cps [|2|] in - decode_index_chain t.index_codecs idx_shape ib, rest + decode_index_chain t.index_codecs (cps @ [2]) ib, rest let decode (type a) (t : t) (repr : a array_repr) (b : string) = - let add_indexed_coord ~grid acc (i, coord) = + let add_indexed_coord ~grid acc i coord = let k, c = RegularGrid.index_coord_pair grid coord in - 
ArrayMap.add_to_list k (i, c) acc + CoordMap.add_to_list k (i, c) acc in - let read_inner_chunk ~t ~idx_arr ~inner_repr ~chunk_bytes (idx, pairs) = - let oc = Array.append idx [|0|] in - let nc = Array.append idx [|1|] in - let ofs = Stdint.Uint64.to_int (Ndarray.get idx_arr oc) in - let nb = Stdint.Uint64.to_int (Ndarray.get idx_arr nc) in + let read_inner_chunk ~t ~idx_arr ~inner_repr ~chunk_bytes key value acc = + let ofs = Stdint.Uint64.to_int (Ndarray.get idx_arr (key @ [0])) in + let nb = Stdint.Uint64.to_int (Ndarray.get idx_arr (key @ [1])) in let arr = decode_chain t.codecs inner_repr (String.sub chunk_bytes ofs nb) in - List.map (fun (i, c) -> i, Ndarray.get arr c) pairs + acc @ List.map (fun ((i, c) : int * int list) -> i, Ndarray.get arr c) value in - let cps = Array.map2 (/) repr.shape t.chunk_shape in + let cps = List.map2 (/) repr.shape t.chunk_shape in let idx_arr, chunk_bytes = decode_index t cps b in let grid = RegularGrid.create ~array_shape:repr.shape t.chunk_shape in - let coords = Indexing.coords_of_slice [||] repr.shape in - let indexed_coords = Array.mapi (fun i v -> i, v) coords in - let m = Array.fold_left (add_indexed_coord ~grid) ArrayMap.empty indexed_coords in + let coords = Indexing.coords_of_slice [] repr.shape in + let m = List.fold_left2 (add_indexed_coord ~grid) CoordMap.empty List.(init (length coords) Fun.id) coords in let inner_repr = {repr with shape = t.chunk_shape} in - let bd = ArrayMap.bindings m in - let pairs = List.concat_map (read_inner_chunk ~t ~idx_arr ~inner_repr ~chunk_bytes) bd in + let pairs = CoordMap.fold (read_inner_chunk ~t ~idx_arr ~inner_repr ~chunk_bytes) m [] in let sorted_pairs = List.fast_sort (fun (x, _) (y, _) -> Int.compare x y) pairs in let vs = List.map snd sorted_pairs in Ndarray.of_array inner_repr.kind repr.shape (Array.of_list vs) @@ -624,16 +598,15 @@ end = struct | End -> `String "end" | Start -> `String "start" in - let chunk_shape = `List (List.map (fun x -> `Int x) @@ Array.to_list 
t.chunk_shape) in `Assoc [("name", `String "sharding_indexed"); ("configuration", `Assoc - [("chunk_shape", chunk_shape); + [("chunk_shape", `List (List.map (fun x -> `Int x) t.chunk_shape)); ("index_location", index_location); ("index_codecs", index_codecs); ("codecs", chain_to_yojson t.codecs)])] - let chain_of_yojson chunk_shape codecs = + let chain_of_yojson (chunk_shape : int list) codecs = let open Util.Result_syntax in let split ~f codec (l, r) = Result.fold ~ok:(fun v -> v :: l, r) ~error:(fun _ -> l, codec :: r) (f codec) @@ -655,7 +628,7 @@ end = struct let codec = Util.get_name x in Error (Printf.sprintf "%s codec is unsupported or has invalid configuration." codec) - let of_yojson shard_shape x = + let of_yojson (shard_shape : int list) x = let open Util.Result_syntax in let extract ~assoc name = Yojson.Safe.Util.filter_map (fun (n, v) -> if n = name then Some v else None) assoc @@ -673,11 +646,10 @@ end = struct | `Gzip _ | `Zstd _ -> Error error_msg in let assoc = Yojson.Safe.Util.(member "configuration" x |> to_assoc) in - let* l' = match extract ~assoc "chunk_shape" with + let* chunk_shape = match extract ~assoc "chunk_shape" with | [] -> Error "sharding_indexed must contain a chunk_shape field" | x :: _ -> List.fold_right add_as_int (Yojson.Safe.Util.to_list x) (Ok []) in - let chunk_shape = Array.of_list l' in let* index_location = match extract ~assoc "index_location" with | [] -> Error "sharding_indexed must have a index_location field" | x :: _ -> @@ -693,9 +665,8 @@ end = struct let* ic = match extract ~assoc "index_codecs" with | [] -> Error "sharding_indexed must have a index_codecs field" | x :: _ -> - let cps = Array.map2 (/) shard_shape chunk_shape in - let idx_shape = Array.append cps [|2|] in - chain_of_yojson idx_shape (Yojson.Safe.Util.to_list x) + let cps = List.map2 (/) shard_shape chunk_shape in + chain_of_yojson (cps @ [2]) (Yojson.Safe.Util.to_list x) in (* Ensure index_codecs only contains fixed size array->bytes and 
bytes->bytes codecs. *) @@ -712,7 +683,7 @@ type variable_array_tobytes = [ `ShardingIndexed of shard_config ] and codec = [ arraytoarray | fixed_arraytobytes | `ShardingIndexed of shard_config | bytestobytes ] and index_codec = [ arraytoarray | fixed_arraytobytes | fixed_bytestobytes ] and shard_config = - {chunk_shape : int array + {chunk_shape : int list ;codecs : codec list ;index_codecs : index_codec list ;index_location : loc} @@ -750,7 +721,7 @@ module Chain = struct begin match x with | `ShardingIndexed cfg -> let codecs = create shape cfg.codecs in - let index_codecs = create (Array.append shape [|2|]) (cfg.index_codecs :> codec list) in + let index_codecs = create (shape @ [2]) (cfg.index_codecs :> codec list) in (* coerse to a fixed codec chain list type *) let pred = function #fixed_bytestobytes as c -> Some c | _ -> None in let b2b = List.filter_map pred index_codecs.b2b in @@ -794,9 +765,10 @@ module Chain = struct Result.fold ~ok:(fun v -> v :: l, r) ~error:(fun _ -> l, codec :: r) (f codec) in let partition f encoded = List.fold_right (split ~f) encoded ([], []) in - let* codecs = match Yojson.Safe.Util.to_list x with - | [] -> Error "No codec specified." - | y -> Ok y + let* codecs = match x with + | `List xs -> Ok xs + | `Null -> Error "array metadata must contain a codecs field." + | _ -> Error "codecs field must be a list of objects." in let* a2b, rest = match partition (ArrayToBytes.of_yojson chunk_shape) codecs with | [x], rest -> Ok (x, rest) diff --git a/zarr/src/codecs.mli b/zarr/src/codecs.mli index 5e39ed1d..52f60fe4 100644 --- a/zarr/src/codecs.mli +++ b/zarr/src/codecs.mli @@ -20,7 +20,7 @@ exception Invalid_zstd_level (** raised when a codec chain contains a Zstd codec with an incorrect compression value.*) (** The type of [array -> array] codecs. *) -type arraytoarray = [ `Transpose of int array ] +type arraytoarray = [ `Transpose of int list ] (** A type representing valid Gzip codec compression levels. 
*) type compression_level = L0 | L1 | L2 | L3 | L4 | L5 | L6 | L7 | L8 | L9 @@ -55,14 +55,14 @@ and index_codec = [ arraytoarray | fixed_arraytobytes | fixed_bytestobytes ] (** A type representing the Sharding indexed codec's configuration parameters. *) and shard_config = - {chunk_shape : int array + {chunk_shape : int list ;codecs : codec list ;index_codecs : index_codec list ;index_location : loc} (** The type summarizing the decoded/encoded representation of a Zarr array or chunk. *) -type 'a array_repr = {kind : 'a Ndarray.dtype; shape : int array} +type 'a array_repr = {kind : 'a Ndarray.dtype; shape : int list} (** A module containing functions to encode/decode an array chunk using a predefined set of codecs. *) @@ -83,7 +83,7 @@ module Chain : sig @raise Invalid_sharding_chunk_shape if [c] contains a shardingindexed codec with an incorrect inner chunk shape. *) - val create : int array -> codec list -> t + val create : int list -> codec list -> t (** [encode t x] computes the encoded byte string representation of array chunk [x]. *) @@ -99,7 +99,7 @@ module Chain : sig (** [of_yojson x] returns a code chain of type {!t} from its json object representation. *) - val of_yojson : int array -> Yojson.Safe.t -> (t, string) result + val of_yojson : int list -> Yojson.Safe.t -> (t, string) result (** [to_yojson x] returns a json object representation of codec chain [x]. 
*) val to_yojson : t -> Yojson.Safe.t @@ -119,7 +119,7 @@ module Make (IO : Types.IO) : sig (?append:bool -> (int * string) list -> unit IO.t) -> int -> 'a array_repr -> - (int array * 'a) list -> + (int list * 'a) list -> 'a -> unit IO.t @@ -128,7 +128,7 @@ module Make (IO : Types.IO) : sig (Types.range list -> string list IO.t) -> int -> 'a array_repr -> - (int * int array) list -> + (int * int list) list -> 'a -> (int * 'a) list IO.t end diff --git a/zarr/src/extensions.ml b/zarr/src/extensions.ml index d379048c..ea4098f7 100644 --- a/zarr/src/extensions.ml +++ b/zarr/src/extensions.ml @@ -1,33 +1,37 @@ module RegularGrid = struct exception Grid_shape_mismatch - - type t = int array - - let chunk_shape : t -> int array = Fun.id + type t = int list + let chunk_shape : t -> int list = Fun.id let ceildiv x y = Float.(to_int @@ ceil (of_int x /. of_int y)) let floordiv x y = Float.(to_int @@ floor (of_int x /. of_int y)) - let grid_shape t array_shape = Array.map2 ceildiv array_shape t - let index_coord_pair t coord = (Array.map2 floordiv coord t, Array.map2 Int.rem coord t) - let ( = ) x y = x = y + let grid_shape t array_shape = List.map2 ceildiv array_shape t + let index_coord_pair t coord = (List.map2 floordiv coord t, List.map2 Int.rem coord t) + let ( = ) x y = List.equal Int.equal x y + let max = List.fold_left Int.max Int.min_int - let create : array_shape:int array -> int array -> t - = fun ~array_shape chunk_shape -> - if Array.(length chunk_shape <> length array_shape) || Util.(max chunk_shape > max array_shape) + let create ~array_shape chunk_shape = + if List.(length chunk_shape <> length array_shape) || (max chunk_shape > max array_shape) then raise Grid_shape_mismatch else chunk_shape (* returns all chunk indices in this regular grid *) let indices t array_shape = - grid_shape t array_shape - |> Array.to_list - |> List.map (fun x -> List.init x Fun.id) - |> Ndarray.Indexing.cartesian_prod - |> List.map Array.of_list - - let to_yojson : t -> 
Yojson.Safe.t = fun t -> - let chunk_shape = `List (List.map (fun x -> `Int x) @@ Array.to_list t) in - `Assoc - [("name", `String "regular") - ;("configuration", `Assoc [("chunk_shape", chunk_shape)])] + let lol = List.map (fun x -> List.init x Fun.id) (grid_shape t array_shape) in + Ndarray.Indexing.cartesian_prod lol + + let to_yojson (g : t) : Yojson.Safe.t = + let name = ("name", `String "regular") in + `Assoc [name; ("configuration", `Assoc [("chunk_shape", `List (List.map (fun x -> `Int x) g))])] + + let add (x : Yojson.Safe.t) acc = match x with + | `Int i when i > 0 -> Result.map (List.cons i) acc + | _ -> Error "chunk_shape must only contain positive ints." + + let of_yojson (array_shape: int list) (x : Yojson.Safe.t) = match x with + | `Assoc ["name", `String "regular"; "configuration", `Assoc ["chunk_shape", `List l]] -> + begin try Result.map (create ~array_shape) (List.fold_right add l (Ok [])) + with Grid_shape_mismatch -> Error "grid shape mismatch." end + | `Null -> Error "array metadata must contain a chunk_grid field." + | _ -> Error "Invalid Chunk grid name or configuration." end module ChunkKeyEncoding = struct @@ -40,13 +44,12 @@ module ChunkKeyEncoding = struct (* map a chunk coordinate index to a key. 
E.g, (2,3,1) maps to c/2/3/1 *) let encode {name; sep; _} index = - let xs = Array.fold_right (fun i acc -> string_of_int i :: acc) index [] in + let xs = List.fold_right (fun i acc -> string_of_int i :: acc) index [] in match name with | Default -> String.concat sep ("c" :: xs) - | V2 -> if Array.length index = 0 then "0" else String.concat sep xs + | V2 -> if List.length index = 0 then "0" else String.concat sep xs - let ( = ) x y = - x.name = y.name && x.sep = y.sep && x.is_default = y.is_default + let ( = ) x y = Bool.equal x.is_default y.is_default && String.equal x.sep y.sep && x.name = y.name let to_yojson : t -> Yojson.Safe.t = fun {name; sep; is_default} -> let str = match name with @@ -54,23 +57,20 @@ module ChunkKeyEncoding = struct | V2 -> "v2" in if is_default then `Assoc [("name", `String str)] else - `Assoc - [("name", `String str) - ;("configuration", `Assoc [("separator", `String sep)])] - - let of_yojson x = match Util.get_name x, Yojson.Safe.Util.member "configuration" x with - | "default", `Null -> - Ok {name = Default; sep = "/"; is_default = true} - | "default", `Assoc [("separator", `String "/")] -> - Ok {name = Default; sep = "/"; is_default = false} - | "default", `Assoc [("separator", `String ".")] -> - Ok {name = Default; sep = "."; is_default = false} - | "v2", `Null -> - Ok {name = V2; sep = "."; is_default = true} - | "v2", `Assoc [("separator", `String "/")] -> - Ok {name = V2; sep = "/"; is_default = false} - | "v2", `Assoc [("separator", `String ".")] -> - Ok {name = V2; sep = "."; is_default = false} + `Assoc [("name", `String str); ("configuration", `Assoc [("separator", `String sep)])] + + let of_yojson : Yojson.Safe.t -> (t, string) result = function + | `Assoc [("name", `String "v2")] -> Ok {name = V2; sep = "."; is_default = true} + | `Assoc [("name", `String "v2"); ("configuration", `Assoc [("separator", `String ("/" as slash))])] -> + Ok {name = V2; sep = slash; is_default = false} + | `Assoc [("name", `String "v2"); 
("configuration", `Assoc [("separator", `String ("." as dot))])] -> + Ok {name = V2; sep = dot; is_default = false} + | `Assoc [("name", `String "default")] -> Ok {name = Default; sep = "/"; is_default = true} + | `Assoc [("name", `String "default"); ("configuration", `Assoc [("separator", `String ("/" as slash))])] -> + Ok {name = Default; sep = slash; is_default = false} + | `Assoc [("name", `String "default"); ("configuration", `Assoc [("separator", `String ("." as dot))])] -> + Ok {name = Default; sep = dot; is_default = false} + | `Null -> Error "array metadata must contain a chunk_key_encoding field." | _ -> Error "Invalid chunk key encoding configuration." end @@ -144,5 +144,6 @@ module Datatype = struct | `String "complex64" -> Ok Complex64 | `String "int" -> Ok Int | `String "nativeint" -> Ok Nativeint - | _ -> Error ("Unsupported metadata data_type") + | `Null -> Error "array metadata must contain a data_type field." + | _ -> Error "Unsupported metadata data_type" end diff --git a/zarr/src/extensions.mli b/zarr/src/extensions.mli index 390727fe..6c14db02 100644 --- a/zarr/src/extensions.mli +++ b/zarr/src/extensions.mli @@ -1,18 +1,19 @@ module RegularGrid : sig exception Grid_shape_mismatch type t - val create : array_shape:int array -> int array -> t - val chunk_shape : t -> int array - val indices : t -> int array -> int array list - val index_coord_pair : t -> int array -> int array * int array + val create : array_shape:int list -> int list -> t + val chunk_shape : t -> int list + val indices : t -> int list -> int list list + val index_coord_pair : t -> int list -> int list * int list val ( = ) : t -> t -> bool + val of_yojson : int list -> Yojson.Safe.t -> (t, string) result val to_yojson : t -> Yojson.Safe.t end module ChunkKeyEncoding : sig type t val create : [< `Slash | `Dot > `Slash ] -> t - val encode : t -> int array -> string + val encode : t -> int list -> string val ( = ) : t -> t -> bool val of_yojson : Yojson.Safe.t -> (t, string) 
result val to_yojson : t -> Yojson.Safe.t diff --git a/zarr/src/metadata.ml b/zarr/src/metadata.ml index 495822e9..86b94681 100644 --- a/zarr/src/metadata.ml +++ b/zarr/src/metadata.ml @@ -6,254 +6,304 @@ module FillValue = struct type t = | Char of char | Bool of bool - | Int of Stdint.uint64 + | Int of int + | Intlit of string * Stdint.uint64 (* for ints that cannot fit in a 63bit integer type *) | Float of float - | FloatBits of float - | IntComplex of Complex.t - | FloatComplex of Complex.t - | FFComplex of Complex.t - | FBComplex of Complex.t - | BFComplex of Complex.t - | BBComplex of Complex.t - - let ( = ) x y = x = y - - let of_kind - : type a. a Ndarray.dtype -> a -> t - = fun kind a -> match kind with - | Ndarray.Char -> Char a - | Ndarray.Bool -> Bool a - | Ndarray.Int8 -> Int (Stdint.Uint64.of_int a) - | Ndarray.Uint8 -> Int (Stdint.Uint64.of_int a) - | Ndarray.Int16 -> Int (Stdint.Uint64.of_int a) - | Ndarray.Uint16 -> Int (Stdint.Uint64.of_int a) - | Ndarray.Int32 -> Int (Stdint.Uint64.of_int32 a) - | Ndarray.Int64 -> Int (Stdint.Uint64.of_int64 a) - | Ndarray.Uint64 -> Int a - | Ndarray.Float32 -> Float a - | Ndarray.Float64 -> Float a - | Ndarray.Complex32 -> FloatComplex a - | Ndarray.Complex64 -> FloatComplex a - | Ndarray.Int -> Int (Stdint.Uint64.of_int a) - | Ndarray.Nativeint -> Int (Stdint.Uint64.of_nativeint a) - - let rec of_yojson x = match x with - | `Bool b -> Ok (Bool b) - | `Int i -> Result.ok @@ Int (Stdint.Uint64.of_int i) - | `String "Infinity" -> Ok (Float Float.infinity) - | `String "-Infinity" -> Ok (Float Float.neg_infinity) - | `String "NaN" -> Ok (Float Float.nan) - | `Float f -> Ok (Float f) - | `String s when String.length s = 1 -> Ok (Char (String.get s 0)) - | `String s when String.starts_with ~prefix:"0x" s -> - Ok (FloatBits Int64.(float_of_bits @@ of_string s)) - | `List [`Int x; `Int y] -> - Ok (IntComplex Complex.{re=Float.of_int x; im=Float.of_int y}) - | `List [`Float re; `Float im] -> Ok (FloatComplex Complex.{re; 
im}) - | `List [`String _ as a; `String _ as b] -> - Result.bind (of_yojson a) @@ fun x -> - Result.bind (of_yojson b) @@ fun y -> - (match x, y with - | Float re, Float im -> Ok (FFComplex Complex.{re; im}) - | Float re, FloatBits im -> Ok (FBComplex Complex.{re; im}) - | FloatBits re, Float im -> Ok (BFComplex Complex.{re; im}) - | FloatBits re, FloatBits im -> Ok (BBComplex Complex.{re; im}) - | _ -> Error "Unsupported fill value.") + | IntFloat of int * float + | IntlitFloat of string * float + | StringFloat of string * float (* float represented using hex string in the metadata json. *) + | IntComplex of (int * int) * Complex.t (* complex number represented using ints in the metadata json. *) + | IntlitComplex of (string * string) * Complex.t (* complex number represented using ints in the metadata json. *) + | FloatComplex of Complex.t (* complex number represented using floats in the metadata json. *) + | StringComplex of (string * string) * Complex.t + + let rec create : type a. 
a Ndarray.dtype -> a -> t = fun kind x -> match kind with + | Ndarray.Char -> Char x + | Ndarray.Bool -> Bool x + | Ndarray.Int8 -> Int x + | Ndarray.Uint8 -> Int x + | Ndarray.Int16 -> Int x + | Ndarray.Uint16 -> Int x + | Ndarray.Int32 -> Int (Int32.to_int x) + | Ndarray.Int -> Int x + | Ndarray.Int64 when x >= -4611686018427387904L && x <= 4611686018427387903L -> Int (Int64.to_int x) + | Ndarray.Int64 -> Intlit (Int64.to_string x, Stdint.Uint64.of_int64 x) + | Ndarray.Uint64 when Stdint.Uint64.(compare x (of_int Int.max_int)) < 0 -> Int (Stdint.Uint64.to_int x) + | Ndarray.Uint64 -> Intlit (Stdint.Uint64.to_string x, x) + | Ndarray.Float32 -> Float x + | Ndarray.Float64 -> Float x + | Ndarray.Complex32 -> FloatComplex x + | Ndarray.Complex64 -> FloatComplex x + | Ndarray.Nativeint -> create Ndarray.Int64 (Int64.of_nativeint x) + + let equal x y = match x, y with + | Char a, Char b when Char.equal a b -> true + | Bool false, Bool false -> true + | Bool true, Bool true -> true + | Int a, Int b when Int.equal a b -> true + | Intlit (a, _), Intlit (b, _) when String.equal a b -> true + | Float a, Float b when Float.equal a b -> true + | IntFloat (a, _), IntFloat (b, _) when Int.equal a b -> true + | IntlitFloat (a, _), IntlitFloat (b, _) when String.equal a b -> true + | StringFloat ("Infinity", _), StringFloat ("Infinity", _) -> true + | StringFloat ("-Infinity", _), StringFloat ("-Infinity", _) -> true + | StringFloat ("NaN", _), StringFloat ("NaN", _) -> true + | StringFloat (a, _), StringFloat (b, _) when String.equal a b -> true + | IntComplex ((a1, b1), _), IntComplex ((a2, b2), _) when Int.(equal a1 a2 && equal b1 b2) -> true + | IntlitComplex ((a1, b1), _), IntlitComplex ((a2, b2), _) when String.(equal a1 a2 && equal b1 b2) -> true + | FloatComplex Complex.{re=r1;im=i1}, FloatComplex Complex.{re=r2;im=i2} when Float.(equal r1 r2 && equal i1 i2) -> true + | StringComplex ((a1, b1), _), StringComplex ((a2, b2), _) when String.(equal a1 a2 && equal b1 b2) -> 
true + | _ -> false + + (* [of_yojson d x] parses the JSON fill value [x] for data type [d]. The JSON encoding is kept next to the decoded value so that it is + preserved when converting a parsed FillValue.t back to its JSON value. A "0x..." string holds the raw bit pattern of the float; malformed values yield [Error]. *) + let rec of_yojson (d : Datatype.t) (x : Yojson.Safe.t) = match d, x with + | Datatype.Char, `String s when String.length s = 1 -> Ok (Char (String.get s 0)) + | Datatype.Bool, `Bool b -> Ok (Bool b) + | Datatype.Int8, `Int a when a >= -128 && a <= 127 -> Ok (Int a) + | Datatype.Uint8, `Int a when a >= 0 && a <= 255 -> Ok (Int a) + | Datatype.Int16, `Int a when a >= -32768 && a <= 32767 -> Ok (Int a) + | Datatype.Uint16, `Int a when a >= 0 && a <= 65535 -> Ok (Int a) + | Datatype.Int32, `Int a when a >= -2147483648 && a <= 2147483647 -> Ok (Int a) + | Datatype.Int, `Int a -> Ok (Int a) + | Datatype.Int64, `Int a -> Ok (Int a) + | Datatype.Int64, `Intlit a -> begin match Int64.of_string_opt a with + | None -> Error "Unsupported fill value." + | Some b -> Ok (Intlit (a, Stdint.Uint64.of_int64 b)) end + | Datatype.Nativeint, a -> of_yojson Datatype.Int64 a + | Datatype.Uint64, `Int a when a >= 0 -> Ok (Int a) + | Datatype.Uint64, `Intlit a when not (String.starts_with ~prefix:"-" a) -> + begin match Stdint.Uint64.of_string a with + | exception Invalid_argument _ -> Error "Unsupported fill value." + | b -> Ok (Intlit (a, b)) end + | Datatype.Float32, `Float a -> Ok (Float a) + | Datatype.Float32, `Int a -> Ok (IntFloat (a, Float.of_int a)) + | Datatype.Float32, `Intlit a -> Ok (IntlitFloat (a, Float.of_string a)) + | Datatype.Float32, `String ("Infinity" as s) -> Ok (StringFloat (s, Float.infinity)) + | Datatype.Float32, `String ("-Infinity" as s) -> Ok (StringFloat (s, Float.neg_infinity)) + | Datatype.Float32, `String ("NaN" as s) -> Ok (StringFloat (s, Float.nan)) + | Datatype.Float32, `String s when String.starts_with ~prefix:"0x" s -> + begin match Stdint.Uint64.of_string s with + | exception Invalid_argument _ -> Error "Unsupported fill value." 
+ | a -> Ok (StringFloat (s, Int32.float_of_bits (Stdint.Uint64.to_int32 a))) (* reinterpret the 32-bit pattern, not the integer's numeric value *) end + | Datatype.Float64, `Float a -> Ok (Float a) + | Datatype.Float64, `Int a -> Ok (IntFloat (a, Float.of_int a)) + | Datatype.Float64, `Intlit a -> Ok (IntlitFloat (a, Float.of_string a)) + | Datatype.Float64, `String ("Infinity" as s) -> Ok (StringFloat (s, Float.infinity)) + | Datatype.Float64, `String ("-Infinity" as s) -> Ok (StringFloat (s, Float.neg_infinity)) + | Datatype.Float64, `String ("NaN" as s) -> Ok (StringFloat (s, Float.nan)) + | Datatype.Float64, `String s when String.starts_with ~prefix:"0x" s -> + begin match Stdint.Uint64.of_string s with + | exception Invalid_argument _ -> Error "Unsupported fill value." + | a -> Ok (StringFloat (s, Int64.float_of_bits (Stdint.Uint64.to_int64 a))) (* reinterpret the 64-bit pattern, not the integer's numeric value *) end + | Datatype.Complex32, `List [`Int a; `Int b] -> Ok (IntComplex ((a, b), Complex.{re=Float.of_int a; im=Float.of_int b})) + | Datatype.Complex32, `List [`Intlit a; `Intlit b] -> Ok (IntlitComplex ((a, b), Complex.{re=Float.of_string a; im=Float.of_string b})) + | Datatype.Complex32, `List [`Float re; `Float im] -> Ok (FloatComplex Complex.{re; im}) + | Datatype.Complex32, `List [`String a; `String b] -> begin match Float.of_string_opt a, Float.of_string_opt b with | Some re, Some im -> Ok (StringComplex ((a, b), Complex.{re; im})) | _ -> Error "Unsupported fill value." (* arbitrary strings must not raise Failure *) end + | Datatype.Complex64, `List [`Int a; `Int b] -> Ok (IntComplex ((a, b), Complex.{re=Float.of_int a; im=Float.of_int b})) + | Datatype.Complex64, `List [`Intlit a; `Intlit b] -> Ok (IntlitComplex ((a, b), Complex.{re=Float.of_string a; im=Float.of_string b})) + | Datatype.Complex64, `List [`Float re; `Float im] -> Ok (FloatComplex Complex.{re; im}) + | Datatype.Complex64, `List [`String a; `String b] -> begin match Float.of_string_opt a, Float.of_string_opt b with | Some re, Some im -> Ok (StringComplex ((a, b), Complex.{re; im})) | _ -> Error "Unsupported fill value." (* arbitrary strings must not raise Failure *) end + | _, `Null -> Error "array metadata must contain a fill_value field." + | _ -> Error "Unsupported fill value." 
- let rec to_yojson = function - | Bool b -> `Bool b - | Int i -> `Int (Stdint.Uint64.to_int i) + let to_yojson : t -> Yojson.Safe.t = function | Char c -> `String (Printf.sprintf "%c" c) - | Float f when Float.is_nan f -> `String "NaN" - | Float f when f = Float.infinity -> `String "Infinity" - | Float f when f = Float.neg_infinity -> `String "-Infinity" + | Bool b -> `Bool b + | Int i -> `Int i + | Intlit (s, _) -> `Intlit s | Float f -> `Float f - | FloatBits f -> `String (Stdint.Int64.to_string_hex @@ Int64.bits_of_float f) - | IntComplex Complex.{re; im} -> - `List [`Int (Float.to_int re); `Int (Float.to_int im)] + | IntFloat (i, _) -> `Int i + | IntlitFloat (s, _) -> `Intlit s + | StringFloat (s, _) -> `String s + | IntComplex ((a, b), _) -> `List [`Int a; `Int b] + | IntlitComplex ((a, b), _) -> `List [`Intlit a; `Intlit b] | FloatComplex Complex.{re; im} -> `List [`Float re; `Float im] - | FFComplex Complex.{re; im} -> - `List [to_yojson (Float re); to_yojson (Float im)] - | FBComplex Complex.{re; im} -> - `List [to_yojson (Float re); to_yojson (FloatBits im)] - | BFComplex Complex.{re; im} -> - `List [to_yojson (FloatBits re); to_yojson (Float im)] - | BBComplex Complex.{re; im} -> - `List [to_yojson (FloatBits re); to_yojson (FloatBits im)] + | StringComplex ((a, b), _) -> `List [`String a; `String b] +end + +module NodeType = struct + type t = Array | Group + let rec to_yojson x : Yojson.Safe.t = `String (show x) + and show = function + | Array -> "array" + | Group -> "group" + + module Array = struct + let of_yojson : Yojson.Safe.t -> (t, string) result = function + | `String "array" -> Ok Array + | `Null -> Error "metadata must contain a node_type field." + | _ -> Error "node_type field must be 'array'." + end + + module Group = struct + let of_yojson : Yojson.Safe.t -> (t, string) result = function + | `String "group" -> Ok Group + | `Null -> Error "group metadata must contain a node_type field." + | _ -> Error "node_type field must be 'group'." 
+ end +end + +(* The shape of a Zarr array is the list of dimension lengths. It can be the + empty list in the case of a zero-dimension array (scalar). *) +module Shape = struct + type t = Empty | Dims of int list + + let create = function + | [] -> Empty + | xs -> Dims xs + + let ( = ) x y = match x, y with + | Empty, Empty -> true + | Dims a, Dims b when List.equal Int.equal a b -> true + | _ -> false + + let add (x : Yojson.Safe.t) acc = match x with + | `Int i when i > 0 -> Result.map (List.cons i) acc + | _ -> Error "shape field list must only contain positive integers." + + let of_yojson : Yojson.Safe.t -> (t, string) result = function + | `List [] -> Ok Empty + | `List xs -> Result.map (fun x -> Dims x) (List.fold_right add xs (Ok [])) + | `Null -> Error "array metadata must contain a shape field." + | _ -> Error "shape field must be a list of integers." + + let to_yojson : t -> Yojson.Safe.t = function + | Empty -> `List [] + | Dims xs -> `List (List.map (fun x -> `Int x) xs) + + let to_list = function + | Empty -> [] + | Dims xs -> xs + + let ndim = function + | Empty -> 0 + | Dims xs -> List.length xs +end + +module ZarrFormat = struct + type t = int + let to_yojson x : Yojson.Safe.t = `Int x + let of_yojson = function + | `Int (3 as i) -> Ok i + | `Null -> Error "metadata must contain a zarr_format field." + | _ -> Error "zarr_format field must be the integer 3." +end + +module DimensionNames = struct + type t = string option list + + let to_yojson (xs : t) : Yojson.Safe.t = + `List (List.map (Option.fold ~none:`Null ~some:(fun s -> `String s)) xs) + + let add (x : Yojson.Safe.t) acc = match x with + | `String s -> Result.map (List.cons (Some s)) acc + | `Null -> Result.map (List.cons None) acc + | _ -> Error "dimension_names must contain strings or null values." 
+ + let of_yojson ndim x = match x with + | `Null -> Ok [] + | `List xs -> + if List.length xs = ndim then List.fold_right add xs (Ok []) + else Error "dimension_names length and array dimensionality must be equal." + | _ -> Error "dimension_names field must be a list." end module Array = struct type t = - {zarr_format : int - ;shape : int array - ;node_type : string + {zarr_format : ZarrFormat.t + ;shape : Shape.t + ;node_type : NodeType.t ;data_type : Datatype.t ;codecs : Codecs.Chain.t ;fill_value : FillValue.t ;chunk_grid : RegularGrid.t ;chunk_key_encoding : ChunkKeyEncoding.t ;attributes : Yojson.Safe.t - ;dimension_names : string option list + ;dimension_names : DimensionNames.t ;storage_transformers : Yojson.Safe.t list} - let create - ?(sep=`Slash) ?(dimension_names=[]) ?(attributes=`Null) - ~codecs ~shape - kind fv chunks - = - {shape - ;codecs + let create ?(sep=`Slash) ?(dimension_names=[]) ?(attributes=`Null) ~codecs ~shape kind fv chunks = + {codecs ;attributes ;dimension_names ;zarr_format = 3 - ;node_type = "array" + ;shape = Shape.create shape + ;node_type = NodeType.Array ;storage_transformers = [] - ;fill_value = FillValue.of_kind kind fv + ;fill_value = FillValue.create kind fv ;data_type = Datatype.of_kind kind ;chunk_key_encoding = ChunkKeyEncoding.create sep ;chunk_grid = RegularGrid.create ~array_shape:shape chunks} let to_yojson : t -> Yojson.Safe.t = fun t -> - let shape = List.map (fun x -> `Int x) (Array.to_list t.shape) in let l = - [("zarr_format", `Int t.zarr_format) - ;("shape", `List shape) - ;("node_type", `String t.node_type) + [("zarr_format", ZarrFormat.to_yojson t.zarr_format) + ;("shape", Shape.to_yojson t.shape) + ;("node_type", NodeType.to_yojson t.node_type) ;("data_type", Datatype.to_yojson t.data_type) ;("codecs", Codecs.Chain.to_yojson t.codecs) ;("fill_value", FillValue.to_yojson t.fill_value) ;("chunk_grid", RegularGrid.to_yojson t.chunk_grid) - ;("chunk_key_encoding", - ChunkKeyEncoding.to_yojson 
t.chunk_key_encoding)] - in - let l = match t.attributes with - | `Null -> l - | x -> l @ [("attributes", x)] + ;("chunk_key_encoding", ChunkKeyEncoding.to_yojson t.chunk_key_encoding)] in - match t.dimension_names with - | [] -> `Assoc l - | xs -> - let xs' = List.map (Option.fold ~none:`Null ~some:(fun s -> `String s)) xs in - `Assoc (l @ [("dimension_names", `List xs')]) + (* optional fields.*) + match t.attributes, t.dimension_names with + | `Null, [] -> `Assoc l + | `Null, xs -> `Assoc (l @ ["dimension_names", DimensionNames.to_yojson xs]) + | x, [] -> `Assoc (l @ ["attributes", x]) + | x, xs -> `Assoc (l @ [("attributes", x); ("dimension_names", DimensionNames.to_yojson xs)]) let of_yojson x = - let open Yojson.Safe.Util in let open Util.Result_syntax in - let add_as_int ~error a acc = - let* k = acc in - match a with - | `Int i when i > 0 -> Ok (i :: k) - | _ -> Error error - in - let* zarr_format = match member "zarr_format" x with - | `Int (3 as i) -> Ok i - | `Null -> Error "array metadata must contain a zarr_format field." - | _ -> Error "zarr_format field must be the integer 3." - in - let* node_type = match member "node_type" x with - | `String ("array" as a) -> Ok a - | `Null -> Error "array metadata must contain a node_type field." - | _ -> Error "node_type field must be 'array'." - in - let* shape = match member "shape" x with - | `List xs -> - let error = "shape field list must only contain positive integers." in - let+ l = List.fold_right (add_as_int ~error) xs (Ok []) in - Array.of_list l - | `Null -> Error "array metadata must contain a shape field." - | _ -> Error "shape field must be a list of integers." - in - let* data_type = match member "data_type" x with - | `String _ as c -> Datatype.of_yojson c - | `Null -> Error "array metadata must contain a data_type field." - | _ -> Error "data_type field must be a string." 
- in - let* chunk_shape, chunk_grid = match member "chunk_grid" x with - | `Null -> Error "array metadata must contain a chunk_grid field." - | xs -> - match Util.get_name xs, member "configuration" xs with - | "regular", `Assoc [("chunk_shape", `List l)] -> - let error = "chunk_shape must only contain positive ints." in - let* v = List.fold_right (add_as_int ~error) l (Ok []) in - let cs = Array.of_list v in - let+ r = match RegularGrid.create ~array_shape:shape cs with - | exception RegularGrid.Grid_shape_mismatch -> Error "grid shape mismatch." - | g -> Ok g - in cs, r - | _ -> Error "Invalid Chunk grid name or configuration." - in - let* codecs = match member "codecs" x with - | `List _ as c -> Codecs.Chain.of_yojson chunk_shape c - | `Null -> Error "array metadata must contain a codecs field." - | _ -> Error "codecs field must be a list of objects." - in - let* fill_value = match member "fill_value" x with - | `Null -> Error "array metadata must contain a fill_value field." - | xs -> FillValue.of_yojson xs - in - let* chunk_key_encoding = match member "chunk_key_encoding" x with - | `Null -> Error "array metadata must contain a chunk_key_encoding field." 
- | xs -> ChunkKeyEncoding.of_yojson xs - in + let member = Yojson.Safe.Util.member in + let* zarr_format = ZarrFormat.of_yojson (member "zarr_format" x) in + let* shape = Shape.of_yojson (member "shape" x) in + let* data_type = Datatype.of_yojson (member "data_type" x) in + let* fill_value = FillValue.of_yojson data_type (member "fill_value" x) in + let* chunk_key_encoding = ChunkKeyEncoding.of_yojson (member "chunk_key_encoding" x) in + let* chunk_grid = RegularGrid.of_yojson (Shape.to_list shape) (member "chunk_grid" x) in + let* codecs = Codecs.Chain.of_yojson (RegularGrid.chunk_shape chunk_grid) (member "codecs" x) in + let* node_type = NodeType.Array.of_yojson (member "node_type" x) in (* Optional fields *) - let add_as_str ~error a acc = - let* k = acc in - match a with - | `String s -> Ok (Some s :: k) - | `Null -> Ok (None :: k) - | _ -> Error error - in - let attributes = member "attributes" x in - let* dimension_names = match member "dimension_names" x with - | `Null -> Ok [] - | `List xs -> - if List.length xs <> Array.length shape then - Error "dimension_names length and array dimensionality must be equal." - else - let error = "dimension_names must contain strings or null values." in - List.fold_right (add_as_str ~error) xs (Ok []) - | _ -> Error "dimension_names field must be a list." - in + let* dimension_names = DimensionNames.of_yojson (Shape.ndim shape) (member "dimension_names" x) in let+ storage_transformers = match member "storage_transformers" x with | `Null -> Ok [] | _ -> Error "storage_transformers field is not yet supported." 
in + let attributes = member "attributes" x in {zarr_format; shape; node_type; data_type; codecs; fill_value; chunk_grid ;chunk_key_encoding; attributes; dimension_names; storage_transformers} let ( = ) x y = - x.zarr_format = y.zarr_format - && x.shape = y.shape - && x.node_type = y.node_type + Shape.(x.shape = y.shape) && Datatype.(x.data_type = y.data_type) && Codecs.Chain.(x.codecs = y.codecs) - && FillValue.(x.fill_value = y.fill_value) + && FillValue.(equal x.fill_value y.fill_value) && RegularGrid.(x.chunk_grid = y.chunk_grid) && ChunkKeyEncoding.(x.chunk_key_encoding = y.chunk_key_encoding) - && x.attributes = y.attributes - && x.dimension_names = y.dimension_names - && x.storage_transformers = y.storage_transformers + && Yojson.Safe.(equal x.attributes y.attributes) + && List.equal (fun a b -> Option.equal String.equal a b) x.dimension_names y.dimension_names + && List.equal Yojson.Safe.equal x.storage_transformers y.storage_transformers - let shape t = t.shape let codecs t = t.codecs - let dimension_names t = t.dimension_names let attributes t = t.attributes + let shape t = Shape.to_list t.shape + let dimension_names t = t.dimension_names let chunk_shape t = RegularGrid.chunk_shape t.chunk_grid let index_coord_pair t coord = RegularGrid.index_coord_pair t.chunk_grid coord let chunk_key t index = ChunkKeyEncoding.encode t.chunk_key_encoding index let chunk_indices t shape = RegularGrid.indices t.chunk_grid shape let encode t = Yojson.Safe.to_string (to_yojson t) let update_attributes t attrs = {t with attributes = attrs} - let update_shape t shape = {t with shape} + (* FIXME: must ensure the dimensions of the array remain unchanged. *) + let update_shape t shape = {t with shape = Shape.create shape} let decode s = match of_yojson (Yojson.Safe.from_string s) with | Error e -> raise (Parse_error e) | Ok m -> m - let is_valid_kind - : type a. 
t -> a Ndarray.dtype -> bool - = fun t kind -> match kind, t.data_type with + let is_valid_kind (type a) t (kind : a Ndarray.dtype) = match kind, t.data_type with | Ndarray.Char, Datatype.Char | Ndarray.Bool, Datatype.Bool | Ndarray.Int8, Datatype.Int8 @@ -271,78 +321,55 @@ module Array = struct | Ndarray.Nativeint, Datatype.Nativeint -> true | _ -> false - let fillvalue_of_kind - : type a. t -> a Ndarray.dtype -> a - = fun t kind -> match kind, t.fill_value with + let fillvalue_of_kind (type a) t (kind : a Ndarray.dtype) : a = match kind, t.fill_value with | Ndarray.Char, FillValue.Char c -> c | Ndarray.Bool, FillValue.Bool b -> b - | Ndarray.Int8, FillValue.Int i -> Stdint.Uint64.to_int i - | Ndarray.Uint8, FillValue.Int i -> Stdint.Uint64.to_int i - | Ndarray.Int16, FillValue.Int i -> Stdint.Uint64.to_int i - | Ndarray.Uint16, FillValue.Int i -> Stdint.Uint64.to_int i - | Ndarray.Int32, FillValue.Int i -> Stdint.Uint64.to_int32 i - | Ndarray.Int64, FillValue.Int i -> Stdint.Uint64.to_int64 i - | Ndarray.Uint64, FillValue.Int i -> i - | Ndarray.Int, FillValue.Int i -> Stdint.Uint64.to_int i - | Ndarray.Nativeint, FillValue.Int i -> Stdint.Uint64.to_nativeint i + | Ndarray.Int8, FillValue.Int i -> i + | Ndarray.Uint8, FillValue.Int i -> i + | Ndarray.Int16, FillValue.Int i -> i + | Ndarray.Uint16, FillValue.Int i -> i + | Ndarray.Int32, FillValue.Int i -> Int32.of_int i + | Ndarray.Int, FillValue.Int i -> i + | Ndarray.Int64, FillValue.Int i -> Int64.of_int i + | Ndarray.Int64, FillValue.Intlit (_, i) -> Stdint.Uint64.to_int64 i + | Ndarray.Uint64, FillValue.Int i -> Stdint.Uint64.of_int i + | Ndarray.Uint64, FillValue.Intlit (_, i) -> i + | Ndarray.Nativeint, FillValue.Int i -> Nativeint.of_int i + | Ndarray.Nativeint, FillValue.Intlit (_, i) -> Stdint.Uint64.to_nativeint i | Ndarray.Float32, FillValue.Float f -> f - | Ndarray.Float32, FillValue.FloatBits f -> f | Ndarray.Float64, FillValue.Float f -> f - | Ndarray.Float64, FillValue.FloatBits f -> f - | 
Ndarray.Complex32, FillValue.IntComplex c -> c - | Ndarray.Complex32, FillValue.FloatComplex c -> c - | Ndarray.Complex32, FillValue.FFComplex c -> c - | Ndarray.Complex32, FillValue.FBComplex c -> c - | Ndarray.Complex32, FillValue.BFComplex c -> c - | Ndarray.Complex32, FillValue.BBComplex c -> c - | Ndarray.Complex64, FillValue.IntComplex c -> c - | Ndarray.Complex64, FillValue.FloatComplex c -> c - | Ndarray.Complex64, FillValue.FFComplex c -> c - | Ndarray.Complex64, FillValue.FBComplex c -> c - | Ndarray.Complex64, FillValue.BFComplex c -> c - | Ndarray.Complex64, FillValue.BBComplex c -> c + | Ndarray.Complex32, FillValue.FloatComplex f -> f + | Ndarray.Complex64, FillValue.FloatComplex f -> f | _ -> failwith "kind is not compatible with node's fill value." end module Group = struct - type t = {zarr_format : int; node_type : string; attributes : Yojson.Safe.t} + type t = {zarr_format : ZarrFormat.t; node_type : NodeType.t; attributes : Yojson.Safe.t} let to_yojson : t -> Yojson.Safe.t = fun t -> - let l = [("zarr_format", `Int t.zarr_format); ("node_type", `String t.node_type)] in + let l = [("zarr_format", ZarrFormat.to_yojson t.zarr_format); ("node_type", NodeType.to_yojson t.node_type)] in + (* optional fields.*) match t.attributes with | `Null -> `Assoc l | x -> `Assoc (l @ [("attributes", x)]) - let default = {zarr_format = 3; node_type = "group"; attributes = `Null} + let default = {zarr_format = 3; node_type = NodeType.Group; attributes = `Null} let encode t = Yojson.Safe.to_string (to_yojson t) + let ( = ) x y = Yojson.Safe.(equal x.attributes y.attributes) let update_attributes t attrs = {t with attributes = attrs} let attributes t = t.attributes let of_yojson x = - let open Yojson.Safe.Util in let open Util.Result_syntax in - let* zarr_format = match member "zarr_format" x with - | `Int (3 as i) -> Ok i - | `Null -> Error "group metadata must contain a zarr_format field." - | _ -> Error "zarr_format field must be the integer 3." 
- in - let+ node_type = match member "node_type" x with - | `String ("group" as g) -> Ok g - | `Null -> Error "group metadata must contain a node_type field." - | _ -> Error "node_type field must be 'group." - in - let attributes = match member "attributes" x with - | `Null -> `Null - | xs -> xs - in - {zarr_format; node_type; attributes} + let* zarr_format = ZarrFormat.of_yojson Yojson.Safe.Util.(member "zarr_format" x) in + let+ node_type = NodeType.Group.of_yojson Yojson.Safe.Util.(member "node_type" x) in + {zarr_format; node_type; attributes = Yojson.Safe.Util.member "attributes" x} let decode s = match of_yojson (Yojson.Safe.from_string s) with | Error e -> raise (Parse_error e) | Ok m -> m let show t = - Format.sprintf - {|"{zarr_format=%d; node_type=%s; attributes=%s}"|} - t.zarr_format t.node_type (Yojson.Safe.show t.attributes) + let x, y = NodeType.show t.node_type, Yojson.Safe.show t.attributes in + Format.sprintf {|"{zarr_format=%d; node_type=%s; attributes=%s}"|} t.zarr_format x y end diff --git a/zarr/src/metadata.mli b/zarr/src/metadata.mli index 1213fcc9..d3562ccd 100644 --- a/zarr/src/metadata.mli +++ b/zarr/src/metadata.mli @@ -8,23 +8,6 @@ exception Parse_error of string (** raised when parsing a metadata JSON document fails. *) -module FillValue : sig - type t = - | Char of char (** A single character string. *) - | Bool of bool (** Must be a JSON boolean. *) - | Int of Stdint.uint64 (** Value must be a JSON number with no fractional or exponent part that is within the representable range of the corresponding integer data type. *) - | Float of float (** Value representing a JSON float. *) - | FloatBits of float (** A JSON string specifying a byte representation of the float a hexstring. *) - | IntComplex of Complex.t (** A JSON 2-element array of integers representing a complex number. *) - | FloatComplex of Complex.t (** A JSON 2-element array of floats representing a complex number. 
*) - | FFComplex of Complex.t - | FBComplex of Complex.t - | BFComplex of Complex.t - | BBComplex of Complex.t - (** Provides an element value to use for uninitialised portions of - a Zarr array. The permitted values depend on the data type. *) -end - module Array : sig (** A module which contains functionality to work with a parsed JSON Zarr array metadata document. *) @@ -37,10 +20,10 @@ module Array : sig ?dimension_names:string option list -> ?attributes:Yojson.Safe.t -> codecs:Codecs.Chain.t -> - shape:int array -> + shape:int list -> 'a Ndarray.dtype -> 'a -> - int array -> + int list -> t (** [create ~codecs ~shape kind fv cshp] Creates a new array metadata document with codec chain [codecs], shape [shape], fill value [fv], @@ -56,10 +39,10 @@ module Array : sig @raise Parse_error if metadata string is invalid. *) - val shape : t -> int array + val shape : t -> int list (** [shape t] returns the shape of the zarr array represented by metadata type [t]. *) - val chunk_shape : t -> int array + val chunk_shape : t -> int list (** [chunk_shape t] returns the shape a chunk in this zarr array. *) val is_valid_kind : t -> 'a Ndarray.dtype -> bool @@ -83,23 +66,23 @@ module Array : sig (** [codecs t] Returns a type representing the chain of codecs applied when decoding/encoding a Zarr array chunk. *) - val index_coord_pair : t -> int array -> int array * int array + val index_coord_pair : t -> int list -> int list * int list (** [index_coord_pair t coord] maps a coordinate of this Zarr array to a pair of chunk index and coordinate {i within} that chunk. *) - val chunk_indices : t -> int array -> int array list + val chunk_indices : t -> int list -> int list list (** [chunk_indices t shp] returns a list of all chunk indices that would be contained in a zarr array of shape [shp] given the regular grid defined in array metadata [t]. 
*) - val chunk_key : t -> int array -> string + val chunk_key : t -> int list -> string (** [chunk_key t idx] returns a key encoding of a the chunk index [idx]. *) val update_attributes : t -> Yojson.Safe.t -> t (** [update_attributes t json] returns a new metadata type with an updated attribute field containing contents in [json] *) - val update_shape : t -> int array -> t + val update_shape : t -> int list -> t (** [update_shape t new_shp] returns a new metadata type containing shape [new_shp]. *) @@ -136,4 +119,8 @@ module Group : sig val attributes : t -> Yojson.Safe.t (** [attributes t] Returns a Yojson type containing user attributes assigned to the zarr group represented by [t]. *) + + val ( = ) : t -> t -> bool + (** [a = b] returns true if [a] and [b] are equal group metadata documents + and false otherwise. *) end diff --git a/zarr/src/ndarray.ml b/zarr/src/ndarray.ml index f17eaf94..02b490e0 100644 --- a/zarr/src/ndarray.ml +++ b/zarr/src/ndarray.ml @@ -32,27 +32,31 @@ let dtype_size : type a. a dtype -> int = function | Int -> Sys.word_size / 8 | Nativeint -> Sys.word_size / 8 +let prod x = List.fold_left Int.mul 1 x + let cumprod x start stop = let acc = ref 1 in - for i = start to stop do acc := !acc * x.(i) done; !acc + for i = start to stop do acc := !acc * (List.nth x i) done; !acc (*strides[k] = [cumulative_product with start=k+1 end=n-1] of shape *) let make_strides shape = - let n = Array.length shape - 1 in + let n = List.length shape - 1 in Array.init (n + 1) (fun i -> cumprod shape (i + 1) n) -type 'a t = {shape : int array; strides : int array; dtype : 'a dtype; data : 'a array} +type 'a t = {shape : int list; strides : int array; dtype : 'a dtype; data : 'a array} let equal x y = x.data = y.data && x.shape = y.shape && x.dtype = y.dtype && x.strides = y.strides (* 1d index of coord [i0; ...; in] is SUM(i0 * strides[0] + ...
+ in * strides[n-1] *) let coord_to_index i s = Array.fold_left (fun a (x, y) -> Int.add a (x * y)) 0 @@ Array.combine i s -let create dtype shape fv = {shape; dtype; strides = make_strides shape; data = Array.make (Util.prod shape) fv} -let init dtype shape f = {shape; dtype; strides = make_strides shape; data = Array.init (Util.prod shape) f} +let coord_to_index' x s = let acc = ref 0 in List.iteri (fun i v -> acc := !acc + v * s.(i)) x; !acc +let create dtype shape fv = {shape; dtype; strides = make_strides shape; data = Array.make (prod shape) fv} +let init dtype shape f = {shape; dtype; strides = make_strides shape; data = Array.init (prod shape) f} let of_array dtype shape xs = {shape; dtype; strides = make_strides shape; data = xs} let data_type t = t.dtype -let size t = Util.prod t.shape -let ndims t = Array.length t.shape -let get t i = t.data.(coord_to_index i t.strides) -let set t i x = t.data.(coord_to_index i t.strides) <- x +let size t = prod t.shape +let ndims t = List.length t.shape +let get t i = t.data.(coord_to_index' i t.strides) +let set t i x = t.data.(coord_to_index' i t.strides) <- x +let set' t i x = t.data.(coord_to_index i t.strides) <- x let fill t v = Array.iteri (fun i _ -> t.data.(i) <- v) t.data let map f t = {t with data = Array.map f t.data} let iteri f t = Array.iteri f t.data @@ -77,7 +81,8 @@ let to_bigarray : type a b. a t -> (a, b) B.kind -> (a, b, B.c_layout) B.Genarray.t = fun x kind -> let initialize ~x c = x.data.(coord_to_index c x.strides) in - let f k = B.Genarray.init k C_layout x.shape (initialize ~x) in + let shape = Array.of_list x.shape in + let f k = B.Genarray.init k C_layout shape (initialize ~x) in match[@warning "-8"] kind with | B.Char as k -> f k | B.Int8_signed as k -> f k @@ -96,7 +101,7 @@ let to_bigarray : let of_bigarray : type a b c. 
(a, b, c) B.Genarray.t -> a t = fun x -> let x' = B.Genarray.change_layout x C_layout in - let shape = B.Genarray.dims x' in + let shape = B.Genarray.dims x' |> Array.to_list in let coord = Array.make (B.Genarray.num_dims x') 0 in let strides = make_strides shape in let initialize ~strides ~coord ~x' i = @@ -126,8 +131,8 @@ let of_bigarray : be an issue since the input array is never used again after it is transposed. *) let transpose ?axes x = let n = ndims x in - let p = Option.fold ~none:(Array.init n (fun i -> n - 1 - i)) ~some:Fun.id axes in - let shape = Array.map (fun i -> x.shape.(i)) p in + let p = Option.fold ~none:(List.init n (fun i -> n - 1 - i)) ~some:Fun.id axes in + let shape = List.map (fun i -> List.nth x.shape i) p in let x' = {x with shape; strides = make_strides shape; data = Array.copy x.data} in let c = Array.make n 0 and c' = Array.make n 0 in (* Project a 1d-indexed value of the input ndarray into its corresponding @@ -135,95 +140,87 @@ let transpose ?axes x = permutation described by [p].*) let project_1d_to_nd i a = index_to_coord ~strides:x.strides i c; - Array.iteri (fun j b -> c'.(j) <- c.(b)) p; - set x' c' a + List.iteri (fun j b -> c'.(j) <- c.(b)) p; + set' x' c' a in iteri project_1d_to_nd x; x' (* The [index] type definition as well as functions tagged with [@coverage off] - in this Indexing module were directly copied from the Owl project to emulate - its logic for munipulating slices. The code is licenced under the MIT license - and can be found at: https://github.com/owlbarn/owl + in this Indexing module were directly copied and modified from the Owl project + to emulate its logic for manipulating slices.
The code is licenced under the + MIT license and can be found at: https://github.com/owlbarn/owl The MIT License (MIT) Copyright (c) 2016-2022 Liang Wang liang@ocaml.xyz *) module Indexing = struct type index = + | F | I of int - | L of int array - | R of int array + | T of int + | L of int list + | R of int * int + | R' of int * int * int + + (* internal restricted representation of index type *) + type index' = L of int list | R' of int * int * int (* this is copied from the Owl project so we skip testing it. *) let[@coverage off] check_slice_definition axis shp = - let axis_len = Array.length axis in - let shp_len = Array.length shp in + let axis_len = List.length axis in + let shp_len = List.length shp in assert (axis_len <= shp_len); (* add missing definition on higher dimensions *) - let axis = - if axis_len < shp_len - then ( - let suffix = Array.make (shp_len - axis_len) (R [||]) in - Array.append axis suffix) - else axis - in + let axis = if axis_len < shp_len then axis @ List.init (shp_len - axis_len) (fun _ -> F) else axis in (* re-format slice definition, note I_ will be replaced with L_ *) - Array.map2 - (fun i n -> - match i with - | I x -> - let x = if x >= 0 then x else n + x in - assert (x < n); - R [| x; x; 1 |] - | L x -> - let is_cont = ref true in - if Array.length x <> n then is_cont := false; - let x = - Array.mapi - (fun i j -> - let j = if j >= 0 then j else n + j in - assert (j < n); - if i <> j then is_cont := false; - j) - x - in - if !is_cont = true then R [| 0; n - 1; 1 |] else L x - | R x -> - (match Array.length x with - | 0 -> R [| 0; n - 1; 1 |] - | 1 -> - let a = if x.(0) >= 0 then x.(0) else n + x.(0) in - assert (a < n); - R [| a; a; 1 |] - | 2 -> - let a = if x.(0) >= 0 then x.(0) else n + x.(0) in - let b = if x.(1) >= 0 then x.(1) else n + x.(1) in - let c = if a <= b then 1 else -1 in - assert (not (a >= n || b >= n)); - R [| a; b; c |] - | 3 -> - let a = if x.(0) >= 0 then x.(0) else n + x.(0) in - let b = if x.(1) >= 0 
then x.(1) else n + x.(1) in - let c = x.(2) in - assert (not (a >= n || b >= n || c = 0)); - assert (not ((a < b && c < 0) || (a > b && c > 0))); - R [| a; b; c |] - | _ -> failwith "check_slice_definition: error")) - axis - shp + List.map2 + (fun i n -> match i with + | I x -> + let x = if x >= 0 then x else n + x in + assert (x < n); + R' (x, x, 1) + | L x -> + let is_cont = ref true in + if List.length x <> n then is_cont := false; + let x = + List.mapi + (fun i j -> + let j = if j >= 0 then j else n + j in + assert (j < n); + if i <> j then is_cont := false; + j) + x + in + if !is_cont = true then R' (0, n-1, 1) else L x + | F -> R' (0, n - 1, 1) + | T x -> + let a = if x >= 0 then x else n + x in + assert (a < n); + R' (a, a, 1) + | R (x, y) -> + let a = if x >= 0 then x else n + x in + let b = if y >= 0 then y else n + y in + let c = if a <= b then 1 else -1 in + assert (not (a >= n || b >= n)); + R' (a, b, c) + | R' (x, y, c) -> + let a = if x >= 0 then x else n + x in + let b = if y >= 0 then y else n + y in + assert (not (a >= n || b >= n || c = 0)); + assert (not ((a < b && c < 0) || (a > b && c > 0))); + R' (a, b, c)) axis shp (* this was opied from the Owl project so we skip testing it. 
*) let[@coverage off] calc_slice_shape axis = - Array.map - (function - | I _x -> 1 (* never reached *) - | L x -> Array.length x - | R x -> abs ((x.(1) - x.(0)) / x.(2)) + 1) axis + let f = function + | L x -> List.length x + | R' (x, y, z) -> abs ((y - x) / z) + 1 + in + List.map f axis let rec cartesian_prod : int list list -> int list list = function | [] -> [[]] - | x :: xs -> - List.concat_map (fun i -> List.map (List.cons i) (cartesian_prod xs)) x + | x :: xs -> List.concat_map (fun i -> List.map (List.cons i) (cartesian_prod xs)) x let range ~step start stop = let rec aux ~step ~stop acc = function @@ -234,28 +231,22 @@ module Indexing = struct (* get indices from a reformated slice *) let indices_of_slice = function - | R [|start; stop; step|] -> range ~step start stop - | L l -> Array.to_list l - (* this is added for exhaustiveness but is never reached since - a reformatted slice replaces a I index with an R index.*) - | _ -> failwith "Invalid slice index." + | R' (start, stop, step) -> range ~step start stop + | L x -> x let coords_of_slice slice shape = - let indices = Array.map indices_of_slice (check_slice_definition slice shape) in - let cprod = cartesian_prod (Array.to_list indices) in - let listofarrays = List.map Array.of_list cprod in - Array.of_list listofarrays + cartesian_prod @@ List.map indices_of_slice (check_slice_definition slice shape) let slice_of_coords = function - | [] -> [||] + | [] as x -> x | x :: _ as xs -> let module S = Set.Make(Int) in - let to_slice_index s = L (Array.of_list (S.elements s)) in let add_unique ~acc i y = if S.mem y acc.(i) then () else acc.(i) <- S.add y acc.(i) in - let fill_dims coord acc = Array.iteri (add_unique ~acc) coord; acc in - let ndims = Array.length x in + let fill_dims coord acc = List.iteri (add_unique ~acc) coord; acc in + let ndims = List.length x in let indices = Array.make ndims S.empty in - Array.map to_slice_index (List.fold_right fill_dims xs indices) + let dimsets = List.fold_right 
fill_dims xs indices in + List.map (fun s -> (L (S.elements s) : index)) (Array.to_list dimsets) let slice_shape slice array_shape = calc_slice_shape (check_slice_definition slice array_shape) diff --git a/zarr/src/ndarray.mli b/zarr/src/ndarray.mli index a671d10a..f2473e33 100644 --- a/zarr/src/ndarray.mli +++ b/zarr/src/ndarray.mli @@ -22,11 +22,11 @@ type 'a t val dtype_size : 'a dtype -> int (** [dtype_size kind] returns the size in bytes of data type [kind].*) -val create : 'a dtype -> int array -> 'a -> 'a t +val create : 'a dtype -> int list -> 'a -> 'a t (** [create k s v] creates an N-dimensional array with data_type [k], shape [s] and fill value [v].*) -val init : 'a dtype -> int array -> (int -> 'a) -> 'a t +val init : 'a dtype -> int list -> (int -> 'a) -> 'a t (** [init k s f] creates an N-dimensional array with data_type [k], shape [s] and every element value is assigned using function [f].*) @@ -39,7 +39,7 @@ val size : 'a t -> int val ndims : 'a t -> int (** [ndims x] is the number of dimensions of [x].*) -val shape : 'a t -> int array +val shape : 'a t -> int list (** [shape x] returns an array with the size of each dimension of [x].*) val byte_size : 'a t -> int @@ -51,15 +51,15 @@ val to_array : 'a t -> 'a array {!data_type}. Note that data is not copied, so if the caller modifies the returned array, the changes will be reflected in [x].*) -val of_array : 'a dtype -> int array -> 'a array -> 'a t +val of_array : 'a dtype -> int list -> 'a array -> 'a t (** [of_array k s x] creates an n-dimensional array of shape [s] and data_type [k] using elements of [x]. 
Note that the data is not copied, so the caller must ensure not to modify [x] afterwards.*) -val get : 'a t -> int array -> 'a +val get : 'a t -> int list -> 'a (** [get x c] returns element of [x] at coordinate [c].*) -val set : 'a t -> int array -> 'a -> unit +val set : 'a t -> int list -> 'a -> unit (** [set x c v] sets coordinate [c] of [x] to value [v].*) val iteri : (int -> 'a -> unit) -> 'a t -> unit @@ -79,7 +79,7 @@ val iter : ('a -> unit) -> 'a t -> unit val equal : 'a t -> 'a t -> bool (** [equal x y] is [true] iff [x] and [y] are equal, else [false].*) -val transpose : ?axes:int array -> 'a t -> 'a t +val transpose : ?axes:int list -> 'a t -> 'a t (** [transpose o x] permutes the axes of [x] according to [o].*) val to_bigarray : 'a t -> ('a, 'b) Bigarray.kind -> ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t @@ -93,16 +93,19 @@ module Indexing : sig slices for working with Zarr arrays. *) type index = + | F | I of int - | L of int array - | R of int array + | T of int + | L of int list + | R of int * int + | R' of int * int * int - val slice_of_coords : int array list -> index array + val slice_of_coords : int list list -> index list (** [slice_of_coords c] takes a list of array coordinates and returns a slice corresponding to the coordinates. Elements of each slice variant are sorted in increasing order.*) - val coords_of_slice : index array -> int array -> int array array + val coords_of_slice : index list -> int list -> int list list (** [coords_of_slice s shp] returns an array of coordinates given a slice [s] and array shape [shp]. *) @@ -111,7 +114,7 @@ module Indexing : sig list [ll]. It is mainly used to generate a C-order of chunk indices in a regular Zarr array grid. *) - val slice_shape : index array -> int array -> int array + val slice_shape : index list -> int list -> int list (** [slice_shape s shp] returns the shape of slice [s] within an array of shape [shp]. 
*) end diff --git a/zarr/src/storage/storage.ml b/zarr/src/storage/storage.ml index 136e3d07..17b5d294 100644 --- a/zarr/src/storage/storage.ml +++ b/zarr/src/storage/storage.ml @@ -74,7 +74,7 @@ module Make (IO : Types.IO) (Store : Types.Store with type 'a io = 'a IO.t) = st end module Array = struct - module ArrayMap = Util.ArrayMap + module CoordMap = Util.CoordMap module Indexing = Ndarray.Indexing let exists t node = is_member t (Node.Array.to_metakey node) let delete t node = erase_prefix t (Node.Array.to_key node ^ "/") @@ -91,9 +91,9 @@ module Make (IO : Types.IO) (Store : Types.Store with type 'a io = 'a IO.t) = st let write t node slice x = let update_ndarray ~arr (c, v) = Ndarray.set arr c v in - let add_coord_value ~meta acc (co, y) = + let add_coord_value ~meta acc co y = let chunk_idx, c = Metadata.Array.index_coord_pair meta co in - ArrayMap.add_to_list chunk_idx (c, y) acc + CoordMap.add_to_list chunk_idx (c, y) acc in let update_chunk ~t ~meta ~prefix ~chain ~fv ~repr (idx, pairs) = let ckey = prefix ^ Metadata.Array.chunk_key meta idx in @@ -121,19 +121,17 @@ module Make (IO : Types.IO) (Store : Types.Store with type 'a io = 'a IO.t) = st let kind = Ndarray.data_type x in if not (Metadata.Array.is_valid_kind meta kind) then raise Invalid_data_type else let coords = Indexing.coords_of_slice slice shape in - let coord_value_pair = Array.combine coords (Ndarray.to_array x) in - let m = Array.fold_left (add_coord_value ~meta) ArrayMap.empty coord_value_pair in + let m = List.fold_left2 (add_coord_value ~meta) CoordMap.empty coords (Ndarray.to_array x |> Array.to_list) in let fv = Metadata.Array.fillvalue_of_kind meta kind and repr = Codecs.{kind; shape = Metadata.Array.chunk_shape meta} and prefix = Node.Array.to_key node ^ "/" - and chain = Metadata.Array.codecs meta - and bindings = ArrayMap.bindings m in - IO.iter (update_chunk ~t ~meta ~prefix ~chain ~fv ~repr) bindings + and chain = Metadata.Array.codecs meta in + IO.iter (update_chunk ~t ~meta 
~prefix ~chain ~fv ~repr) (CoordMap.bindings m) let read (type a) t node slice (kind : a Ndarray.dtype) = - let add_indexed_coord ~meta acc (i, y) = + let add_indexed_coord ~meta acc i y = let chunk_idx, c = Metadata.Array.index_coord_pair meta y in - ArrayMap.add_to_list chunk_idx (i, c) acc + CoordMap.add_to_list chunk_idx (i, c) acc in let read_chunk ~t ~meta ~prefix ~chain ~fv ~repr (idx, pairs) = let ckey = prefix ^ Metadata.Array.chunk_key meta idx in @@ -153,13 +151,14 @@ module Make (IO : Types.IO) (Store : Types.Store with type 'a io = 'a IO.t) = st let slice_shape = try Indexing.slice_shape slice shape with | Assert_failure _ -> raise Invalid_array_slice in - let icoords = Array.mapi (fun i v -> i, v) (Indexing.coords_of_slice slice shape) in - let m = Array.fold_left (add_indexed_coord ~meta) ArrayMap.empty icoords + let numel = List.fold_left Int.mul 1 slice_shape in + let coords = Indexing.coords_of_slice slice shape in + let m = List.fold_left2 (add_indexed_coord ~meta) CoordMap.empty List.(init numel Fun.id) coords and chain = Metadata.Array.codecs meta and prefix = Node.Array.to_key node ^ "/" and fv = Metadata.Array.fillvalue_of_kind meta kind and repr = Codecs.{kind; shape = Metadata.Array.chunk_shape meta} in - let+ ps = IO.concat_map (read_chunk ~t ~meta ~prefix ~chain ~fv ~repr) (ArrayMap.bindings m) in + let+ ps = IO.concat_map (read_chunk ~t ~meta ~prefix ~chain ~fv ~repr) (CoordMap.bindings m) in (* sorting restores the C-order of the decoded array coordinates.*) let sorted_pairs = List.fast_sort (fun (x, _) (y, _) -> Int.compare x y) ps in let vs = List.map snd sorted_pairs in @@ -167,7 +166,7 @@ module Make (IO : Types.IO) (Store : Types.Store with type 'a io = 'a IO.t) = st let reshape t node new_shape = let module S = Set.Make (struct - type t = int array + type t = int list let compare : t -> t -> int = Stdlib.compare end) in @@ -181,7 +180,7 @@ module Make (IO : Types.IO) (Store : Types.Store with type 'a io = 'a IO.t) = st in let* 
meta = metadata t node in let old_shape = Metadata.Array.shape meta in - if Array.(length new_shape <> length old_shape) then raise Invalid_resize_shape else + if List.(length new_shape <> length old_shape) then raise Invalid_resize_shape else let s = S.of_list (Metadata.Array.chunk_indices meta old_shape) and s' = S.of_list (Metadata.Array.chunk_indices meta new_shape) in let unreachable_chunks = S.elements (S.diff s s') diff --git a/zarr/src/storage/storage_intf.ml b/zarr/src/storage/storage_intf.ml index c23312c5..6f81145b 100644 --- a/zarr/src/storage/storage_intf.ml +++ b/zarr/src/storage/storage_intf.ml @@ -52,8 +52,8 @@ module type S = sig ?dimension_names:string option list -> ?attributes:Yojson.Safe.t -> codecs:Codecs.codec list -> - shape:int array -> - chunks:int array -> + shape:int list -> + chunks:int list -> 'a Ndarray.dtype -> 'a -> Node.Array.t -> @@ -89,7 +89,7 @@ module type S = sig (** [exists t n] returns [true] if array node [n] is a member of store [t] and [false] otherwise. *) - val write : t -> Node.Array.t -> Ndarray.Indexing.index array -> 'a Ndarray.t -> unit io + val write : t -> Node.Array.t -> Ndarray.Indexing.index list -> 'a Ndarray.t -> unit io (** [write t n s x] writes n-dimensional array [x] to the slice [s] of array node [n] in store [t]. @@ -99,7 +99,7 @@ module type S = sig if the kind of [x] is not compatible with node [n]'s data type as described in its metadata document. *) - val read : t -> Node.Array.t -> Ndarray.Indexing.index array -> 'a Ndarray.dtype -> 'a Ndarray.t io + val read : t -> Node.Array.t -> Ndarray.Indexing.index list -> 'a Ndarray.dtype -> 'a Ndarray.t io (** [read t n s k] reads an n-dimensional array of size determined by slice [s] from array node [n]. 
@@ -109,7 +109,7 @@ module type S = sig @raise Invalid_array_slice if the slice [s] is not a valid slice of array node [n].*) - val reshape : t -> Node.Array.t -> int array -> unit io + val reshape : t -> Node.Array.t -> int list -> unit io (** [reshape t n shape] resizes array node [n] of store [t] into new size [shape]. Note that when the resizing involves shrinking an array along any dimensions, any old unreachable chunks that fall outside of diff --git a/zarr/src/util.ml b/zarr/src/util.ml index 6255e11b..4c7fc777 100644 --- a/zarr/src/util.ml +++ b/zarr/src/util.ml @@ -1,6 +1,6 @@ -module ArrayMap = struct +module CoordMap = struct include Map.Make (struct - type t = int array + type t = int list let compare : t -> t -> int = Stdlib.compare end) @@ -18,7 +18,6 @@ module Result_syntax = struct end let get_name j = Yojson.Safe.Util.(member "name" j |> to_string) -let prod x = Array.fold_left Int.mul 1 x let max = Array.fold_left Int.max Int.min_int (* Obtained from: https://discuss.ocaml.org/t/how-to-create-a-new-file-while-automatically-creating-any-intermediate-directories/14837/5?u=zoj613 *) diff --git a/zarr/src/util.mli b/zarr/src/util.mli index e836545d..698a3942 100644 --- a/zarr/src/util.mli +++ b/zarr/src/util.mli @@ -1,7 +1,7 @@ -(** A finite map over integer array keys. *) -module ArrayMap : sig - include Map.S with type key = int array - val add_to_list : int array -> 'a -> 'a list t -> 'a list t +(** A finite map over Zarr array coordinate keys. *) +module CoordMap : sig + include Map.S with type key = int list + val add_to_list : int list -> 'a -> 'a list t -> 'a list t (** [add_to_list k v map] is [map] with [k] mapped to [l] such that [l] is [v :: CoordMap.find k map] if [k] was bound in [map] and [[v]] otherwise.*) end @@ -17,9 +17,6 @@ val get_name : Yojson.Safe.t -> string configuration of the form [{"name": value, "configuration": ...}], as defined in the Zarr V3 specification. 
*) -val prod : int array -> int -(** [prod x] returns the product of the elements of [x]. *) - val max : int array -> int (** [max x] returns the maximum element of an integer array [x]. *) diff --git a/zarr/test/test_codecs.ml b/zarr/test/test_codecs.ml index ed8155b5..916fe5d7 100644 --- a/zarr/test/test_codecs.ml +++ b/zarr/test/test_codecs.ml @@ -2,132 +2,106 @@ open OUnit2 open Zarr open Zarr.Codecs -let decode_chain ~shape ~str ~msg = - (match Chain.of_yojson shape @@ Yojson.Safe.from_string str with +let decode_chain ~shape ~str ~msg = begin match Chain.of_yojson shape @@ Yojson.Safe.from_string str with | Ok _ -> assert_failure "Impossible to decode an unsupported codec."; - | Error s -> assert_equal ~printer:Fun.id msg s) + | Error s -> assert_equal ~printer:Fun.id msg s end -let bytes_encode_decode - : type a. a array_repr -> a -> unit - = fun decoded_repr fill_value -> +let bytes_encode_decode (type a) (decoded_repr : a array_repr) (fill_value : a) = List.iter (fun bytes_codec -> let chain = [bytes_codec] in let c = Chain.create decoded_repr.shape chain in let arr = Ndarray.create decoded_repr.kind decoded_repr.shape fill_value in - let decoded = Chain.decode c decoded_repr @@ Chain.encode c arr in + let decoded = Chain.decode c decoded_repr (Chain.encode c arr) in assert_equal arr decoded) [`Bytes LE; `Bytes BE] let tests = [ "test codec chain" >:: (fun _ -> - let shape = [|10; 15; 10|] in + let shape = [10; 15; 10] in let kind = Ndarray.Int16 in let fill_value = 10 in let shard_cfg = - {chunk_shape = [|2; 5; 5|] + {chunk_shape = [2; 5; 5] ;index_location = End ;index_codecs = [`Bytes LE; `Crc32c] - ;codecs = [`Transpose [|0; 1; 2|]; `Bytes BE; `Gzip L1]} - in - let chain = - [`Transpose [|2; 1; 0; 3|]; `ShardingIndexed shard_cfg; `Crc32c; `Gzip L9] + ;codecs = [`Transpose [0; 1; 2]; `Bytes BE; `Gzip L1]} in - assert_raises - (Zarr.Codecs.Invalid_transpose_order) - (fun () -> Chain.create shape chain); - - let chain = [`ShardingIndexed shard_cfg; 
`Transpose [|2; 1; 0|]; `Gzip L0] in - assert_raises - (Zarr.Codecs.Invalid_codec_ordering) - (fun () -> Chain.create shape chain); - - let chain = [`Transpose [|2; 1; 0|]; `Crc32c] in - assert_raises - (Zarr.Codecs.Array_to_bytes_invariant) - (fun () -> Chain.create shape chain); - - let chain = - [`Transpose [|2; 1; 0|]; `ShardingIndexed shard_cfg; `Crc32c; `Gzip L9] in + let chain = [`Transpose [2; 1; 0; 3]; `ShardingIndexed shard_cfg; `Crc32c; `Gzip L9] in + assert_raises (Zarr.Codecs.Invalid_transpose_order) (fun () -> Chain.create shape chain); + let chain = [`ShardingIndexed shard_cfg; `Transpose [2; 1; 0]; `Gzip L0] in + assert_raises (Zarr.Codecs.Invalid_codec_ordering) (fun () -> Chain.create shape chain); + let chain = [`Transpose [2; 1; 0]; `Crc32c] in + assert_raises (Zarr.Codecs.Array_to_bytes_invariant) (fun () -> Chain.create shape chain); + let chain = [`Transpose [2; 1; 0]; `ShardingIndexed shard_cfg; `Crc32c; `Gzip L9] in let c = Chain.create shape chain in let arr = Ndarray.create kind shape fill_value in let encoded = Chain.encode c arr in assert_equal arr @@ Chain.decode c {shape; kind} encoded; - - decode_chain ~shape ~str:"[]" ~msg:"No codec specified."; - + decode_chain ~shape ~str:"[]" ~msg:"Must be exactly one array->bytes codec."; decode_chain ~shape ~str:{|[{"name": "gzip", "configuration": {"level": 1}}]|} ~msg:"Must be exactly one array->bytes codec."; - decode_chain ~shape - ~str:{|[{"name": "fake_codec"}, {"name": "bytes", - "configuration": {"endian": "little"}}]|} + ~str:{|[{"name": "fake_codec"}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"fake_codec codec is unsupported or has invalid configuration."; let str = Chain.to_yojson c |> Yojson.Safe.to_string in (match Chain.of_yojson shape @@ Yojson.Safe.from_string str with | Ok v -> assert_equal v c; - | Error _ -> - assert_failure "a serialized chain should successfully deserialize")) + | Error _ -> assert_failure "a serialized chain should successfully 
deserialize")) ; "test transpose codec" >:: (fun _ -> (* test decoding of chain with misspelled configuration name *) decode_chain - ~shape:[|1; 1|] + ~shape:[1; 1] ~str:{|[{"name": "transpose", "configuration": {"ordeR": [0, 1]}}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"transpose codec is unsupported or has invalid configuration."; (* test decoding of chain with empty transpose order *) decode_chain - ~shape:[||] + ~shape:[] ~str:{|[{"name": "transpose", "configuration": {"order": []}}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"transpose codec is unsupported or has invalid configuration."; (* test decoding of chain with duplicated transpose order *) decode_chain - ~shape:[|1; 1|] + ~shape:[1; 1] ~str:{|[{"name": "transpose", "configuration": {"order": [0, 0]}}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"transpose codec is unsupported or has invalid configuration."; (* test decoding with negative transpose dimensions. *) decode_chain - ~shape:[|1|] + ~shape:[1] ~str:{|[{"name": "transpose", "configuration": {"order": [-1]}}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"transpose codec is unsupported or has invalid configuration."; - (* test decoding transpose order bigger than an array's dimensionality. *) decode_chain - ~shape:[|2; 2|] + ~shape:[2; 2] ~str:{|[{"name": "transpose", "configuration": {"order": [0, 1, 2]}}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"transpose codec is unsupported or has invalid configuration."; (* test decoding transpose order containing non-integer value(s). *) decode_chain - ~shape:[|2; 2|] + ~shape:[2; 2] ~str:{|[{"name": "transpose", "configuration": {"order": [0, 1, 2.0]}}, {"name": "bytes", "configuration": {"endian": "little"}}]|} ~msg:"transpose codec is unsupported or has invalid configuration."; - (* test encoding of chain with an empty or too big transpose order. 
*) - let shape = [|2; 2; 2|] in - let chain = [`Transpose [||]; `Bytes LE] in - assert_raises - (Zarr.Codecs.Invalid_transpose_order) - (fun () -> Chain.create shape chain); - assert_raises - (Zarr.Codecs.Invalid_transpose_order) - (fun () -> Chain.create shape [`Transpose [|4; 0; 1|]; `Bytes LE])) + let shape = [2; 2; 2] in + let chain = [`Transpose []; `Bytes LE] in + assert_raises (Zarr.Codecs.Invalid_transpose_order) (fun () -> Chain.create shape chain); + assert_raises (Zarr.Codecs.Invalid_transpose_order) (fun () -> Chain.create shape [`Transpose [4; 0; 1]; `Bytes LE])) ; "test sharding indexed codec" >:: (fun _ -> (* test missing chunk_shape field. *) decode_chain - ~shape:[||] + ~shape:[] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -139,7 +113,7 @@ let tests = [ ~msg:"Must be exactly one array->bytes codec."; (*test missing index_location field. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -151,7 +125,7 @@ let tests = [ ~msg:"Must be exactly one array->bytes codec."; (* test missing codecs field. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -162,7 +136,7 @@ let tests = [ ~msg:"Must be exactly one array->bytes codec."; (* tests missing index_codecs field. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -173,7 +147,7 @@ let tests = [ ~msg:"Must be exactly one array->bytes codec."; (* tests incorrect value for index_location field. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -186,7 +160,7 @@ let tests = [ ~msg:"Must be exactly one array->bytes codec."; (* tests incorrect non-integer values for chunk_shape field. 
*) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -199,7 +173,7 @@ let tests = [ ~msg:"Must be exactly one array->bytes codec."; (* tests unspecified codecs field. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -212,7 +186,7 @@ let tests = [ (* tests ill-formed codecs/index_codecs field. In this case, missing the required bytes->bytes codec. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -225,7 +199,7 @@ let tests = [ (* tests ill-formed codecs/index_codecs field. In this case, parsing an unsupported/unknown codec. *) decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:{|[ {"name": "sharding_indexed", "configuration": @@ -241,7 +215,7 @@ let tests = [ List.iter (fun c -> decode_chain - ~shape:[|5; 5; 5|] + ~shape:[5; 5; 5] ~str:(Format.sprintf {|[ {"name": "sharding_indexed", "configuration": @@ -255,30 +229,27 @@ let tests = [ [{|{"name": "zstd", "configuration": {"level": 0, "checksum": false}}|} ;{|{"name": "gzip", "configuration": {"level": 1}}|}]; - let shape = [|10; 15; 10|] in + let shape = [10; 15; 10] in let kind = Ndarray.Float64 in let cfg = - {chunk_shape = [|3; 5; 5|] + {chunk_shape = [3; 5; 5] ;index_location = Start - ;index_codecs = [`Transpose [|0; 3; 1; 2|]; `Bytes LE; `Crc32c] + ;index_codecs = [`Transpose [0; 3; 1; 2]; `Bytes LE; `Crc32c] ;codecs = [`Bytes BE]} in let chain = [`ShardingIndexed cfg] in (*test failure for chunk shape not evenly dividing shard. 
*) - assert_raises - (Zarr.Codecs.Invalid_sharding_chunk_shape) - (fun () -> Chain.create shape chain); + assert_raises (Zarr.Codecs.Invalid_sharding_chunk_shape) (fun () -> Chain.create shape chain); (* test failure for chunk shape length not equal to dimensionality of shard.*) assert_raises (Zarr.Codecs.Invalid_sharding_chunk_shape) - (fun () -> - Chain.create shape @@ [`ShardingIndexed {cfg with chunk_shape = [|5|]}]); + (fun () -> Chain.create shape @@ [`ShardingIndexed {cfg with chunk_shape = [5]}]); - let chain = [`ShardingIndexed {cfg with chunk_shape = [|5; 3; 5|]}] in + let chain = [`ShardingIndexed {cfg with chunk_shape = [5; 3; 5]}] in let c = Chain.create shape chain in let arr = Ndarray.create kind shape (-10.) in let encoded = Chain.encode c arr in - assert_equal arr @@ Chain.decode c {shape; kind} encoded; + assert_equal arr (Chain.decode c {shape; kind} encoded); (* test correctness of decoding nested sharding codecs.*) let str = @@ -300,8 +271,7 @@ let tests = [ [{"name": "bytes", "configuration": {"endian": "big"}}]}}]}}]|} in let r = Chain.of_yojson shape @@ Yojson.Safe.from_string str in - assert_bool - "Encoding this nested sharding chain should not fail" @@ Result.is_ok r; + assert_bool "Encoding this nested sharding chain should not fail" @@ Result.is_ok r; (* test if decoding of indexed_codec with sharding for array->bytes fails.*) let str = {|[ @@ -322,38 +292,34 @@ let tests = [ [{"name": "bytes", "configuration": {"endian": "big"}}]}}]}}]|} in let r = Chain.of_yojson shape @@ Yojson.Safe.from_string str in - assert_bool - "Decoding of index_codec chain with sharding should fail" @@ - Result.is_error r) + assert_bool "Decoding of index_codec chain with sharding should fail" @@ Result.is_error r) ; "test gzip codec" >:: (fun _ -> (* test wrong compression level *) decode_chain - ~shape:[||] + ~shape:[] ~str:{|[{"name": "bytes", "configuration": {"endian": "little"}}, {"name": "gzip", "configuration": {"level": -1}}]|} ~msg:"gzip codec is 
unsupported or has invalid configuration."; (* test incorrect configuration *) decode_chain - ~shape:[||] + ~shape:[] ~str:{|[{"name": "bytes", "configuration": {"endian": "little"}}, {"name": "gzip", "configuration": {"something": -1}}]|} ~msg:"gzip codec is unsupported or has invalid configuration."; (* test correct deserialization of gzip compression level *) - let shape = [|10; 15; 10|] in + let shape = [10; 15; 10] in List.iter (fun level -> let str = Format.sprintf {|[{"name": "bytes", "configuration": {"endian": "little"}}, - {"name": "gzip", "configuration": {"level": %d}}]|} level - in + {"name": "gzip", "configuration": {"level": %d}}]|} level in let r = Chain.of_yojson shape @@ Yojson.Safe.from_string str in - assert_bool - "Encoding this chain should not fail" @@ Result.is_ok r) + assert_bool "Encoding this chain should not fail" @@ Result.is_ok r) [0; 1; 2; 3; 4; 5; 6; 7; 8; 9]; (* test encoding/decoding for various compression levels *) @@ -374,7 +340,7 @@ let tests = [ List.iter (fun l -> decode_chain - ~shape:[||] + ~shape:[] ~str:(Format.sprintf {|[{"name": "bytes", "configuration": {"endian": "little"}}, {"name": "zstd", "configuration": {"level": %d, "checksum": false}}]|} l) ~msg:"zstd codec is unsupported or has invalid configuration.") @@ -382,13 +348,13 @@ let tests = [ (* test incorrect configuration *) decode_chain - ~shape:[||] + ~shape:[] ~str:{|[{"name": "bytes", "configuration": {"endian": "little"}}, {"name": "zstd", "configuration": {"something": -1}}]|} ~msg:"zstd codec is unsupported or has invalid configuration."; (* test correct deserialization of zstd compression level *) - let shape = [|10; 15; 10|] in + let shape = [10; 15; 10] in List.iter (fun level -> let str = @@ -410,7 +376,7 @@ let tests = [ ; "test bytes codec" >:: (fun _ -> - let shape = [|2; 2; 2|] in + let shape = [2; 2; 2] in (* test decoding of chain with invalid endianness name *) decode_chain ~shape @@ -424,44 +390,31 @@ let tests = [ (* test encoding/decoding of 
Char *) bytes_encode_decode {shape; kind = Ndarray.Char} '?'; - (* test encoding/decoding of Bool *) bytes_encode_decode {shape; kind = Ndarray.Bool} false; bytes_encode_decode {shape; kind = Ndarray.Bool} true; - (* test encoding/decoding of int8 *) bytes_encode_decode {shape; kind = Ndarray.Int8} 0; - (* test encoding/decoding of uint8 *) bytes_encode_decode {shape; kind = Ndarray.Uint8} 0; - (* test encoding/decoding of int16 *) bytes_encode_decode {shape; kind = Ndarray.Int16} 0; - (* test encoding/decoding of uint16 *) bytes_encode_decode {shape; kind = Ndarray.Uint16} 0; - (* test encoding/decoding of int32 *) bytes_encode_decode {shape; kind = Ndarray.Int32} 0l; - (* test encoding/decoding of int64 *) bytes_encode_decode {shape; kind = Ndarray.Int64} 0L; - (* test encoding/decoding of float32 *) bytes_encode_decode {shape; kind = Ndarray.Float32} 0.0; - (* test encoding/decoding of float64 *) bytes_encode_decode {shape; kind = Ndarray.Float64} 0.0; - (* test encoding and decoding of Complex32 *) bytes_encode_decode {shape; kind = Ndarray.Complex32} Complex.zero; - (* test encoding/decoding of complex64 *) bytes_encode_decode {shape; kind = Ndarray.Complex64} Complex.zero; - (* test encoding/decoding of int *) bytes_encode_decode {shape; kind = Ndarray.Int} Int.max_int; - (* test encoding/decoding of nativeint *) bytes_encode_decode {shape; kind = Ndarray.Nativeint} Nativeint.max_int) ] diff --git a/zarr/test/test_indexing.ml b/zarr/test/test_indexing.ml index 103bcd13..5094ce6c 100644 --- a/zarr/test/test_indexing.ml +++ b/zarr/test/test_indexing.ml @@ -6,66 +6,49 @@ open Zarr.Indexing let tests = [ "slice from coords" >:: (fun _ -> - let coords = - [[|0; 1; 2; 3|] - ;[|9; 8; 7; 6|] - ;[|9; 8; 7; 6|] (* slice_of_coords should be duplicate coord-aware *) - ;[|5; 4; 3; 2|]] - in - let expected = - [|L [|0; 5; 9|] - ;L [|1; 4; 8|] - ;L [|2; 3; 7|] - ;L [|2; 3; 6|]|] - in + (* slice_of_coords should be duplicate coord-aware *) + let coords = [[0; 1; 2; 3]; [9; 8; 7; 
6]; [9; 8; 7; 6]; [5; 4; 3; 2]] in + let expected = [L [0; 5; 9]; L [1; 4; 8]; L [2; 3; 7]; L [2; 3; 6]] in assert_equal expected @@ Indexing.slice_of_coords coords; - assert_equal [||] @@ Indexing.slice_of_coords []) + assert_equal [] @@ Indexing.slice_of_coords []) ; "coords from slice" >:: (fun _ -> - let shape = [|10; 10; 10|] in - let slice = [|L [|0; 9; 5|]; I 1; R [|9; 3; -3|]|] in - let expected = - [|[|0; 1; 9|]; [|0; 1; 6|]; [|0; 1; 3|] - ;[|9; 1; 9|]; [|9; 1; 6|]; [|9; 1; 3|] - ;[|5; 1; 9|]; [|5; 1; 6|]; [|5; 1; 3|]|] - in - assert_equal ~printer:[%show: int array array] expected @@ Indexing.coords_of_slice slice shape; + let shape = [10; 10; 10] in + let slice = [L [0; 9; 5]; I 1; R' (9, 3, -3)] in + let expected = [[0; 1; 9]; [0; 1; 6]; [0; 1; 3]; [9; 1; 9]; [9; 1; 6]; [9; 1; 3] ;[5; 1; 9]; [5; 1; 6]; [5; 1; 3]] in + assert_equal ~printer:[%show: int list list] expected @@ Indexing.coords_of_slice slice shape; (* test using an empty slice translates to selecting the whole array. 
*) - assert_equal - [|[|0; 0|]; [|0; 1|]; [|1; 0|]; [|1; 1|]|] - (Indexing.coords_of_slice [||] [|2; 2|]); + assert_equal [[0; 0]; [0; 1]; [1; 0]; [1; 1]] (Indexing.coords_of_slice [] [2; 2]); (* test missing definition on higher dimensions *) - let shape = [|3; 3; 3|] in - let expected = [|[|2; 0; 0|]; [|2; 0; 1|]; [|2; 0; 2|]|] in - let slice = [|I 2; I 0|] in + let shape = [3; 3; 3] in + let expected = [[2; 0; 0]; [2; 0; 1]; [2; 0; 2]] in + let slice = [I 2; I 0] in assert_equal expected @@ Indexing.coords_of_slice slice shape; (* test negative I value *) - let expected = [|[|2; 2; 0|]; [|2; 2; 1|]; [|2; 2; 2|]|] in - let slice = [|I 2; I (-1)|] in + let expected = [[2; 2; 0]; [2; 2; 1]; [2; 2; 2]] in + let slice = [I 2; I (-1)] in assert_equal expected @@ Indexing.coords_of_slice slice shape; - let slice = [|R [|(-1); 2|]; L [|(-1)|]; L [|0; 0; 0|]|] in - let expected = [|[|2; 2; 0|]; [|2; 2; 0|]; [|2; 2; 0|]|] in + let slice = [R (-1, 2); L [-1]; L [0; 0; 0]] in + let expected = [[2; 2; 0]; [2; 2; 0]; [2; 2; 0]] in assert_equal expected @@ Indexing.coords_of_slice slice shape; - let slice = [|R [|0; (-2)|]; R [|1|]; R [|(-1)|]|] in - let expected = [|[|0; 1; 2|]; [|1; 1; 2|]|] in + let slice = [R (0, -2); T 1; T (-1)] in + let expected = [[0; 1; 2]; [1; 1; 2]] in assert_equal expected @@ Indexing.coords_of_slice slice shape; - let slice = [|R [|1; 0|]; R [|1|]; R [|(-1)|]|] in - let expected = [|[|1; 1; 2|]; [|0; 1; 2|]|] in + let slice = [R (1, 0); T 1; T (-1)] in + let expected = [[1; 1; 2]; [0; 1; 2]] in assert_equal expected @@ Indexing.coords_of_slice slice shape; - let slice = [|I 2; I (-1); R [|(-1); (-1); 1|]|] in - let expected = [|[|2; 2; 2|]|] in + let slice = [I 2; I (-1); R' (-1, -1, 1)] in + let expected = [[2; 2; 2]] in assert_equal expected @@ Indexing.coords_of_slice slice shape ) ; "compute slice shape" >:: (fun _ -> - let shape = [|10; 10; 10|] in - let slice = - Ndarray.Indexing.[|L [|0; 9; 5|]; I 1; R [|2; 9; 1|]|] - in - assert_equal [|3; 
1; 8|] @@ Indexing.slice_shape slice shape; - assert_equal shape @@ Indexing.slice_shape [||] shape) + let shape = [10; 10; 10] in + let slice = Ndarray.Indexing.[L [0; 9; 5]; I 1; R' (2, 9, 1)] in + assert_equal [3; 1; 8] @@ Indexing.slice_shape slice shape; + assert_equal shape @@ Indexing.slice_shape [] shape) ; "cartesian product" >:: (fun _ -> let ll = [[1; 2]; [3; 8]; [9; 4]] in diff --git a/zarr/test/test_metadata.ml b/zarr/test/test_metadata.ml index e84677bd..0e34127d 100644 --- a/zarr/test/test_metadata.ml +++ b/zarr/test/test_metadata.ml @@ -1,162 +1,108 @@ open OUnit2 open Zarr -let flatten_fstring s = - String.(split_on_char ' ' s |> concat "" |> split_on_char '\n' |> concat "") - -let decode_bad_group_metadata ~str ~msg = - assert_raises - (Metadata.Parse_error msg) - (fun () -> Metadata.Group.decode str) +let flatten_fstring s = String.(split_on_char ' ' s |> concat "" |> split_on_char '\n' |> concat "") +let decode_bad_group_metadata ~str ~msg = assert_raises (Metadata.Parse_error msg) (fun () -> Metadata.Group.decode str) let group = [ "group metadata" >:: (fun _ -> let meta = Metadata.Group.default in - let expected = {|{"zarr_format":3,"node_type":"group"}|} in let got = Metadata.Group.encode meta in - assert_equal ~printer:Fun.id expected got; - - assert_equal ~printer:Metadata.Group.show meta @@ Metadata.Group.decode got; + assert_bool "should not fail" Metadata.Group.((encode meta |> decode) = meta); + assert_equal ~printer:Fun.id {|{"zarr_format":3,"node_type":"group"}|} got; + assert_equal ~printer:Metadata.Group.show meta Metadata.Group.(decode got); assert_raises - (Metadata.Parse_error "group metadata must contain a zarr_format field.") + (Metadata.Parse_error "metadata must contain a zarr_format field.") (fun () -> Metadata.Group.decode {|{"bad_json":0}|}); - - let meta' = - Metadata.Group.update_attributes - meta @@ `Assoc [("spam", `String "ham"); ("eggs", `Int 42)] - in - let expected = - 
{|{"zarr_format":3,"node_type":"group","attributes":{"spam":"ham","eggs":42}}|} - in - assert_equal expected @@ Metadata.Group.encode meta'; - + let meta' = Metadata.Group.update_attributes meta (`Assoc [("spam", `String "ham"); ("eggs", `Int 42)]) in + let expected = {|{"zarr_format":3,"node_type":"group","attributes":{"spam":"ham","eggs":42}}|} in + assert_equal expected (Metadata.Group.encode meta'); (* test bad zarr_format field value. *) - decode_bad_group_metadata - ~str:{|{"zarr_format":[],"node_type":"group"}|} - ~msg:"zarr_format field must be the integer 3."; - + decode_bad_group_metadata ~str:{|{"zarr_format":[],"node_type":"group"}|} ~msg:"zarr_format field must be the integer 3."; (* test missing node_type field or bad value. *) - decode_bad_group_metadata - ~str:{|{"zarr_format":3,"node_type":"ARRAY"}|} - ~msg:"node_type field must be 'group."; - decode_bad_group_metadata - ~str:{|{"zarr_format":3}|} - ~msg:"group metadata must contain a node_type field.") + decode_bad_group_metadata ~str:{|{"zarr_format":3,"node_type":"ARRAY"}|} ~msg:"node_type field must be 'group'."; + decode_bad_group_metadata ~str:{|{"zarr_format":3}|} ~msg:"group metadata must contain a node_type field.") ] -let test_array_metadata - : type a b. - ?dimension_names:string option list -> - shape:int array -> - chunks:int array -> - a Ndarray.dtype -> - b Ndarray.dtype -> - a -> - unit +let test_array_metadata : + type a b. 
+ ?dimension_names:string option list -> + shape:int list -> + chunks:int list -> + a Ndarray.dtype -> + b Ndarray.dtype -> + a -> + unit = fun ?dimension_names ~shape ~chunks kind bad_kind fv -> let codecs = Codecs.Chain.create chunks [`Bytes LE] in - let meta = - match dimension_names with - | Some d -> - Metadata.Array.create ~codecs ~shape ~dimension_names:d kind fv chunks - | None -> - Metadata.Array.create ~codecs ~shape kind fv chunks + let meta = match dimension_names with + | Some d -> Metadata.Array.create ~codecs ~shape ~dimension_names:d kind fv chunks + | None -> Metadata.Array.create ~codecs ~shape kind fv chunks in - assert_bool - "should not fail" - Metadata.Array.(Metadata.Array.(encode meta |> decode) = meta); - assert_raises - (Metadata.Parse_error "array metadata must contain a zarr_format field.") - (fun () -> Metadata.Array.decode {|{"bad_json":0}|}); - - let show_int_array = [%show: int array] in - assert_equal ~printer:show_int_array shape @@ Metadata.Array.shape meta; - assert_equal ~printer:show_int_array chunks @@ Metadata.Array.chunk_shape meta; - let show_int_array_tuple = [%show: int array * int array] in - - assert_equal - ~printer:show_int_array_tuple - ([|1; 3; 1|], [|3; 1; 0|]) @@ - Metadata.Array.index_coord_pair meta [|8; 7; 6|]; - - assert_equal - ~printer:show_int_array_tuple - ([|2; 5; 1|], [|0; 0; 4|]) @@ - Metadata.Array.index_coord_pair meta [|10; 10; 10|]; - - assert_equal - ~printer:Fun.id - "c/2/5/1" @@ - Metadata.Array.chunk_key meta [|2; 5; 1|]; - - let indices = - [[|0; 0; 0|]; [|0; 0; 1|]; [|0; 1; 0|]; [|0; 1; 1|] - ;[|1; 0; 0|]; [|1; 0; 1|]; [|1; 1; 0|]; [|1; 1; 1|]] - in - assert_equal - ~printer:[%show: int array list] - indices @@ - Metadata.Array.chunk_indices meta [|10; 4; 10|]; - + assert_bool "should not fail" Metadata.Array.((encode meta |> decode) = meta); + let meta' = Metadata.Array.update_shape meta (10 :: shape) in + assert_equal ~msg:"should not be equal" false Metadata.Array.(meta' = meta); + let 
show_int_list = [%show: int list] in + assert_equal ~printer:show_int_list shape (Metadata.Array.shape meta); + assert_equal ~printer:show_int_list chunks (Metadata.Array.chunk_shape meta); + let show_int_list_tuple = [%show: int list * int list] in + assert_equal ~printer:show_int_list_tuple ([1; 3; 1], [3; 1; 0]) (Metadata.Array.index_coord_pair meta [8; 7; 6]); + assert_equal ~printer:show_int_list_tuple ([2; 5; 1], [0; 0; 4]) (Metadata.Array.index_coord_pair meta [10; 10; 10]); + assert_equal ~printer:Fun.id "c/2/5/1" (Metadata.Array.chunk_key meta [2; 5; 1]); + let indices = [[0; 0; 0]; [0; 0; 1]; [0; 1; 0]; [0; 1; 1] ;[1; 0; 0]; [1; 0; 1]; [1; 1; 0]; [1; 1; 1]] in + assert_equal ~printer:[%show: int list list] indices (Metadata.Array.chunk_indices meta [10; 4; 10]); assert_equal ~printer:[%show: string option list] (if dimension_names = None then [] else Option.get dimension_names) (Metadata.Array.dimension_names meta); - - assert_equal - ~printer:Yojson.Safe.show - `Null @@ - Metadata.Array.attributes meta; - + assert_equal ~printer:Yojson.Safe.show `Null (Metadata.Array.attributes meta); let attrs = `Assoc [("questions", `String "answer")] in - assert_equal - ~printer:Yojson.Safe.show - attrs - Metadata.Array.(attributes @@ update_attributes meta attrs); - - let new_shape = [|20; 10; 6|] in - assert_equal - ~printer:show_int_array - new_shape @@ - Metadata.Array.(shape @@ update_shape meta new_shape); - - assert_bool - "Using the correct kind must not fail this op" @@ - Metadata.Array.is_valid_kind meta kind; - - assert_bool - "Float32 is the only valid kind for this metadata" - (not @@ Metadata.Array.is_valid_kind meta bad_kind); - - assert_equal fv @@ Metadata.Array.fillvalue_of_kind meta kind; - - assert_raises - (Failure "kind is not compatible with node's fill value.") - (fun () -> Metadata.Array.fillvalue_of_kind meta bad_kind) - + assert_equal ~printer:Yojson.Safe.show attrs Metadata.Array.(attributes @@ update_attributes meta attrs); + let new_shape 
= [20; 10; 6] in + assert_equal ~printer:show_int_list new_shape Metadata.Array.(shape @@ update_shape meta new_shape); + assert_bool "Using the correct kind must not fail this op" Metadata.Array.(is_valid_kind meta kind); + assert_bool "Float32 is the only valid kind for this metadata" (not @@ Metadata.Array.is_valid_kind meta bad_kind); + assert_equal fv Metadata.Array.(fillvalue_of_kind meta kind); + assert_raises (Failure "kind is not compatible with node's fill value.") (fun () -> Metadata.Array.fillvalue_of_kind meta bad_kind); + assert_raises (Metadata.Parse_error "metadata must contain a zarr_format field.") (fun () -> Metadata.Array.decode {|{"bad_json":0}|}) + +let test_scalar_array_metadata () = + let codecs = Codecs.Chain.create [] [`Bytes LE] in + let meta = Metadata.Array.create ~codecs ~shape:[] Float32 0.0 [] in + assert_bool "should not fail" Metadata.Array.((encode meta |> decode) = meta); + let show_int_list = [%show: int list] in + assert_equal ~printer:show_int_list [] (Metadata.Array.shape meta); + assert_equal ~printer:show_int_list [] (Metadata.Array.chunk_shape meta); + let show_int_list_tuple = [%show: int list * int list] in + assert_equal ~printer:show_int_list_tuple ([], []) (Metadata.Array.index_coord_pair meta []); + assert_equal ~printer:[%show: int list list] [[]] (Metadata.Array.chunk_indices meta []) + (*assert_raises + (Metadata.Parse_error "dimension_names length and array dimensionality must be equal.") + (fun () -> Metadata.Array.create ~codecs ~dimension_names:[Some ""] ~shape:[] Float32 0.0 []) *) (* test decoding an ill-formed array metadata with an expected error message.*) -let decode_bad_array_metadata ~str ~msg = - assert_raises (Metadata.Parse_error msg) (fun () -> Metadata.Array.decode str) +let decode_bad_array_metadata ~str ~msg = assert_raises (Metadata.Parse_error msg) (fun () -> Metadata.Array.decode str) -let test_encode_decode_fill_value fv = - let str = Format.sprintf {|{ +let test_encode_decode_fill_value d 
f1 f2 f3 = + let fmt = Format.sprintf {|{ "zarr_format": 3, "shape": [10000, 1000], "node_type": "array", - "data_type": "float64", - "codecs": [ - {"name": "bytes", "configuration": {"endian": "big"}}], + "data_type": "%s", + "codecs": [{"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": %s, - "chunk_grid": - {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}, + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}, "chunk_key_encoding": {"name": "default"}, - "attributes": {"question": 7}}|} fv + "attributes": {"question": 7}, + "dimension_names": ["x", null]}|} in - assert_equal - ~printer:Fun.id - (flatten_fstring str) - (Metadata.Array.encode @@ Metadata.Array.decode str) + let str = fmt d f1 in + let meta = Metadata.Array.decode str in + let meta' = Metadata.Array.decode (fmt d f2) in + assert_equal false Metadata.Array.(meta = meta'); + assert_bool "Metadata must be equal to itself." Metadata.Array.(meta = meta); + assert_equal ~printer:Fun.id (flatten_fstring str) Metadata.Array.(encode meta); + assert_raises (Metadata.Parse_error "Unsupported fill value.") (fun () -> Metadata.Array.decode (fmt d f3)) let test_decode_encode_chunk_key name sep (key, exp_encode, exp_null) = let str = Format.sprintf {|{ @@ -170,16 +116,16 @@ let test_decode_encode_chunk_key name sep (key, exp_encode, exp_null) = "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}, "chunk_key_encoding": - {"name": %s, "configuration": {"separator": %s}}}|} name sep + {"name": %s, "configuration": {"separator": %s}}, + "attributes": {"question": 7}}|} name sep in let meta = Metadata.Array.decode str in - assert_equal ~printer:Fun.id exp_encode @@ Metadata.Array.chunk_key meta key; - assert_equal ~printer:Fun.id exp_null @@ Metadata.Array.chunk_key meta [||]; - assert_equal ~printer:Fun.id (flatten_fstring str) @@ Metadata.Array.encode meta + assert_equal ~printer:Fun.id exp_encode (Metadata.Array.chunk_key meta 
key); + assert_equal ~printer:Fun.id exp_null (Metadata.Array.chunk_key meta []); + assert_equal ~printer:Fun.id (flatten_fstring str) (Metadata.Array.encode meta) let array = [ "array metadata" >:: (fun _ -> - (* test missing zarr_format field and non-specific value. *) let str = {|{ "node_type": "array", @@ -192,8 +138,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"array metadata must contain a zarr_format field."; + decode_bad_array_metadata ~str:str ~msg:"metadata must contain a zarr_format field."; let str = {|{ "zarr_format": "3", "node_type": "array", @@ -206,9 +151,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"zarr_format field must be the integer 3."; - + decode_bad_array_metadata ~str:str ~msg:"zarr_format field must be the integer 3."; (* test missing node_type field or wrong value. 
*) let str = {|{ "zarr_format": 3, @@ -220,8 +163,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"array metadata must contain a node_type field."; + decode_bad_array_metadata ~str:str ~msg:"metadata must contain a node_type field."; let str = {|{ "zarr_format": 3, "node_type": "group", @@ -234,9 +176,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"node_type field must be 'array'."; - + decode_bad_array_metadata ~str:str ~msg:"node_type field must be 'array'."; (* test missing shape field, and incorrect values *) let str = {|{ "zarr_format": 3, @@ -250,8 +190,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"shape field list must only contain positive integers."; + decode_bad_array_metadata ~str:str ~msg:"shape field list must only contain positive integers."; let str = {|{ "zarr_format": 3, "node_type": "array", @@ -264,8 +203,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"shape field must be a list of integers."; + decode_bad_array_metadata ~str:str ~msg:"shape field must be a list of integers."; let str = {|{ "zarr_format": 3, "node_type": "array", @@ -277,9 +215,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"array metadata must contain a shape field."; - + decode_bad_array_metadata ~str:str ~msg:"array metadata must contain a shape field."; (* test missing codecs field or wrong codec config. 
*) let str = {|{ "zarr_format": 3, @@ -292,8 +228,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"codecs field must be a list of objects."; + decode_bad_array_metadata ~str:str ~msg:"codecs field must be a list of objects."; let str = {|{ "zarr_format": 3, "node_type": "array", @@ -304,9 +239,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"array metadata must contain a codecs field."; - + decode_bad_array_metadata ~str:str ~msg:"array metadata must contain a codecs field."; (* tests incorrect dimension_name field values and incorrect size. *) let str = {|{ "zarr_format": 3, @@ -321,8 +254,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"dimension_names must contain strings or null values."; + decode_bad_array_metadata ~str:str ~msg:"dimension_names must contain strings or null values."; let str = {|{ "zarr_format": 3, "node_type": "array", @@ -336,9 +268,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str - ~msg:"dimension_names length and array dimensionality must be equal."; + decode_bad_array_metadata ~str:str ~msg:"dimension_names length and array dimensionality must be equal."; let str = {|{ "zarr_format": 3, "node_type": "array", @@ -352,9 +282,7 @@ let array = [ "fill_value": "0x7fc00000", "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}}|} in - decode_bad_array_metadata - ~str:str ~msg:"dimension_names field must be a list."; - + decode_bad_array_metadata ~str:str ~msg:"dimension_names field must be a list."; (* test if storage 
transformer unsupported error is reported. *) let str = {|{ "zarr_format": 3, @@ -369,9 +297,7 @@ let array = [ "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [100, 10]}}, "storage_transformers": ["CACHE"]}|} in - decode_bad_array_metadata - ~str:str ~msg:"storage_transformers field is not yet supported."; - + decode_bad_array_metadata ~str:str ~msg:"storage_transformers field is not yet supported."; (* test missing chunk grid field. *) let str = {|{ "zarr_format": 3, @@ -383,9 +309,7 @@ let array = [ "codecs": [ {"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": "0x7fc00000"}|} in - decode_bad_array_metadata - ~str:str ~msg:"array metadata must contain a chunk_grid field."; - + decode_bad_array_metadata ~str:str ~msg:"array metadata must contain a chunk_grid field."; (* test if the decoding fails if regular grid chunk shape is empty, * has non-positive integer values or the grid name is unsupported *) let template = Format.sprintf {|{ @@ -401,17 +325,10 @@ let array = [ {"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": "0x7fc00000"}|} in - decode_bad_array_metadata - ~str:(template {|"regular"|} {|[1, 20, 20]|}) ~msg:"grid shape mismatch."; - decode_bad_array_metadata - ~str:(template {|"regular"|} {|[100000, 20]|}) ~msg:"grid shape mismatch."; - decode_bad_array_metadata - ~str:(template {|"regular"|} {|[-4, 4]|}) - ~msg:"chunk_shape must only contain positive ints."; - decode_bad_array_metadata - ~str:(template {|"UNKNOWN"|} {|[2, 4]|}) - ~msg:"Invalid Chunk grid name or configuration."; - + decode_bad_array_metadata ~str:(template {|"regular"|} {|[1, 20, 20]|}) ~msg:"grid shape mismatch."; + decode_bad_array_metadata ~str:(template {|"regular"|} {|[100000, 20]|}) ~msg:"grid shape mismatch."; + decode_bad_array_metadata ~str:(template {|"regular"|} {|[-4, 4]|}) ~msg:"chunk_shape must only contain positive ints."; + decode_bad_array_metadata ~str:(template {|"UNKNOWN"|} {|[2, 4]|}) ~msg:"Invalid Chunk 
grid name or configuration."; (* test if decoding a chunk key encoding field without a configuration leads to a default value being used. *) let str = {|{ @@ -427,17 +344,10 @@ let array = [ "chunk_key_encoding": {"name": "v2"}}|} in let meta = Metadata.Array.decode str in (* we except it to use the default "." separator. *) - assert_equal - ~printer:Fun.id "2.0.1" @@ Metadata.Array.chunk_key meta [|2; 0; 1|]; - (* we expect the default (unspecified) config seperator to be - dropped when serializing the metadata to JSON format. *) - assert_equal - ~printer:Fun.id - Yojson.Safe.(from_string str |> to_string) @@ - Metadata.Array.encode meta; - - (* test if the decoding fails if chunk key encoding contains unknown - * separator or name. *) + assert_equal ~printer:Fun.id "2.0.1" Metadata.Array.(chunk_key meta [2; 0; 1]); + (* we expect the default (unspecified) config seperator to be dropped when serializing the metadata to JSON format. *) + assert_equal ~printer:Fun.id Yojson.Safe.(from_string str |> to_string) Metadata.Array.(encode meta); + (* test if the decoding fails if chunk key encoding contains unknown separator or name.*) let str = {|{ "zarr_format": 3, "node_type": "array", @@ -448,9 +358,7 @@ let array = [ "codecs": [ {"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": "0x7fc00000"}|} in - decode_bad_array_metadata - ~str ~msg:"array metadata must contain a chunk_key_encoding field."; - + decode_bad_array_metadata ~str ~msg:"array metadata must contain a chunk_key_encoding field."; let template = Format.sprintf {|{ "zarr_format": 3, "node_type": "array", @@ -464,15 +372,9 @@ let array = [ {"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": ["Infinity", "0x7fc00000"]}|} in - decode_bad_array_metadata - ~str:(template {|"default"|} {|"_"|}) - ~msg:"Invalid chunk key encoding configuration."; - decode_bad_array_metadata - ~str:(template {|"V3"|} {|"."|}) - ~msg:"Invalid chunk key encoding configuration."; - - (* test if the 
decoding fails if data type is missing not - a string or unsupported *) + decode_bad_array_metadata ~str:(template {|"default"|} {|"_"|}) ~msg:"Invalid chunk key encoding configuration."; + decode_bad_array_metadata ~str:(template {|"V3"|} {|"."|}) ~msg:"Invalid chunk key encoding configuration."; + (* test if the decoding fails if data type is missing not a string or unsupported *) decode_bad_array_metadata ~str:{|{ "zarr_format": 3, @@ -486,7 +388,7 @@ let array = [ "codecs": [ {"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": "NaN"}|} - ~msg:"data_type field must be a string."; + ~msg:"Unsupported metadata data_type"; decode_bad_array_metadata ~str:{|{ "zarr_format": 3, @@ -514,7 +416,6 @@ let array = [ {"name": "bytes", "configuration": {"endian": "big"}}], "fill_value": "NaN"}|} ~msg:"Unsupported metadata data_type"; - (* test missing fill_value field. *) decode_bad_array_metadata ~str:{|{ @@ -532,20 +433,30 @@ let array = [ (* test if the JSON document fill value form is preserved when decoding * and encoding back into a JSON. 
* See: https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#fill-value *) - test_encode_decode_fill_value {|false|}; - test_encode_decode_fill_value {|0|}; - test_encode_decode_fill_value {|5.9|}; - test_encode_decode_fill_value {|"Infinity"|}; - test_encode_decode_fill_value {|"-Infinity"|}; - test_encode_decode_fill_value {|"NaN"|}; - test_encode_decode_fill_value {|"?"|}; - test_encode_decode_fill_value {|"0x7fc00000"|}; - test_encode_decode_fill_value {|[10, 0]|}; - test_encode_decode_fill_value {|[10.0, 0.0]|}; - test_encode_decode_fill_value {|["0x7fc00000", "0x7fc00000"]|}; - test_encode_decode_fill_value {|["0x7fc00000", "NaN"]|}; - test_encode_decode_fill_value {|["Infinity", "0x7fc00000"]|}; - test_encode_decode_fill_value {|["NaN", "Infinity"]|}; + test_encode_decode_fill_value "char" {|"?"|} {|"-"|} {|"??"|}; + test_encode_decode_fill_value "bool" {|false|} {|true|} {|"??"|}; + test_encode_decode_fill_value "bool" {|true|} {|false|} {|"??"|}; + test_encode_decode_fill_value "int8" {|-10|} {|10|} {|5000|}; + test_encode_decode_fill_value "uint8" {|0|} {|255|} {|-1|}; + test_encode_decode_fill_value "int16" {|-1000|} {|10|} {|50000|}; + test_encode_decode_fill_value "uint16" {|0|} {|1|} {|-10.5|}; + test_encode_decode_fill_value "int32" {|0|} {|1|} {|21474836475|}; + test_encode_decode_fill_value "int" {|0|} {|1|} {|-4611686018427387909|}; + test_encode_decode_fill_value "int64" {|-4611686018427387909|} {|0|} {|18446744073709551619|}; + test_encode_decode_fill_value "uint64" {|4611686018427387909|} {|0|} {|18446744073709551619|}; + test_encode_decode_fill_value "float32" {|1|} {|0.0|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "float32" {|-4611686018427387908|} {|-4611686018427387909|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "float32" {|"Infinity"|} {|-4611686018427387909|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "float32" {|"NaN"|} {|"-Infinity"|} {|"0x2a032f00000000000000000000000000"|}; + 
test_encode_decode_fill_value "float32" {|"0x7fc00000"|} {|"-Infinity"|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "float64" {|0.0|} {|1|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "float64" {|"Infinity"|} {|-4611686018427387909|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "float64" {|"-Infinity"|} {|"NaN"|} {|"0x2a032f00000000000000000000000000"|}; + test_encode_decode_fill_value "float64" {|"-Infinity"|} {|"0x7fc00000"|} {|"adlalkjdald"|}; + test_encode_decode_fill_value "complex32" {|[1, 2]|} {|[-4611686018427387909, -4611686018427387909]|} {|[1, 0.5]|}; + test_encode_decode_fill_value "complex32" {|[1.0, 2.0]|} {|["Infinity", "NaN"]|} {|[1, 0.5]|}; + test_encode_decode_fill_value "complex64" {|[-4611686018427387909, -4611686018427387909]|} {|[1, 2]|} {|[1, 0.5]|}; + test_encode_decode_fill_value "complex64" {|["Infinity", "NaN"]|} {|[1.0, 2.0]|} {|[1, 0.5]|}; (* tests decoding failure of unsupported fill value. *) let template = Format.sprintf {|{ "zarr_format": 3, @@ -561,163 +472,53 @@ let array = [ {"name": "default", "configuration": {"separator": "."}}}|} in (* we dont support float literals as strings *) - decode_bad_array_metadata - ~str:(template {|["0.5", "5.0"]|}) - ~msg:"Unsupported fill value."; - decode_bad_array_metadata - ~str:(template {|["Infinity", "?"]|}) - ~msg:"Unsupported fill value."; + decode_bad_array_metadata ~str:(template {|["0.5", "5.0"]|}) ~msg:"Unsupported fill value."; + decode_bad_array_metadata ~str:(template {|["Infinity", "?"]|}) ~msg:"Unsupported fill value."; (* a complex number cannot be a list with less or more than 2 elements. *) - decode_bad_array_metadata - ~str:(template {|[1, 4, 3]|}) - ~msg:"Unsupported fill value."; - + decode_bad_array_metadata ~str:(template {|[1, 4, 3]|}) ~msg:"Unsupported fill value."; (* Test correctness of chunk-key encoding of keys. 
*) - test_decode_encode_chunk_key - {|"default"|} {|"/"|} ([|5; 32; 4|], "c/5/32/4", "c"); - test_decode_encode_chunk_key - {|"default"|} {|"."|} ([|5; 32; 4|], "c.5.32.4", "c"); - test_decode_encode_chunk_key - {|"v2"|} {|"/"|} ([|5; 32; 4|], "5/32/4", "0"); - test_decode_encode_chunk_key - {|"v2"|} {|"."|} ([|5; 32; 4|], "5.32.4", "0"); - - let shape = [|10; 10; 10|] in - let chunks = [|5; 2; 6|] in + test_decode_encode_chunk_key {|"default"|} {|"/"|} ([5; 32; 4], "c/5/32/4", "c"); + test_decode_encode_chunk_key {|"default"|} {|"."|} ([5; 32; 4], "c.5.32.4", "c"); + test_decode_encode_chunk_key {|"v2"|} {|"/"|} ([5; 32; 4], "5/32/4", "0"); + test_decode_encode_chunk_key {|"v2"|} {|"."|} ([5; 32; 4], "5.32.4", "0"); + let shape = [10; 10; 10] in + let chunks = [5; 2; 6] in let dimension_names = [Some "x"; None; Some "z"] in - (* tests using bool data type. *) - test_array_metadata - ~shape - ~chunks - Ndarray.Bool - Ndarray.Float32 - false; - + test_array_metadata ~shape ~chunks Ndarray.Bool Ndarray.Float32 false; (* tests using char data type. *) - test_array_metadata - ~shape - ~chunks - Ndarray.Char - Ndarray.Float32 - '?'; - + test_array_metadata ~shape ~chunks Ndarray.Char Ndarray.Float32 '?'; (* tests using int8 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Int8 - Ndarray.Float32 - 0; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Int8 Ndarray.Float32 0; (* tests using uint8 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Uint8 - Ndarray.Float32 - 0; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Uint8 Ndarray.Float32 0; (* tests using int16 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Int16 - Ndarray.Float32 - 0; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Int16 Ndarray.Float32 0; (* tests using uint16 data type. 
*) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Uint16 - Ndarray.Float32 - 0; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Uint16 Ndarray.Float32 0; (* tests using int32 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Int32 - Ndarray.Float32 - Int32.max_int; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Int32 Ndarray.Float32 Int32.max_int; (* tests using int64 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Int64 - Ndarray.Float32 - 0L; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Int64 Ndarray.Float32 0L; + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Int64 Ndarray.Float32 Int64.max_int; (* tests using uint64 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Uint64 - Ndarray.Float32 - Stdint.Uint64.max_int; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Uint64 Ndarray.Float32 Stdint.Uint64.min_int; + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Uint64 Ndarray.Float32 Stdint.Uint64.max_int; (* tests using float32 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Float32 - Ndarray.Int - Float.neg_infinity; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Float32 Ndarray.Int Float.neg_infinity; (* tests using float64 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Float64 - Ndarray.Int - Float.neg_infinity; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Float64 Ndarray.Int Float.neg_infinity; (* tests using complex32 data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Complex32 - Ndarray.Float32 - Complex.zero; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Complex32 Ndarray.Float32 Complex.zero; (* tests using complex64 data type. 
*) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Complex64 - Ndarray.Float32 - Complex.zero; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Complex64 Ndarray.Float32 Complex.zero; (* tests using int data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Int - Ndarray.Float32 - Int.max_int; - + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Int Ndarray.Float32 Int.max_int; (* tests using nativeint data type. *) - test_array_metadata - ~dimension_names - ~shape - ~chunks - Ndarray.Nativeint - Ndarray.Float32 - 0n) + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Nativeint Ndarray.Float32 0n; + test_array_metadata ~dimension_names ~shape ~chunks Ndarray.Nativeint Ndarray.Float32 Nativeint.max_int; + (* tests correctness when using an array with empty shape (scalar). *) + test_scalar_array_metadata ()); ] let tests = group @ array diff --git a/zarr/test/test_ndarray.ml b/zarr/test/test_ndarray.ml index 0f47e890..cca7cfe1 100644 --- a/zarr/test/test_ndarray.ml +++ b/zarr/test/test_ndarray.ml @@ -7,24 +7,24 @@ let run_test : let x = M.create repr.kind repr.shape fv in assert_equal repr.shape (M.shape x); - let num_elt = Array.fold_left Int.mul 1 repr.shape in + let num_elt = List.fold_left Int.mul 1 repr.shape in assert_equal (num_elt * is) (M.byte_size x); assert_equal num_elt (M.size x); assert_equal is (M.dtype_size @@ M.data_type x); - assert_equal (Array.length repr.shape) (M.ndims x); + assert_equal (List.length repr.shape) (M.ndims x); let y = M.init repr.kind repr.shape (Fun.const fv) in assert_equal x y; M.fill y fv; assert_equal x y; - assert_equal fv (M.get x [|0; 0; 0|]); - M.set x [|0; 0; 0|] fv; + assert_equal fv (M.get x [0; 0; 0]); + M.set x [0; 0; 0] fv; assert_bool "" @@ M.equal x y; M.iteri (fun _ v -> ignore v) x let tests = [ "test char ndarray" >:: (fun _ -> - let shape = [|2; 5; 3|] in + let shape = [2; 5; 3] in run_test {shape; kind = M.Char} '?' 
1; run_test {shape; kind = M.Bool} false 1; @@ -55,31 +55,31 @@ let tests = [ ) ; "test map, iter and fold" >:: (fun _ -> - let shape = [|2; 5; 3|] in + let shape = [2; 5; 3] in let x = M.create Int32 shape 0l in let x' = M.map (Int32.add 1l) x in - assert_equal 1l (M.get x' [|0;0;0|]); + assert_equal 1l (M.get x' [0;0;0]); - let x = M.create Char [|4|] '?' in + let x = M.create Char [4] '?' in let buf = Buffer.create @@ M.byte_size x in M.iter (Buffer.add_char buf) x; assert_equal ~printer:Fun.id "????" (Buffer.contents buf); ) ; "test transpose functionality" >:: (fun _ -> - let shape = [|2; 1; 3|] - and axes = [|2; 0; 1|] + let shape = [2; 1; 3] + and axes = [2; 0; 1] and a = [|0.15458236; 0.94363903; 0.63893012; 0.29207497; 0.31390295; 0.42341309|] in let x = M.of_array Float32 shape a in let x' = M.transpose ~axes x in - assert_equal ~printer:[%show: int array] [|3; 2; 1|] (M.shape x'); + assert_equal ~printer:[%show: int list] [3; 2; 1] (M.shape x'); (* test if a particular value is transposed correctly. *) - assert_equal ~printer:string_of_float (M.get x [|1; 0; 2|]) (M.get x' [|2; 1; 0|]); + assert_equal ~printer:string_of_float (M.get x [1; 0; 2]) (M.get x' [2; 1; 0]); let flat_exp = [|0.15458236; 0.29207497; 0.94363903; 0.31390295; 0.63893012; 0.42341309|] in assert_equal ~printer:[%show: float array] flat_exp (M.to_array x'); - let inv_order = Array.(make (length axes) 0) in - Array.iteri (fun i x -> inv_order.(x) <- i) axes; - assert_equal true @@ M.equal x (M.transpose ~axes:inv_order x') + let inv_order = Array.(make (List.length axes) 0) in + List.iteri (fun i x -> inv_order.(x) <- i) axes; + assert_equal true @@ M.equal x (M.transpose ~axes:(Array.to_list inv_order) x') ) ; "test interop with bigarrays" >:: (fun _ -> @@ -89,7 +89,7 @@ let tests = [ let convert_to : type a b. 
a M.dtype -> (a, b) B.kind -> a -> unit = fun fromdtype todtype fv -> - let x = M.create fromdtype s fv in + let x = M.create fromdtype (Array.to_list s) fv in let y = M.to_bigarray x todtype in assert_equal s (B.Genarray.dims y); assert_equal fv (B.Genarray.get y [|0; 0; 0|]); @@ -109,17 +109,15 @@ let tests = [ convert_to M.Int B.Int Int.max_int; convert_to M.Nativeint B.Nativeint Nativeint.max_int; - let showarray = [%show: int array] in - let convert_from : type a b c. (a, b, c) B.Genarray.t -> a M.dtype -> unit = fun x dtype -> let y = M.of_bigarray x in assert_equal dtype (M.data_type y); assert_equal - ~printer:showarray - (Array.of_list @@ List.rev @@ Array.to_list @@ B.Genarray.dims x) + ~printer:[%show: int list] + (List.rev @@ Array.to_list @@ B.Genarray.dims x) (M.shape y); - assert_equal (B.Genarray.get x [|1; 1; 1|]) (M.get y [|0; 0; 0|]) + assert_equal (B.Genarray.get x [|1; 1; 1|]) (M.get y [0; 0; 0]) in convert_from (B.Genarray.create Char Fortran_layout s) M.Char; convert_from (B.Genarray.create Int8_signed Fortran_layout s) M.Int8;