Domain Adaptation with a Gradient Reverse Layer

April 21, 2019    Domain Adaptation

아래의 코드는 Unsupervised Domain Adaptation by Backpropagation의 내용을 코드로 정리한 내용입니다.


Code

import tensorflow as tf
import numpy as np
import random
import os
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from tensorflow.python.framework import ops
import pickle


# Load MNIST from the 'MNIST_data' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
WARNING:tensorflow:From <ipython-input-2-71e12f4bac70>:1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/donghwa/lab/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/donghwa/lab/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/donghwa/lab/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/donghwa/lab/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/donghwa/lab/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


load mnist

mnist.train.images.shape
(55000, 784)


# Process MNIST: binarize the flat 784-pixel rows (> 0 -> 255), reshape them
# to 28x28x1 and replicate the single channel three times so the images have
# the (N, 28, 28, 3) layout used by the rest of the script.
# -1 lets NumPy infer the number of samples instead of hard-coding 55000/10000.
mnist_train = (mnist.train.images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)

mnist_test = (mnist.test.images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)


mnist_train.shape
(55000, 28, 28, 3)


Load MNIST-M

  • mnistm_data.pkl 파일은 여기(원문 링크)에서 다운로드할 수 있음
# Load MNIST-M
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load a mnistm_data.pkl that comes from a trusted source.
with open('mnistm_data.pkl', 'rb') as mnistm_file:
    # The context manager closes the file handle as soon as loading finishes.
    mnistm = pickle.load(mnistm_file)
mnistm_train = mnistm['train']
mnistm_test = mnistm['test']
mnistm_valid = mnistm['valid']


  • mean pixel
# Per-channel mean over both training sets: average across the sample,
# height and width axes, leaving one value per colour channel.
pixel_mean = np.vstack((mnist_train, mnistm_train)).mean(axis=(0, 1, 2))
pixel_mean.shape
(3,)


# Mixed source/target subset used for the t-SNE visualization:
# the first num_test rows come from MNIST, the next num_test from MNIST-M.
num_test = 500
combined_test_imgs = np.vstack((mnist_test[:num_test], mnistm_test[:num_test]))
# Both halves reuse the first num_test MNIST test labels.
combined_test_labels = np.vstack((mnist.test.labels[:num_test], mnist.test.labels[:num_test]))
# One-hot domain id per row: [1, 0] for the MNIST half, [0, 1] for the MNIST-M half.
combined_test_domain = np.vstack((
    np.tile([1., 0.], (num_test, 1)),
    np.tile([0., 1.], (num_test, 1)),
))


def imshow_grid(images, shape=(2, 8)):
    """Plot images in a grid of a given shape.

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.
        shape: ``(rows, cols)`` of the grid; any two-element sequence works.
            The default is an immutable tuple so it cannot be mutated
            between calls.

    Only the first ``rows * cols`` images are drawn. If fewer images are
    supplied, the remaining cells are left blank instead of raising
    ``IndexError``.
    """
    rows, cols = shape
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=(rows, cols), axes_pad=0.05)

    num_images = len(images)
    for i in range(rows * cols):
        # Hide ticks/frames on every cell, including ones left empty.
        grid[i].axis('off')
        if i < num_images:
            grid[i].imshow(images[i])

    plt.show()


imshow_grid(mnist_train)

Imgur


imshow_grid(mnistm_train)

Imgur


  • Placeholders 정의
# Size of one training mini-batch (source half + target half).
batch_size = 64
# Raw uint8 images, 28x28 with 3 channels; batch dimension left open.
X = tf.placeholder(dtype=tf.uint8, shape=(None, 28, 28, 3))
# One-hot class labels (10 classes).
y = tf.placeholder(dtype=tf.float32, shape=(None, 10))
# One-hot domain labels (2 domains).
domainID = tf.placeholder(dtype=tf.float32, shape=(None, 2))
# Scalar lambda fed to the gradient reversal layer.
lb = tf.placeholder(dtype=tf.float32, shape=())
# Scalar learning rate fed to the optimizer.
lr = tf.placeholder(dtype=tf.float32, shape=())
# Scalar flag switching between the training and evaluation branches.
is_training = tf.placeholder(dtype=tf.bool, shape=())


  • Input normalize
# Cast the uint8 images to float, subtract the per-channel mean and scale by 1/255.
X_input = (tf.cast(X, tf.float32) - pixel_mean) / 255.
X_input
<tf.Tensor 'truediv:0' shape=(?, 28, 28, 3) dtype=float32>



featureExtractor

  • Conv block 정의
def conv_layer(inputs,
               name,
               kernelSize,
               inChannel,
               outChannel,
               stride = 1):
    """Build a conv + bias + ReLU layer inside its own variable scope.

    Args:
        inputs: Input tensor (presumably NHWC, given the
            ``[1, stride, stride, 1]`` strides - confirm against caller).
        name: Variable scope that holds the layer's "W" and "b" variables.
        kernelSize: Side length of the square convolution kernel.
        inChannel: Number of input channels.
        outChannel: Number of output channels (filters).
        stride: Spatial stride applied along both height and width.

    Returns:
        The ReLU-activated output tensor; 'SAME' padding is used.
    """
    filter_shape = [kernelSize, kernelSize, inChannel, outChannel]
    strides = [1, stride, stride, 1]

    with tf.variable_scope(name):
        kernel = tf.get_variable(
            "W",
            shape=filter_shape,
            initializer=tf.initializers.truncated_normal(stddev=0.1))
        bias = tf.get_variable(
            "b",
            shape=[outChannel],
            initializer=tf.constant_initializer(0.1))

        convolved = tf.nn.conv2d(inputs, kernel, strides, padding='SAME')
        activated = tf.nn.relu(tf.nn.bias_add(convolved, bias))

    return activated


with tf.variable_scope('feature_extractor') as scope:

    # conv1: 5x5 kernel, 3 -> 32 channels; 'SAME' padding keeps 28 x 28.
    conv1 = conv_layer(X_input, name="conv1", kernelSize=5, inChannel=3, outChannel=32)

    # pool1: 2x2 max-pool with stride 2 -> 14 x 14 x 32.
    pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool1')

    # conv2: 5x5 kernel, 32 -> 48 channels; spatial size stays 14 x 14.
    conv2 = conv_layer(pool1, name="conv2", kernelSize=5, inChannel=32, outChannel=48)

    # pool2: 2x2 max-pool with stride 2 -> 7 x 7 x 48.
    pool2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool2')

    # Flatten every non-batch dimension into one feature vector per example;
    # this is the representation shared by the label and domain predictors.
    feature = tf.reshape(pool2, [-1, np.prod(pool2.get_shape().as_list()[1:])])


feature
<tf.Tensor 'feature_extractor/Reshape:0' shape=(?, 2352) dtype=float32>



labelPredictor

  • Train할 때는 source/target domain을 구분하여 source 데이터만으로 label predictor를 학습하고, Test할 때는 입력된 데이터를 전부 사용해야 한다.
  • tf.cond를 사용하여 ifelse조건에 따라 함수1(for train) or 함수2(for test)를 반환한다.
  • 그 함수의 형태로 바꿔주는 간단한 방법은 lambda를 사용하는 것이다.

  • when training
    • tf.slice를 사용하여 source 데이터와 target 데이터를 나눠준다.
    • [0, 0]: location
    • [batch_size // 2, -1]: location으로부터 batch_size // 2 행만큼, -1(남은 모든 열, 여기서는 2352) 열만큼 데이터를 잘라냄
    • 결국 주어진 데이터를 행으로 반으로 자르는 역할
# Branch functions for tf.cond. During training only the first half of the
# batch (the source-domain rows) is passed to the label predictor; at test
# time every row is used.
def source_features():
    """Training branch: first batch_size // 2 rows of the feature matrix."""
    return tf.slice(feature, begin=[0, 0], size=[batch_size // 2, -1])


def source_labels():
    """Training branch: first batch_size // 2 rows of the label matrix."""
    return tf.slice(y, begin=[0, 0], size=[batch_size // 2, -1])


def all_features():
    """Test branch: the full feature matrix."""
    return feature


def all_labels():
    """Test branch: the full label matrix."""
    return y


# Select features/labels depending on the is_training flag.
training_features = tf.cond(is_training, true_fn=source_features, false_fn=all_features)
true_labels = tf.cond(is_training, true_fn=source_labels, false_fn=all_labels)


def fc_layer(inputs,
             name,
            in_channel,
            out_channel, act_output = True):
    """Build a fully connected layer inside its own variable scope.

    Args:
        inputs: 2-D input tensor whose last dimension is ``in_channel``.
        name: Variable scope that holds the layer's "W" and "b" variables.
        in_channel: Number of input features.
        out_channel: Number of output units.
        act_output: When truthy, apply ReLU to the output; otherwise return
            the raw affine output (logits).

    Returns:
        ``relu(inputs @ W + b)`` if ``act_output`` is truthy, else
        ``inputs @ W + b``.
    """
    # No `as name` here: the original rebound the `name` argument to the
    # scope object, which shadowed the parameter for no benefit.
    with tf.variable_scope(name):
        fcWeights = tf.get_variable(
                "W",
                shape=[in_channel, out_channel],
                initializer=tf.initializers.truncated_normal(stddev=0.1)
                )

        fcBiases = tf.get_variable(
                "b",
                shape=[out_channel],
                initializer=tf.constant_initializer(0.1)
                )

        logit = tf.nn.bias_add(tf.matmul(inputs, fcWeights), fcBiases)

        # Plain truthiness test instead of comparing against True with `==`.
        if act_output:
            return tf.nn.relu(logit)
        return logit


  • feature dimensions
feature.shape.as_list()[-1]
2352


with tf.variable_scope('label_predictor') as scope:
    # Two hidden ReLU layers of 100 units followed by a linear 10-way output
    # layer; fc3 returns raw logits (no activation).
    label_fc1_output = fc_layer(
        training_features,
        name="fc1",
        in_channel=feature.get_shape().as_list()[-1],
        out_channel=100,
        act_output=True)
    label_fc2_output = fc_layer(
        label_fc1_output,
        name="fc2",
        in_channel=100,
        out_channel=100,
        act_output=True)
    label_fc3_output = fc_layer(
        label_fc2_output,
        name="fc3",
        in_channel=100,
        out_channel=10,
        act_output=False)


label_fc3_output
<tf.Tensor 'label_predictor/fc3/BiasAdd:0' shape=(?, 10) dtype=float32>


# Class probabilities: softmax over the last axis of the label logits.
label_pred = tf.nn.softmax(label_fc3_output, axis = -1)
label_pred
<tf.Tensor 'Softmax:0' shape=(?, 10) dtype=float32>


# Per-example softmax cross-entropy between the label logits
# (label_fc3_output) and the labels chosen by the is_training switch.
pred_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=label_fc3_output, labels=true_labels)
pred_loss
<tf.Tensor 'softmax_cross_entropy_with_logits/Reshape_2:0' shape=(?,) dtype=float32>



domainPredictor

  • Gradient reversal layer를 구현하는 방법을 살펴보자
    • Forward propagation: identity에 의해서 같은 값이 나옴
    • backward propagation: negative gradient(-lambda * gradient)이 계산됨
class FlipGradientBuilder(object):
    """Callable factory for gradient reversal ops.

    Each call returns ``tf.identity(x)`` whose gradient is overridden so
    that backpropagation multiplies the incoming gradient by ``-lb``; the
    forward pass leaves the value unchanged.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Number of reversal ops created by this builder so far.
        self.num_calls = 0

    def __call__(self, x, lb=1.0):
        """Return an identity op on ``x`` with a reversed, scaled gradient.

        Args:
            x: Tensor to pass through unchanged in the forward pass.
            lb: Scale applied to the negated gradient (number or tensor).

        Returns:
            A tensor equal to ``x`` in the forward pass.
        """
        # Local import keeps this block self-contained.
        import uuid

        # Gradient functions are registered by name, and registering the
        # same name twice raises. A counter alone restarts at 0 for every
        # new builder (e.g. when a notebook cell is re-run), so add a random
        # suffix to keep each registration unique.
        grad_name = "FlipGradient%d_%s" % (self.num_calls, uuid.uuid4().hex)

        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            return [tf.negative(grad) * lb]

        g = tf.get_default_graph()
        # Only Identity ops created inside this context use the override.
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)

        self.num_calls += 1
        return y


  • 위의 flip gradient를 적용하면 아래와 같다.
    • lb(lambda)는 label loss의 grad와 domain loss의 -grad의 상대적 중요도를 의미한다.
# Builder that creates the gradient reversal op used below.
flip_gradient = FlipGradientBuilder()


with tf.variable_scope('domain_predictor'):

    # Identity in the forward pass; during backpropagation the gradient
    # flowing into the feature extractor is negated and scaled by lb.
    grad_reversed_feature = flip_gradient(feature, lb=lb)
    # One hidden ReLU layer of 100 units, then a linear 2-way output (logits).
    domain_fc1_output = fc_layer(
        grad_reversed_feature,
        name="fc1",
        in_channel=feature.get_shape().as_list()[-1],
        out_channel=100,
        act_output=True)
    domain_fc2_output = fc_layer(
        domain_fc1_output,
        name="fc2",
        in_channel=100,
        out_channel=2,
        act_output=False)



# Domain probabilities from the domain logits.
domain_pred = tf.nn.softmax(domain_fc2_output)


# Per-example cross-entropy against the one-hot domain labels.
domain_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=domainID, logits=domain_fc2_output)
domain_loss
<tf.Tensor 'softmax_cross_entropy_with_logits_1/Reshape_2:0' shape=(?,) dtype=float32>



Loss

# Scalar losses: batch means of the per-example cross-entropies, summed into
# the single objective that is minimized.
predLoss = tf.reduce_mean(pred_loss)
domainLoss = tf.reduce_mean(domain_loss)
totalLoss = predLoss + domainLoss

# Train op: SGD with momentum 0.9; the learning rate is fed at every step.
DA_train_op = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9).minimize(totalLoss)

# Accuracy: fraction of rows whose arg-max prediction matches the one-hot target.
correct_label_pred = tf.equal(tf.argmax(true_labels, axis=1), tf.argmax(label_pred, axis=1))  # boolean per row
label_acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))  # scalar
correct_domain_pred = tf.equal(tf.argmax(domainID, axis=1), tf.argmax(domain_pred, axis=1))
domain_acc = tf.reduce_mean(tf.cast(correct_domain_pred, tf.float32))



Session on

# Open a session on the default graph and initialize every variable
# (layer weights/biases and the optimizer's slots) before training.
sess=tf.Session()
sess.run(tf.global_variables_initializer())


def shuffle_aligned_list(pairs):
    """Return the arrays in ``pairs`` reordered by one shared permutation.

    A single random permutation of the first array's leading dimension is
    drawn and applied to every array, so rows that were aligned before the
    shuffle remain aligned afterwards. The inputs are not modified.
    """
    order = np.random.permutation(pairs[0].shape[0])
    shuffled = []
    for array in pairs:
        shuffled.append(array[order])
    return shuffled


def batch_generator(pairs, batch_size, shuffle=True):
    """Generate batches of data indefinitely.

    Given a list of array-like objects, generate batches of a given
    size by yielding a list of array-like objects corresponding to the
    same slice of each input.

    Args:
        pairs: List of sliceable array-likes with the same length; the
            length of the first one defines the number of samples.
        batch_size: Number of rows per yielded batch.
        shuffle: When true, the arrays are shuffled identically (via
            ``shuffle_aligned_list``) before the first batch and again at
            the start of every new pass over the data.

    Yields:
        A list with one ``[start:end]`` slice per input array. A trailing
        partial batch is skipped; the generator wraps around to the start
        of the data instead.
    """
    if shuffle:
        pairs = shuffle_aligned_list(pairs)

    # Shuffling only reorders rows, so the sample count never changes.
    num_samples = len(pairs[0])

    batch_num = 0
    while True:
        # Wrap around only when the next batch would run past the data.
        # Strict `>`: with `>=` a final batch ending exactly on the last
        # sample was never yielded when num_samples is a multiple of
        # batch_size.
        if (batch_num + 1) * batch_size > num_samples:
            batch_num = 0

            if shuffle:
                pairs = shuffle_aligned_list(pairs)

        start = batch_num * batch_size
        end = start + batch_size
        batch_num += 1
        yield [pair[start:end] for pair in pairs]


  • a half batch each of source and target domain examples
  • source/target domain examples을 반씩 잘라 붙여야 사전에 설정해 놓은 batch size랑 일치한다.
    • [(32, 28, 28, 3), (32, 10)]
# Half-size batches: a source half and a target half are stacked into one
# full training batch. The MNIST-M images are paired with the MNIST labels.
gen_source_batch = batch_generator(
    batch_size=batch_size // 2,
    pairs=[mnist_train, mnist.train.labels])
gen_target_batch = batch_generator(
    batch_size=batch_size // 2,
    pairs=[mnistm_train, mnist.train.labels])


  • All of source/target domain examples
    • [(64, 28, 28, 3), (64, 10)]
# Full-size batches drawn from a single domain.
gen_source_only_batch = batch_generator(
    batch_size=batch_size,
    pairs=[mnist_train, mnist.train.labels])
gen_target_only_batch = batch_generator(
    batch_size=batch_size,
    pairs=[mnistm_train, mnist.train.labels])


  • domain labels if source or target
# One-hot domain labels for one training batch: the first half is
# source ([1, 0]) and the second half is target ([0, 1]).
domain_labels = np.vstack((
    np.tile([1., 0.], (batch_size // 2, 1)),
    np.tile([0., 1.], (batch_size // 2, 1)),
))


# Training model
num_steps = 8600

for i in range(num_steps):

    # Adaptation param and learning rate schedule as described in the paper
    p = float(i) / num_steps # training progress: grows linearly from 0 towards 1
    _lb = 2. / (1. + np.exp(-10. * p)) - 1 # lambda: 0 at p = 0, rises towards 1 as p grows
    _lr = 0.01 / (1. + 10 * p)**0.75 # learning rate: starts at 0.01 and decays as p grows

    # Training step

    # Draw half a batch from each domain.
    XS, yS = next(gen_source_batch)
    XT, yT = next(gen_target_batch)

    # Source rows first, target rows second: this order must match both
    # domain_labels and the first-half slice taken for the label predictor.
    X_feed = np.vstack([XS, XT])
    # Target labels are fed too, but with is_training=True only the first
    # (source) half reaches the label loss.
    y_feed = np.vstack([yS, yT])

    _, batch_loss, dloss, ploss, d_acc, p_acc = sess.run(
        [DA_train_op, totalLoss, domainLoss, predLoss, domain_acc, label_acc],
        feed_dict = {X: X_feed,
           y: y_feed,
           domainID: domain_labels,
           is_training: True,
           lb: _lb,
           lr: _lr})
    
    # Log the batch metrics and the current schedule values every 100 steps.
    if i % 100 == 0:
        print('loss: {}  d_acc: {}  p_acc: {}  p: {}  lb: {}  lr: {}'.format(
                batch_loss, d_acc, p_acc, p, _lb, _lr))
loss: 4.825463771820068  d_acc: 0.5  p_acc: 0.0625  p: 0.0  lb: 0.0  lr: 0.01
loss: 0.40940043330192566  d_acc: 0.90625  p_acc: 0.9375  p: 0.011627906976744186  lb: 0.05807411547586794  lr: 0.009208108196143174
loss: 0.5818338394165039  d_acc: 0.890625  p_acc: 0.875  p: 0.023255813953488372  lb: 0.11575782577418603  lr: 0.008548589038554535
loss: 0.5852369666099548  d_acc: 0.796875  p_acc: 0.9375  p: 0.03488372093023256  lb: 0.1726711536624863  lr: 0.007989697680708598
loss: 0.5564760565757751  d_acc: 0.90625  p_acc: 0.9375  p: 0.046511627906976744  lb: 0.22845439143629886  lr: 0.0075092390442056635
loss: 0.35558179020881653  d_acc: 0.875  p_acc: 0.96875  p: 0.05813953488372093  lb: 0.28277682569099527  lr: 0.0070911987391239035
loss: 0.2507930397987366  d_acc: 0.9375  p_acc: 1.0  p: 0.06976744186046512  lb: 0.33534391863054624  lr: 0.006723713451795381
loss: 0.25439000129699707  d_acc: 0.953125  p_acc: 0.96875  p: 0.08139534883720931  lb: 0.3859026564904293  lr: 0.006397796023621206
loss: 0.5283358097076416  d_acc: 0.859375  p_acc: 0.9375  p: 0.09302325581395349  lb: 0.43424492823161986  lr: 0.006106505778508959
loss: 0.4712458550930023  d_acc: 0.859375  p_acc: 0.90625  p: 0.10465116279069768  lb: 0.4802089471455322  lr: 0.00584439199954653
loss: 0.4825291633605957  d_acc: 0.890625  p_acc: 0.96875  p: 0.11627906976744186  lb: 0.5236788585597902  lr: 0.0056071106987353935
loss: 0.19707489013671875  d_acc: 0.90625  p_acc: 1.0  p: 0.12790697674418605  lb: 0.5645827773148586  lr: 0.005391154585323163
loss: 0.4898703396320343  d_acc: 0.78125  p_acc: 1.0  p: 0.13953488372093023  lb: 0.6028895635619975  lr: 0.005193658897076379
loss: 0.3004189431667328  d_acc: 0.9375  p_acc: 1.0  p: 0.1511627906976744  lb: 0.6386046745549958  lr: 0.005012259239308981
loss: 0.4147544503211975  d_acc: 0.859375  p_acc: 0.96875  p: 0.16279069767441862  lb: 0.6717654275930591  lr: 0.004844985806046933
loss: 0.2904179096221924  d_acc: 0.875  p_acc: 0.96875  p: 0.1744186046511628  lb: 0.7024359819836457  lr: 0.004690183518397721
loss: 0.6488478183746338  d_acc: 0.84375  p_acc: 0.90625  p: 0.18604651162790697  lb: 0.730702303851418  lr: 0.0045464509301225125
loss: 0.2889407277107239  d_acc: 0.90625  p_acc: 0.96875  p: 0.19767441860465115  lb: 0.75666732465968  lr: 0.004412592926309448
loss: 0.34746885299682617  d_acc: 0.9375  p_acc: 0.90625  p: 0.20930232558139536  lb: 0.7804464491564882  lr: 0.004287583697558238
loss: 0.40334779024124146  d_acc: 0.84375  p_acc: 1.0  p: 0.22093023255813954  lb: 0.8021635162227865  lr: 0.004170537464590465
loss: 0.5859626531600952  d_acc: 0.953125  p_acc: 0.90625  p: 0.23255813953488372  lb: 0.8219472701703336  lr: 0.00406068511562926
loss: 0.5954075455665588  d_acc: 0.796875  p_acc: 0.9375  p: 0.2441860465116279  lb: 0.8399283622203215  lr: 0.003957355402193208
loss: 0.5475267171859741  d_acc: 0.8125  p_acc: 0.9375  p: 0.2558139534883721  lb: 0.8562368727213061  lr: 0.003859959683449113
loss: 0.8015364408493042  d_acc: 0.65625  p_acc: 0.9375  p: 0.26744186046511625  lb: 0.8710003237511923  lr: 0.0037679794579745665
loss: 0.9847149848937988  d_acc: 0.671875  p_acc: 0.9375  p: 0.27906976744186046  lb: 0.8843421381310939  lr: 0.0036809561034616242
loss: 0.5318312048912048  d_acc: 0.828125  p_acc: 0.90625  p: 0.29069767441860467  lb: 0.8963804933094455  lr: 0.0035984823790740587
loss: 1.4411591291427612  d_acc: 0.75  p_acc: 0.8125  p: 0.3023255813953488  lb: 0.907227515730854  lr: 0.0035201953452889123
loss: 1.3302115201950073  d_acc: 0.53125  p_acc: 0.875  p: 0.313953488372093  lb: 0.9169887619362538  lr: 0.0034457704314723612
loss: 0.5538606643676758  d_acc: 0.71875  p_acc: 1.0  p: 0.32558139534883723  lb: 0.9257629356552557  lr: 0.003374916438763681
loss: 0.8089284896850586  d_acc: 0.53125  p_acc: 0.9375  p: 0.3372093023255814  lb: 0.9336417946475963  lr: 0.003307371309778185
loss: 0.6464406847953796  d_acc: 0.65625  p_acc: 0.96875  p: 0.3488372093023256  lb: 0.9407102063254855  lr: 0.0032428985305833465
loss: 0.892132043838501  d_acc: 0.796875  p_acc: 0.875  p: 0.36046511627906974  lb: 0.9470463167220862  lr: 0.003181284056820732
loss: 0.5464159250259399  d_acc: 0.796875  p_acc: 0.96875  p: 0.37209302325581395  lb: 0.9527218027997246  lr: 0.0031223336765527007
loss: 0.8238319158554077  d_acc: 0.796875  p_acc: 0.96875  p: 0.38372093023255816  lb: 0.9578021831783621  lr: 0.003065870738750244
loss: 0.22796763479709625  d_acc: 0.890625  p_acc: 1.0  p: 0.3953488372093023  lb: 0.962347166972511  lr: 0.0030117341893097
loss: 0.7139698266983032  d_acc: 0.765625  p_acc: 0.96875  p: 0.4069767441860465  lb: 0.9664110244879942  lr: 0.002959776866846464
loss: 0.32875534892082214  d_acc: 0.84375  p_acc: 1.0  p: 0.4186046511627907  lb: 0.9700429670347053  lr: 0.002909864018835692
loss: 0.5904496908187866  d_acc: 0.875  p_acc: 0.90625  p: 0.43023255813953487  lb: 0.9732875260777021  lr: 0.002861872005390412
loss: 0.4229314923286438  d_acc: 0.875  p_acc: 1.0  p: 0.4418604651162791  lb: 0.9761849244171494  lr: 0.0028156871634225067
loss: 0.3996061086654663  d_acc: 0.796875  p_acc: 1.0  p: 0.45348837209302323  lb: 0.9787714341094471  lr: 0.0027712048083815533
loss: 0.4992428421974182  d_acc: 0.828125  p_acc: 0.96875  p: 0.46511627906976744  lb: 0.9810797174732149  lr: 0.002728328354412792
loss: 0.5148041248321533  d_acc: 0.828125  p_acc: 0.9375  p: 0.47674418604651164  lb: 0.9831391488202712  lr: 0.002686968536776859
loss: 0.5221948027610779  d_acc: 0.765625  p_acc: 0.96875  p: 0.4883720930232558  lb: 0.9849761155658179  lr: 0.0026470427228549257
loss: 0.6307872533798218  d_acc: 0.78125  p_acc: 0.90625  p: 0.5  lb: 0.9866142981514305  lr: 0.0026084743001221454
loss: 0.6382717490196228  d_acc: 0.828125  p_acc: 0.90625  p: 0.5116279069767442  lb: 0.9880749288014745  lr: 0.002571192131188139
loss: 0.4289991557598114  d_acc: 0.90625  p_acc: 0.96875  p: 0.5232558139534884  lb: 0.9893770295647741  lr: 0.002535130067438327
loss: 0.40435507893562317  d_acc: 0.890625  p_acc: 0.96875  p: 0.5348837209302325  lb: 0.9905376303999733  lr: 0.002500226514014457
loss: 0.6215903759002686  d_acc: 0.78125  p_acc: 0.96875  p: 0.5465116279069767  lb: 0.9915719682712341  lr: 0.0024664240398871913
loss: 0.5021497011184692  d_acc: 0.796875  p_acc: 1.0  p: 0.5581395348837209  lb: 0.9924936683523062  lr: 0.0024336690276309993
loss: 0.5572361946105957  d_acc: 0.75  p_acc: 0.96875  p: 0.5697674418604651  lb: 0.9933149085094493  lr: 0.0024019113582383557
loss: 0.5989366769790649  d_acc: 0.65625  p_acc: 1.0  p: 0.5813953488372093  lb: 0.9940465682614799  lr: 0.0023711041269283283
loss: 1.0323117971420288  d_acc: 0.6875  p_acc: 0.8125  p: 0.5930232558139535  lb: 0.9946983634099507  lr: 0.00234120338643174
loss: 0.9565933346748352  d_acc: 0.796875  p_acc: 0.90625  p: 0.6046511627906976  lb: 0.9952789675032998  lr: 0.002312167914685915
loss: 0.4626711308956146  d_acc: 0.796875  p_acc: 0.96875  p: 0.6162790697674418  lb: 0.9957961212529436  lr: 0.0022839590042587057
loss: 0.4873444736003876  d_acc: 0.8125  p_acc: 1.0  p: 0.627906976744186  lb: 0.9962567309623462  lr: 0.0022565402711539955
loss: 0.8458352088928223  d_acc: 0.75  p_acc: 0.90625  p: 0.6395348837209303  lb: 0.996666956966364  lr: 0.0022298774809375384
loss: 0.7572154402732849  d_acc: 0.796875  p_acc: 0.90625  p: 0.6511627906976745  lb: 0.9970322930109379  lr: 0.0022039383903697716
loss: 0.45236918330192566  d_acc: 0.859375  p_acc: 0.96875  p: 0.6627906976744186  lb: 0.9973576374349025  lr: 0.0021786926029468733
loss: 0.35806888341903687  d_acc: 0.890625  p_acc: 0.96875  p: 0.6744186046511628  lb: 0.9976473569480822  lr: 0.0021541114369377466
loss: 0.435672402381897  d_acc: 0.84375  p_acc: 0.96875  p: 0.686046511627907  lb: 0.9979053437342209  lr: 0.002130167804666827
loss: 0.5702878832817078  d_acc: 0.71875  p_acc: 1.0  p: 0.6976744186046512  lb: 0.9981350665445223  lr: 0.0021068361019340996
loss: 0.7712510228157043  d_acc: 0.65625  p_acc: 0.96875  p: 0.7093023255813954  lb: 0.998339616388191  lr: 0.002084092106587389
loss: 0.45937344431877136  d_acc: 0.859375  p_acc: 0.96875  p: 0.7209302325581395  lb: 0.9985217473707251  lr: 0.002061912885370307
loss: 0.8559670448303223  d_acc: 0.640625  p_acc: 0.96875  p: 0.7325581395348837  lb: 0.9986839131789393  lr: 0.0020402767082642755
loss: 0.7581861019134521  d_acc: 0.71875  p_acc: 0.9375  p: 0.7441860465116279  lb: 0.9988282996638527  lr: 0.0020191629696266573
loss: 0.8087642192840576  d_acc: 0.71875  p_acc: 0.9375  p: 0.7558139534883721  lb: 0.9989568539285347  lr: 0.001998552115500608
loss: 0.7783613204956055  d_acc: 0.578125  p_acc: 0.96875  p: 0.7674418604651163  lb: 0.9990713102877162  lr: 0.0019784255765372813
loss: 0.5618288516998291  d_acc: 0.59375  p_acc: 1.0  p: 0.7790697674418605  lb: 0.9991732134291553  lr: 0.0019587657060284253
loss: 0.5146697163581848  d_acc: 0.640625  p_acc: 1.0  p: 0.7906976744186046  lb: 0.9992639390733074  lr: 0.0019395557225983058
loss: 0.6340583562850952  d_acc: 0.625  p_acc: 1.0  p: 0.8023255813953488  lb: 0.9993447123974792  lr: 0.0019207796571489855
loss: 0.7205818891525269  d_acc: 0.640625  p_acc: 0.9375  p: 0.813953488372093  lb: 0.999416624463169  lr: 0.0019024223036931153
loss: 0.634271502494812  d_acc: 0.625  p_acc: 0.96875  p: 0.8255813953488372  lb: 0.9994806468604833  lr: 0.0018844691737440408
loss: 0.6444191336631775  d_acc: 0.578125  p_acc: 0.96875  p: 0.8372093023255814  lb: 0.9995376447611286  lr: 0.0018669064539648453
loss: 0.7390989065170288  d_acc: 0.59375  p_acc: 0.96875  p: 0.8488372093023255  lb: 0.9995883885513366  lr: 0.0018497209668063268
loss: 0.5952032804489136  d_acc: 0.6875  p_acc: 1.0  p: 0.8604651162790697  lb: 0.9996335641979495  lr: 0.0018329001338892768
loss: 0.5313435792922974  d_acc: 0.75  p_acc: 1.0  p: 0.872093023255814  lb: 0.9996737824846389  lr: 0.0018164319419091504
loss: 0.9486538767814636  d_acc: 0.75  p_acc: 0.875  p: 0.8837209302325582  lb: 0.9997095872406112  lr: 0.001800304910861562
loss: 0.5704761147499084  d_acc: 0.640625  p_acc: 1.0  p: 0.8953488372093024  lb: 0.9997414626710781  lr: 0.0017845080644053356
loss: 0.7692083120346069  d_acc: 0.640625  p_acc: 0.96875  p: 0.9069767441860465  lb: 0.9997698398870507  lr: 0.0017690309021962457
loss: 0.6062610745429993  d_acc: 0.71875  p_acc: 0.96875  p: 0.9186046511627907  lb: 0.9997951027215122  lr: 0.0017538633740393809
loss: 0.6653850078582764  d_acc: 0.609375  p_acc: 1.0  p: 0.9302325581395349  lb: 0.9998175929096487  lr: 0.0017389958557213823
loss: 0.6266434788703918  d_acc: 0.640625  p_acc: 1.0  p: 0.9418604651162791  lb: 0.9998376147024313  lr: 0.0017244191263958181
loss: 0.5251486897468567  d_acc: 0.75  p_acc: 1.0  p: 0.9534883720930233  lb: 0.9998554389753282  lr: 0.0017101243474058232
loss: 0.6981379985809326  d_acc: 0.59375  p_acc: 0.96875  p: 0.9651162790697675  lb: 0.9998713068872478  lr: 0.0016961030424379434
loss: 0.718218207359314  d_acc: 0.640625  p_acc: 0.96875  p: 0.9767441860465116  lb: 0.999885433138822  lr: 0.001682347078910014
loss: 0.7044693827629089  d_acc: 0.6875  p_acc: 0.90625  p: 0.9883720930232558  lb: 0.9998980088738034  lr: 0.0016688486505039575


Evaluation

  • 학습을 할때는 source/target데이터가 모두 필요하다.
  • 평가할때 source 데이터 성능을 볼땐 source만, target 데이터 성능을 볼땐 target만 사용할 수 있다.
  • 어떤 domain인지를 평가할때는 당연히 두개의 source/target 데이터가 필요하다.
# Final evaluation on the test data.
# Label accuracy on the source domain (MNIST). is_training=False routes
# every row of the batch through the label predictor.
source_acc = sess.run(
    label_acc,
    feed_dict={X: mnist_test,
               y: mnist.test.labels,
               is_training: False})

# Label accuracy on the target domain (MNIST-M), scored against the MNIST labels.
target_acc = sess.run(
    label_acc,
    feed_dict={X: mnistm_test,
               y: mnist.test.labels,
               is_training: False})


# Domain classification accuracy on the mixed source/target subset.
test_domain_acc = sess.run(
    domain_acc,
    feed_dict={X: combined_test_imgs,
               domainID: combined_test_domain,
               lb: 1.0})


# Feature-extractor embeddings of the mixed subset, kept for visualization.
test_emb = sess.run(feature, feed_dict={X: combined_test_imgs})
test_emb.shape
(1000, 2352)


print('\nDomain adaptation training')
print('Source (MNIST) accuracy:', source_acc)
print('Target (MNIST-M) accuracy:', target_acc)
print('Domain accuracy:', test_domain_acc)
Domain adaptation training
Source (MNIST) accuracy: 0.9807
Target (MNIST-M) accuracy: 0.7198
Domain accuracy: 0.669


  • 시각화를 위해 따로 임베딩 vector, domain, label를 뽑아 두었다.
# Persist the arrays needed for later visualization.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could
# fail if the directory appeared between the check and the creation.
os.makedirs('./results', exist_ok=True)
# images (source: 500, target: 500)
np.save('./results/SourceTargetImages.npy', combined_test_imgs)

# 1000 embedding vectors of 2352 dimensions (source: 500, target: 500)
np.save('./results/SourceTargetEmbed.npy', test_emb)

# one-hot domain id: source or target
np.save('./results/SourceTargetDomain.npy', combined_test_domain)

# one-hot MNIST digit labels
np.save('./results/SourceTargetLabels.npy', combined_test_labels)


Reference

https://github.com/TGISer/tf-dann


DSBA