-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Network graph for style_transfer_n model #33
Comments
Dear vikasrs,
I could not retrieve the original StyleTransferN model from a codebase that I can publicly release; sorry about that.
…-- Michaël
From: vikasrs <[email protected]>
Reply-To: mgharbi/hdrnet_legacy <[email protected]>
Date: Monday, October 8, 2018 at 3:34 PM
To: mgharbi/hdrnet_legacy <[email protected]>
Cc: Subscribed <[email protected]>
Subject: [mgharbi/hdrnet_legacy] Network graph for style_transfer_n model (#33)
Dear Gharbi,
I noticed that the TF network graph definition for style_transfer_n is missing from models.py. It looks like the pre-trained model specifies that the computation graph should be 'StyleTransferCurves'. Could you please provide the computation graph?
I took a stab at defining the graph, but I get visually poor results (missing high-frequency detail). I have attached my attempt at the graph below. Do you see any issues?
class StyleTransferCurves(HDRNetCurves):
    """User reconstruction of the 'StyleTransferCurves' graph for style_transfer_n.

    The pre-trained ``style_transfer_n`` model names 'StyleTransferCurves' as
    its computation graph, but that class is not defined in models.py, and the
    original definition could not be released. This is an attempt to rebuild
    it on top of ``HDRNetCurves``; it is NOT the author's original code.

    Pipeline (see ``inference``):
      * ``_coefficients`` predicts a grid of coefficients from the low-res input.
      * ``_guide`` computes a single-channel guide map from the full-res input.
      * ``_output`` (not defined here, presumably inherited from
        ``HDRNetCurves``) combines the full-res input, guide and coefficients.

    NOTE(review): this reconstruction was reported to give visually poor
    results (missing high-frequency detail), so treat it as unverified.

    NOTE(review): ``n_in()`` reports 7 (6 + 1), but the prediction head sizes
    the grid for ``n_in() - 3`` = 4 inputs per output channel. Presumably that
    is 3 image channels plus an offset term, which would mean ``fullres_input``
    is expected to have 3 channels. TODO confirm against
    ``HDRNetCurves._output``.
    """

    @classmethod
    def n_in(cls):
        """Return the declared number of inputs, written as ``6 + 1`` (= 7)."""
        return 6 + 1

    @classmethod
    def inference(cls, lowres_input, fullres_input, params, is_training=False):
        """Build the full graph: coefficients -> guide -> output.

        Args:
            lowres_input: tensor fed to the coefficient network (``_coefficients``).
            fullres_input: tensor used both to compute the guide map and as the
                image passed to ``_output``.
            params: hyper-parameter dict; ``_coefficients`` reads 'luma_bins',
                'channel_multiplier', 'spatial_bin', 'net_input_size' and
                'batch_norm'.
            is_training: forwarded to the conv / fc layers in ``_coefficients``
                (alongside their batch-norm setting).

        Returns:
            The tensor produced by ``_output``. The coefficients, guide and
            output are also added to the 'bilateral_coefficients', 'guide' and
            'output' graph collections respectively.
        """
        with tf.variable_scope('coefficients'):
            bilateral_coeffs = cls._coefficients(lowres_input, params, is_training)
            tf.add_to_collection('bilateral_coefficients', bilateral_coeffs)

        with tf.variable_scope('guide'):
            guide = cls._guide(fullres_input, params, is_training)
            tf.add_to_collection('guide', guide)

        with tf.variable_scope('output'):
            output = cls._output(fullres_input, guide, bilateral_coeffs)
            tf.add_to_collection('output', output)

        return output

    @classmethod
    def _coefficients(cls, input_tensor, params, is_training):
        """Predict the packed coefficient grid from the low-resolution input.

        Stages:
          * 'splat': stride-2 convs.
          * 'global': two stride-2 convs, then three fully connected layers.
          * 'local': two stride-1 convs.
          * 'fusion': ReLU(local + broadcast global).
          * 'prediction': 1x1 conv, unrolled into the grid layout.

        Args:
            input_tensor: rank-4 tensor with static spatial and channel sizes.
                Presumably NHWC; TODO confirm.
            params: dict with 'luma_bins', 'channel_multiplier', 'spatial_bin',
                'net_input_size' and 'batch_norm'.
            is_training: forwarded to the conv / fc layers.

        Returns:
            Tensor of shape ``(batch, h, w, luma_bins, n_out(), n_in() - 3)``,
            also added to the 'packed_coefficients' collection.
        """
        gd = params['luma_bins']
        cm = params['channel_multiplier']
        spatial_bin = params['spatial_bin']
        batch_norm = params['batch_norm']
        n_out = cls.n_out()
        # Coefficient inputs per output channel; see the class docstring note.
        n_coeff_in = cls.n_in() - 3

        # -----------------------------------------------------------------------
        with tf.variable_scope('splat'):
            # Enough stride-2 convs to go from net_input_size down to spatial_bin.
            n_ds_layers = int(np.log2(params['net_input_size'] / spatial_bin))

            current_layer = input_tensor
            for i in range(n_ds_layers):
                # Don't normalize the first layer.
                use_bn = batch_norm if i > 0 else False
                current_layer = conv(current_layer, cm * (2 ** i) * gd, 3, stride=2,
                                     batch_norm=use_bn, is_training=is_training,
                                     scope='conv{}'.format(i + 1))

            splat_features = current_layer
        # -----------------------------------------------------------------------

        # -----------------------------------------------------------------------
        with tf.variable_scope('global'):
            current_layer = splat_features

            # Fixed at two stride-2 convs, i.e. a 4x downsample per side
            # (4x4 at the coarsest level when spatial_bin == 16). The count is
            # kept fixed, presumably so variable names (conv1, conv2) match the
            # pre-trained checkpoint.
            for i in range(2):
                current_layer = conv(current_layer, 8 * cm * gd, 3, stride=2,
                                     batch_norm=batch_norm, is_training=is_training,
                                     scope="conv{}".format(i + 1))

            # Flatten for the fully connected layers. -1 lets the batch size be
            # inferred, so a dynamic (None) batch dimension also works.
            _, lh, lw, lc = current_layer.get_shape().as_list()
            current_layer = tf.reshape(current_layer, [-1, lh * lw * lc])

            current_layer = fc(current_layer, 32 * cm * gd,
                               batch_norm=batch_norm, is_training=is_training,
                               scope="fc1")
            current_layer = fc(current_layer, 16 * cm * gd,
                               batch_norm=batch_norm, is_training=is_training,
                               scope="fc2")
            # Don't normalize before fusion.
            current_layer = fc(current_layer, 8 * cm * gd, activation_fn=None, scope="fc3")

            global_features = current_layer
        # -----------------------------------------------------------------------

        # -----------------------------------------------------------------------
        with tf.variable_scope('local'):
            current_layer = splat_features
            current_layer = conv(current_layer, 8 * cm * gd, 3,
                                 batch_norm=batch_norm,
                                 is_training=is_training,
                                 scope='conv1')
            # Don't normalize before fusion.
            current_layer = conv(current_layer, 8 * cm * gd, 3, activation_fn=None,
                                 use_bias=False, scope='conv2')

            grid_features = current_layer
        # -----------------------------------------------------------------------

        # -----------------------------------------------------------------------
        with tf.name_scope('fusion'):
            fusion_grid = grid_features
            # Broadcast the per-example global vector over every grid cell.
            fusion_global = tf.reshape(global_features, [-1, 1, 1, 8 * cm * gd])
            fusion = tf.nn.relu(fusion_grid + fusion_global)
        # -----------------------------------------------------------------------

        # -----------------------------------------------------------------------
        with tf.variable_scope('prediction'):
            current_layer = conv(fusion, gd * n_out * n_coeff_in, 1,
                                 activation_fn=None, scope='conv1')

            with tf.name_scope('unroll_grid'):
                # (..., gd * n_out * k) -> (..., gd, n_out * k)
                current_layer = tf.stack(
                    tf.split(current_layer, n_out * n_coeff_in, axis=3), axis=4)
                # (..., gd, n_out * k) -> (..., gd, n_out, k)
                current_layer = tf.stack(
                    tf.split(current_layer, n_coeff_in, axis=4), axis=5)

            tf.add_to_collection('packed_coefficients', current_layer)
        # -----------------------------------------------------------------------

        return current_layer

    @classmethod
    def _guide(cls, input_tensor, params, is_training):
        """Compute a single-channel guide map from the full-resolution input.

        The map is built in three learned steps:
          1. A color transform (``nchans x nchans`` matrix plus bias).
          2. A per-channel piecewise-linear curve
             ``sum_k slopes[k] * relu(x - shifts[k])`` with ``npts`` control
             points.
          3. A 1x1 conv that mixes the channels down to one.

        The result is clipped to [0, 1].

        Args:
            input_tensor: rank-4 tensor whose last (channel) axis has a static
                size.
            params: unused here.
            is_training: unused here.

        Returns:
            Tensor with the channel axis (axis 3) squeezed out.
        """
        npts = 16  # number of control points for the curve
        nchans = input_tensor.get_shape().as_list()[-1]

        # Color space change: x -> x @ ccm + ccm_bias. ccm is initialised to
        # the identity plus one tiny random scalar (scale 1e-4) added to every
        # entry.
        ccm_init = (np.identity(nchans, dtype=np.float32)
                    + np.random.randn(1).astype(np.float32) * 1e-4)
        ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=ccm_init)
        with tf.name_scope('ccm'):
            ccm_bias = tf.get_variable('ccm_bias', shape=[nchans], dtype=tf.float32,
                                       initializer=tf.constant_initializer(0.0))

            guidemap = tf.matmul(tf.reshape(input_tensor, [-1, nchans]), ccm)
            guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add')

            guidemap = tf.reshape(guidemap, tf.shape(input_tensor))

        # Per-channel curve. Shifts are initialised evenly over [0, 1) and
        # only the first slope is 1, so the curve starts out as relu(x).
        with tf.name_scope('curve'):
            shifts_ = np.linspace(0, 1, npts, endpoint=False, dtype=np.float32)
            shifts_ = shifts_[np.newaxis, np.newaxis, np.newaxis, :]
            shifts_ = np.tile(shifts_, (1, 1, nchans, 1))  # (1, 1, nchans, npts)

            guidemap = tf.expand_dims(guidemap, 4)
            shifts = tf.get_variable('shifts', dtype=tf.float32, initializer=shifts_)

            slopes_ = np.zeros([1, 1, 1, nchans, npts], dtype=np.float32)
            slopes_[:, :, :, :, 0] = 1.0
            slopes = tf.get_variable('slopes', dtype=tf.float32, initializer=slopes_)

            guidemap = tf.reduce_sum(slopes * tf.nn.relu(guidemap - shifts), axis=[4])

        # Mix the curved channels into one. The 1x1 conv is initialised to the
        # channel mean.
        guidemap = tf.contrib.layers.convolution2d(
            inputs=guidemap,
            num_outputs=1, kernel_size=1,
            weights_initializer=tf.constant_initializer(1.0 / nchans),
            biases_initializer=tf.constant_initializer(0),
            activation_fn=None,
            variables_collections={'weights': [tf.GraphKeys.WEIGHTS], 'biases': [tf.GraphKeys.BIASES]},
            outputs_collections=[tf.GraphKeys.ACTIVATIONS],
            scope='channel_mixing')

        guidemap = tf.clip_by_value(guidemap, 0, 1)
        guidemap = tf.squeeze(guidemap, axis=[3])

        return guidemap
