勉強の記録

機械学習、情報処理について勉強した事柄など

二分探索のための遅延評価リストを作った

docs.python.org

標準ライブラリに二分探索が用意されているが,sortと違ってkeyが指定できない.

毎回二分探索を書いても良いが,境界条件などで足を救われがち. よって,クラス継承の練習を兼ねてgetitemしたときだけfuncで遅延評価されるリストを作った.

ほぼlistを継承してgetitem()だけ変更しているのでそのままbisect_lightなどに渡せる.あまりそのようなことはないが元のリストはsortされている必要はなく,funcの結果として単調増加になっていればok.

class lazy_eval_list(list):
    # when used with bisect method, func must be monotonically increasing function.
    # not compatible with insort method.
    def __init__(self, func, a, memorize=True):
        super().__init__(a)
        self._func = func
        self._eval = {}
        
    def __getitem__(self, key):
        if type(key) is int:
            return self.__evaluate__(super().__getitem__(key))
        elif type(key) is slice:
            return [self.__evaluate__(item) for item in super().__getitem__(key)]
        
    def __evaluate__(self, item):
        if item not in self._eval:
            self._eval[item] = self._func(item)
        return self._eval[item] 
    
    def __repr__(self):
        return 'lazy_eval_list({}, {})'.format(self._func, super().__repr__())

上記は,一度計算した結果を保持していたり,スライスでの入力に対応したりしているが,より簡単には以下.

競プロで使うならこちらが良いかも.

class lazy_eval_list(list):
    def __init__(self, func, a):
        super().__init__(a)
        self._func = func
        
    def __getitem__(self, key):
        return self._func(super().__getitem__(key))

いずれにせよ,

ll = lazy_eval_list(lambda x:x**2, [10,20,30])

import bisect
bisect.bisect_left(ll, 400)
>>> 1

と,インデックスを得ることができる.

なお,insert, appendやextendは問題ないが,いくつかの場面でリスト自体を返すので,ll = ll + [100]や ll+=[100],ll=ll*3などはただのリストになってしまう.

bisect.insort_left(a, x) も,内部でxが直接aと比較されてしまうので意図した動きにならない.

GitHub - tmitanitky/lazy_eval_list: lazy evaluation list for bisect

Union Findの問題をpythonで解く

グループ分けして、適宜グループを合体させるような局面で登場するデータ構造。

全ノードに対してグループ名自体を持つようすると、合体のたびに少なくとも一方のグループをすべて書き換える必要があって、時間がかかる。ツリー構造で持っておいて、グループの合体をするときは、一方のrootを他方のrootにぶらさげる形。同じグループかを判定するときは、木を辿っていってrootが共通か確認する。

Union Find木自体は https://atc001.contest.atcoder.jp/tasks/unionfind_a

あとは、こことか。 https://qiita.com/ofutonfuton/items/c17dfd33fc542c222396

AtCoder Typical Contests 001 - B - Union Find

atc001.contest.atcoder.jp

N, Q = list(map(int, input().split()))
PAB = [list(map(int, input().split())) for _ in range(Q)]

roots = {}

def find_root(a, roots=roots):
  if a not in roots:
    return a
  else:
    roots[a] = find_root(roots[a])
  return roots[a]

def unite(a, b, roots=roots):
  root_a = find_root(a)
  root_b = find_root(b)
  if root_a == root_b:
    return
  else:
    roots[root_a] = root_b
    return

def is_union(a, b, roots=roots):
  if find_root(a) == find_root(b):
    return True
  else:
    return False

for p,a,b in PAB:
  if p==0:
    unite(a, b)
  if p==1:
    if is_union(a, b):
      print('Yes')
    else:
      print('No')

なんとなくrootsは辞書にしたが、リストでも良い。時間制限が厳しければリストのほうが良いが、インデックスに注意する。リストのときはNoneを入れるなり、roots[a]=aのときをルートと判別するなりする。

ポイント(?)はfind_root()のroots[a] = find_root(roots[a])の行で、スライドにある繋ぎ直す高速化を実現している。ここで再帰呼び出しを使うことで、自前でスタックなどする必要がない。

AtCoder Beginner Contest 087 - D - People on a Line

