MAGAZINE
ルーターマガジン
機械学習で日本語広告コピーの自動生成をやってみる
はじめに
近年、テキスト自動生成技術は目覚ましい発展を遂げています。2021年にはAIが小説を書いてくれるサービス「AIのべりすと」が登場し、話題になりました。同じく2021年、有名なLINEのチャットボット「りんな」の開発元であるrinna株式会社は、日本語版GPT-2の事前学習モデルを公開しました。GPT-2は、下の例のようにヒントの文を入力すると続きの文章を生成することができるAIです。
入力: お探しのページは、
出力: お探しのページは、ページが削除されているか移動または削除された可能性があります。 あなたが検索した条件また...
更に、事前学習モデルは、追加で学習を行う「ファインチューニング」を施すことで文章生成の傾向を変えることができます。 今回は、弊社のアドクロールという広告クリエイティブ収集サービスの膨大なテキストデータを用いて、広告らしい文章の生成に挑戦しました。
学習結果
アドクロールの広告テキストの中から5万件をサンプリングしてファインチューニングに用いました。学習にはGPUが必要となるため、弊社の保有する機械学習用マシンを使用しました。簡易的なモデル・データ量であるため、学習時間は20分以下でした。
ファインチューニング前後の生成テキスト比較結果は以下のようになりました。
入力 | rinnaモデル (チューニング前) | アドクロールモデル (チューニング後) |
---|---|---|
ギフトが当たる | ギフトが当たる「応募ハガキ」を印刷した箱をプレゼント。 応募ハガキの封筒に当選番号と、応募者名、郵便番号を記入。 抽選で5名にamazonギフト券100円分と | ギフトが当たるお得な応募キャンペーン実施中! |
痩せれない人は絶対やって | 痩せれない人は絶対やってはいけないダイエット法だ。 というのも、ダイエットで大切なのは、リバウンドしないダイエット法として継続性があるかどうかである。 最近、テレビや雑誌で芸能人がダイエットに成功しているのが目に | 痩せれない人は絶対やってる運動 話題の酵素ダイエット |
エンジニアセキュリティ担当者必見 | エンジニアセキュリティ担当者必見! 企業の情報セキュリティ対策セミナー開催 【東京】開催 大手自動車部品メーカーの製品で、「自動運転の実現を目指す」人が多い。その理由はどうなっているのか。本セミナーでは、自動車産業における自動運転 | エンジニアセキュリティ担当者必見の求人募集 |
rinnaモデルは普通の文章であるのに対し、アドクロールモデルは広告テキストらしくなっています。アドクロールモデルは文章が短く、内容が破綻しにくいのも良い点です。ただし、現時点では必ずしも文脈に沿った内容になるわけではなく、一見それらしいけれど少しずれてる文章が生成されることもあります。今後はキーワードを指定して文中に入れられるよう工夫するなど、改善していく予定です。
ファインチューニングのやり方
次の記事を参考にさせて頂きました。
https://towardsdatascience.com/how-to-fine-tune-gpt-2-for-text-generation-ae2ea53bc272
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (AdamW, AutoModelForCausalLM, T5Tokenizer,
get_linear_schedule_with_warmup)
class CopyTexts(Dataset):
def __init__(
self,
df,
pretrained_name,
n_samples=50000,
truncate=True,
max_length=1024,
):
self.tokenizer = T5Tokenizer.from_pretrained(pretrained_name)
s_tokens = self.tokenizer.special_tokens_map
self.tokenizer.do_lower_case = True
self.texts = []
if truncate:
rows = df.sample(n_samples)
else:
rows = df
for _, row in tqdm(rows.iterrows()):
text = str(row["text"])
formatted = f"{s_tokens['bos_token']}{text}{s_tokens['eos_token']}"
self.texts.append(
torch.tensor(self.tokenizer.encode(formatted[:max_length]))
)
self.texts_count = len(self.texts)
def __len__(self):
return self.texts_count
def __getitem__(self, item):
return self.texts[item]
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
if packed_tensor is None:
return new_tensor, True, None
if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
return packed_tensor, False, new_tensor
else:
packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
return packed_tensor, True, None
def train(
dataset,
model,
batch_size=12,
n_epochs=5,
learning_rate=2e-5,
warmup_steps=200,
):
device = torch.device("cuda")
model = model.cuda()
model.train()
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
)
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
loss = 0
accumulating_batch_count = 0
input_tensor = None
for epoch in range(n_epochs):
print(f"Training epoch {epoch}")
print(loss)
for idx, entry in tqdm(enumerate(train_dataloader)):
(input_tensor, carry_on, _) = pack_tensor(entry, input_tensor, 768)
if carry_on and idx != len(train_dataloader) - 1:
continue
input_tensor = input_tensor.to(device)
outputs = model(input_tensor, labels=input_tensor)
loss = outputs[0]
loss.backward()
if (accumulating_batch_count % batch_size) == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
model.zero_grad()
accumulating_batch_count += 1
input_tensor = None
return model
model_size = "small"
n_samples = 50000
df = pd.read_csv("texts.csv")
pretrained_name = f"rinna/japanese-gpt2-{model_size}"
dataset = CopyTexts(df, pretrained_name=pretrained_name, n_samples=n_samples)
tokenizer = T5Tokenizer.from_pretrained(pretrained_name)
tokenizer.do_lower_case = True
model = AutoModelForCausalLM.from_pretrained(pretrained_name)
model = train(dataset, model, batch_size=12)
torch.save(model.state_dict(), "modelname.model")
from typing import List
import readline
import sys
import torch
import torch.nn.functional as F
from tqdm import trange
from transformers import AutoModelForCausalLM, T5Tokenizer
def generate(
model,
tokenizer,
prompt,
entry_count=10,
entry_length=30, # maximum number of words
top_p=0.8,
temperature=1.0,
) -> List[str]:
model.eval()
generated_num = 0
generated_list = []
filter_value = -float("Inf")
eos_token = tokenizer.special_tokens_map["eos_token"]
eos_id = tokenizer.all_special_ids[tokenizer.all_special_tokens.index(eos_token)]
with torch.no_grad():
for _ in trange(entry_count):
entry_finished = False
prompt_token = tokenizer.encode(prompt, return_tensors="pt")
generated = prompt_token[prompt_token != eos_id].unsqueeze(0)
for _ in range(entry_length):
outputs = model(generated, labels=generated)
loss, logits = outputs[:2]
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[:, indices_to_remove] = filter_value
next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token), dim=1)
if next_token in tokenizer.encode(eos_token):
entry_finished = True
if entry_finished:
generated_num = generated_num + 1
output_list = list(generated.squeeze().numpy())
output_text = tokenizer.decode(output_list)
generated_list.append(output_text)
break
if not entry_finished:
output_list = list(generated.squeeze().numpy())
output_text = f"{tokenizer.decode(output_list)}"
generated_list.append(output_text)
return generated_list
model = "modelname.model"
model_size = "small"
if len(sys.argv) >= 2:
model = sys.argv[1]
model_size = sys.argv[2]
model = AutoModelForCausalLM.from_pretrained(f"rinna/japanese-gpt2-{model_size}")
model.load_state_dict(
torch.load(f"path/to/model/{model}", map_location=torch.device("cpu"))
)
tokenizer = T5Tokenizer.from_pretrained(f"rinna/japanese-gpt2-{model_size}")
tokenizer.do_lower_case = True
while True:
try:
input_text = input("\ninput: ")
if input_text == "":
continue
generated_copy = generate(model, tokenizer, input_text, entry_count=1)
print(generated_copy[0])
except KeyboardInterrupt:
print("exit.")
exit()
機械学習による広告コピー生成のまとめ
アドクロールデータを用いて簡単に広告らしいテキストを作れることが分かりました。今回は特別な工夫をしているわけでもなく、モデルも小規模なものを使用したため、やり方次第では更に自然なテキストを生成できるはずです。
CONTACT
お問い合わせ・ご依頼はこちらから