In progress: revamping device-host sync
The goal: proper replication for non-parallel-dependent tensors,
requiring update-on-host for tensors that are non-parallel but
parallel-dependent. Follow-up: (optionally on by default when
num tasks > 1) convert commutative assignments into from-constant
inits + updates.
lukstafi committed May 29, 2023
1 parent 298abb7 commit b5b9e38
Showing 5 changed files with 173 additions and 37 deletions.
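
The core of the change is a per-tensor classification of device-host synchronization, the sync_properties type added to lib/exec_as_gccjit.ml below. As a minimal standalone sketch of the decision order, choose_sync and its flag arguments are hypothetical stand-ins for the values the real get_tensor derives from the node, the traced store, and the shape:

type sync_properties = Device_only | Update_on_host | Parallel_dim | Replicated

(* Pick a synchronization mode, mirroring the order of checks in get_tensor. *)
let choose_sync ~hosted ~is_parallel ~update_on_host ~can_be_replicated ~num_tasks =
  if not hosted then Device_only (* no host array: the tensor stays task-local *)
  else if is_parallel then Parallel_dim (* each task transfers its own slice to the host *)
  else if update_on_host then Update_on_host (* update assignments are applied directly on the host *)
  else if can_be_replicated || num_tasks <= 1 then Replicated (* only task 0 transfers to the host *)
  else failwith "synchronization pattern NOT IMPLEMENTED YET"

For example, choose_sync ~hosted:true ~is_parallel:false ~update_on_host:true ~can_be_replicated:false ~num_tasks:4 yields Update_on_host, the "update-on-host for non-parallel but parallel-dependent" case from the commit message, while a hosted value tensor with no parallel dependencies yields Replicated.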
11 changes: 9 additions & 2 deletions bin/moons_benchmark.ml
@@ -157,8 +157,8 @@ let classify_moons ~on_device executor ~opti_level ~inlining_cutoff ?(inline_con
Exec_as_gccjit.optimization_level := opti_level;
SDSL.num_parallel_tasks := num_parallel_tasks;
SDSL.disable_all_debugs ();
(* Code.with_debug := true;
Code.keep_files_in_run_directory := true; *)
Code.with_debug := true;
Code.keep_files_in_run_directory := true;
(* SDSL.enable_all_debugs (); *)
SDSL.drop_all_sessions ();
let open SDSL.O in
@@ -245,6 +245,13 @@ let classify_moons ~on_device executor ~opti_level ~inlining_cutoff ?(inline_con
let loss = ref 0.0 in
Stdio.printf "\n%!";
SDSL.refresh_session ~run_for_steps ();
Stdio.printf "\nTotal loss:\n%!";
SDSL.print_node_tree ~with_backend_info:true ~with_grad:true ~depth:9 total_loss.id;
Stdio.printf "\nEpoch loss:\n%!";
SDSL.print_node_tree ~with_backend_info:true ~with_grad:true ~depth:9 epoch_loss.id;
Stdio.printf "\nMinus learning rate:\n%!";
SDSL.print_node_tree ~with_backend_info:true ~with_grad:true ~depth:9 minus_lr.id;
Stdio.printf "\nTrain loop.\n%!";
let start_time = Time_now.nanoseconds_since_unix_epoch () in
while not !stop do
step := !step + advance_per_run;
124 changes: 124 additions & 0 deletions lib/attic.mld
@@ -264,4 +264,128 @@ let get_grad (type val_t arr_t) (prec : (val_t, arr_t) precision) uid : arr_t =
| None -> Lines (flat_lines [| nontask_ts; task_ts |])
| Some s -> Lines (flat_lines [| Comment s; nontask_ts; task_ts |])





mutable localized_to : int option;
(** The ID of the task to which the tensor is localized. A non-none value by itself does not guarantee
that all of the tensor's computations are localized to a single task, only that the computations
which are localized use only the given task. *)
mutable read_by_localized : int list;
(** Tasks from which this tensor is read by localized computations. *)
mutable debug_read_by_localized : string list;


let localize_tensors store ~for_task_id llc =
let for_task = Some for_task_id in
let debug = ref "" in
let rec loop = function
| Lines llcs -> Array.iter ~f:loop llcs
| For_loop { body; _ } | Dynamic_indices { body; _ } -> loop body
| Rebalance (_, cs) -> Array.iter ~f:loop cs
| Set (ptr, _, llv) ->
let n = NodeUI.get ptr.id in
if Option.is_some n.localized_to then assert ([%equal: int option] n.localized_to for_task)
else n.localized_to <- Some for_task_id;
let never_device_only =
match ptr.field with
| NodeUI.Value -> n.value_never_device_only
| NodeUI.Grad -> n.grad_never_device_only
in
if never_device_only then (
debug := NodeUI.tensor_ptr_name ptr;
loop_float llv)
| Set_local (_, llv) -> loop_float llv
| Comment _ -> ()
| If_task_id_is { for_task_id = id2; body; _ } ->
assert (id2 = for_task_id);
loop body
and loop_float = function
| Local_scope { body; _ } -> loop body
| Get_local _ | Get_global _ -> ()
| Get (ptr, _) ->
let n = NodeUI.get ptr.id in
let dn = get_node store ptr in
dn.non_device_only <- true;
n.read_by_localized <- for_task_id :: n.read_by_localized;
n.debug_read_by_localized <- !debug :: n.debug_read_by_localized
| Constant _ -> ()
| Binop (_, v1, v2) ->
loop_float v1;
loop_float v2
| Unop (_, v) -> loop_float v
in
loop llc


let rebalance_across_tasks = ref true

let rebalance store llcs =
if not !rebalance_across_tasks then (
let for_task_id = 0 in
let body = Lines (flat_lines llcs) in
localize_tensors store ~for_task_id body;
If_task_id_is { for_task_id; body })
else
let tasks = Array.create ~len:!Shape.num_parallel_tasks [] in
Array.iteri llcs ~f:(fun task body ->
let i = task % !Shape.num_parallel_tasks in
tasks.(i) <- body :: tasks.(i));
Lines
(Array.map tasks ~f:(Fn.compose flat_lines Array.of_list_rev)
|> Array.mapi ~f:(fun for_task_id lines ->
let body = Lines (flat_lines lines) in
If_task_id_is { for_task_id; body }))


(* Inside cleanup_virtual_llc: *)
| Rebalance (s, cs) when balanced -> (
let cs = flat_lines @@ Array.filter_map cs ~f:loop in
match (s, cs) with
| _, [||] -> None
| None, [| c |] -> Some c
| _, cs ->
let c = Array.map ~f:(fun s -> Comment s) @@ Option.to_array s in
Some (Lines (Array.append c cs)))
| Rebalance (s, cs) -> (
(* Don't flatten lines before rebalancing: keep elements of [cs] as single units. *)
let multitask, unitask = partition_tf_with_comment ~f:has_parallel_dim cs in
let rebalanced =
let unitask = Array.filter_map unitask ~f:(loop_proc ~balanced:true ~env_dom) in
if Array.is_empty unitask then None else Some (rebalance traced_store unitask)
in
let multitask = flat_lines @@ Array.filter_map ~f:loop multitask in
if Array.is_empty multitask && Option.is_none rebalanced then None
else
match s with
| None -> Some (Lines (Array.append (Option.to_array rebalanced) multitask))
| Some s ->
Some
(Lines (flat_lines @@ Array.concat [ [| Comment s |]; Option.to_array rebalanced; multitask ]))
)


let rec has_parallel_dim : type a. a low_level -> bool = function
| Comment _ -> false
| Lines ls -> Array.exists ~f:has_parallel_dim ls
| For_loop { body; _ } -> has_parallel_dim body
| Rebalance (_, cs) -> Array.exists ~f:has_parallel_dim cs
| If_task_id_is { body; _ } -> has_parallel_dim body
| Dynamic_indices { tensor = _; tensor_idcs; dynamic_idcs = _; target_dims; body } ->
Array.exists tensor_idcs ~f:Shape.is_task_id
|| Array.exists ~f:Shape.is_parallel target_dims
|| has_parallel_dim body
| Set (_, indices, llv) -> Array.exists indices ~f:Shape.is_task_id || has_parallel_dim llv
| Set_local (_, llv) -> has_parallel_dim llv
| Local_scope { body; orig_indices; _ } ->
Array.exists orig_indices ~f:Shape.is_task_id || has_parallel_dim body
| Get_local _ -> false
| Get_global Task_id -> true
| Get_global _ -> false
| Get (_, indices) -> Array.exists indices ~f:Shape.is_task_id
| Binop (_, llv1, llv2) -> has_parallel_dim llv1 || has_parallel_dim llv2
| Unop (_, llv) -> has_parallel_dim llv
| Constant _ -> false

]}
22 changes: 0 additions & 22 deletions lib/code.ml
@@ -650,28 +650,6 @@ let partition_tf_with_comment cs ~f =
in
(trues, falses)

let rec has_parallel_dim : type a. a low_level -> bool = function
| Comment _ -> false
| Lines ls -> Array.exists ~f:has_parallel_dim ls
| For_loop { body; _ } -> has_parallel_dim body
| Rebalance (_, cs) -> Array.exists ~f:has_parallel_dim cs
| If_task_id_is { body; _ } -> has_parallel_dim body
| Dynamic_indices { tensor = _; tensor_idcs; dynamic_idcs = _; target_dims; body } ->
Array.exists tensor_idcs ~f:Shape.is_task_id
|| Array.exists ~f:Shape.is_parallel target_dims
|| has_parallel_dim body
| Set (_, indices, llv) -> Array.exists indices ~f:Shape.is_task_id || has_parallel_dim llv
| Set_local (_, llv) -> has_parallel_dim llv
| Local_scope { body; orig_indices; _ } ->
Array.exists orig_indices ~f:Shape.is_task_id || has_parallel_dim body
| Get_local _ -> false
| Get_global Task_id -> true
| Get_global _ -> false
| Get (_, indices) -> Array.exists indices ~f:Shape.is_task_id
| Binop (_, llv1, llv2) -> has_parallel_dim llv1 || has_parallel_dim llv2
| Unop (_, llv) -> has_parallel_dim llv
| Constant _ -> false

let precompute_constants ?idcs traced_store top_node llv =
let exception Non_literal of int in
let rec loop llv =
47 changes: 34 additions & 13 deletions lib/exec_as_gccjit.ml
@@ -8,15 +8,25 @@ let session_context =
set_option ctx Optimization_level !optimization_level;
ref ctx

type sync_properties =
| Device_only (** The tensor is only needed for a task-local computation and does not exist on host. *)
| Update_on_host
(** All assignments to the tensor are update assignments. They happen directly on the host, simultaneously
syncing the tensor's cell values. *)
| Parallel_dim
(** The shape of the tensor has a [Parallel] dimension. Each task computes a slice of this dimension
and independently transfers its slice to the host. *)
| Replicated
(** The tensor computation happens on-device in all tasks, but the result is transferred to the host
from only one task ([task_id = 0]). *)
[@@deriving sexp, equal, compare, variants]

type tensor = {
hosted_ptr : Gccjit.rvalue option;
(** Pointer to the first value of the associated [Bigarray], if hosted. Usually it does not correspond
to the local tensor (e.g. if task id > 0). *)
local : Gccjit.lvalue option; (** A local array, if any. *)
update_on_host : bool;
(** If true, in case of update assignment ([Block.assign_op]), the assignment will happen directly
on the host. *)
is_parallel : bool; (** Whether the shape of the tensor has a [Parallel] dimension. *)
sync : sync_properties;
host_dims : int array;
(** Dimensions (shape) of the tensor as a whole, or an empty array if [hosted_ptr]
is [None]. *)
@@ -86,9 +96,21 @@ let get_tensor
let local = Function.local func arr_typ @@ NodeUI.tensor_ptr_name ptr in
let host_dims = Bigarray.Genarray.dims arr in
let is_parallel = Array.exists ~f:Shape.is_parallel @@ Shape.to_dims n.shape in
let can_be_replicated =
(* TODO: currently we do not check for gradient tensors, since their computation dependencies are
different from the node dependencies. *)
NodeUI.(equal_data_kind ptr.field Value && (not @@ has_parallel_deps n))
in
let update_on_host =
(not is_parallel) && tn.read_before_write && tn.reduced_racyness && Option.is_some hosted_ptr
in
let sync =
if Option.is_none hosted_ptr then Device_only
else if is_parallel then Parallel_dim
else if update_on_host then Update_on_host
else if can_be_replicated || !Shape.num_parallel_tasks <= 1 then Replicated
else failwith "exec_as_gccjit: synchronization pattern NOT IMPLEMENTED YET"
in
Option.iter hosted_ptr ~f:(fun hosted_ptr ->
if local_is_slice_of_host then (
let offset_idcs =
@@ -135,8 +157,7 @@ let get_tensor
{
hosted_ptr;
local = Some local;
update_on_host;
is_parallel;
sync;
host_dims;
device_dims;
host_size_in_bytes;
@@ -189,7 +210,7 @@ let get_ptr tensor =

let get_sync_ptr tensor =
match (tensor.hosted_ptr, tensor.local) with
| Some rv, _ when tensor.update_on_host -> rv
| Some rv, _ when is_update_on_host tensor.sync -> rv
| _, Some lv -> Gccjit.RValue.lvalue lv
| Some rv, _ -> rv
| None, None -> assert false
@@ -287,10 +308,10 @@ let jit_code ~name ~env ~task_id ({ ctx; func; _ } as state) initial_block (body
let host_idcs = lookup ~on_host:true env idcs in
let tensor = get_tensor state ~jit_code:loop_proc ~host_idcs tensor in
let value = loop_float ~name ~env ~num_typ:tensor.num_typ ~is_double:tensor.is_double c2 in
let idcs = lookup ~on_host:tensor.update_on_host env idcs in
let idcs = lookup ~on_host:(is_update_on_host tensor.sync) env idcs in
let device_offset = jit_array_offset ctx ~idcs ~dims:tensor.device_dims in
let device_lhs = LValue.access_array (get_ptr tensor) device_offset in
if tensor.update_on_host then (
if is_update_on_host tensor.sync then (
let host_offset = jit_array_offset ctx ~idcs:host_idcs ~dims:tensor.host_dims in
let host_lhs = LValue.access_array (get_sync_ptr tensor) host_offset in
Block.assign_op !current_block host_lhs (builtin_op op) value;
@@ -301,10 +322,10 @@
let host_idcs = lookup ~on_host:true env idcs in
let tensor = get_tensor state ~jit_code:loop_proc ~host_idcs tensor in
let value = loop_float ~name ~env ~num_typ:tensor.num_typ ~is_double:tensor.is_double c2 in
let idcs = lookup ~on_host:tensor.update_on_host env idcs in
let idcs = lookup ~on_host:(is_update_on_host tensor.sync) env idcs in
let device_offset = jit_array_offset ctx ~idcs ~dims:tensor.device_dims in
let device_lhs = LValue.access_array (get_ptr tensor) device_offset in
if tensor.update_on_host then (
if is_update_on_host tensor.sync then (
let host_offset = jit_array_offset ctx ~idcs:host_idcs ~dims:tensor.host_dims in
let host_lhs = LValue.access_array (get_sync_ptr tensor) host_offset in
let result =
Expand All @@ -324,10 +345,10 @@ let jit_code ~name ~env ~task_id ({ ctx; func; _ } as state) initial_block (body
let host_idcs = lookup ~on_host:true env idcs in
let tensor = get_tensor state ~jit_code:loop_proc ~host_idcs tensor in
let value = loop_float ~name ~env ~num_typ:tensor.num_typ ~is_double:tensor.is_double c2 in
let idcs = lookup ~on_host:tensor.update_on_host env idcs in
let idcs = lookup ~on_host:(is_update_on_host tensor.sync) env idcs in
let device_offset = jit_array_offset ctx ~idcs ~dims:tensor.device_dims in
let device_lhs = LValue.access_array (get_ptr tensor) device_offset in
if tensor.update_on_host then (
if is_update_on_host tensor.sync then (
let host_offset = jit_array_offset ctx ~idcs:host_idcs ~dims:tensor.host_dims in
let host_lhs = LValue.access_array (get_sync_ptr tensor) host_offset in
let result =
6 changes: 6 additions & 0 deletions lib/nodeUI.ml
@@ -48,6 +48,12 @@ let host_size_in_bytes ptr =
let f arr = if Array.is_empty @@ A.dims arr then 0 else A.size_in_bytes arr in
Option.value ~default:0 @@ Option.map ~f:(map_as_bigarray { f }) @@ get_tensor ptr

(** Whether the node or any of its descendants have a [Parallel] dimension in their shape.
A negative answer can only be guaranteed after shape inference. *)
let rec has_parallel_deps n =
if Array.exists ~f:Shape.is_parallel @@ Shape.to_dims n.shape then true
else List.exists ~f:has_parallel_deps @@ List.map n.children ~f:(fun sn -> get sn.sub_node_id)

type prec =
| Void_prec : prec
(* | Bit_as_bool: (bool, bit_as_bool_nd) precision *)