これも2点間の相対位置が決まっていくので、互いの相対位置が決まっている群をグループとすると、情報が1個加えられるごとにその端点を含むグループが新たにグループに融合される。

上記のTypical問題では同一グループかの情報のみ保持すればよかったが、今回はそれに加えて、rootからの距離、という情報を持っておく必要がある。

N, M = list(map(int, input().split()))
LRD = [list(map(int, input().split())) for _ in range(M)]

roots = [(i,0) for i in range(N+1)] #(root, distance_from_i_to_root)

def find_root(a, roots=roots):
  if roots[a][0] == a:
    return (a,0)
  else:
    r, d = find_root(roots[a][0])
    roots[a] = (r, roots[a][1] + d)
  return roots[a]

def unite(l, r, d, roots=roots):
  root_l, d_l = find_root(l)
  root_r, d_r = find_root(r)
  if root_l == root_r:
    if d_l == d + d_r:
      return True
    else:
      return False
  else:
    roots[root_l] = (root_r, d +d_r - d_l)
    return True

for l,r,d in LRD:
  if not unite(l,r,d):
    print('No')
    break
else:
  print('Yes')

AtCoder Beginner Contest 120 - D - Decayed Bridges

atcoder.jp

減らしていくのは大変なので、逆に1個ずつ橋でつなげていくことを考える。逆順にするのがやや面倒だが、これも橋でつなぐと2つのグループが合体するので、今までの問題と同じ。

N, M = list(map(int, input().split()))
AB = [list(map(int, input().split())) for _ in range(M)]

roots = [(i,1) for i in range(N+1)] # root, num_islands_in_group

inconveniences = [0]*(M+1)
inconveniences[-1] = N*(N-1)//2 #すべての橋が崩落するとNC2

def find_root(a, roots=roots):
  if roots[a][0] == a:
    return roots[a]
  else:
    roots[a] = find_root(roots[a][0])
  return roots[a]

def reduce_inconvenience_with_unite(a, b, roots=roots):
  root_a, n_a = find_root(a)
  root_b, n_b = find_root(b)
  if root_a == root_b:
    return 0
  else:
    if n_a > n_b: #せっかく要素数の情報を持っているので大きい方に合流させる
      roots[root_a] = (root_a, n_a + n_b)
      roots[root_b] = (root_a, n_a + n_b)
    else:
      roots[root_a] = (root_b, n_a + n_b)
      roots[root_b] = (root_b, n_a + n_b)
    return n_a * n_b

for i, ab in enumerate(AB[:0:-1]): #実は1個めの橋は全く関係ない
  a, b = ab
  inconveniences[M-1-i] = inconveniences[M-i] - reduce_inconvenience_with_unite(a,b)

for inc in inconveniences[1:]:
  print(inc)

ちなみに、大きい方に合流の効率化を入れないと、そのままでは再帰の上限に達してREとなる(それでかなりハマった&RE→ACの人のコードを見てやっと気づいた)。スライドにあるランクの概念を導入するか、$sys.setrecursionlimit(int(1e7))$などとしてもよい。

たぶんほんとは

classで実装した方が良い.

Google CodeJamをVSC上でデバッグする

Qualification RoundはAtCoderのコードテストを間借りしてコーディングしたのだけど、やっぱりVisual Studio Codeを使いたい。

input()だけ書き換えて、printと想定解を目でみて比較すればよいのだけど、 'Case' → 'CASE' ': ' → ':' あたりでWAになったことがあるので、できればサンプルケースはコピペして実行したい。

ついでに、input()もコピペを間違えると怖いので、提出へもそのままコピペしたい。

ということで、真ん中にそのままコピペできるコード部を残して前後に挟むスクリプトを作った。

stdoutをredirectする方法は見つかったのだけど、自身のstdinに流し込む方法はなさそうなので、inputメソッドをgeneraterで書き換えている。sys.stdin()などinput()以外には対応していない。redirect_stdoutはwith文で使うことが多いようだが、インデントが残ってしまうのでenter()とexit()を使っている。#mainのセクションのみコピペすると幸せになれる、はず!

###########
# header #
###########
from contextlib import redirect_stdout
from io import StringIO
from difflib import Differ
from itertools import zip_longest as _zip_longest

with open('input.txt', 'r') as f:
    _input = f.read()

