Python torch.nn module: InstanceNorm2d() example source code
We extracted the following 32 code examples from open-source Python projects to illustrate how to use torch.nn.InstanceNorm2d().
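Before the project examples, here is a minimal stand-alone sketch of the layer itself (shapes are illustrative): InstanceNorm2d normalizes each channel of each sample over its own spatial dimensions, rather than over the whole batch as BatchNorm2d does, which is why the GAN generators and discriminators below frequently select it.

import torch
import torch.nn as nn

# Each (sample, channel) pair is normalized over its own H x W plane.
norm = nn.InstanceNorm2d(64, affine=False)
x = torch.randn(8, 64, 32, 32)   # (batch, channels, height, width)
y = norm(x)                      # same shape, per-channel zero mean / unit variance
print(y.shape)                   # torch.Size([8, 64, 32, 32])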
def __init__(self):
super(SuccessorPredictor, self).__init__()
def identity(v):
return lambda x: x
bn2d = nn.InstanceNorm2d
bn1d = identity
self.input_size = 9
self.hidden_size = 512
self.nb_layers = 1
self.hidden_fc1 = nn.Linear(512, self.nb_layers*2*self.hidden_size)
self.hidden_fc1_bn = bn1d(self.nb_layers*2*self.hidden_size)
self.rnn = nn.LSTM(self.input_size, self.hidden_size, self.nb_layers, dropout=0.1, batch_first=False)
self.fc1 = nn.Linear(self.hidden_size, 512)
init_weights(self)
def __init__(self):
super(AEDecoder, self).__init__()
def identity(v):
return lambda x: x
bn2d = nn.InstanceNorm2d
bn1d = identity
self.ae_fc1 = nn.Linear(512, 128*3*5)
self.ae_fc1_bn = bn1d(128*3*5)
self.ae_c1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.ae_c1_bn = bn2d(128)
self.ae_c2 = nn.Conv2d(128, 128, kernel_size=3, padding=(0, 1))
self.ae_c2_bn = bn2d(128)
self.ae_c3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.ae_c3_bn = bn2d(128)
self.ae_c4 = nn.Conv2d(128, 3, kernel_size=3, padding=1)
init_weights(self)
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(PixelDiscriminator, self).__init__()
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
self.net = [
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
norm_layer(ndf * 2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
if use_sigmoid:
self.net.append(nn.Sigmoid())
self.net = nn.Sequential(*self.net)
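The use_bias logic above is a recurring idiom in these examples: because the norm layer may arrive wrapped in functools.partial (see get_norm_layer below), the code unwraps .func before comparing. A minimal sketch of the same check as a stand-alone helper; needs_conv_bias is a hypothetical name, not part of the original project:

import functools
import torch.nn as nn

def needs_conv_bias(norm_layer):
    # Hypothetical helper: unwrap functools.partial so the comparison sees
    # the underlying class, exactly as the constructors above do inline.
    norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
    # InstanceNorm2d is used here without affine parameters, so the conv keeps
    # its own bias; BatchNorm2d already learns a per-channel shift, making it redundant.
    return norm_cls == nn.InstanceNorm2d

conv = nn.Conv2d(64, 128, kernel_size=3, padding=1,
                 bias=needs_conv_bias(functools.partial(nn.InstanceNorm2d, affine=False)))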
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
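A hedged usage sketch for get_norm_layer: it returns a constructor (a functools.partial around the norm class), so callers invoke it like the class itself. The channel counts below are arbitrary and only assume the same imports used by the snippet above:

import torch.nn as nn

norm_layer = get_norm_layer(norm_type='instance')
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    norm_layer(64),          # equivalent to nn.InstanceNorm2d(64, affine=False)
    nn.ReLU(True),
)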
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
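weights_init is written to be passed to Module.apply, which visits every submodule. A small sketch with an arbitrary network; note that the norm layer needs affine=True here, since otherwise its weight is None and the initializer above would fail:

import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.InstanceNorm2d(64, affine=True),  # affine=True so weight/bias exist to initialize
    nn.LeakyReLU(0.2, True),
)
net.apply(weights_init)  # recursively applies the initializer defined above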
def get_norm_layer(norm_type):
if norm_type == 'batch':
norm_layer = nn.BatchNorm2d
elif norm_type == 'instance':
norm_layer = nn.InstanceNorm2d
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def test_InstanceNorm2d(self):
b = random.randint(3, 5)
c = random.randint(1, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
input = torch.Tensor(b, c, h, w).uniform_()
self._test_InstanceNorm(nn.InstanceNorm2d, input)
def test_InstanceNorm2d(self):
b = random.randint(3, 5)
c = random.randint(1, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
input = torch.Tensor(b, c, h, w).uniform_()
self._test_InstanceNorm(nn.InstanceNorm2d, input)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def __init__(self):
super(Embedder, self).__init__()
def identity(v):
return lambda x: x
bn2d = nn.InstanceNorm2d
bn1d = identity
self.nb_previous_images = 2
self.emb_sup_c1 = nn.Conv2d(512, 1024, kernel_size=3, padding=0, stride=1)
self.emb_sup_c1_bn = bn2d(1024)
self.emb_sup_c1_sd = nn.Dropout2d(0.0)
self.emb_add_fc1 = nn.Linear(
(self.nb_previous_images+1) # speeds
+ (self.nb_previous_images+1) # is_reverse
+ (self.nb_previous_images+1) # steering wheel
+ (self.nb_previous_images+1) # steering wheel raw
+ self.nb_previous_images*9,
128
)
self.emb_add_fc1_bn = bn1d(128)
self.emb_fc1 = nn.Linear(1024*3 + 128, 512)
self.emb_fc1_bn = bn1d(512)
init_weights(self)
def __init__(self, nb_bins):
super(DirectRewardPredictor, self).__init__()
def identity(v):
return lambda x: x
bn2d = nn.InstanceNorm2d
bn1d = identity
self.fc1 = nn.Linear(512, 128)
self.fc1_bn = bn1d(128)
self.fc2 = nn.Linear(128, nb_bins)
init_weights(self)
def __init__(self):
super(IndirectRewardPredictor, self).__init__()
def identity(v):
return lambda x: x
bn2d = nn.InstanceNorm2d
bn1d = identity
self.fc1 = nn.Linear(512, 128)
self.fc1_bn = bn1d(128)
self.fc_v = nn.Linear(128, 1)
self.fc_advantage = nn.Linear(128, 9)
init_weights(self)
def test_InstanceNorm2d(self):
b = random.randint(3, 5)
c = random.randint(1, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
input = torch.Tensor(b, c, h, w).uniform_()
self._test_InstanceNorm(nn.InstanceNorm2d, input)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_norm_layer(norm_type):
if norm_type == 'batch':
norm_layer = nn.BatchNorm2d
elif norm_type == 'instance':
norm_layer = nn.InstanceNorm2d
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def test_InstanceNorm2d(self):
b = random.randint(3, 5)
c = random.randint(1, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
input = torch.Tensor(b, c, h, w).uniform_()
self._test_InstanceNorm(nn.InstanceNorm2d, input)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def weights_init(m):
# m is a submodule of the network on which apply() was called
# e.g. type(m) returns <class 'torch.nn.modules.conv.Conv2d'>
# m.weight.data returns a torch tensor indicating the weight values
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
# get requested normalization layer
def get_norm_layer(norm_type):
if norm_type == 'batch':
norm_layer = nn.BatchNorm2d
elif norm_type == 'instance':
norm_layer = nn.InstanceNorm2d
else:
raise NotImplementedError('normalization layer [{}] is not found'.format(norm_type))
return norm_layer
# define generator G
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], use_parallel = True, learn_residual = False, padding_type='reflect'):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.gpu_ids = gpu_ids
self.use_parallel = use_parallel
self.learn_residual = learn_residual
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
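A hedged instantiation of the ResnetGenerator above, assuming get_norm_layer and ResnetBlock from the other snippets are in scope; forward() is not shown, so the underlying nn.Sequential is called directly, and the 256x256 CycleGAN-style input size is only an assumption:

import torch

norm_layer = get_norm_layer(norm_type='instance')
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64,
                       norm_layer=norm_layer, use_dropout=False, n_blocks=9)

x = torch.randn(1, 3, 256, 256)
y = netG.model(x)       # forward() is not shown above, so the Sequential is called directly
print(y.shape)          # torch.Size([1, 3, 256, 256]) -- Tanh output in [-1, 1]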
def __init__(self, outer_nc, inner_nc,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[], use_parallel = True):
super(NLayerDiscriminator, self).__init__()
self.gpu_ids = gpu_ids
self.use_parallel = use_parallel
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = int(np.ceil((kw-1)/2))
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(NLayerDiscriminator, self).__init__()
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
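A hedged instantiation of the PatchGAN-style NLayerDiscriminator above, again assuming get_norm_layer is in scope; forward() is not shown, so the nn.Sequential is called directly and the input size is illustrative:

import torch

norm_layer = get_norm_layer(norm_type='instance')
netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3,
                           norm_layer=norm_layer, use_sigmoid=False)

x = torch.randn(1, 3, 256, 256)
patch_logits = netD.model(x)   # forward() is not shown above; call the Sequential directly
print(patch_logits.shape)      # torch.Size([1, 1, 30, 30]) -- one score per overlapping patch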
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type,
norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(NLayerDiscriminator, self).__init__()
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def build_layers(img_sz, img_fm, init_fm, max_fm, n_layers, n_attr, n_skip,
deconv_method, instance_norm, enc_dropout, dec_dropout):
"""
Build auto-encoder layers.
"""
assert init_fm <= max_fm
assert n_skip <= n_layers - 1
assert np.log2(img_sz).is_integer()
assert n_layers <= int(np.log2(img_sz))
assert type(instance_norm) is bool
assert 0 <= enc_dropout < 1
assert 0 <= dec_dropout < 1
norm_fn = nn.InstanceNorm2d if instance_norm else nn.BatchNorm2d
enc_layers = []
dec_layers = []
n_in = img_fm
n_out = init_fm
for i in range(n_layers):
enc_layer = []
dec_layer = []
skip_connection = n_layers - (n_skip + 1) <= i < n_layers - 1
n_dec_in = n_out + n_attr + (n_out if skip_connection else 0)
n_dec_out = n_in
# encoder layer
enc_layer.append(nn.Conv2d(n_in, n_out, 4, 2, 1))
if i > 0:
enc_layer.append(norm_fn(n_out, affine=True))
enc_layer.append(nn.LeakyReLU(0.2, inplace=True))
if enc_dropout > 0:
enc_layer.append(nn.Dropout(enc_dropout))
# decoder layer
if deconv_method == 'upsampling':
dec_layer.append(nn.UpsamplingNearest2d(scale_factor=2))
dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out, 3, 1, 1))
elif deconv_method == 'convtranspose':
dec_layer.append(nn.ConvTranspose2d(n_dec_in, n_dec_out, 4, 2, 1, bias=False))
else:
assert deconv_method == 'pixelshuffle'
dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out * 4, 3, 1, 1))
dec_layer.append(nn.PixelShuffle(2))
if i > 0:
dec_layer.append(norm_fn(n_dec_out, affine=True))
if dec_dropout > 0 and i >= n_layers - 3:
dec_layer.append(nn.Dropout(dec_dropout))
dec_layer.append(nn.ReLU(inplace=True))
else:
dec_layer.append(nn.Tanh())
# update
n_in = n_out
n_out = min(2 * n_out, max_fm)
enc_layers.append(nn.Sequential(*enc_layer))
dec_layers.insert(0, nn.Sequential(*dec_layer))
return enc_layers, dec_layers
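A hedged call to build_layers; every value below is an illustrative assumption, and wrapping the returned layer lists in nn.Sequential is only for demonstration (the original project composes them inside its own encoder/decoder modules together with the attribute embedding):

import torch.nn as nn   # build_layers itself also relies on numpy as np, as above

enc_layers, dec_layers = build_layers(
    img_sz=256,                     # must be a power of two
    img_fm=3,                       # input feature maps (e.g. RGB)
    init_fm=32,                     # feature maps after the first encoder layer
    max_fm=512,                     # cap on feature maps
    n_layers=7,                     # must satisfy n_layers <= log2(img_sz)
    n_attr=40,                      # attribute channels appended to decoder inputs
    n_skip=0,                       # number of skip connections
    deconv_method='convtranspose',  # or 'upsampling' / 'pixelshuffle'
    instance_norm=True,             # selects nn.InstanceNorm2d over nn.BatchNorm2d
    enc_dropout=0.0,
    dec_dropout=0.0,
)
encoder = nn.Sequential(*enc_layers)
decoder = nn.Sequential(*dec_layers)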