2018년 1월 30일 화요일

TensorFlow 학습 결과 Save & Restore

간단한 linear regression 모델을 학습해 저장하고, 저장된 모델을 다시 불러와 예측 값을 반환하는 예제 코드


# Paths used to save the model and to load it back.
base_path = './tf-models'
# NOTE(review): presumably the date 2018-01-30 with the month's leading zero dropped.
date_str = '2018130'
# Directory that holds the checkpoint files.
load_path = '/'.join([base_path, date_str])
# Checkpoint prefix: the saver writes <prefix>.meta, <prefix>.index, ...
save_path = '/'.join([load_path, 'my_model'])

# Example data: three features (x1, x2, x3) and the target y, five samples each.
# (In the original paste the x1_data assignment was swallowed by this comment,
# leaving x1_data undefined.)
x1_data = [73., 93., 89., 96., 73.]
x2_data = [80., 88., 91., 98., 66.]
x3_data = [75., 93., 90., 100., 70.]
y_data = [152., 185., 180., 196., 142.]

# Train the model and save it.
def learning(self):
    """Fit a 3-feature linear regression and save the trained variables.

    Trains y ~ x1*w1 + x2*w2 + x3*w3 + b with plain gradient descent for
    2001 steps on ``self.x1_data``, ``self.x2_data``, ``self.x3_data`` and
    ``self.y_data``, then writes a checkpoint with prefix ``self.save_path``.
    The checkpoint's parent directory is created if it does not exist yet.
    Returns nothing.

    NOTE(review): the ``self`` parameter suggests this was a method of a
    class not shown here; the attributes read above presumably mirror the
    module-level constants of the same names -- confirm.
    """
    import os

    # Build everything in a private graph so that repeated calls do not pile
    # ops into the process-wide default graph. Otherwise the variables of a
    # second call would be renamed (weight1_1, ...), and prediction(), which
    # looks them up as "weight1:0" etc., would no longer find them.
    with tf.Graph().as_default():
        x1 = tf.placeholder(tf.float32)
        x2 = tf.placeholder(tf.float32)
        x3 = tf.placeholder(tf.float32)
        Y = tf.placeholder(tf.float32)

        # Explicit names let prediction() fetch these tensors from the
        # restored graph.
        w1 = tf.Variable(tf.random_normal([1]), name='weight1')
        w2 = tf.Variable(tf.random_normal([1]), name='weight2')
        w3 = tf.Variable(tf.random_normal([1]), name='weight3')
        b = tf.Variable(tf.random_normal([1]), name='bias')

        # Created after the variables so that it tracks all of them.
        saver = tf.train.Saver()

        hypothesis = x1 * w1 + x2 * w2 + x3 * w3 + b
        cost = tf.reduce_mean(tf.square(hypothesis - Y))
        # Presumably kept tiny because the features are unscaled (~70-100).
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-5)
        train = optimizer.minimize(cost)

        feed = {x1: self.x1_data, x2: self.x2_data,
                x3: self.x3_data, Y: self.y_data}

        # The context manager closes the session (the original leaked it).
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(2001):
                sess.run(train, feed_dict=feed)

            # Saver.save fails when the checkpoint's parent directory is
            # missing, so create it first (MakeDirs is a no-op if it exists).
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                tf.gfile.MakeDirs(save_dir)
            saver.save(sess, self.save_path)

# Load the saved model and predict.
def prediction(self, _x1, _x2, _x3):
    """Restore the checkpoint written by learning() and evaluate the model.

    Imports the meta graph ``self.save_path + '.meta'`` and restores the
    latest checkpoint in ``self.load_path``. It then evaluates
    x1*w1 + x2*w2 + x3*w3 + b on the given inputs, using the restored
    variables weight1, weight2, weight3 and bias.

    Args:
        _x1, _x2, _x3: feature values; each is passed through float().

    Returns:
        The value of ``sess.run(hypothesis)``. Since learning() creates the
        weights with shape [1], this is presumably a 1-element numpy array --
        confirm before relying on the shape.
    """
    # Restore into a fresh graph. Importing the meta graph into a default
    # graph that already holds learning()'s ops would duplicate or rename
    # nodes, and "weight1:0" could then resolve to a variable that was never
    # initialised in this session.
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(self.save_path + '.meta')
        saver.restore(sess, tf.train.latest_checkpoint(self.load_path))

        # New input placeholders. The model expression is rebuilt on top of
        # the restored variables rather than reusing the saved graph's ops.
        x1 = tf.placeholder(tf.float32)
        x2 = tf.placeholder(tf.float32)
        x3 = tf.placeholder(tf.float32)

        w1 = graph.get_tensor_by_name("weight1:0")
        w2 = graph.get_tensor_by_name("weight2:0")
        w3 = graph.get_tensor_by_name("weight3:0")
        b = graph.get_tensor_by_name("bias:0")

        hypothesis = x1 * w1 + x2 * w2 + x3 * w3 + b

        score = sess.run(hypothesis, feed_dict={x1: float(_x1), x2: float(_x2), x3: float(_x3)})

    return score

댓글 없음:

댓글 쓰기