ChatBot(챗봇)에 학습되는 입력 데이터

February 24, 2018    ChatBot(챗봇)

ChatBot Input

챗봇 입력데이터는 질문을 한 사람(parent_id) 응답하는 사람(comment_id)의 paired dataset으로 구성해야 하며, 또한 모델을 평가하기 위해 학습(training), 평가(test)데이터로 구분해야만 한다. test.from는 평가데이터이면서 질문하는 내용을 저장, test.to는 평가데이터이면서 응답하는 내용을 저장, train.from train.to 는 학습데이터에 대해서 앞서 한 내용을 적용한다.

  • 이전 블로그에서 업로드한 데이터를 이용하여 seq2seq 입력데이터를 만들어 보자.
  • 월별 많은 데이터 테이블이 존재하게 되는 데 for loop를 실행하면 되니 하나만 예를 들어보자.
timeframes = ['2006_01','2006_02',...]


  • 2006_01 데이터 테이블 예시
timeData = '2006_01'


  • 몇가지 필요한 변수들을 정의해보자

    limit : 한번에 불어오는 데이터의 행의 갯수이다. 이 크기 만큼 training set, test set파일에 번갈아 가면서 업로드 된다. last_unix: 한번 업데이트 하고나면, 업데이트 된 데이터와 중복이 되지 않는 다른 데이터를 training set, test set파일에 재귀적으로 업로드 해야 하는데 데이터 구별을 위해 업데이트 될때마다 마지막 시간값을 기록한다. cur_length: 처음에 limit이라는 값으로 초기화 되지만, 계속적의 데이터 테이블의 row의 수로 업데이트 된다. 이 의미는 업데이트를 하려는 데이터가 얼마 안 남았을 때, cur_length는 사전에 정의한 limit 보다 작을 것이다. 즉, cur_length < limit이 조건이 되었을 때 반복문을 종료시킨다. counter: 실행순서를 출력을 위해 사용한다. test_done: 반복문이 실행될 때마다 True, False값을 교체시켜 training set, test set파일에 저장하는데 쓰인다.

# limit : 얼마나 많은 행을 가져올 것인가? e.g limit = 2000 => 2000 rows와
limit = 5000
# 끝나는 시점
last_unix = 0

cur_length = limit
counter = 0
test_done = False


  • 2006_01 데이터 테이블을 불러오기 위한 MySQL DB연결
connection = MySQLdb.connect(host='localhost',
                             user='root',
                             password='1225')
c = connection.cursor()
c.execute("USE {}_reddit;".format(timeData.split('_')[0]))


  • pandas를 활용한 2006_01 데이터 테이블을 불러오기
df = pd.read_sql("SELECT * FROM {} WHERE unix > {} AND parent IS NOT NULL AND score > 0 ORDER BY unix ASC LIMIT {}".format(timeData, last_unix, limit),
            connection)
df
parent_id comment_id parent comment subreddit unix score
0 c3982 c4141 Too bad Mr Crockford himself misunderstands Ja... Hmm, I tried the following, and it worked: new... reddit.com 1136838703 3
1 c3899 c4186 Sorry, not basic enough. How about explanation... U is a combinator, a function that takes a fun... reddit.com 1136855468 3
2 c3187 c4239 well, I'm learning Python and so are all the g... The ranking claims to rate how 'mainstream' ea... reddit.com 1136889581 3
3 c3846 c5141 Most distributions allow you to update from on... That's not the point. If I install SuSE 9.1 o... reddit.com 1137121161 2
4 c5278 c5310 Quite frankly, I'm sick of hearing about the b... This isn't about socialism vs capitalism. It's... reddit.com 1137165941 5
5 c5222 c5355 why? Because Bill Gates is a very intelligent man. reddit.com 1137174845 2
6 c5287 c5356 I'm wondering how my comment got 3 negative vo... Indeed - you just can't have casual discussion... reddit.com 1137175214 3
7 c5291 c5373 wow, this is the only practical use of fractal... A while back, lots of people were excited abou... reddit.com 1137177834 2
8 c5228 c5411 'hilarious' news piece from *2002*. Can you ev... nope, after I showed the link to my girlfriend... reddit.com 1137183450 2


  • 다음 업데이트 때 순차적인 데이터 셋을 불러오기 위해서 last_unix에 업데이트된 마지막 행의 시간을 기록
last_unix = df.tail(1)['unix'].values[0]
last_unix

1137183450
cur_length = len(df)
cur_length

9


Total Code

for timeData in timeframes:
    connection = MySQLdb.connect(host='localhost',
                                 user='root',
                                 password='1225')
    c = connection.cursor()
    c.execute("USE {}_reddit;".format(timeData.split('_')[0]))

    # limit : 얼마나 많은 행을 가져올 것인가? e.g limit = 2000 => 2000 rows
    limit = 5000
    last_unix = 0
    cur_length = limit
    counter = 0
    test_done = False

    while cur_length == limit:
        # 5000 rows 씩 데이터를 불러옴
        df = pd.read_sql("SELECT * FROM {} WHERE unix > {} AND parent IS NOT NULL AND score > 0 ORDER BY unix ASC LIMIT {}".format(timeData, last_unix, limit),
                    connection)
        # 가장 늦은 시간
        last_unix = df.tail(1)['unix'].values[0]
        cur_length = len(df)
        if not test_done:
            with open('test.from','a', encoding='utf8') as f:
                for content in df['parent'].values:
                    f.write(content+'\n')
            with open('test.to', 'a', encoding='utf8') as f:
                for content in df['comment'].values:
                    f.write(content + '\n')

            test_done = True

        else:
            with open('train.from', 'a', encoding='utf8') as f:
                for content in df['parent'].values:
                    f.write(content + '\n')
            with open('train.to', 'a', encoding='utf8') as f:
                for content in df['comment'].values:
                    f.write(content + '\n')

        counter += 1
        if counter % 20 == 0:
            print('Update:',counter*limit)

DSBA