(* Ordering required by [Make]: [compare a b] is negative, zero or positive
   when [a] is respectively smaller than, equal to, or greater than [b]. *)
module type OrderedType = sig
  type t
  val compare : t -> t -> int
end

(* Functional sets of [Ord.t] stored in 2-3-4 trees.

   Invariants of every tree built from [empty] with [add]:
   - [Node2] / [Node3] / [Node4] carry 1 / 2 / 3 keys in increasing order;
     each subtree holds the keys lying between its neighbouring keys;
   - all [Empty] leaves are at the same depth ([test_height] checks this);
   - [Node5] never appears: it is a transient overflow created during [add]
     and is split by the parent's [bNodeK] or by [treat_root]. *)
module Make (Ord : OrderedType) = struct
  type elt = Ord.t

  type t =
    | Empty
    | Node2 of t * elt * t
    | Node3 of t * elt * t * elt * t
    | Node4 of t * elt * t * elt * t * elt * t
    | Node5 of t * elt * t * elt * t * elt * t * elt * t
        (* Only intermediate, never in final result *)

  let empty = Empty

  (* [mem x s] is [true] iff some key of [s] compares equal to [x].
     Follows a single root-to-leaf path. *)
  let rec mem x = function
    | Empty -> false
    | Node2 (t1, d1, t2) ->
        let c = Ord.compare x d1 in
        c = 0 || mem x (if c < 0 then t1 else t2)
    | Node3 (t1, d1, t2, d2, t3) ->
        let c = Ord.compare x d1 in
        c = 0
        || (if c < 0 then mem x t1
            else
              let c = Ord.compare x d2 in
              c = 0 || mem x (if c < 0 then t2 else t3))
    | Node4 (t1, d1, t2, d2, t3, d3, t4) ->
        (* Compare with the middle key first to halve the search. *)
        let c = Ord.compare x d2 in
        c = 0
        || (if c < 0 then
              let c = Ord.compare x d1 in
              c = 0 || mem x (if c < 0 then t1 else t2)
            else
              let c = Ord.compare x d3 in
              c = 0 || mem x (if c < 0 then t3 else t4))
    | Node5 _ -> assert false

  (* [bNode2 t1 d1 t2] rebuilds a one-key node after one child may have
     changed during insertion:
     - a [Node5] child is split into [Node3], a promoted key and [Node2],
       and the promoted key is absorbed here (giving a [Node3]);
     - a singleton [Node2 (Empty, _, Empty)] next to an [Empty] sibling
       (i.e. a key just inserted under a leaf) is merged into this leaf. *)
  let bNode2 t1 d1 t2 =
    match t1, t2 with
    | Node5 (t11, d11, t12, d12, t13, d13, t14, d14, t15), _ ->
        Node3 (Node3 (t11, d11, t12, d12, t13), d13, Node2 (t14, d14, t15),
               d1, t2)
    | _, Node5 (t21, d21, t22, d22, t23, d23, t24, d24, t25) ->
        Node3 (t1, d1, Node2 (t21, d21, t22),
               d22, Node3 (t23, d23, t24, d24, t25))
    | Node2 (Empty, d11, Empty), Empty -> Node3 (Empty, d11, Empty, d1, Empty)
    | Empty, Node2 (Empty, d21, Empty) -> Node3 (Empty, d1, Empty, d21, Empty)
    | _ -> Node2 (t1, d1, t2)

  (* Same as [bNode2] for a two-key node; the result is a [Node4] when a key
     is absorbed, otherwise a [Node3]. *)
  let bNode3 t1 d1 t2 d2 t3 =
    match t1, t2, t3 with
    | Node5 (t11, d11, t12, d12, t13, d13, t14, d14, t15), _, _ ->
        Node4 (Node3 (t11, d11, t12, d12, t13), d13, Node2 (t14, d14, t15),
               d1, t2, d2, t3)
    | _, Node5 (t21, d21, t22, d22, t23, d23, t24, d24, t25), _ ->
        Node4 (t1, d1, Node3 (t21, d21, t22, d22, t23),
               d23, Node2 (t24, d24, t25), d2, t3)
    | _, _, Node5 (t31, d31, t32, d32, t33, d33, t34, d34, t35) ->
        Node4 (t1, d1, t2, d2, Node3 (t31, d31, t32, d32, t33),
               d33, Node2 (t34, d34, t35))
    | Node2 (Empty, d11, Empty), Empty, Empty ->
        Node4 (Empty, d11, Empty, d1, Empty, d2, Empty)
    | Empty, Node2 (Empty, d21, Empty), Empty ->
        Node4 (Empty, d1, Empty, d21, Empty, d2, Empty)
    | Empty, Empty, Node2 (Empty, d31, Empty) ->
        Node4 (Empty, d1, Empty, d2, Empty, d31, Empty)
    | _ -> Node3 (t1, d1, t2, d2, t3)

  (* Same as [bNode2] for a three-key node.  Absorbing a key overflows the
     node into a [Node5], which the caller's [bNodeK] or [treat_root] must
     split; otherwise the result is a [Node4]. *)
  let bNode4 t1 d1 t2 d2 t3 d3 t4 =
    match t1, t2, t3, t4 with
    | Node5 (t11, d11, t12, d12, t13, d13, t14, d14, t15), _, _, _ ->
        Node5 (Node3 (t11, d11, t12, d12, t13), d13, Node2 (t14, d14, t15),
               d1, t2, d2, t3, d3, t4)
    | _, Node5 (t21, d21, t22, d22, t23, d23, t24, d24, t25), _, _ ->
        Node5 (t1, d1, Node3 (t21, d21, t22, d22, t23),
               d23, Node2 (t24, d24, t25), d2, t3, d3, t4)
    | _, _, Node5 (t31, d31, t32, d32, t33, d33, t34, d34, t35), _ ->
        Node5 (t1, d1, t2, d2, Node3 (t31, d31, t32, d32, t33),
               d33, Node2 (t34, d34, t35), d3, t4)
    | _, _, _, Node5 (t41, d41, t42, d42, t43, d43, t44, d44, t45) ->
        Node5 (t1, d1, t2, d2, t3, d3, Node3 (t41, d41, t42, d42, t43),
               d43, Node2 (t44, d44, t45))
    | Node2 (Empty, d11, Empty), Empty, Empty, Empty ->
        Node5 (Empty, d11, Empty, d1, Empty, d2, Empty, d3, Empty)
    | Empty, Node2 (Empty, d21, Empty), Empty, Empty ->
        Node5 (Empty, d1, Empty, d21, Empty, d2, Empty, d3, Empty)
    | Empty, Empty, Node2 (Empty, d31, Empty), Empty ->
        Node5 (Empty, d1, Empty, d2, Empty, d31, Empty, d3, Empty)
    | Empty, Empty, Empty, Node2 (Empty, d41, Empty) ->
        Node5 (Empty, d1, Empty, d2, Empty, d3, Empty, d41, Empty)
    | _ -> Node4 (t1, d1, t2, d2, t3, d3, t4)

  (* Splits an overflowing root [Node5] into a [Node2] whose children are a
     [Node3] and a [Node2]; this is the only place the tree grows in height.
     Any other tree is returned as is. *)
  let treat_root = function
    | Node5 (t11, d11, t12, d12, t13, d13, t14, d14, t15) ->
        Node2 (Node3 (t11, d11, t12, d12, t13), d13, Node2 (t14, d14, t15))
    | t -> t

  (* [add x s] is [s] with [x] inserted.  If [x] is already present, [s]
     itself is returned (physically equal, no allocation), matching the
     contract of [Stdlib.Set.add].

     [ins] descends to the insertion point.  When the recursive call hands
     back the very same child, nothing below changed, so the current node is
     returned instead of being copied through [bNodeK]. *)
  let add x s =
    let rec ins s =
      match s with
      | Empty -> Node2 (Empty, x, Empty)
      | Node2 (t1, d1, t2) ->
          let c = Ord.compare x d1 in
          if c = 0 then s
          else if c < 0 then (
            let t1' = ins t1 in
            if t1' == t1 then s else bNode2 t1' d1 t2)
          else
            let t2' = ins t2 in
            if t2' == t2 then s else bNode2 t1 d1 t2'
      | Node3 (t1, d1, t2, d2, t3) ->
          let c = Ord.compare x d1 in
          if c = 0 then s
          else if c < 0 then (
            let t1' = ins t1 in
            if t1' == t1 then s else bNode3 t1' d1 t2 d2 t3)
          else
            let c = Ord.compare x d2 in
            if c = 0 then s
            else if c < 0 then (
              let t2' = ins t2 in
              if t2' == t2 then s else bNode3 t1 d1 t2' d2 t3)
            else
              let t3' = ins t3 in
              if t3' == t3 then s else bNode3 t1 d1 t2 d2 t3'
      | Node4 (t1, d1, t2, d2, t3, d3, t4) ->
          let c = Ord.compare x d2 in
          if c = 0 then s
          else if c < 0 then (
            let c = Ord.compare x d1 in
            if c = 0 then s
            else if c < 0 then (
              let t1' = ins t1 in
              if t1' == t1 then s else bNode4 t1' d1 t2 d2 t3 d3 t4)
            else
              let t2' = ins t2 in
              if t2' == t2 then s else bNode4 t1 d1 t2' d2 t3 d3 t4)
          else
            let c = Ord.compare x d3 in
            if c = 0 then s
            else if c < 0 then (
              let t3' = ins t3 in
              if t3' == t3 then s else bNode4 t1 d1 t2 d2 t3' d3 t4)
            else
              let t4' = ins t4 in
              if t4' == t4 then s else bNode4 t1 d1 t2 d2 t3 d3 t4'
      | Node5 _ -> assert false
    in
    treat_root (ins s)

  (* Number of keys stored in the tree. *)
  let rec cardinal = function
    | Empty -> 0
    | Node2 (t1, _, t2) -> cardinal t1 + cardinal t2 + 1
    | Node3 (t1, _, t2, _, t3) -> cardinal t1 + cardinal t2 + cardinal t3 + 2
    | Node4 (t1, _, t2, _, t3, _, t4) ->
        cardinal t1 + cardinal t2 + cardinal t3 + cardinal t4 + 3
    | Node5 _ -> assert false

  (* Debug check: returns the number of node levels above the [Empty] leaves
     and raises [Assert_failure] if the subtrees of any node differ in
     height (i.e. if the tree is not perfectly balanced). *)
  let rec test_height t =
    match t with
    | Empty -> 0
    | Node2 (t1, _, t2) ->
        let h1 = test_height t1 in
        let h2 = test_height t2 in
        assert (h1 = h2);
        h1 + 1
    | Node3 (t1, _, t2, _, t3) ->
        let h1 = test_height t1 in
        let h2 = test_height t2 in
        let h3 = test_height t3 in
        assert (h1 = h2);
        assert (h1 = h3);
        h1 + 1
    | Node4 (t1, _, t2, _, t3, _, t4) ->
        let h1 = test_height t1 in
        let h2 = test_height t2 in
        let h3 = test_height t3 in
        let h4 = test_height t4 in
        assert (h1 = h2);
        assert (h1 = h3);
        assert (h1 = h4);
        h1 + 1
    | Node5 _ -> assert false
end