Skip to content

Commit

Permalink
Broken: abandon localized / cross-task rebalanced
Browse files Browse the repository at this point in the history
nodes or tensors.
  • Loading branch information
lukstafi committed May 29, 2023
1 parent a32a3d4 commit 298abb7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 115 deletions.
101 changes: 12 additions & 89 deletions lib/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@ let to_low_level (code : t) : unit low_level =
let for_loops =
for_loop [] (Array.to_list projections.product_space, Array.to_list projections.product_iterators)
in
let s = Some ("Computing node " ^ NodeUI.tensor_ptr_name lhs) in
(* In case the computation is neither virtual nor parallelized, it would be invalid to replicate it. *)
let s = Comment ("Computing node " ^ NodeUI.tensor_ptr_name lhs) in
(* Note: it might be invalid to replicate computation across tasks. *)
if zero_out then
Rebalance (s, [| Lines [| loop (Fetch { tensor = lhs; fetch_op = Constant 0. }); for_loops |] |])
else Rebalance (s, [| for_loops |])
Lines [| s; loop (Fetch { tensor = lhs; fetch_op = Constant 0. }); for_loops |]
else Lines [| s; for_loops |]
| Accum_unop { zero_out; accum; op; lhs; rhs; projections } ->
let projections = projections () in
let lhs_idx = Shape.(derive_index projections.product_iterators projections.project_lhs) in
Expand Down Expand Up @@ -245,11 +245,11 @@ let to_low_level (code : t) : unit low_level =
let for_loops =
for_loop [] (Array.to_list projections.product_space, Array.to_list projections.product_iterators)
in
let s = Some ("Computing node " ^ NodeUI.tensor_ptr_name lhs) in
(* In case the computation is neither virtual nor parallelized, it would be invalid to replicate it. *)
let s = Comment ("Computing node " ^ NodeUI.tensor_ptr_name lhs) in
(* Note: it might be invalid to replicate computation across tasks. *)
if zero_out then
Rebalance (s, [| Lines [| loop (Fetch { tensor = lhs; fetch_op = Constant 0. }); for_loops |] |])
else Rebalance (s, [| for_loops |])
Lines [| s; loop (Fetch { tensor = lhs; fetch_op = Constant 0. }); for_loops |]
else Lines [| s; for_loops |]
| Noop -> Lines [||]
| Block_comment (s, (Par _ as c)) -> loop_par ~s c
| Block_comment (s, (ParHint _ as c)) when !force_unsafe_parhint -> loop_par ~s c
Expand Down Expand Up @@ -528,7 +528,7 @@ let interpret_code ?task_id llc =
loop_proc (Map.Poly.empty, Map.Poly.empty) llc

let code_sexp_margin = ref 200

let fprint_code ppf c =
(* TODO: something nicely concise. *)
Caml.Format.pp_set_margin ppf !code_sexp_margin;
Expand Down Expand Up @@ -650,62 +650,6 @@ let partition_tf_with_comment cs ~f =
in
(trues, falses)

let localize_tensors store ~for_task_id llc =
let for_task = Some for_task_id in
let debug = ref "" in
let rec loop = function
| Lines llcs -> Array.iter ~f:loop llcs
| For_loop { body; _ } | Dynamic_indices { body; _ } -> loop body
| Rebalance (_, cs) -> Array.iter ~f:loop cs
| Set (ptr, _, llv) ->
let n = NodeUI.get ptr.id in
if Option.is_some n.localized_to then assert ([%equal: int option] n.localized_to for_task)
else n.localized_to <- Some for_task_id;
let never_device_only =
match ptr.field with
| NodeUI.Value -> n.value_never_device_only
| NodeUI.Grad -> n.grad_never_device_only
in
if never_device_only then (
debug := NodeUI.tensor_ptr_name ptr;
loop_float llv)
| Set_local (_, llv) -> loop_float llv
| Comment _ -> ()
| If_task_id_is { for_task_id = id2; body; _ } ->
assert (id2 = for_task_id);
loop body
and loop_float = function
| Local_scope { body; _ } -> loop body
| Get_local _ | Get_global _ -> ()
| Get (ptr, _) ->
let n = NodeUI.get ptr.id in
let dn = get_node store ptr in
dn.non_device_only <- true;
n.read_by_localized <- for_task_id :: n.read_by_localized;
n.debug_read_by_localized <- !debug :: n.debug_read_by_localized
| Constant _ -> ()
| Binop (_, v1, v2) ->
loop_float v1;
loop_float v2
| Unop (_, v) -> loop_float v
in
loop llc