—
You are receiving this because you are subscribed to this thread.
Reply to this email directly, view it on GitHub<#33>, or mute the thread<https://github.com/notifications/unsubscribe-auth/ABib33KAFkUVYl39LxAc1S_8QhGF5zHkks5ui9MOgaJpZM4XNyIS>.
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Dear Gharbi,
I noticed that the TF network graph definition for style_transfer_n is missing from models.py. It looks like the pre-trained model specifies that the computation graph should be 'StyleTransferCurves'. Could you please provide the computation graph?
I took a stab at defining the graph, but I get visually poor results (missing high-frequency detail). I have attached my attempt at the graph below. Do you see any issues?
# NOTE(review): abridged rendering of the StyleTransferCurves snippet that is
# quoted in full in the email earlier on this page. Indentation was stripped
# and method bodies are cut off (e.g. `inference` has a signature but no
# body), so this excerpt is not runnable on its own.
class StyleTransferCurves(HDRNetCurves):
@classmethod
def n_in(cls):
return 6 + 1
@classmethod
def inference(cls, lowres_input, fullres_input, params,
is_training=False):
# ... body elided in this excerpt ...
@classmethod
def _coefficients(cls, input_tensor, params, is_training):
bs = input_tensor.get_shape().as_list()[0]
gd = params['luma_bins']
cm = params['channel_multiplier']
spatial_bin = params['spatial_bin']
# ... remainder of _coefficients elided in this excerpt ...
@classmethod
def _guide(cls, input_tensor, params, is_training):
npts = 16 # number of control points for the curve
nchans = input_tensor.get_shape().as_list()[-1]
# ... remainder of _guide elided in this excerpt ...
The text was updated successfully, but these errors were encountered: