The code below works through an implementation of the paper Unsupervised Domain Adaptation by Backpropagation.
Code
import tensorflow as tf
import numpy as np
import random
import os
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from tensorflow.python.framework import ops
import pickle
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Load MNIST
mnist.train.images.shape
(55000, 784)
# Process MNIST (55000, 784): binarize and replicate to 3 channels to match MNIST-M's RGB format
mnist_train = (mnist.train.images > 0).reshape(55000, 28, 28, 1).astype(np.uint8) * 255
mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
mnist_test = (mnist.test.images > 0).reshape(10000, 28, 28, 1).astype(np.uint8) * 255
mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)
mnist_train.shape
(55000, 28, 28, 3)
Load MNIST-M
mnistm_data.pkl can be downloaded or generated separately (see the tf-dann repository in the Reference below).
# Load MNIST-M
mnistm = pickle.load(open('mnistm_data.pkl', 'rb'))
mnistm_train = mnistm['train']
mnistm_test = mnistm['test']
mnistm_valid = mnistm['valid']
# per-channel (RGB) pixel mean over the source and target training images
pixel_mean = np.vstack([mnist_train, mnistm_train]).mean((0, 1, 2))
pixel_mean.shape
(3,)
# Create a mixed dataset for TSNE visualization
num_test = 500
combined_test_imgs = np.vstack([mnist_test[:num_test], mnistm_test[:num_test]])
combined_test_labels = np.vstack([mnist.test.labels[:num_test], mnist.test.labels[:num_test]])
combined_test_domain = np.vstack([np.tile([1., 0.], [num_test, 1]),
np.tile([0., 1.], [num_test, 1])])
def imshow_grid(images, shape=[2, 8]):
"""Plot images in a grid of a given shape."""
fig = plt.figure(1)
grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
size = shape[0] * shape[1]
for i in range(size):
grid[i].axis('off')
grid[i].imshow(images[i])  # the AxesGrid object works as a list of axes
plt.show()
imshow_grid(mnist_train)
imshow_grid(mnistm_train)
batch_size = 64
X = tf.placeholder(tf.uint8, [None, 28, 28, 3])
y = tf.placeholder(tf.float32, [None, 10])
domainID = tf.placeholder(tf.float32, [None, 2])
lb = tf.placeholder(tf.float32, []) # lambda
lr = tf.placeholder(tf.float32, []) # learning rate
is_training = tf.placeholder(tf.bool, [])
X_input = (tf.cast(X, tf.float32) - pixel_mean) / 255.
X_input
<tf.Tensor 'truediv:0' shape=(?, 28, 28, 3) dtype=float32>
featureExtractor
def conv_layer(inputs,
name,
kernelSize,
inChannel,
outChannel,
stride = 1):
with tf.variable_scope(name) as name:
convWeights = tf.get_variable(
"W",
shape= [kernelSize, kernelSize, inChannel, outChannel],
initializer=tf.initializers.truncated_normal(stddev=0.1)
)
convBiases = tf.get_variable(
"b",
shape=[outChannel],
initializer=tf.constant_initializer(0.1)
)
conv = tf.nn.conv2d(inputs, convWeights, [1,stride,stride,1], padding='SAME')
convBias = tf.nn.bias_add(conv, convBiases)
relu = tf.nn.relu(convBias)
return relu
with tf.variable_scope('feature_extractor') as scope:
# Conv1: n x 28(=(28-5+4)/1+1) x 28 x 32 ;
conv1 = conv_layer(X_input, kernelSize=5, name= "conv1", inChannel=3, outChannel=32)
# Pool1: n x 14(=(28-2)/2+1) x 14 x 32
pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
# Conv2: n x 14(=(14-5+4)/1+1) x 14 x 48
conv2 = conv_layer(pool1, kernelSize=5, name= "conv2", inChannel=32, outChannel=48)
# Pool2: n x 7(=(14-2)/2+1) x 7 x 48
pool2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2')
# domain-invariant feature
feature = tf.reshape(pool2, [-1, np.prod(pool2.shape.as_list()[1:])])
feature
<tf.Tensor 'feature_extractor/Reshape:0' shape=(?, 2352) dtype=float32>
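As a quick sanity check, the flattened dimension can be verified directly: two stride-2 poolings reduce the 28 x 28 input to 7 x 7 with 48 channels.
# 28 -> 14 -> 7 after the two stride-2 max-poolings, so 7 * 7 * 48 = 2352
assert feature.shape.as_list()[-1] == 7 * 7 * 48 == 2352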
labelPredictor
tf.cond takes its branches as functions (callables); a simple way to turn the slices into that form is to wrap them in lambda.
In tf.slice(feature, [0, 0], [batch_size // 2, -1]), [0, 0] is the starting location, and [batch_size // 2, -1] means: from that location, take batch_size // 2 rows and -1 (i.e. all 2352) columns, so only the source half of the batch is kept during training. A toy example follows the code below.
# when training
source_features = lambda: tf.slice(feature, [0, 0], [batch_size // 2, -1])
source_labels = lambda: tf.slice(y, [0, 0], [batch_size // 2, -1])
# when testing
all_features = lambda: feature
all_labels = lambda: y
# features/labels
training_features = tf.cond(is_training, source_features, all_features)
true_labels = tf.cond(is_training, source_labels, all_labels)
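As a toy illustration (the tensors toy, first_half, and picked below are hypothetical and not part of the model), tf.slice(t, [0, 0], [k, -1]) takes the first k rows and all columns, and tf.cond evaluates one of the two callables:
# toy 4 x 3 tensor; slice out the first 2 rows and all columns
toy = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])
first_half = tf.slice(toy, [0, 0], [2, -1])  # shape (2, 3), same as toy[:2, :]
# tf.cond picks a branch according to the boolean predicate
picked = tf.cond(tf.constant(True), lambda: first_half, lambda: toy)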
def fc_layer(inputs,
name,
in_channel,
out_channel, act_output = True):
with tf.variable_scope(name) as name:
fcWeights = tf.get_variable(
"W",
shape= [in_channel, out_channel],
initializer=tf.initializers.truncated_normal(stddev=0.1)
)
fcBiases = tf.get_variable(
"b",
shape=[out_channel],
initializer=tf.constant_initializer(0.1)
)
logit = tf.nn.bias_add(tf.matmul(inputs, fcWeights), fcBiases)
if act_output==True:
return tf.nn.relu(logit)
return logit
feature.shape.as_list()[-1]
2352
with tf.variable_scope('label_predictor') as scope:
label_fc1_output= fc_layer(training_features, name ="fc1", in_channel= feature.shape.as_list()[-1], out_channel=100, act_output=True)
label_fc2_output= fc_layer(label_fc1_output, name ="fc2", in_channel= 100, out_channel=100, act_output=True)
label_fc3_output= fc_layer(label_fc2_output, name ="fc3", in_channel= 100, out_channel=10, act_output=False)
label_fc3_output
<tf.Tensor 'label_predictor/fc3/BiasAdd:0' shape=(?, 10) dtype=float32>
label_pred = tf.nn.softmax(label_fc3_output, axis = -1)
label_pred
<tf.Tensor 'Softmax:0' shape=(?, 10) dtype=float32>
# logit: label_fc3_output
pred_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=label_fc3_output, labels=true_labels)
pred_loss
<tf.Tensor 'softmax_cross_entropy_with_logits/Reshape_2:0' shape=(?,) dtype=float32>
domainPredictor
class FlipGradientBuilder(object):
def __init__(self):
self.num_calls = 0
def __call__(self, x, lb=1.0):
grad_name = "FlipGradient%d" % self.num_calls
@ops.RegisterGradient(grad_name)
def _flip_gradients(op, grad):
return [tf.negative(grad) * lb]
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": grad_name}):
y = tf.identity(x)  # identity in the forward pass; its gradient is overridden above
self.num_calls += 1
return y
Applying the flip gradient (gradient reversal) layer gives the following. lb (lambda) is the relative weight between the gradient of the label loss and the negated gradient of the domain loss.
flip_gradient = FlipGradientBuilder()
with tf.variable_scope('domain_predictor'):
# Flip the gradient when backpropagating through this operation
grad_reversed_feature = flip_gradient(feature, lb)
domain_fc1_output= fc_layer(grad_reversed_feature, name ="fc1", in_channel= feature.shape.as_list()[-1], out_channel=100, act_output=True)
domain_fc2_output= fc_layer(domain_fc1_output, name ="fc2", in_channel= 100, out_channel=2, act_output=False)
domain_pred = tf.nn.softmax(domain_fc2_output)
domain_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=domain_fc2_output, labels=domainID)
domain_loss
<tf.Tensor 'softmax_cross_entropy_with_logits_1/Reshape_2:0' shape=(?,) dtype=float32>
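As a small sanity check of the gradient reversal (x_check, y_check, and grad_check are hypothetical names; evaluate grad_check after the session is created below), the gradient flowing back through flip_gradient is the negated upstream gradient scaled by lb:
x_check = tf.constant([1., 2., 3.])
y_check = flip_gradient(x_check, lb=0.5)  # identity in the forward pass
grad_check = tf.gradients(tf.reduce_sum(y_check), x_check)[0]
# sess.run(grad_check) -> array([-0.5, -0.5, -0.5], dtype=float32)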
Loss
# loss
predLoss = tf.reduce_mean(pred_loss)
domainLoss = tf.reduce_mean(domain_loss)
totalLoss = predLoss + domainLoss
# train op
# label_train_op = tf.train.AdamOptimizer(0.9).minimize(predLoss)
DA_train_op = tf.train.MomentumOptimizer(lr,0.9).minimize(totalLoss)
# Accuracy
correct_label_pred = tf.equal(tf.argmax(true_labels, 1), tf.argmax(label_pred, 1)) #boolean
label_acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32)) #scalar
correct_domain_pred = tf.equal(tf.argmax(domainID, 1), tf.argmax(domain_pred, 1))
domain_acc = tf.reduce_mean(tf.cast(correct_domain_pred, tf.float32))
Session
sess=tf.Session()
sess.run(tf.global_variables_initializer())
def shuffle_aligned_list(pairs):
"""Shuffle arrays in a list by shuffling each array identically."""
num = pairs[0].shape[0]
pm_idx = np.random.permutation(num)
return [pair[pm_idx] for pair in pairs]
def batch_generator(pairs, batch_size, shuffle=True):
"""Generate batches of data.
Given a list of array-like objects, generate batches of a given
size by yielding a list of array-like objects corresponding to the
same slice of each input.
"""
if shuffle:
pairs = shuffle_aligned_list(pairs)
batch_num = 0
while True:
if batch_num * batch_size + batch_size >= len(pairs[0]):
batch_num = 0
if shuffle:
pairs = shuffle_aligned_list(pairs)
start = batch_num * batch_size
end = start + batch_size
batch_num += 1
yield [pair[start:end] for pair in pairs]
gen_source_batch = batch_generator(pairs=[mnist_train, mnist.train.labels], batch_size= batch_size // 2)
gen_target_batch = batch_generator(pairs=[mnistm_train, mnist.train.labels], batch_size= batch_size // 2)
gen_source_only_batch = batch_generator(pairs=[mnist_train, mnist.train.labels],batch_size= batch_size)
gen_target_only_batch = batch_generator(pairs=[mnistm_train, mnist.train.labels], batch_size=batch_size)
domain_labels = np.vstack([np.tile([1., 0.], [batch_size // 2, 1]),
np.tile([0., 1.], [batch_size // 2, 1])])
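A quick shape check (note that calling next here consumes one half-batch from the generator):
Xb, yb = next(gen_source_batch)
# Xb.shape == (32, 28, 28, 3), yb.shape == (32, 10)
# domain_labels.shape == (64, 2): the first 32 rows are [1, 0] (source), the last 32 are [0, 1] (target)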
# Training model
num_steps = 8600
for i in range(num_steps):
# Adaptation param and learning rate schedule as described in the paper
p = float(i) / num_steps  # training progress: grows from 0 to 1
_lb = 2. / (1. + np.exp(-10. * p)) - 1  # lambda schedule: grows from 0 toward 1
_lr = 0.01 / (1. + 10 * p)**0.75  # learning rate: decays over time
# Training step
XS, yS = next(gen_source_batch)
XT, yT = next(gen_target_batch)
X_feed = np.vstack([XS, XT])
y_feed = np.vstack([yS, yT])
_, batch_loss, dloss, ploss, d_acc, p_acc = sess.run(
[DA_train_op, totalLoss, domainLoss, predLoss, domain_acc, label_acc],
feed_dict = {X: X_feed,
y: y_feed,
domainID: domain_labels,
is_training: True,
lb: _lb,
lr: _lr})
if i % 100 == 0:
print('loss: {} d_acc: {} p_acc: {} p: {} lb: {} lr: {}'.format(
batch_loss, d_acc, p_acc, p, _lb, _lr))
loss: 4.825463771820068 d_acc: 0.5 p_acc: 0.0625 p: 0.0 lb: 0.0 lr: 0.01
loss: 0.40940043330192566 d_acc: 0.90625 p_acc: 0.9375 p: 0.011627906976744186 lb: 0.05807411547586794 lr: 0.009208108196143174
loss: 0.5818338394165039 d_acc: 0.890625 p_acc: 0.875 p: 0.023255813953488372 lb: 0.11575782577418603 lr: 0.008548589038554535
loss: 0.5852369666099548 d_acc: 0.796875 p_acc: 0.9375 p: 0.03488372093023256 lb: 0.1726711536624863 lr: 0.007989697680708598
loss: 0.5564760565757751 d_acc: 0.90625 p_acc: 0.9375 p: 0.046511627906976744 lb: 0.22845439143629886 lr: 0.0075092390442056635
loss: 0.35558179020881653 d_acc: 0.875 p_acc: 0.96875 p: 0.05813953488372093 lb: 0.28277682569099527 lr: 0.0070911987391239035
loss: 0.2507930397987366 d_acc: 0.9375 p_acc: 1.0 p: 0.06976744186046512 lb: 0.33534391863054624 lr: 0.006723713451795381
loss: 0.25439000129699707 d_acc: 0.953125 p_acc: 0.96875 p: 0.08139534883720931 lb: 0.3859026564904293 lr: 0.006397796023621206
loss: 0.5283358097076416 d_acc: 0.859375 p_acc: 0.9375 p: 0.09302325581395349 lb: 0.43424492823161986 lr: 0.006106505778508959
loss: 0.4712458550930023 d_acc: 0.859375 p_acc: 0.90625 p: 0.10465116279069768 lb: 0.4802089471455322 lr: 0.00584439199954653
loss: 0.4825291633605957 d_acc: 0.890625 p_acc: 0.96875 p: 0.11627906976744186 lb: 0.5236788585597902 lr: 0.0056071106987353935
loss: 0.19707489013671875 d_acc: 0.90625 p_acc: 1.0 p: 0.12790697674418605 lb: 0.5645827773148586 lr: 0.005391154585323163
loss: 0.4898703396320343 d_acc: 0.78125 p_acc: 1.0 p: 0.13953488372093023 lb: 0.6028895635619975 lr: 0.005193658897076379
loss: 0.3004189431667328 d_acc: 0.9375 p_acc: 1.0 p: 0.1511627906976744 lb: 0.6386046745549958 lr: 0.005012259239308981
loss: 0.4147544503211975 d_acc: 0.859375 p_acc: 0.96875 p: 0.16279069767441862 lb: 0.6717654275930591 lr: 0.004844985806046933
loss: 0.2904179096221924 d_acc: 0.875 p_acc: 0.96875 p: 0.1744186046511628 lb: 0.7024359819836457 lr: 0.004690183518397721
loss: 0.6488478183746338 d_acc: 0.84375 p_acc: 0.90625 p: 0.18604651162790697 lb: 0.730702303851418 lr: 0.0045464509301225125
loss: 0.2889407277107239 d_acc: 0.90625 p_acc: 0.96875 p: 0.19767441860465115 lb: 0.75666732465968 lr: 0.004412592926309448
loss: 0.34746885299682617 d_acc: 0.9375 p_acc: 0.90625 p: 0.20930232558139536 lb: 0.7804464491564882 lr: 0.004287583697558238
loss: 0.40334779024124146 d_acc: 0.84375 p_acc: 1.0 p: 0.22093023255813954 lb: 0.8021635162227865 lr: 0.004170537464590465
loss: 0.5859626531600952 d_acc: 0.953125 p_acc: 0.90625 p: 0.23255813953488372 lb: 0.8219472701703336 lr: 0.00406068511562926
loss: 0.5954075455665588 d_acc: 0.796875 p_acc: 0.9375 p: 0.2441860465116279 lb: 0.8399283622203215 lr: 0.003957355402193208
loss: 0.5475267171859741 d_acc: 0.8125 p_acc: 0.9375 p: 0.2558139534883721 lb: 0.8562368727213061 lr: 0.003859959683449113
loss: 0.8015364408493042 d_acc: 0.65625 p_acc: 0.9375 p: 0.26744186046511625 lb: 0.8710003237511923 lr: 0.0037679794579745665
loss: 0.9847149848937988 d_acc: 0.671875 p_acc: 0.9375 p: 0.27906976744186046 lb: 0.8843421381310939 lr: 0.0036809561034616242
loss: 0.5318312048912048 d_acc: 0.828125 p_acc: 0.90625 p: 0.29069767441860467 lb: 0.8963804933094455 lr: 0.0035984823790740587
loss: 1.4411591291427612 d_acc: 0.75 p_acc: 0.8125 p: 0.3023255813953488 lb: 0.907227515730854 lr: 0.0035201953452889123
loss: 1.3302115201950073 d_acc: 0.53125 p_acc: 0.875 p: 0.313953488372093 lb: 0.9169887619362538 lr: 0.0034457704314723612
loss: 0.5538606643676758 d_acc: 0.71875 p_acc: 1.0 p: 0.32558139534883723 lb: 0.9257629356552557 lr: 0.003374916438763681
loss: 0.8089284896850586 d_acc: 0.53125 p_acc: 0.9375 p: 0.3372093023255814 lb: 0.9336417946475963 lr: 0.003307371309778185
loss: 0.6464406847953796 d_acc: 0.65625 p_acc: 0.96875 p: 0.3488372093023256 lb: 0.9407102063254855 lr: 0.0032428985305833465
loss: 0.892132043838501 d_acc: 0.796875 p_acc: 0.875 p: 0.36046511627906974 lb: 0.9470463167220862 lr: 0.003181284056820732
loss: 0.5464159250259399 d_acc: 0.796875 p_acc: 0.96875 p: 0.37209302325581395 lb: 0.9527218027997246 lr: 0.0031223336765527007
loss: 0.8238319158554077 d_acc: 0.796875 p_acc: 0.96875 p: 0.38372093023255816 lb: 0.9578021831783621 lr: 0.003065870738750244
loss: 0.22796763479709625 d_acc: 0.890625 p_acc: 1.0 p: 0.3953488372093023 lb: 0.962347166972511 lr: 0.0030117341893097
loss: 0.7139698266983032 d_acc: 0.765625 p_acc: 0.96875 p: 0.4069767441860465 lb: 0.9664110244879942 lr: 0.002959776866846464
loss: 0.32875534892082214 d_acc: 0.84375 p_acc: 1.0 p: 0.4186046511627907 lb: 0.9700429670347053 lr: 0.002909864018835692
loss: 0.5904496908187866 d_acc: 0.875 p_acc: 0.90625 p: 0.43023255813953487 lb: 0.9732875260777021 lr: 0.002861872005390412
loss: 0.4229314923286438 d_acc: 0.875 p_acc: 1.0 p: 0.4418604651162791 lb: 0.9761849244171494 lr: 0.0028156871634225067
loss: 0.3996061086654663 d_acc: 0.796875 p_acc: 1.0 p: 0.45348837209302323 lb: 0.9787714341094471 lr: 0.0027712048083815533
loss: 0.4992428421974182 d_acc: 0.828125 p_acc: 0.96875 p: 0.46511627906976744 lb: 0.9810797174732149 lr: 0.002728328354412792
loss: 0.5148041248321533 d_acc: 0.828125 p_acc: 0.9375 p: 0.47674418604651164 lb: 0.9831391488202712 lr: 0.002686968536776859
loss: 0.5221948027610779 d_acc: 0.765625 p_acc: 0.96875 p: 0.4883720930232558 lb: 0.9849761155658179 lr: 0.0026470427228549257
loss: 0.6307872533798218 d_acc: 0.78125 p_acc: 0.90625 p: 0.5 lb: 0.9866142981514305 lr: 0.0026084743001221454
loss: 0.6382717490196228 d_acc: 0.828125 p_acc: 0.90625 p: 0.5116279069767442 lb: 0.9880749288014745 lr: 0.002571192131188139
loss: 0.4289991557598114 d_acc: 0.90625 p_acc: 0.96875 p: 0.5232558139534884 lb: 0.9893770295647741 lr: 0.002535130067438327
loss: 0.40435507893562317 d_acc: 0.890625 p_acc: 0.96875 p: 0.5348837209302325 lb: 0.9905376303999733 lr: 0.002500226514014457
loss: 0.6215903759002686 d_acc: 0.78125 p_acc: 0.96875 p: 0.5465116279069767 lb: 0.9915719682712341 lr: 0.0024664240398871913
loss: 0.5021497011184692 d_acc: 0.796875 p_acc: 1.0 p: 0.5581395348837209 lb: 0.9924936683523062 lr: 0.0024336690276309993
loss: 0.5572361946105957 d_acc: 0.75 p_acc: 0.96875 p: 0.5697674418604651 lb: 0.9933149085094493 lr: 0.0024019113582383557
loss: 0.5989366769790649 d_acc: 0.65625 p_acc: 1.0 p: 0.5813953488372093 lb: 0.9940465682614799 lr: 0.0023711041269283283
loss: 1.0323117971420288 d_acc: 0.6875 p_acc: 0.8125 p: 0.5930232558139535 lb: 0.9946983634099507 lr: 0.00234120338643174
loss: 0.9565933346748352 d_acc: 0.796875 p_acc: 0.90625 p: 0.6046511627906976 lb: 0.9952789675032998 lr: 0.002312167914685915
loss: 0.4626711308956146 d_acc: 0.796875 p_acc: 0.96875 p: 0.6162790697674418 lb: 0.9957961212529436 lr: 0.0022839590042587057
loss: 0.4873444736003876 d_acc: 0.8125 p_acc: 1.0 p: 0.627906976744186 lb: 0.9962567309623462 lr: 0.0022565402711539955
loss: 0.8458352088928223 d_acc: 0.75 p_acc: 0.90625 p: 0.6395348837209303 lb: 0.996666956966364 lr: 0.0022298774809375384
loss: 0.7572154402732849 d_acc: 0.796875 p_acc: 0.90625 p: 0.6511627906976745 lb: 0.9970322930109379 lr: 0.0022039383903697716
loss: 0.45236918330192566 d_acc: 0.859375 p_acc: 0.96875 p: 0.6627906976744186 lb: 0.9973576374349025 lr: 0.0021786926029468733
loss: 0.35806888341903687 d_acc: 0.890625 p_acc: 0.96875 p: 0.6744186046511628 lb: 0.9976473569480822 lr: 0.0021541114369377466
loss: 0.435672402381897 d_acc: 0.84375 p_acc: 0.96875 p: 0.686046511627907 lb: 0.9979053437342209 lr: 0.002130167804666827
loss: 0.5702878832817078 d_acc: 0.71875 p_acc: 1.0 p: 0.6976744186046512 lb: 0.9981350665445223 lr: 0.0021068361019340996
loss: 0.7712510228157043 d_acc: 0.65625 p_acc: 0.96875 p: 0.7093023255813954 lb: 0.998339616388191 lr: 0.002084092106587389
loss: 0.45937344431877136 d_acc: 0.859375 p_acc: 0.96875 p: 0.7209302325581395 lb: 0.9985217473707251 lr: 0.002061912885370307
loss: 0.8559670448303223 d_acc: 0.640625 p_acc: 0.96875 p: 0.7325581395348837 lb: 0.9986839131789393 lr: 0.0020402767082642755
loss: 0.7581861019134521 d_acc: 0.71875 p_acc: 0.9375 p: 0.7441860465116279 lb: 0.9988282996638527 lr: 0.0020191629696266573
loss: 0.8087642192840576 d_acc: 0.71875 p_acc: 0.9375 p: 0.7558139534883721 lb: 0.9989568539285347 lr: 0.001998552115500608
loss: 0.7783613204956055 d_acc: 0.578125 p_acc: 0.96875 p: 0.7674418604651163 lb: 0.9990713102877162 lr: 0.0019784255765372813
loss: 0.5618288516998291 d_acc: 0.59375 p_acc: 1.0 p: 0.7790697674418605 lb: 0.9991732134291553 lr: 0.0019587657060284253
loss: 0.5146697163581848 d_acc: 0.640625 p_acc: 1.0 p: 0.7906976744186046 lb: 0.9992639390733074 lr: 0.0019395557225983058
loss: 0.6340583562850952 d_acc: 0.625 p_acc: 1.0 p: 0.8023255813953488 lb: 0.9993447123974792 lr: 0.0019207796571489855
loss: 0.7205818891525269 d_acc: 0.640625 p_acc: 0.9375 p: 0.813953488372093 lb: 0.999416624463169 lr: 0.0019024223036931153
loss: 0.634271502494812 d_acc: 0.625 p_acc: 0.96875 p: 0.8255813953488372 lb: 0.9994806468604833 lr: 0.0018844691737440408
loss: 0.6444191336631775 d_acc: 0.578125 p_acc: 0.96875 p: 0.8372093023255814 lb: 0.9995376447611286 lr: 0.0018669064539648453
loss: 0.7390989065170288 d_acc: 0.59375 p_acc: 0.96875 p: 0.8488372093023255 lb: 0.9995883885513366 lr: 0.0018497209668063268
loss: 0.5952032804489136 d_acc: 0.6875 p_acc: 1.0 p: 0.8604651162790697 lb: 0.9996335641979495 lr: 0.0018329001338892768
loss: 0.5313435792922974 d_acc: 0.75 p_acc: 1.0 p: 0.872093023255814 lb: 0.9996737824846389 lr: 0.0018164319419091504
loss: 0.9486538767814636 d_acc: 0.75 p_acc: 0.875 p: 0.8837209302325582 lb: 0.9997095872406112 lr: 0.001800304910861562
loss: 0.5704761147499084 d_acc: 0.640625 p_acc: 1.0 p: 0.8953488372093024 lb: 0.9997414626710781 lr: 0.0017845080644053356
loss: 0.7692083120346069 d_acc: 0.640625 p_acc: 0.96875 p: 0.9069767441860465 lb: 0.9997698398870507 lr: 0.0017690309021962457
loss: 0.6062610745429993 d_acc: 0.71875 p_acc: 0.96875 p: 0.9186046511627907 lb: 0.9997951027215122 lr: 0.0017538633740393809
loss: 0.6653850078582764 d_acc: 0.609375 p_acc: 1.0 p: 0.9302325581395349 lb: 0.9998175929096487 lr: 0.0017389958557213823
loss: 0.6266434788703918 d_acc: 0.640625 p_acc: 1.0 p: 0.9418604651162791 lb: 0.9998376147024313 lr: 0.0017244191263958181
loss: 0.5251486897468567 d_acc: 0.75 p_acc: 1.0 p: 0.9534883720930233 lb: 0.9998554389753282 lr: 0.0017101243474058232
loss: 0.6981379985809326 d_acc: 0.59375 p_acc: 0.96875 p: 0.9651162790697675 lb: 0.9998713068872478 lr: 0.0016961030424379434
loss: 0.718218207359314 d_acc: 0.640625 p_acc: 0.96875 p: 0.9767441860465116 lb: 0.999885433138822 lr: 0.001682347078910014
loss: 0.7044693827629089 d_acc: 0.6875 p_acc: 0.90625 p: 0.9883720930232558 lb: 0.9998980088738034 lr: 0.0016688486505039575
Evaluation
# Compute final evaluation on test data
source_acc = sess.run(label_acc,
feed_dict={X: mnist_test, y: mnist.test.labels,
is_training: False})
target_acc = sess.run(label_acc,
feed_dict={X: mnistm_test, y: mnist.test.labels,
is_training: False})
test_domain_acc = sess.run(domain_acc,
feed_dict={X: combined_test_imgs,
domainID: combined_test_domain, lb: 1.0})
test_emb = sess.run(feature, feed_dict={X: combined_test_imgs})
test_emb.shape
(1000, 2352)
print('\nDomain adaptation training')
print('Source (MNIST) accuracy:', source_acc)
print('Target (MNIST-M) accuracy:', target_acc)
print('Domain accuracy:', test_domain_acc)
Domain adaptation training
Source (MNIST) accuracy: 0.9807
Target (MNIST-M) accuracy: 0.7198
Domain accuracy: 0.669
if not os.path.exists('./results'):
os.mkdir('./results')
# images (source: 500, target: 500)
np.save('./results/SourceTargetImages.npy', combined_test_imgs)
# 1000 embedding vectors of dimension 2352 (source: 500, target: 500)
np.save('./results/SourceTargetEmbed.npy', test_emb)
# source or target
np.save('./results/SourceTargetDomain.npy', combined_test_domain)
# mnist label
np.save('./results/SourceTargetLabels.npy', combined_test_labels)
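The saved embeddings can be visualized with t-SNE; below is a minimal sketch, assuming scikit-learn is available:
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, perplexity=30, random_state=0)
emb_2d = tsne.fit_transform(test_emb)  # (1000, 2)
domain_idx = combined_test_domain.argmax(1)  # 0: source (MNIST), 1: target (MNIST-M)
plt.figure(figsize=(6, 6))
plt.scatter(emb_2d[:, 0], emb_2d[:, 1], c=domain_idx, cmap='coolwarm', s=8)
plt.title('t-SNE of domain-invariant features')
plt.show()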
Reference
https://github.com/TGISer/tf-dann