def input_generator(_input):
    for line in _input.splitlines():
        yield line

input_gen=input_generator(_input)
input = input_gen.__next__

captured_stdout = StringIO()
_redirect_stdout = redirect_stdout(captured_stdout)
_redirect_stdout.__enter__()

########
# main #
########

def solve(a,b):
    print(a+b)
    print(a)
    

T=int(input())
for _ in range(T):
    a=int(input())
    b=int(input())
    
    solve(a,b)


##########
# footer  #
##########
_redirect_stdout.__exit__(None, None, None)

with open('output.txt', 'r') as f:
    out = f.read()

for x,y in _zip_longest(captured_stdout.getvalue().splitlines(),
                out.splitlines()):
    print('\n'.join(Differ().compare([x],[y])))

使っているところはこんな感じ。input, outputともひとまずサンプルケースのコピペで良いし、必要があれば適宜自作のテストケースを書き加える。 f:id:tmitani-tky:20190409000358p:plain

Google Code Jam 2019 Qualification Round 全完でした

codingcompetitions.withgoogle.com

Foregone Solution

任意の一例を構築すれば良いだけなので適当に。

T = int(input())
N = [input() for _ in range(T)]

for i, n in enumerate(N):
  A = ''
  B = ''
  for d in n:
    if d=='4':
      A+='2'
      B+='2'
    else:
      A+=d
      B+='0'
  print('Case  #' + str(i+1) + ': ' + str(int(A)) + ' ' + str(int(B)))

You Can Go Your Own Way

深さ優先探索してみたらTest set1のみ通過するものの、Test set2(n<=1000)でTLE。 N*Nの正方形でこれも任意の一例を示せばよいだけなので、裏返すだけでよかった。

部分点解法

T = int(input())
N = []
L = []

for _ in range(T):
  n = int(input())
  l = input() #lydia's move
  N.append(n)
  L.append(l)

from collections import deque

for i, [n, l] in enumerate(zip(N,L)):
  # i: case i+1
  # n: n*n maze
  # l: lydia's move
  
  l_positions = [(1,1)]
  dq=deque()
  dq.append((0,(1,1),''))
  
  for l_direction in l:
    if l_direction =='E':
      l_positions.append((l_positions[-1][0] + 1, l_positions[-1][1] + 0))
    if l_direction =='S':
      l_positions.append((l_positions[-1][0] + 0, l_positions[-1][1] + 1))
  #print(l_positions)
      
  while dq:
    j, position, way = dq.pop()
    #print(i, n, l, j, position, way)
    
    if position == (n,n):
      print('Case #' + str(i+1) + ': ' + way)
      break
    
    if position!=l_positions[j]:
      if position[0] < n and position[1] <n:
        dq.extend([(j+1, (position[0]+1, position[1]), way+'E'),
                 (j+1, (position[0], position[1]+1), way+'S')])
      elif position[0] ==n:
        dq.append((j+1, (position[0], position[1]+1), way+'S'))
      elif position[1] ==n:
        dq.append((j+1, (position[0]+1, position[1]), way+'E'))
    else:
      if l[j] =='E' and position[1] < n:
        dq.append((j+1, (position[0], position[1]+1), way+'S'))
      elif l[j] =='S' and position[0] < n:
        dq.append((j+1, (position[0]+1, position[1]), way+'E'))  

満点解答

T = int(input())
N = []
L = []

for _ in range(T):
  n = int(input())
  l = input() #lydia's move
  N.append(n)
  L.append(l)
  
for i, l in enumerate(L):
  l=l.replace('S','_')
  l=l.replace('E','S')
  l=l.replace('_','E')
  
  print('Case #' + str(i+1) + ': ' + l)

Cryptopangrams

ユークリッドの互除法はmath.gcdで省力化。エッジケースをあまり考えずに実装して、test caseは通るものの、実際のtest setではRE.最初同じ文字が並ぶ場合を考慮していなかった。

T = int(input())
N = []
L = []
ciphertexts = []

for _ in range(T):
  n, l = list(map(int, input().split()))
  ciphertext = list(map(int, input().split()))
  N.append(n)
  L.append(l)
  ciphertexts.append(ciphertext)

from math import gcd
#from fractions import gcd