let rebalance_across_tasks = ref false

let rebalance store llcs =
if not !rebalance_across_tasks then (
let for_task_id = 0 in
let body = Lines (flat_lines llcs) in
localize_tensors store ~for_task_id body;
If_task_id_is { for_task_id; body })
else
let tasks =
Array.mapi llcs ~f:(fun task body ->
If_task_id_is { for_task_id = task % !Shape.num_parallel_tasks; body })
in
Lines tasks

let rec has_parallel_dim : type a. a low_level -> bool = function
| Comment _ -> false
| Lines ls -> Array.exists ~f:has_parallel_dim ls
Expand Down Expand Up @@ -1109,30 +1053,9 @@ let cleanup_virtual_llc traced_store reverse_node_map (llc : unit low_level) : u
| None ->
Option.map ~f:(fun body -> For_loop { for_config with body }) @@ loop_proc ~balanced ~env_dom body
)
| Rebalance (s, cs) when balanced -> (
let cs = flat_lines @@ Array.filter_map cs ~f:loop in
match (s, cs) with
| _, [||] -> None
| None, [| c |] -> Some c
| _, cs ->
let c = Array.map ~f:(fun s -> Comment s) @@ Option.to_array s in
Some (Lines (Array.append c cs)))
| Rebalance (s, cs) -> (
(* Don't flatten lines before rebalancing: keep elements of [cs] as single units. *)
let multitask, unitask = partition_tf_with_comment ~f:has_parallel_dim cs in
let rebalanced =
let unitask = Array.filter_map unitask ~f:(loop_proc ~balanced:true ~env_dom) in
if Array.is_empty unitask then None else Some (rebalance traced_store unitask)
in
let multitask = flat_lines @@ Array.filter_map ~f:loop multitask in
if Array.is_empty multitask && Option.is_none rebalanced then None
else
match s with
| None -> Some (Lines (Array.append (Option.to_array rebalanced) multitask))
| Some s ->
Some
(Lines (flat_lines @@ Array.concat [ [| Comment s |]; Option.to_array rebalanced; multitask ]))
)
| Rebalance (s, cs) ->
let cs = Array.filter_map cs ~f:loop in
if Array.is_empty cs then None else Some (Rebalance (s, cs))
| If_task_id_is { for_task_id; body } ->
Option.map ~f:(fun body -> If_task_id_is { for_task_id; body }) @@ loop body
| Set (tensor, indices, llv) ->
Expand Down
26 changes: 15 additions & 11 deletions lib/exec_as_gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type state = {
traced_store : Code.traced_store;
task_init_block : Gccjit.block;
task_finalize_block : Gccjit.block;
localized_finalize_block : Gccjit.block;
replicated_finalize_block : Gccjit.block;
}

let jit_array_offset ctx ~idcs ~dims =
Expand All @@ -53,7 +53,7 @@ let jit_array_offset ctx ~idcs ~dims =
@@ RValue.binary_op ctx Mult c_index offset (RValue.int ctx c_index dim))

