Spaces:
Runtime error
Runtime error
| from keras.engine.base_layer import Layer | |
| from keras.engine.input_spec import InputSpec | |
| from keras import initializers, regularizers, constraints | |
| from keras import backend as K | |
| from keras.saving.object_registration import get_custom_objects | |
| import tensorflow as tf | |
class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).

    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    Statistics are computed separately for every sample (the batch axis 0
    is never reduced) and, when `axis` is given, separately for every
    index along `axis`.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Negative values index from the last axis (Python indexing).
            Axis 0 is the batch dimension. `axis` cannot be set to 0, nor to a
            negative value that refers to axis 0 (i.e. `-ndim`), to avoid errors.
            `axis` cannot be set at all for inputs of shape `(batch, features)`.
        epsilon: Small float added to the standard deviation (after taking the
            square root of the variance) to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        # Resolve string/dict/None specs into Keras objects once, up front.
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create the `gamma` / `beta` weights for inputs of `input_shape`.

        Raises:
            ValueError: if `axis` refers to the batch axis (0 or `-ndim`), or
                if `axis` is given for a `(batch, features)` input.
        """
        ndim = len(input_shape)
        # A negative axis equal to -ndim indexes the batch axis exactly like 0
        # does (all later uses are plain Python indexing), so reject it too.
        if self.axis == 0 or (self.axis is not None and self.axis == -ndim):
            raise ValueError('Axis cannot be zero')
        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')
        self.input_spec = InputSpec(ndim=ndim)

        # One gamma/beta entry per index along `axis`; a single scalar pair
        # when all non-batch values are normalized together.
        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        """Normalize `inputs` per instance, then apply `gamma` and `beta`.

        `training` is accepted for Keras API compatibility but is not used:
        the statistics always come from the current `inputs`.
        """
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))
        if self.axis is not None:
            # Keep `axis` so each index along it gets its own statistics.
            del reduction_axes[self.axis]
        # Never reduce over the batch axis: statistics are per sample.
        del reduction_axes[0]
        mean, var = tf.nn.moments(inputs, reduction_axes, keepdims=True)
        # NOTE: epsilon is added to the standard deviation, not the variance.
        # Changing this would alter the outputs of already-trained models.
        stddev = tf.sqrt(var) + self.epsilon
        normed = (inputs - mean) / stddev

        # Reshape gamma/beta so they broadcast along `axis` only (or act as
        # scalars when axis is None).
        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        """Return the layer config, merging these options into the base config."""
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Register the layer globally so saved models referencing
# 'InstanceNormalization' can be deserialized without passing custom_objects.
get_custom_objects()['InstanceNormalization'] = InstanceNormalization