Commit: fix bug

Huan Ling committed Sep 12, 2021
1 parent dee6d7d commit d9564d4
Showing 3 changed files with 12 additions and 5 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -108,11 +108,17 @@
python test_deeplab_cross_validation.py --exp experiments/face_34.json\
--resume [path-to-downstream task checkpoint] --cross_validate True
```

-**June 21st Update:**
+**June 21 Update:**

For training the interpreter, we changed the upsampling method from nearest-neighbor to bilinear upsampling in [line](https://github.com/nv-tlabs/datasetGAN_release/blob/release_finallll/datasetGAN/train_interpreter.py#L163) and updated the results in Table 1. The table reports mIOU.

<img src = "./figs/new_table.png" width="80%"/>
+
+**Sep 12 Update:**
+
+Thanks to [@greatwallet](https://github.com/greatwallet). As reported in [this issue](https://github.com/nv-tlabs/datasetGAN_release/issues/27), we fixed an uncertainty score calculation bug; the Ours-Fix row in the table shows the updated results.
+
+<img src = "./figs/new_table2.png" width="80%"/>
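The fixed uncertainty computation can be sketched as follows. This is a minimal illustration, not the repository's code: `js_uncertainty` and `logits_list` are made-up names, and the per-member entropy term is assumed to mirror the repository's `all_entropy`.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

def js_uncertainty(logits_list):
    # logits_list: per-classifier raw logits, each of shape (N, C).
    softmax_f = nn.Softmax(dim=1)
    probs = [softmax_f(l) for l in logits_list]      # normalize before averaging (the fix)
    mean_seg = torch.stack(probs).mean(0)            # arithmetic mixture of member distributions
    full_entropy = Categorical(mean_seg).entropy()   # positional arg = probabilities, not logits
    all_entropy = torch.stack([Categorical(p).entropy() for p in probs])
    # Jensen-Shannon-style score: entropy of the mixture minus mean member entropy.
    return full_entropy - torch.mean(all_entropy, 0)
```

Averaging raw logits and then calling `Categorical(logits=...)`, as the old code effectively did, corresponds to a renormalized geometric rather than arithmetic mean of the member distributions, so the resulting score loses its non-negativity guarantee.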



## Create your own model

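The June 21 note above swaps the interpreter's feature-upsampling mode; in PyTorch terms the difference is one `F.interpolate` argument. The tensor shape here is illustrative, not taken from the repository.

```python
import torch
import torch.nn.functional as F

feat = torch.randn(1, 64, 32, 32)  # illustrative feature map (N, C, H, W)

# Nearest-neighbor upsampling: copies pixels, producing blocky features.
up_nearest = F.interpolate(feat, scale_factor=4, mode='nearest')

# Bilinear upsampling (the June 21 change): interpolates smoothly between pixels.
up_bilinear = F.interpolate(feat, scale_factor=4, mode='bilinear', align_corners=False)
```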
7 changes: 4 additions & 3 deletions datasetGAN/train_interpreter.py
@@ -229,6 +229,7 @@ def generate_data(args, checkpoint_path, num_sample, start_step=0, vis=True):
classifier.eval()
classifier_list.append(classifier)

+softmax_f = nn.Softmax(dim=1)
with torch.no_grad():
latent_cache = []
image_cache = []
@@ -289,9 +290,9 @@ def generate_data(args, checkpoint_path, num_sample, start_step=0, vis=True):

all_seg.append(img_seg)
 if mean_seg is None:
-    mean_seg = img_seg
+    mean_seg = softmax_f(img_seg)
 else:
-    mean_seg += img_seg
+    mean_seg += softmax_f(img_seg)

img_seg_final = oht_to_scalar(img_seg)
img_seg_final = img_seg_final.reshape(args['dim'][0], args['dim'][1], 1)
@@ -301,7 +302,7 @@ def generate_data(args, checkpoint_path, num_sample, start_step=0, vis=True):

mean_seg = mean_seg / len(all_seg)

-full_entropy = Categorical(logits=mean_seg).entropy()
+full_entropy = Categorical(mean_seg).entropy()

js = full_entropy - torch.mean(torch.stack(all_entropy), 0)

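The `Categorical(logits=...)` to `Categorical(...)` change above matters because the constructor treats a positional argument as probabilities, while `logits=` applies another softmax first. A small demonstration with illustrative values:

```python
import torch
from torch.distributions import Categorical

p = torch.tensor([[0.7, 0.2, 0.1]])  # already a normalized distribution

h_probs = Categorical(p).entropy()          # entropy of [0.7, 0.2, 0.1]
h_logits = Categorical(logits=p).entropy()  # softmax([0.7, 0.2, 0.1]) first -- wrong here

# h_probs is about 0.802 nats; h_logits is larger, because the extra
# softmax flattens the distribution toward uniform.
```

Since `mean_seg` already holds softmaxed, averaged probabilities after this commit, passing it positionally avoids that double softmax.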
Binary file added figs/new_table2.png

0 comments on commit d9564d4
