個人因果効果をどうやって推論するか①
①と銘打っているが、続くかどうか。ほかに
- causal forest
- microsoft/ALICE
あたりも見ていきたいところ。
Estimating individual treatment effect: generalization bounds and algorithms
ICML 2017から。ICMLには毎年causalityのセクションがあり、機械学習を活用した因果推論について他にも面白そうな発表が多いので、これを数年前から追ってみようという企画。
枠組み
・因果効果の枠組みはRubinのpotential outcome model.
・いわゆる強く無視できる割当の仮定
このあたりは因果推論のデファクトか。introに述べられている通り、「強く無視できる割当」はデータからは判断することができず、ドメイン知識や変数間の因果関係の事前知識によって判断される必要がある。
概要
治療群の分布と、非治療群の分布が異なることがICE推定の難点。 これを、Integral Probability Metric (IPM)を使って上からおさえてやろうという発想。
IPMはざっくり言うと確率分布間の距離の指標らしく、 ・Maximum Mean Discrepancy ・Wasserstein distance の2つが用いられている。
ITE: τはY1, Y0のエラー+IPM項で上から押さえられる。
これをY0, Y1, IPMを別々に解くのではなく、総和を最適化すべくrepresentation learning (Bengio et al., 2013)の枠組みで学習している。
貢献
漸近一致性のみではなく、generalization-errorの程度を求めた。
randomised componentや操作変数を想定せず、観察研究の場面にも適用可能なモデル。
実験
CausalForestsやdoubly robustなど各種の因果推論の方法と比較。データセットはシミュレーションで生成したものとJobsというデータセット。 Jobsの方はrandom化されたcomponentで評価したのではっきりしない(という主張)だが、シミュレーションのものではもっとも良かったらしい。
感想
IPMの計算のところでLipschitz fuction、Hilbert spaceのあたりがでてきたのだけど、このあたりは理解の範囲外。
ざっくりいうとmulti-task learningのような枠組みか。IPWを最小化することでtreatmentと独立した「病状」といった表現を学習し、治療への効果はその表現から後段のネットワークで学習する。outcomeは実数値でbinaryの制約もなく、ネットワーク構造も比較的柔軟にモデリングできそう。
causal forestは信頼区間が広くなりがちなのが難点っぽいのだけど、こちらのほうが押さえられるのであれば嬉しい。
terminology
asymptotic consistency:漸近一致性
Microsoft/InterpretMLの中身
一言でいうと
ある一つの特徴量のみのdecision treeを作って残差を予測というのをcyclicに加えていくモデル
知ったきっかけ
Microsoftが、解釈性が高くかつ精度も高いBoostingのモデル(Explainable Boosting Machine=EBM)をOSSで公開。LIMEやSHAPといった解釈を行うための手法も搭載しており、モデルの挙動も簡単に可視化することができる。https://t.co/ghQuLwTNCj
— piqcy (@icoxfog417) 2019年6月27日
解釈可能性が高くモデルの修正も容易な新たなboosting machineとして話題のInterpretML。
本家githubによると、
- EBM is a fast implementation of GA2M. Details on the algorithm can be found here.
ということらしいので、まずはhereのリンクされている論文から。
Intelligible Models for HealthCare: Predicting Pneumonia Risk and Hospital 30-day Readmission (Rich C, KDD'15)
この論文自体にアルゴリズムの詳細はあまり載っていなかった。
書いてあるのは、以下のように二次の交互作用まで加味した一般化加法的モデルを用いていることと、baggingして100個のモデルを作ってから平均したよ、ということくらい。加法的モデルの中身であるfについては[5]の論文を読め、とのこと。gradient boosting of bagging of shallow regression treesらしい。[5]に比べると、二次の交互作用項を考慮してその数KをCVで求め、モデルに組み込んでいる部分が新しい点。
あとは、肺炎データセットにおいて、いかにintelligibleなモデルが構築できたかの実例が提示されているので気になる方は論文を見て欲しい。
(introductionではいわゆるuntreatedな場合のdisease risk scoreを推定するような論調で書かれているのに、求めているのはtreatmentを一切考慮しない説明変数のみの周辺確率なのはちょっとどうなの、という印象)
Intelligible Models for Classification and Regression (Yin L+, KDD'12)
二次の交互作用を加えていないboosted bagged treesによるgeneralized additive model (GAM)を他のsplineなどによるGAMや単純なLogistic regressionなどと比較している。結局、boosted bagged treesによるGAMが一番良かったとのこと。
boosted bagged treesによるGAM構築のアルゴリズム
- 各特徴量を一周するのを1 iterationとして、M iteration繰り返す
- 各特徴量については、その時点での残差を予測する一変数からのdecision treeを構築、これをfunctionに加える(boosting)
- この残差を予測するdecision treeをbaggingしてensembleしたものにする(bagging)
- これをある一定の収束条件を満たすまで繰り返す。
あくまで相関であることに注意
最初の論文の5.7でも協調されているが、あくまで変数間の相関をとらえており原則として因果推論に用いるモデルではないことに注意。
これSHAPやLIMEあたりも求められるようになっているのだけど、特徴量を計算していく順番に大きく依存するのでは。順番に一巡する代わりに一斉に求めて1/K倍にするなどすると、収束は遅くなるけどその分特徴量の重要性については順序のバイアスなく求められそうな気もしなくもない。
精度が出るdatasetは限定されるかも(私見)
あくまで二次の交互作用までしかとらえていないので、精度がでるdatasetは限定されるかもしれない。githubページに載っている精度は交互作用を加味していないEBMの精度っぽい。いずれも交互作用のない1次のlogistic regressionでAUROCが0.9前後となるような比較的単純なdatasetで、random forestよりもlogistic regressionが有効なものも多い。より複雑なdatasetではXGBoostなどが予測性能として上回ることもあるのでは。
個々の特徴量に対するfunctionだけみてそれを落とすことの妥当性は?
主張のひとつに、一般化加法的モデルで各特徴量についてのfunctionは平均0になるようにモデル化しているので、明らかにおかしな推定式になったものは取り除ける=操作できるのがメリットというのがあった。例としては肺炎予測モデルにおいて、喘息があるとかえって死亡率が低い(とモデル化される)といった特徴量。筆者は喘息ありだと最初からICUに入って集中治療を受けるからでは?と推察しており、この項目のfunctionを加法的モデルから取り除くことを提案していた。
これ喘息だから良いけど、他の項目と交絡があるような因子だと、マイナスになったものを取り除くだけでは逆のバイアスを生みかねない気がする。
まとめ
医療系のデータでは割とうまくいきそうな気もする。Microsoftからなのに比較対象がなんでXGBoostなんですかね.
モデルの不確実性:Analyzing the Role of Model Uncertainty for Electronic Health Records
医療などの最終判断が極めて大きな意味を持つタスクにおいては最終的なモデルの出力だけではなく、それがどの程度確信を持った出力なのかが重要。そこにはdata uncertaintyと、model uncertaintyがある。
同グループから既報のMIMIC-IIIというICU電子カルテ公開データからRNNで死亡率を予測するタスクにおいて、
・random seedを振った200個のRNNから各サンプルごとの予測値の経験分布を得る
・ベイズ深層学習を用いてパラメータの分布を直接学習する
という2つの方法でmodel uncertaintyを検証。その分布はよく似ており、サンプルによってsdの広い分布と、狭い分布があり、サンプルごとのmodel uncertaintyの違いを反映しているのではという結果。ベイズ深層学習モデルから、どのfeatureがmodel uncertaintyに寄与していたかをみる解析では、出現頻度の低い単語(モデルに自然言語処理モデルも内包している)のパラメータほどmodel uncertaintyへの寄与が大きい傾向にあったとのこと。
以下私見&コメント
このmodel uncertaintyの定義だと、figure2の通りmeanが0.5に近いほど大きくなる。Fig3右、Fig5右についてはmeanの影響が無視できないので、meanが等しいもの同士の比較をもっと載せてほしかった。あとはサンプルごとのmodel uncertaintyの推定が重要という主題とはずれるが、これだけ出力がブレるならensembleしないと怖いな、という印象。ensembleした上で、閾値を何割のモデルが上回っているかで切るなどするとうまくmodel uncertaintyを取り込んだ判断ができそう。
ベイズ深層学習はまだよく分かってないので、勉強したい。
このRNNモデルは確か電子カルテ中のデータをほぼ全部入れたみたいな巨大なモデルだったので、それを200個気軽に実験できるのはさすがgoogleといったところ。
なお、MIMIC-IIIの概観については以下のQiitaが手っ取り早い。
LeetCode 1036 - Escape a Large Mazeの別解(?)
https://leetcode.com/contest/weekly-contest-134/submissions/detail/225394673/
問題設定自体は普通の迷路と同じだが,迷路のサイズが106と巨大.
普通にdfsやbfsをしていたのでは到底終わらない.
逆にblocked cellは合計200個以下という制約がある
考えたこと
どういう場合にたどり着けないか.
→広い空間にぱらぱらblockされていれば容易に辿り着ける.
→たどり着けないのは,sourceとtargetが分断されている(局所空間に囲い込まれている)とき.
→blockedは高々200個しかないから,最大の領域をblockできるのは角に斜めに並べたときで,その時中身は19800(=1+2+...+199) cell.
→よってsourceから辿り着ける領域の広さを求め,19800を超えればblockedの魔の手を逃れて広い空間に辿り着けることができる.targetについても同様.
→同じ局所空間内に囲い込まれているケースを除外して,正解が得られる.
当初はblock可能な最大領域を10000と勘違いして提出したのだが,それでもacceptedになってしまったので想定解ではなさそう.
より高速な解法
targetからmanhattan距離で200離れることができれば,必ず逃れられるので,dfsで200超えられるように探索する.多くの場合上記の20000cell探索しなくても解が得られるのでその方が早い.
汎用的な解法
blockedとblockedの間の空白空間はそれが何列あっても1列であっても本質的には変わりない.なので題に登場する座標と前後の一列を保つように座標圧縮してしまえば,通常の迷路の問題と同じアルゴリズムで解ける.
問題となるテストケース
一応反例testcaseとして投稿してみた.
[[0, 199], [1, 198], [2, 197], [3, 196], [4, 195], [5, 194], [6, 193], [7, 192], [8, 191], [9, 190], [10, 189], [11, 188], [12, 187], [13, 186], [14, 185], [15, 184], [16, 183], [17, 182], [18, 181], [19, 180], [20, 179], [21, 178], [22, 177], [23, 176], [24, 175], [25, 174], [26, 173], [27, 172], [28, 171], [29, 170], [30, 169], [31, 168], [32, 167], [33, 166], [34, 165], [35, 164], [36, 163], [37, 162], [38, 161], [39, 160], [40, 159], [41, 158], [42, 157], [43, 156], [44, 155], [45, 154], [46, 153], [47, 152], [48, 151], [49, 150], [50, 149], [51, 148], [52, 147], [53, 146], [54, 145], [55, 144], [56, 143], [57, 142], [58, 141], [59, 140], [60, 139], [61, 138], [62, 137], [63, 136], [64, 135], [65, 134], [66, 133], [67, 132], [68, 131], [69, 130], [70, 129], [71, 128], [72, 127], [73, 126], [74, 125], [75, 124], [76, 123], [77, 122], [78, 121], [79, 120], [80, 119], [81, 118], [82, 117], [83, 116], [84, 115], [85, 114], [86, 113], [87, 112], [88, 111], [89, 110], [90, 109], [91, 108], [92, 107], [93, 106], [94, 105], [95, 104], [96, 103], [97, 102], [98, 101], [99, 100], [100, 99], [101, 98], [102, 97], [103, 96], [104, 95], [105, 94], [106, 93], [107, 92], [108, 91], [109, 90], [110, 89], [111, 88], [112, 87], [113, 86], [114, 85], [115, 84], [116, 83], [117, 82], [118, 81], [119, 80], [120, 79], [121, 78], [122, 77], [123, 76], [124, 75], [125, 74], [126, 73], [127, 72], [128, 71], [129, 70], [130, 69], [131, 68], [132, 67], [133, 66], [134, 65], [135, 64], [136, 63], [137, 62], [138, 61],[139, 60], [140, 59], [141, 58], [142, 57], [143, 56], [144, 55], [145, 54], [146, 53], [147, 52], [148, 51], [149, 50], [150, 49], [151, 48], [152, 47], [153, 46], [154, 45], [155, 44], [156, 43], [157, 42], [158, 41], [159, 40], [160, 39], [161, 38], [162, 37], [163, 36], [164, 35], [165, 34], [166, 33], [167, 32], [168, 31], [169, 30], [170, 29], [171, 28], [172, 27], [173,26], [174, 25], [175, 24], [176, 23], [177, 22], [178, 21], [179, 20], [180, 19], [181, 18], [182, 17], [183, 16], [184, 15], [185, 14], [186, 13], [187, 12], [188, 11], [189, 10], [190, 9], [191, 8], [192, 7], [193, 6], [194, 5], [195, 4], [196, 3], [197, 2], [198, 1], [199, 0]] [0,0] [199,199]
Google Code Jam Round1B 通過
Round 1Aは通過できなかったが,1Bで通過できた.深夜ラウンド辛い....日本は幸い翌日が休日なので参加できた.
Visibleを通ってUnvisibleを通らないコードもUnvisible正解時のpenaltyになるので,少し考えてunvisibleも通せそうならその解法でいった方が良い.interaction problemのデバッグがよくわからん....
Code Jam - Google’s Coding Competitions
Question 1. Manhattan Crepe Cart
問題概略:南北東西それぞれQ+1本(西から0-Q, 南から0-Q番)の通りと交差点からなる.P人の人が各々(Xi, Yi)の交差点にいてDi(W,E,N,S)の方向へ向かって歩いている.ほとんどの人はマンハッタン距離に沿って,(X, Y)にあるクレープ屋さんを目指しているらしい.一番多くの人が向かっていることになるX,Yの座標のうち最も西南のものを求めよ.(制約:Q<=105, P<500,街の端にいるとき街の外は向いていない)
from itertools import accumulate def Solution(P,Q,XYD): X = [0]*(Q+1) Y = [0]*(Q+1) for x,y,d in XYD: if d=='W': X[0] += 1 X[x] -= 1 elif d=='E': X[x+1] += 1 elif d=='N': Y[y+1] += 1 elif d=='S': Y[0] += 1 Y[y] -= 1 X = list(accumulate(X)) Y = list(accumulate(Y)) minx=0 for x in range(Q+1): if X[x]<X[minx]: minx=x miny=0 for y in range(Q+1): if Y[y]<Y[miny]: miny=y return x,y T = int(input()) for t in range(T): P, Q = list(map(int, input().split())) XYD = [list(map(int, input().split())) for _ in range(P)] x, y = Solution(P,Q,XYD) print('Case #{}: {} {}'.format(t+1,x,y))
xi, yiの人がWに向かっているなら,x < xiの全領域が候補でyは関係ない.よってx, yを別に考えることにして,x < xiにフラグを立てて,全員分の和を取れば良い.愚直にやると遅いO(N2)ので,領域の端点をそれぞれ+1, -1しておいて最後に累積和を取るとO(N).この累積和が最大となるxのうち一番西のものを答えれば良い.素のpythonでちょこちょこ書いたが,numpyをimportしてargminをとっても良いかも.
Question2. Draupnir
問題概要:Odirさんは1日毎に倍増するリングをR[1]個,2日ごとに倍増するリングをR[2]個,...,6日ごとに倍増するリングをR[5]個持っている(0<=R[i]<=100, 1<=sum(R)).あなたはある日(1~500)日目における合計の指輪の個数(mod 263)をW回(tests et1: 6回, test set: 2回)まで尋ねることができる.Rを求めよ.
def Solution(n42, n210): R = [0]*6 R[3] = (n210>>52) & 0b1111111 R[4] = (n210>>42) & 0b1111111 R[5] = (n210>>35) & 0b1111111 n42 = n42 - R[3]*2**10 - R[4]*2**8 - R[5]*2**7 R[0] = (n42>>42) & 0b1111111 R[1] = (n42>>21) & 0b1111111 R[2] = (n42>>14) & 0b1111111 return R T,W = list(map(int, input().split())) for t in range(T): print(42, flush=True) n42 = int(input()) print(210, flush=True) n210 = int(input()) R = Solution(n42, n210) print(' '.join([str(r) for r in R]), flush=True) ret = int(input()) if ret == -1: raise
キレイに解けてよかった.ビット演算の問題はじめてまともに解けた気がする.100 < 27=128なので,それ以上差が開く日を聞けば良い.mod 291以上とかなら,n=42の代わりにn=84を聞けば,途中でn42 = n42 - R[3]*2**10 - R[4]*2**8 - R[5]*2**7
が要らない.
終わってから気づいたけど,REってWAのことか....
Question.3 Fair Fights
問題概要:数列Cと数列Dが与えられる.その部分列CL~CRの最大値とDL~DRの最大値について,差の絶対値がK以下となるようなL, Rの組み合わせの数を求めよ.(制約:0<=Ci,Di<=105, K<= 105, test case 1: n<=100, test case 2: n<=105)
class SegTree(): def __init__(self, func, A, identity=1e9): self._func = func self._N = 1 while self._N < len(A): self._N *= 2 # 長さ 7 なら N=3 (2^3=8) self._identity = identity # maxなら-1e9, sumなら0, productなら1 self._tree = [self._identity] * (2*self._N -1) self._tree[self._N-1:self._N-1+ len(A)] = A #最下段に元の配列を for i in range(self._N-2, -1, -1): self._calc_single_node(i) ######## # N=3: # # 0000 # # 1122 # # 3456 # ######## self._node_range = [(i-self._N+1,i-self._N+2) for i in range(2*self._N-1)] #right_exclusive for i in range(self._N-2, -1, -1): self._node_range[i] = (self._node_range[2*i+1][0], self._node_range[2*i+2][1]) def _calc_single_node(self,i): self._tree[i] = self._func(self._tree[2*i+1], self._tree[2*i+2]) def point_update(self,i,x): # A[i] = x (addではないので注意) j = self._N-1+i self._tree[j] = x while j>0: j = (j-1)//2 self._calc_single_node(j) def range_query(self, l, r, on=0): # reduce(func, A[l:r]) # onはノード[on]上と重なる部分の,ということ onl, onr = self._node_range[on] # ノード[on]上はすべて覆う if l<=onl and onr<=r: return self._tree[on] # 重なりなし elif onr<=l or r<=onl: return self._identity else: return self._func(self.range_query(l, r, on=2*on+1), self.range_query(l, r, on=2*on+2)) from itertools import combinations def Solution(N,K,C,D): C_st = SegTree(max, C, identity=0) D_st = SegTree(max, D, identity=0) answer = 0 for l,r in combinations(range(N+1), 2): C_max = C_st.range_query(l,r) D_max = D_st.range_query(l,r) if abs(C_max-D_max)<=K: answer += 1 return answer T = int(input()) for t in range(T): N, K = list(map(int, input().split())) C = list(map(int, input().split())) D = list(map(int, input().split())) answer = Solution(N,K,C,D) print('Case #{}: {}'.format(str(t+1), str(answer)))
N=105で制限時間は30秒なのでO(N2)は通らない.test set1はok.test set2はTLEと予想通り.class SegTreeは時間内に実装したわけではなく,自作コード集からのコピペ.ただ,解説によるとn=100なので愚直にmaxとっても通るらしい.
なお,test set2を通すためにはこれだけではだめで,選ばれるCiに注目してCL~Ci-1のmaxがCi未満となるような最大のL,DL~DiのmaxがCi+K以下となる最大のL,DL~DiのmaxがCi-K以上となる最小のLを,それぞれ二分探索+上記Seg木のrange_queryで探す,ということらしい.
教訓:何を固定するかをよく考える.最大値・最小値などの極端な値に注目すること(L固定,R固定しか考えてなかった).Seg木はそれで終わりではなく二分探索と組み合わせられると強力.二列あるときはあえて非対称に扱うことも考える.
二分探索のための遅延評価リストを作った
標準ライブラリに二分探索が用意されているが,sortと違ってkeyが指定できない.
毎回二分探索を書いても良いが,境界条件などで足を救われがち. よって,クラス継承の練習を兼ねてgetitemしたときだけfuncで遅延評価されるリストを作った.
ほぼlistを継承してgetitem()だけ変更しているのでそのままbisect_lightなどに渡せる.あまりそのようなことはないが元のリストはsortされている必要はなく,funcの結果として単調増加になっていればok.
class lazy_eval_list(list): # when used with bisect method, func must be monotonically increasing function. # not compatible with insort method. def __init__(self, func, a, memorize=True): super().__init__(a) self._func = func self._eval = {} def __getitem__(self, key): if type(key) is int: return self.__evaluate__(super().__getitem__(key)) elif type(key) is slice: return [self.__evaluate__(item) for item in super().__getitem__(key)] def __evaluate__(self, item): if item not in self._eval: self._eval[item] = self._func(item) return self._eval[item] def __repr__(self): return 'lazy_eval_list({}, {})'.format(self._func, super().__repr__())
上記は,一度計算した結果を保持していたり,スライスでの入力に対応したりしているが,より簡単には以下.
競プロで使うならこちらが良いかも.
class lazy_eval_list(list): def __init__(self, func, a): super().__init__(a) self._func = func def __getitem__(self, key): return self._func(super().__getitem__(key))
いずれにせよ,
ll = lazy_eval_list(lambda x:x**2, [10,20,30]) import bisect bisect.bisect_left(ll, 400) >>> 1
と,インデックスを得ることができる.
なお,insert, appendやextendは問題ないが,いくつかの場面でリスト自体を返すので,ll = ll + [100]や ll+=[100],ll=ll*3などはただのリストになってしまう.
bisect.insort_left(a, x) も,内部でxが直接aと比較されてしまうので意図した動きにならない.
GitHub - tmitanitky/lazy_eval_list: lazy evaluation list for bisect
Union Findの問題をpythonで解く
グループ分けして、適宜グループを合体させるような局面で登場するデータ構造。
全ノードに対してグループ名自体を持つようすると、合体のたびに少なくとも一方のグループをすべて書き換える必要があって、時間がかかる。ツリー構造で持っておいて、グループの合体をするときは、一方のrootを他方のrootにぶらさげる形。同じグループかを判定するときは、木を辿っていってrootが共通か確認する。
Union Find木自体は https://atc001.contest.atcoder.jp/tasks/unionfind_a
あとは、こことか。 https://qiita.com/ofutonfuton/items/c17dfd33fc542c222396
AtCoder Typical Contests 001 - B - Union Find
N, Q = list(map(int, input().split())) PAB = [list(map(int, input().split())) for _ in range(Q)] roots = {} def find_root(a, roots=roots): if a not in roots: return a else: roots[a] = find_root(roots[a]) return roots[a] def unite(a, b, roots=roots): root_a = find_root(a) root_b = find_root(b) if root_a == root_b: return else: roots[root_a] = root_b return def is_union(a, b, roots=roots): if find_root(a) == find_root(b): return True else: return False for p,a,b in PAB: if p==0: unite(a, b) if p==1: if is_union(a, b): print('Yes') else: print('No')
なんとなくrootsは辞書にしたが、リストでも良い。時間制限が厳しければリストのほうが良いが、インデックスに注意する。リストのときはNoneを入れるなり、roots[a]=aのときをルートと判別するなりする。
ポイント(?)はfind_root()のroots[a] = find_root(roots[a])の行で、スライドにある繋ぎ直す高速化を実現している。ここで再帰呼び出しを使うことで、自前でスタックなどする必要がない。
AtCoder Beginner Contest 087 - D - People on a Line
これも2点間の相対位置が決まっていくので、互いの相対位置が決まっている群をグループとすると、情報が1個加えられるごとにその端点を含むグループが新たにグループに融合される。
上記のTypical問題では同一グループかの情報のみ保持すればよかったが、今回はそれに加えて、rootからの距離、という情報を持っておく必要がある。
N, M = list(map(int, input().split())) LRD = [list(map(int, input().split())) for _ in range(M)] roots = [(i,0) for i in range(N+1)] #(root, distance_from_i_to_root) def find_root(a, roots=roots): if roots[a][0] == a: return (a,0) else: r, d = find_root(roots[a][0]) roots[a] = (r, roots[a][1] + d) return roots[a] def unite(l, r, d, roots=roots): root_l, d_l = find_root(l) root_r, d_r = find_root(r) if root_l == root_r: if d_l == d + d_r: return True else: return False else: roots[root_l] = (root_r, d +d_r - d_l) return True for l,r,d in LRD: if not unite(l,r,d): print('No') break else: print('Yes')
AtCoder Beginner Contest 120 - D - Decayed Bridges
減らしていくのは大変なので、逆に1個ずつ橋でつなげていくことを考える。逆順にするのがやや面倒だが、これも橋でつなぐと2つのグループが合体するので、今までの問題と同じ。
N, M = list(map(int, input().split())) AB = [list(map(int, input().split())) for _ in range(M)] roots = [(i,1) for i in range(N+1)] # root, num_islands_in_group inconveniences = [0]*(M+1) inconveniences[-1] = N*(N-1)//2 #すべての橋が崩落するとNC2 def find_root(a, roots=roots): if roots[a][0] == a: return roots[a] else: roots[a] = find_root(roots[a][0]) return roots[a] def reduce_inconvenience_with_unite(a, b, roots=roots): root_a, n_a = find_root(a) root_b, n_b = find_root(b) if root_a == root_b: return 0 else: if n_a > n_b: #せっかく要素数の情報を持っているので大きい方に合流させる roots[root_a] = (root_a, n_a + n_b) roots[root_b] = (root_a, n_a + n_b) else: roots[root_a] = (root_b, n_a + n_b) roots[root_b] = (root_b, n_a + n_b) return n_a * n_b for i, ab in enumerate(AB[:0:-1]): #実は1個めの橋は全く関係ない a, b = ab inconveniences[M-1-i] = inconveniences[M-i] - reduce_inconvenience_with_unite(a,b) for inc in inconveniences[1:]: print(inc)
ちなみに、大きい方に合流の効率化を入れないと、そのままでは再帰の上限に達してREとなる(それでかなりハマった&RE→ACの人のコードを見てやっと気づいた)。スライドにあるランクの概念を導入するか、$sys.setrecursionlimit(int(1e7))$などとしてもよい。
たぶんほんとは
classで実装した方が良い.