# CelebA_models.py
# (Captured from a GitHub file view: 24 lines (18 loc), 1.09 KB.)
# https://github.com/alps-lab/dpgan/tree/master/models/gans
import torch
from torch import nn
import torch.nn.functional as F
from DCResNet_models import *
import util
class CelebA_DCRN_G64(DCResNetGenerator):
    """CelebA preset of ``DCResNetGenerator`` (the "64" presumably means 64x64 output — TODO confirm).

    Thin wrapper: it only forwards a fixed configuration to the base class,
    always passing ``out_ch=3`` (presumably RGB images — confirm against
    ``DCResNetGenerator``).

    Args:
        z_dim: Latent dimension forwarded to the base class (default 128).
        channels: Channel widths forwarded to the base class. ``None`` (the
            default) selects ``[512, 512, 256, 128, 64]``.
        first_filter_size: Forwarded to the base class (default 4).
        **kwargs: Any further keyword arguments accepted by ``DCResNetGenerator``.
    """

    def __init__(self, z_dim=128, channels=None, first_filter_size=4, **kwargs):
        if channels is None:
            # Build the default list per call. A list literal in the signature is
            # evaluated once and shared by every instance, so any in-place change
            # made downstream would leak into later instances.
            channels = [512, 512, 256, 128, 64]
        super().__init__(
            z_dim=z_dim,
            channels=channels,
            first_filter_size=first_filter_size,
            out_ch=3,
            **kwargs,
        )
class CelebA_DCRN_D64(DCResNetDiscriminator):
    """CelebA preset of ``DCResNetDiscriminator`` (the "64" presumably means 64x64 input — TODO confirm).

    Thin wrapper: it only forwards a fixed configuration to the base class.

    Args:
        channels: Channel widths forwarded to the base class. ``None`` (the
            default) selects ``[3, 64, 128, 256, 512]``. The leading 3 is
            presumably the image channel count — confirm against
            ``DCResNetDiscriminator``.
        last_filter_size: Forwarded to the base class (default 4).
        **kwargs: Any further keyword arguments accepted by ``DCResNetDiscriminator``.
    """

    def __init__(self, channels=None, last_filter_size=4, **kwargs):
        if channels is None:
            # Fresh list per call, so instances never share one mutable default.
            channels = [3, 64, 128, 256, 512]
        super().__init__(
            channels=channels,
            last_filter_size=last_filter_size,
            **kwargs,
        )
class CelebA_DCRN_G48(DCResNetGenerator):
    """CelebA preset of ``DCResNetGenerator`` (the "48" presumably means 48x48 output — TODO confirm).

    Thin wrapper: it only forwards a fixed configuration to the base class,
    always passing ``out_ch=3`` (presumably RGB images — confirm against
    ``DCResNetGenerator``).

    Args:
        z_dim: Latent dimension forwarded to the base class (default 128).
        channels: Channel widths forwarded to the base class. ``None`` (the
            default) selects ``[512, 512, 256, 128]``.
        first_filter_size: Forwarded to the base class (default 6).
        **kwargs: Any further keyword arguments accepted by ``DCResNetGenerator``.
    """

    def __init__(self, z_dim=128, channels=None, first_filter_size=6, **kwargs):
        if channels is None:
            # Fresh list per call, so instances never share one mutable default.
            channels = [512, 512, 256, 128]
        super().__init__(
            z_dim=z_dim,
            channels=channels,
            first_filter_size=first_filter_size,
            out_ch=3,
            **kwargs,
        )
class CelebA_DCRN_D48(DCResNetDiscriminator):
    """CelebA preset of ``DCResNetDiscriminator`` (the "48" presumably means 48x48 input — TODO confirm).

    Thin wrapper: it only forwards a fixed configuration to the base class.

    Args:
        channels: Channel widths forwarded to the base class. ``None`` (the
            default) selects ``[3, 128, 256, 512]``. The leading 3 is
            presumably the image channel count — confirm against
            ``DCResNetDiscriminator``.
        last_filter_size: Forwarded to the base class (default 6).
        **kwargs: Any further keyword arguments accepted by ``DCResNetDiscriminator``.
    """

    def __init__(self, channels=None, last_filter_size=6, **kwargs):
        if channels is None:
            # Fresh list per call, so instances never share one mutable default.
            channels = [3, 128, 256, 512]
        super().__init__(
            channels=channels,
            last_filter_size=last_filter_size,
            **kwargs,
        )