ABC330 E Mex and Update をPythonで解く【Atcoder】


問題はこちらです。 atcoder.jp

Pythonでセグ木を使う解法記事がないので書いておきます。 まず、セグ木でAに0~N+1がそれぞれ登場する回数を管理します。
クエリが飛んできた時には、 A_iのセグ木上の登場回数1へらし、xを1増やします。(実装ではupdateメソッドに該当します)
次にMEXを求める処理を考えます。
ここで、MEXは登場回数が0回である最小の数と言い換えることができます。
したがって、セグ木を二分探索すればよいです。

# oj t -c "python3 main.py"
import sys,math
from collections import defaultdict,deque
from itertools import combinations,permutations,accumulate,product
from bisect import bisect,bisect_left,bisect_right
from heapq import heappop,heappush,heapify
from sortedcontainers import SortedList,SortedSet
def input(): return sys.stdin.readline().rstrip()
def ii(): return int(input())
def ms(): return map(int, input().split())
def li(): return list(map(int,input().split()))
inf = pow(10,18)
#////////////////////////////////////
class segtree:  # 0-index
  # 要素treeの初期化(最初はinit_val)treeがセグメント木そのもの
  def __init__(self,init_val,segfunc,ide_ele):
    self.segfunc = segfunc
    self.ide_ele = ide_ele
    self.n = len(init_val)
    self.size = 1
    while self.size < self.n:
      self.size *= 2
    self.tree = [self.ide_ele] * (self.size * 2)
    for i in range(self.n):
      self.tree[self.size + i] = init_val[i]
    for i in range(self.size-1,0,-1):
      self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])

  def get(self,pos): # 元の配列のpos番目の要素を取得する(0-index)
    return self.tree[pos+self.size]

  def update(self, pos, x): 
    if pos<=N+1:
      pos += self.size
      self.tree[pos] -= 1
      while pos >= 2:
        pos //= 2
        self.tree[pos] = self.segfunc(self.tree[pos * 2], self.tree[pos * 2 + 1])
    if x<=N+1:
      x += self.size
      self.tree[x] += 1
      while x >= 2:
        x //= 2
        self.tree[x] = self.segfunc(self.tree[x * 2], self.tree[x * 2 + 1])

  # [l, r) の最大値を求める処理
  def query(self, l, r):
    ret = self.ide_ele
    l += self.size
    r += self.size

    while l < r:
      if l & 1:
        ret = self.segfunc(ret, self.tree[l])
        l += 1
      if r & 1:
        ret = self.segfunc(ret, self.tree[r-1])
      l >>= 1
      r >>= 1
    return ret
# 最小値
def segfunc(x, y): return min(x,y)
ide_ele = pow(10,18)

N,Q = ms()
A = li()
B = [0]*(N+1)
for a in A: 
  if a<=N: B[a] += 1
  
seg = segtree(B,segfunc,ide_ele)
for _ in range(Q):
  i,x = ms()
  tmp = A[i-1]
  seg.update(tmp,x)
  A[i-1] = x
  ng,ok = 0,N+2
  while abs(ok-ng)>1:
    mid = (ok+ng)//2
    if seg.query(0,mid)==0: # 条件を満たすなら
      ok = mid
    else:
      ng = mid
  print(ok-1)