多腕バンディット問題の実装

こんばんは。
今日のことなのですが、アルバイト先である案件を解決するために問題を多腕バンディット問題に帰着させる必要性にかられました。
今までバンディット問題などは証明が難しい印象があり食わず嫌いだったのですが、調べてみたら思っていたよりも簡単(あくまでアルゴリズムは。bound証明などはもちろん僕の手にはおえません(笑))だったので、家でも状況を多腕バンディット問題に戻してサラッと実装してみました。

多腕バンディット問題とは

多腕バンディット問題とは、機械学習の分野でもよく取り沙汰される問題です。
僕たちがゲーセンのスロットマシーンで遊ぶことを考えます。
ただし簡単のため、スロットマシーンで1回遊ぶことで得られる報酬は当たりなら1円、はずれなら0円です。
ゲーセンにはスロットマシーンが複数台あり、店員さんが小細工しているので各台ごとに1円がもらえる確率は違います。
僕たちは出来るだけたくさんお金をゲットしたいので、複数あるスロットマシーンの中で一番当たりが出るもので最初から最後までずっと遊び続ければ、利益は最大化されますよね?

でもここで問題があります。
僕たちはどのスロットマシーンが、どれくらいの確率で当たりになるか知らないのです。
なので遊んでいる回数が少ないうちは、どのマシンが利益が大きくなるのかを調べるために、多少の犠牲を払ってでもあらゆるマシンで遊ぶ必要があります。これを一般に探索といいます。
一方で、探索ばかりしていては当たる確率の低いマシンでも遊ぶため、利益を最大化するためにはどのマシンが当たる確率が高くなるかわかったら、そのマシンに集中するべきです。これを一般に活用といいます。

おわかりだと思いますが、探索と活用はトレードオフの関係にあります。
探索ばかりしていては利益を大きくできませんし、活用ばかりしていたらひょっとしたらもっと良いマシンがあるかもしれないのにもったいないって状況になるかもしれません。

探索と活用のトレードオフをいい塩梅にするアルゴリズムがいくつか提案されています。
以下では代表的な3つ、「トンプソンサンプリング」「UCB」「ε-greedy」を実装しています。
詳細は省きますが、多分コードを見てすぐわかるくらい単純な数式で導かれます。
slot_machine関数の中身や、スロットで遊ぶ回数を変えながら遊んでみると地味に楽しいですよ。
1分くらいで飽きますけど

# -*- coding: utf-8 -*-
import numpy as np
import scipy.stats

def slot_machine(n):
	# 人工スロットマシーン。
	# nは引くスロットのラベル。
	# 引くスロットの当たる確率をもとに、当たりならば1を、外れならば0を出力
	prob_list = [0.3, 0.5, 0.7, 0.5, 0.3]
	prob = prob_list[n]
	return scipy.stats.bernoulli.rvs(prob, size=1)

class BanditProblem:
	def __init__(self, n_trial=np.zeros(5), n_hit=np.zeros(5)):
		self.n_trial = n_trial  # 各スロットマシーンの試行回数
		self.n_hit = n_hit  # 各スロットマシーン試行回数のうち当たった回数

	def add_new_information(self, new_arm, hit_or_miss):
		# 新しく試行をしたときに、結果を追加する
		# new_armは新しくスロットを引いたマシーンのラベル、hit_or_missは当たりかはずれか。
		# 当たりならば1、はずれならば0を入れる
		self.n_trial[new_arm] += 1
		self.n_hit[new_arm] += hit_or_miss

	def decide_next_slot(self, method = 'ThompsonSampling', epsilon=0.2):
		# 次にどのスロットを引いたらいいか、スロットのラベルを出力する。
		#  アルゴリズムは「トンプソンサンプリング」「UCB方策」「ε-greedy探索」から選べる。
		if np.count_nonzero(self.n_trial) !=5:
			# もしまだ引いたことがないスロットがあるのなら、それを引く。
			indexes=np.where(self.n_trial == 0)[0]
			return np.random.choice(indexes)

		else:
			if method == 'ThompsonSampling':
				return self.thompson_sampling()

			elif method == 'UCB':
				return self.UCB()

			else:
				return self.epsilon_greedy(epsilon)

	def thompson_sampling(self):
		alpha = self.n_hit + 1
		beta = self.n_trial - self.n_hit +1
		prob = np.random.beta(alpha, beta)
		return np.argmax(prob)

	def UCB(self):
		mu = self.n_hit / self.n_trial
		UCB=mu + np.sqrt(2*np.log(self.n_trial.sum()) / self.n_trial)
		return np.argmax(UCB)

	def epsilon_greedy(self, epsilon):
		if np.random.rand() < epsilon:
			return np.argmin(self.n_trial)

		else:
			return np.argmax(self.n_hit / self.n_trial)

if __name__ == '__main__':
	n_search=30  # スロットを引く回数
	bp=BanditProblem()
	for i in range(n_search):
		next_arm = bp.decide_next_slot()  # 次のスロットマシーンを選ぶ
		result = slot_machine(next_arm)  # スロットマシーンを引く
		bp.add_new_information(next_arm, result)  # 得られた情報の追加
	print(bp.n_hit)
	print(bp.n_trial)