let get_tensor
{ ctx; func; tensors; traced_store; task_init_block; task_finalize_block; localized_finalize_block }
{ ctx; func; tensors; traced_store; task_init_block; task_finalize_block; replicated_finalize_block }
~jit_code ~host_idcs ptr : tensor =
let open Gccjit in
Hashtbl.find_or_add tensors ptr ~default:(fun () ->
Expand Down Expand Up @@ -112,7 +112,7 @@ let get_tensor
RValue.int ctx c_index device_size_in_bytes;
];
if is_parallel || not update_on_host then
Block.eval (if is_parallel then task_finalize_block else localized_finalize_block)
Block.eval (if is_parallel then task_finalize_block else replicated_finalize_block)
@@ RValue.call ctx (Function.builtin ctx "memcpy")
[
cast_void @@ LValue.address lhs;
Expand Down Expand Up @@ -219,7 +219,10 @@ let jit_code ~name ~env ~task_id ({ ctx; func; _ } as state) initial_block (body
We also need unique ids for computation ordering lvalues. *)
let uid = ref 0 in
let get_uid () =
let id = Int.incr uid; !uid in
let id =
Int.incr uid;
!uid
in
Int.to_string id
in
let locals = ref Map.Poly.empty in
Expand Down Expand Up @@ -255,9 +258,10 @@ let jit_code ~name ~env ~task_id ({ ctx; func; _ } as state) initial_block (body
match body with
| Code.Lines lines ->
Array.iteri lines ~f:(fun i line -> loop ~name:(name ^ "_at_line_" ^ Int.to_string i) line)
| For_loop { index; from_; to_; body; trace_it = _ } -> jit_for_loop ~env index ~from_ ~to_ (Either.First body)
| For_loop { index; from_; to_; body; trace_it = _ } ->
jit_for_loop ~env index ~from_ ~to_ (Either.First body)
| Rebalance (_, cs) ->
(* FIXME: NOT IMPLEMENTED YET *)
(* This backend does not implement a relevant form of fine-grain parallelism. *)
Array.iteri cs ~f:(fun i line -> loop ~name:(name ^ "_at_par_line_" ^ Int.to_string i) line)
| If_task_id_is { for_task_id = _; body } when !Shape.num_parallel_tasks <= 1 -> loop ~name body
| If_task_id_is { for_task_id; body } ->
Expand Down Expand Up @@ -451,7 +455,7 @@ let jit_func ~name ctx (traced_store, proc) =
let func = Function.create ctx fkind (Type.get ctx Void) name [ task_id ] in
let task_init_block = Block.create ~name:("init_" ^ name) func in
let task_finalize_block = Block.create ~name:("finalize_" ^ name) func in
let localized_finalize_block = Block.create ~name:("finalize_localized_" ^ name) func in
let replicated_finalize_block = Block.create ~name:("finalize_replicated_" ^ name) func in
let main_block = Block.create ~name func in
let state =
{
Expand All @@ -460,18 +464,18 @@ let jit_func ~name ctx (traced_store, proc) =
traced_store;
task_init_block;
task_finalize_block;
localized_finalize_block;
replicated_finalize_block;
tensors = Hashtbl.Poly.create ();
}
in
let after_proc = jit_code ~name ~env ~task_id state main_block proc in
Block.jump task_init_block main_block;
Block.jump after_proc task_finalize_block;
let c_index = Type.get ctx Type.Int in
let b_after_if = Block.create ~name:("after_finalize_localized_" ^ name) func in
let b_after_if = Block.create ~name:("after_finalize_replicated_" ^ name) func in
let guard = RValue.comparison ctx Eq (RValue.param task_id) (RValue.zero ctx c_index) in
Block.cond_jump task_finalize_block guard localized_finalize_block (* on true *) b_after_if (* on false *);
Block.jump localized_finalize_block b_after_if;
Block.cond_jump task_finalize_block guard replicated_finalize_block (* on true *) b_after_if (* on false *);
Block.jump replicated_finalize_block b_after_if;
Block.return_void b_after_if;
if !Code.with_debug then
let suf = "-gccjit-debug.c" in
Expand Down
16 changes: 1 addition & 15 deletions lib/nodeUI.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@ type t = {
- [node.grad] is always [None]. *)
mutable backend_info : string;
(** Information about e.g. the memory strategy that the most recent backend chose for the tensor. *)
mutable localized_to : int option;
(** The ID of the task to which the tensor is localized. A non-none value by itself does not guarantee
that all of the tensor's computations are localized to a single task, only that those which are
only use the given task. *)
mutable read_by_localized : int list;
(** Tasks from which this tensor is read by localized computations. *)
mutable debug_read_by_localized : string list;
}
[@@deriving sexp_of]
(** A DAG of decorated [Node]s, also storing the shape information. *)
Expand Down Expand Up @@ -147,9 +140,6 @@ let create ~(value_prec : prec) ?(grad_prec : prec option) ?(literal = false) ~n
grad_never_device_only = false;
literal;
backend_info = "";
localized_to = None;
read_by_localized = [];
debug_read_by_localized = [];
}
in
Hashtbl.add_exn global_node_store ~key:node.id ~data;
Expand Down Expand Up @@ -540,12 +530,8 @@ let to_printbox ?single_node ?entries_per_axis ?with_backend_info ?(with_id = fa
let print_node_preamble id =
try
let n = get id in
Caml.Format.printf "Node %s%s,@ read-by-task-id: %a@ via %a;\n%!" (node_header n)
Caml.Format.printf "Node %s%s\n%!" (node_header n)
(if String.is_empty n.backend_info then "" else " " ^ n.backend_info)
Sexp.pp_hum
([%sexp_of: int list] n.read_by_localized)
Sexp.pp_hum
([%sexp_of: string list] n.debug_read_by_localized)
with Not_found_s _ | Caml.Not_found -> Caml.Format.printf "Node #%d does not exist.\n%!" id

let print_preamble ?(from = 0) () =
Expand Down

0 comments on commit 298abb7

Please sign in to comment.