Skip to content

Commit

Permalink
✨Enhanced get_outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Nov 21, 2024
1 parent b884039 commit acba190
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
14 changes: 11 additions & 3 deletions core/learn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_outputs(
return_outputs: bool = True,
target_outputs: Union[str, List[str]] = PREDICTIONS_KEY,
recover_labels: bool = True,
recover_predictions: bool = True,
return_labels: bool = False,
target_labels: Union[str, List[str]] = LABEL_KEY,
stack_outputs: bool = True,
Expand Down Expand Up @@ -129,7 +130,15 @@ def recover_labels_of(tensors: tensor_dict_type) -> tensor_dict_type:
if recover_labels:
tensors = shallow_copy_dict(tensors)
for k, v in tensors.items():
if v is not None and k in need_recover:
if v is not None and k in target_labels:
tensors[k] = loader.recover_labels(k, v)
return tensors

def recover_predictions_of(tensors: tensor_dict_type) -> tensor_dict_type:
if recover_predictions:
tensors = shallow_copy_dict(tensors)
for k, v in tensors.items():
if v is not None and isinstance(v, Tensor):
tensors[k] = loader.recover_labels(k, v)
return tensors

Expand Down Expand Up @@ -196,7 +205,7 @@ def _run() -> InferenceOutputs:
tensor_batch,
shallow_copy_dict(kwargs),
get_losses=use_losses_as_metrics,
recover_labels_fn=recover_labels_of,
recover_predictions_fn=recover_predictions_of,
)
flags.in_step = False
tensor_outputs = step_outputs.forward_results
Expand Down Expand Up @@ -339,7 +348,6 @@ def run() -> InferenceOutputs:
target_outputs = [target_outputs]
if isinstance(target_labels, str):
target_labels = [target_labels]
need_recover = target_outputs + target_labels
try:
return run()
except KeyboardInterrupt:
Expand Down
4 changes: 4 additions & 0 deletions core/learn/pipeline/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def predict(
return_probabilities: bool = False,
target_outputs: Union[str, List[str]] = PREDICTIONS_KEY,
recover_labels: bool = True,
recover_predictions: bool = True,
accelerator: Optional[Accelerator] = None,
pad_dim: Optional[Union[int, Dict[str, int]]] = None,
**kwargs: Any,
Expand All @@ -171,6 +172,7 @@ def predict(
kw["loader"] = loader
kw["target_outputs"] = target_outputs
kw["recover_labels"] = recover_labels
kw["recover_predictions"] = recover_predictions
kw["accelerator"] = accelerator
kw["pad_dim"] = pad_dim
outputs = safe_execute(self.build_inference.inference.get_outputs, kw)
Expand Down Expand Up @@ -226,6 +228,7 @@ def evaluate(
return_outputs: bool = False,
target_outputs: Union[str, List[str]] = PREDICTIONS_KEY,
recover_labels: bool = True,
recover_predictions: bool = True,
return_labels: bool = False,
target_labels: Union[str, List[str]] = LABEL_KEY,
progress: Optional[Progress] = None,
Expand All @@ -245,6 +248,7 @@ def evaluate(
return_outputs=return_outputs,
target_outputs=target_outputs,
recover_labels=recover_labels,
recover_predictions=recover_predictions,
return_labels=return_labels,
target_labels=target_labels,
progress=progress,
Expand Down
9 changes: 6 additions & 3 deletions core/learn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ def get_outputs(
return_outputs: bool = True,
target_outputs: Union[str, List[str]] = PREDICTIONS_KEY,
recover_labels: bool = True,
recover_predictions: bool = True,
return_labels: bool = False,
target_labels: Union[str, List[str]] = LABEL_KEY,
stack_outputs: bool = True,
Expand Down Expand Up @@ -1413,7 +1414,7 @@ def step(
get_losses: bool = False,
detach_losses: bool = True,
loss_kwargs: Optional[Dict[str, Any]] = None,
recover_labels_fn: Optional[Callable[[td_type], td_type]] = None,
recover_predictions_fn: Optional[Callable[[td_type], td_type]] = None,
) -> StepOutputs:
loss_tensors = {}
loss_kwargs = loss_kwargs or {}
Expand All @@ -1428,8 +1429,8 @@ def step(
continue
if fw is None or train_step.requires_new_forward:
fw = get_fw()
if recover_labels_fn is not None:
fw = recover_labels_fn(fw)
if recover_predictions_fn is not None:
fw = recover_predictions_fn(fw)
if get_losses:
loss_res = train_step.loss_fn(self, None, batch, fw, **loss_kwargs)
if not detach_losses:
Expand Down Expand Up @@ -1565,6 +1566,7 @@ def evaluate(
return_outputs: bool = False,
target_outputs: Union[str, List[str]] = PREDICTIONS_KEY,
recover_labels: bool = True,
recover_predictions: bool = True,
return_labels: bool = False,
target_labels: Union[str, List[str]] = LABEL_KEY,
progress: Optional[Progress] = None,
Expand All @@ -1583,6 +1585,7 @@ def evaluate(
return_outputs=return_outputs,
target_outputs=target_outputs,
recover_labels=recover_labels,
recover_predictions=recover_predictions,
return_labels=return_labels,
target_labels=target_labels,
progress=progress,
Expand Down

0 comments on commit acba190

Please sign in to comment.