勉強の記録

機械学習、情報処理について勉強した事柄など

Union Findの問題をpythonで解く

グループ分けして、適宜グループを合体させるような局面で登場するデータ構造。

全ノードに対してグループ名自体を持つようすると、合体のたびに少なくとも一方のグループをすべて書き換える必要があって、時間がかかる。ツリー構造で持っておいて、グループの合体をするときは、一方のrootを他方のrootにぶらさげる形。同じグループかを判定するときは、木を辿っていってrootが共通か確認する。

Union Find木自体は https://atc001.contest.atcoder.jp/tasks/unionfind_a

あとは、こことか。 https://qiita.com/ofutonfuton/items/c17dfd33fc542c222396

AtCoder Typical Contests 001 - B - Union Find

atc001.contest.atcoder.jp

N, Q = list(map(int, input().split()))
PAB = [list(map(int, input().split())) for _ in range(Q)]

roots = {}

def find_root(a, roots=roots):
  if a not in roots:
    return a
  else:
    roots[a] = find_root(roots[a])
  return roots[a]

def unite(a, b, roots=roots):
  root_a = find_root(a)
  root_b = find_root(b)
  if root_a == root_b:
    return
  else:
    roots[root_a] = root_b
    return

def is_union(a, b, roots=roots):
  if find_root(a) == find_root(b):
    return True
  else:
    return False

for p,a,b in PAB:
  if p==0:
    unite(a, b)
  if p==1:
    if is_union(a, b):
      print('Yes')
    else:
      print('No')

なんとなくrootsは辞書にしたが、リストでも良い。時間制限が厳しければリストのほうが良いが、インデックスに注意する。リストのときはNoneを入れるなり、roots[a]=aのときをルートと判別するなりする。

ポイント(?)はfind_root()のroots[a] = find_root(roots[a])の行で、スライドにある繋ぎ直す高速化を実現している。ここで再帰呼び出しを使うことで、自前でスタックなどする必要がない。

AtCoder Beginner Contest 087 - D - People on a Line

これも2点間の相対位置が決まっていくので、互いの相対位置が決まっている群をグループとすると、情報が1個加えられるごとにその端点を含むグループが新たにグループに融合される。

上記のTypical問題では同一グループかの情報のみ保持すればよかったが、今回はそれに加えて、rootからの距離、という情報を持っておく必要がある。

N, M = list(map(int, input().split()))
LRD = [list(map(int, input().split())) for _ in range(M)]

roots = [(i,0) for i in range(N+1)] #(root, distance_from_i_to_root)

def find_root(a, roots=roots):
  if roots[a][0] == a:
    return (a,0)
  else:
    r, d = find_root(roots[a][0])
    roots[a] = (r, roots[a][1] + d)
  return roots[a]

def unite(l, r, d, roots=roots):
  root_l, d_l = find_root(l)
  root_r, d_r = find_root(r)
  if root_l == root_r:
    if d_l == d + d_r:
      return True
    else:
      return False
  else:
    roots[root_l] = (root_r, d +d_r - d_l)
    return True

for l,r,d in LRD:
  if not unite(l,r,d):
    print('No')
    break
else:
  print('Yes')

AtCoder Beginner Contest 120 - D - Decayed Bridges

atcoder.jp

減らしていくのは大変なので、逆に1個ずつ橋でつなげていくことを考える。逆順にするのがやや面倒だが、これも橋でつなぐと2つのグループが合体するので、今までの問題と同じ。

N, M = list(map(int, input().split()))
AB = [list(map(int, input().split())) for _ in range(M)]

roots = [(i,1) for i in range(N+1)] # root, num_islands_in_group

inconveniences = [0]*(M+1)
inconveniences[-1] = N*(N-1)//2 #すべての橋が崩落するとNC2

def find_root(a, roots=roots):
  if roots[a][0] == a:
    return roots[a]
  else:
    roots[a] = find_root(roots[a][0])
  return roots[a]

def reduce_inconvenience_with_unite(a, b, roots=roots):
  root_a, n_a = find_root(a)
  root_b, n_b = find_root(b)
  if root_a == root_b:
    return 0
  else:
    if n_a > n_b: #せっかく要素数の情報を持っているので大きい方に合流させる
      roots[root_a] = (root_a, n_a + n_b)
      roots[root_b] = (root_a, n_a + n_b)
    else:
      roots[root_a] = (root_b, n_a + n_b)
      roots[root_b] = (root_b, n_a + n_b)
    return n_a * n_b

for i, ab in enumerate(AB[:0:-1]): #実は1個めの橋は全く関係ない
  a, b = ab
  inconveniences[M-1-i] = inconveniences[M-i] - reduce_inconvenience_with_unite(a,b)

for inc in inconveniences[1:]:
  print(inc)

ちなみに、大きい方に合流の効率化を入れないと、そのままでは再帰の上限に達してREとなる(それでかなりハマった&RE→ACの人のコードを見てやっと気づいた)。スライドにあるランクの概念を導入するか、$sys.setrecursionlimit(int(1e7))$などとしてもよい。

たぶんほんとは

classで実装した方が良い.