for i, [n,l,ciphertext] in enumerate(zip(N,L,ciphertexts)):
  # i: case i+1
  # n: maximum of prime, 101<=n<=10000(test1), 1e100(test2)
  # l: the length of the list of values in the ciphertext, 25<=l<=100
  # ciphertext: list of ints: (the prime for X) * (the prime for Y)
  
  prime_l=[]
  
  p1 = gcd(ciphertext[0], ciphertext[1])
  p0 = ciphertext[0]//p1
  prime_l.append(p0)
  
  for j in range(l):
    prime_l.append(ciphertext[j]//prime_l[j])
  
  cipher_dict={p:c for p, c in zip(sorted(list(set(prime_l))), 'ABCDEFGHIJKLMNOPQRSTUVWXYZ')}
  
  print('Case #' + str(i+1) + ': ' + ''.join([cipher_dict[p] for p in prime_l]))

初めて2個めの数字が出てくる位置を検索して、同じ文字が並んでいる場合と、2つの文字が交互に並んでいる場合があることに注意しつつ微修正して無事通過。この初めて違う数字が出てくるインデックス、whileで回したのだけどもっとスマートな書き方がありそう。

T = int(input())
N = []
L = []
ciphertexts = []

for _ in range(T):
  n, l = list(map(int, input().split()))
  ciphertext = list(map(int, input().split()))
  N.append(n)
  L.append(l)
  ciphertexts.append(ciphertext)

#from math import gcd
from fractions import gcd

for i, [n,l,ciphertext] in enumerate(zip(N,L,ciphertexts)):
  # i: case i+1
  # n: maximum of prime, 101<=n<=10000(test1), 1e100(test2)
  # l: the length of the list of values in the ciphertext, 25<=l<=100
  # ciphertext: list of ints: (the prime for X) * (the prime for Y)
  
  prime_l=[]
  
  j=0
  while ciphertext[j]==ciphertext[j+1]:
    j+=1
  
  q = gcd(ciphertext[j], ciphertext[j+1])
  p = ciphertext[j] // q
  
  prime_l = [q if (k-j)%2 else p for k in range(j+1)]
  #print(ciphertext[j], ciphertext[j+1], p, q, prime_l)
    
  for k in range(l-j):
    prime_l.append(ciphertext[j+k]//prime_l[j+k])
  
  cipher_dict={p:c for p, c in zip(sorted(list(set(prime_l))), 'ABCDEFGHIJKLMNOPQRSTUVWXYZ')}
  
  print('Case #' + str(i+1) + ': ' + ''.join([cipher_dict[p] for p in prime_l]))

Dat Bae

このタイトル、Data Baseからaとsが脱落したタイトルなのか。今気づいた。

Test set1のF=10までは比較的自然な発想。全ビットを見分けられるようなQueryを投げて、返りを適当に。n列目をnの2進数表記をして、それを1~F回目のqueryで1桁ずつ送るようにすると、返り値を列ごとに並べて2進数→intで列数に戻るので便利。 itertoolsとnumpyによる転置と、int('xxxx', 2)を駆使して実装。高級言語万歳。

from itertools import product
import math
import numpy as np

def solution():
    n, b, f = list(map(int, input().split())) # n of workers, broken workers, lines
    m = min(f, math.ceil(math.log2(n)))
    
    queries = np.array(list(product([0,1], repeat = m))[:n]).T
    responses = []
    
    for i in range(m):
        print(''.join(queries[i].astype('str')), flush=True)
        responses.append([int(s) for s in input()])
    responses = np.array(responses)
    
    correct = [int(''.join(r.astype('str')), 2) for r in responses.T]
    broken = [str(_) for _ in range(n) if _ not in correct]
    
    print(' '.join(broken), flush=True)
    assert(input()=='1')
    
t = int(input())
for _ in range(t):
    solution()

そうはいってもこれだとF=5に対応できない。B<=15という意味深な条件があるが…。

16個選ぶとかならず1個は生きてるビットがあることに気づく。116, 016と16個ずつ並べれば個々の戻り列は、どの16個の中にあるか特定できるので、そのなかで残りの4列を24=16に使って解ける。

from itertools import product, groupby
import math
import numpy as np

def solution():
    n, b, f = list(map(int, input().split())) # n of workers, broken workers, lines
    
    f=5    
    
    # 1st line
    r = (n+31)//32
    print((('0'*16+'1'*16)*r)[:n], flush=True)
    response=input()
    broken_in_blocks = [len(list(g)) for k, g in groupby(response)]
    
    # 2nd - 5th line
    r = (n+15)//16
    query4_16 = np.array(list(product([0,1], repeat = 4))).T
    queries = np.tile(query4_16, (1,r))[:,:n]
    #print(queries)
    responses = []
    
    for i in range(f-1):
        print(''.join(queries[i].astype('str')), flush=True)
        responses.append([int(s) for s in input()])
    
    responses = np.array(responses)
    
    l=0
    correct = []
    for i in range(r):
        m = broken_in_blocks[i]
        correct.extend([i*16 + int(''.join(x.astype('str')), 2) for x in responses[:,l:l+m].T])
        l+=m
    
    broken = [str(_) for _ in range(n) if _ not in correct]
    print(' '.join(broken), flush=True)
    assert(input()=='1')
        
t = int(input())
for _ in range(t):
    solution()

整数問題で商にmath.ceilを使ってはいけない?

n/m以下の整数を使いたいときはn//mを使えばよいが,n/m以上の最小の整数に相当する演算子は存在しない.

そこで,横着してmath.ceilを使うと,答えがずれることがある.

ABC046 C - AtCoDeerくんと選挙速報 / AtCoDeer and Election Report atcoder.jp

誤り解法

import math

N=int(input())

T=1
A=1
for _ in range(N):
  t,a=list(map(int, input().split()))
  m = max(int(math.ceil(T/t)), int(math.ceil(A/a))))
  T = t*m
  A = a*m
  #print(T,A)
  
print(T+A)

これをするとtest caseのうち4件でWAとなる.

# 略
def divceil(a,b):
  if a%b==0:
    return a//b
  else:
    return a//b + 1

# 略
  m = max(divceil(T,t), divceil(A,a))

これならok.

T/tの部分が浮動小数点で計算されるので,割り切れるような場合であっても精度の末尾のところで微妙にずれる のだろうと推測.

ABC-047C - 一次元リバーシ / 1D Reversi をpythonで解く

atcoder.jp

リバーシにおいて同じ色の1個の石も2個の石も等価なので,まずは同じ色の並んだ部分を1個の石に置き換えると,あとは1個ずつひっくり返していくしかないのでその長さ-1が答えとなる.

解法1 np.diffを使う

import numpy as np

S=input()
S=np.array([1 if s=='B' else 0 for s in S])

answer = np.abs(np.diff(S)).sum()
print(answer)

AC @ 334ms

変化点を見つけるためには差分系列を取れば良い.01からなるndarrayに変更しておけば,np.diffで差分をとって,その絶対値の和をとれば良い.

経過時間はimport numpyに時間がかかっている.

解法2 itertools.groupby

from itertools import groupby

S=input()
answer = -1
for _ in groupby(S):
  answer+=1

print(answer)

AC @ 32ms itertools内のgroupbyは値が切り替わるたびに値を送出するiterater.sort済みの配列に対して使われることが多いが,このように単に連続する値をグループ化したいときにも使える.iteratorなので回しきらないと全体数がわからないのがネックだが,逆に複雑な処理が必要な場合も有効かも.

for文のところは,リスト内包表記でも書けてちょっと早くなる(AC @28ms)

from itertools import groupby

S=input()
answer = sum([1 for _ in groupby(S)]) -1

print(answer)

解法3 実直に

S=input()

answer=0
prev=S[0]
for s in S:
  if s!=prev:
    answer+=1
  prev=s

print(answer)

AC @33ms いろが変わるたびにanswerを+1する.

個人的にはnumpyの解法が好き.

AtCoderのランク感

最近ちょこちょことAtCoderをやっている.

以下のblogでも引用されているように, kumagi.hatenablog.com

ということでABCのC問題が解けないとお話にならないらしいが,C問題難しいorz.となっていた.

ただ,この投稿,よくよく見ると2018年6月末の投稿.これは~ABC100あたりまでの話.自分が解けなかったのはABC100以降が多く,ABC100以前は8割くらいは自力で解けている感じ.

とりあえずC埋めぐらいはして水色になりたい.