Try   HackMD
tags: HUPC2021

G: プログラミングコンテストの作問 解説

原案: ngng628 | 解説: ngng628

基本方針

構文解析をして、画面出力命令 print の動作を実際にシミュレートしていけば良いです。

繰り返し文の処理さえ工夫すれば、十分高速に動作します。

構文解析について

再帰下降構文解析と一字句先読み、すなわちLL(1)構文解析を行えば良いです。

一字句先読みが必要になるのは、<declaration><assignment><for><print> かを判定するときです。

詳しくは実装例を参考にしてください。

繰り返し文の処理について

繰り返し回数が最大で

109 になることから、与えられた繰り返し文をそのままシミュレートする方法をとってしまうと、実行時間制約を超えてしまいます。

問題文にもある通り、<for>

L 個の <for_element> が生成されているとき、
i
番目の <for_element><expression> の値を
ni
とすると、
i=1Lni
回だけ、<assignment> または <print> を実行します。この繰り返し回数を
π
とおきます。

【1】for文に続く命令が <assignment> のとき

変数 a<expression>

π 回代入することを考えます。繰り返し中、値が変化する変数は a のみなので、<expression> は次のように整理できます。

  • <expression>
    =p×
    a
    +q

a

p× a
+q
π
回繰り返し代入したとき、最終的に a

  • a{a(π=0)a+πq(π1p=1)q(pπ1)(p1)1+apπ(otherwize.)

となります。累乗の計算に繰り返し二乗法を使い、逆元の計算に拡張 Euclid の互除法などを使えば、十分高速に動作します。

式の証明

π=0 の場合は、明らかです。

π1p=1 の場合は、数列
bn=bn1+q,b0=a
の一般項を求めれば良いです。
{bn}
は等差数列なので一般項は
bn=a+nq
となります。
n=π
とすれば、上式を得ます。

それ以外の場合は、数列

bn=pbn1+q,b0=a の一般項を求めれば良いです。まず、特性方程式
β=pβ+q
を解くと、
β=q(p1)1
となります。
β
を使うと、
{bn}
の漸化式は、
bnβ=p(bn1β)
と変形できます。
cn=bnβ
とおくと、
{cn}
は等比数列になっているので、一般項
cn=c0pn
が求まります。よって、
bn=(aβ)pn+β
となります。
n=π
とし、
β
q(p1)1
に戻すと、上式を得ます。

【2】for文に続く命令が <print> のとき

<expression> とその出力回数をペアで管理し、データを圧縮すると良いです。

ただし、次のようなケースには注意しましょう。

1 2
for(2)print(100)
print(100)
print(100)
0 0

実装例

C++

# include <bits/stdc++.h> constexpr int MOD = 998244353; using namespace std; int Mod(long long v) { return ((v % MOD) + MOD) % MOD; } int pow_mod(long long a, long long n, int m) { ((a %= m) += m) %= m; long long res = 1; while (n) { if (n & 1) (res *= a) %= m; (a *= a) %= m; n >>= 1; } return res; } int inv_mod(long long a, int m) { ((a %= m) += m) %= m; long long b = m, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } u %= m; if (u < 0) u += m; return u; } class Parser { public: Parser(int _n, const vector<string>& _s) : n(_n) { s = ""; for (string t : _s) { s += t + "\n"; } it = s.begin(); output.emplace_back(0, -1); values['?'] = 0; } vector<pair<long long, int>> parse() { program(); return output; } private: int n; string s; string::const_iterator it; vector<pair<long long, int>> output; map<char, int> values; void program() { while (it != s.end()) { if (*it == '\n') it++; else if (*next(it) == '<') assignment(); else if (peek(2) == "fo") for_statement(); else if (peek(2) == "pr") print(); else declaration(); } } void declaration() { assert(isalpha(*it)); values[*it++] = 0; } void assignment(int n_loops = 1) { char name = value_name(); skip("<-"); int val = expression(n_loops, name); values[name] = val; } void print(int n_loops = 1) { skip("print("); int val = expression(); skip(")"); if (n_loops == 0) return; if (output.back().second == val) { output.back().first += n_loops; } else { output.emplace_back(n_loops, val); } } void for_statement() { int n_loops = 1; while (peek(2) == "fo") { n_loops *= for_element(); if (n_loops > 1e9) n_loops = 0; } if (peek(2) == "pr") print(n_loops); else assignment(n_loops); } int for_element() { skip("for("); int val = expression(); skip(")"); return val; } int expression(int n_loops = 1, char name = '?') { int p = 0; int q = 0; int sgn = 1; while (true) { if (*it == '+') { sgn = 1; ++it; } if (*it == '-') { sgn = -1; ++it; } if (isdigit(*it)) { q += sgn*number(); } else if (isalpha(*it)) { if (*it == name) { p += sgn; ++it; } else { q += sgn * values[value_name()]; } } else { break; } p = Mod(p); q = Mod(q); } const int a = values[name]; if (n_loops == 0) { return a; } else if (n_loops >= 1 and p == 1) { return Mod(a + Mod((long long)Mod(n_loops) * q)); } else { int num = Mod((long long)q * (pow_mod(p, n_loops, MOD) - 1)); int den_inv = inv_mod(p - 1, MOD); int cnst = Mod((long long)a * pow_mod(p, n_loops, MOD)); return Mod(Mod((long long)num * den_inv) + cnst); } } int number() { int res = 0; while (isdigit(*it)) { res = Mod(10LL*res + (*it++ - '0')); } return res; } char value_name() { return *it++; } void skip(const string& t) { for (char c : t) { assert(*it == c); ++it; } } string peek(int n) { string res; for (int i = 0; i < n; ++i) { res.push_back(*next(it, i)); } return res; } }; int main() { while (true) { int n, m; cin >> n >> m; if (!(n | m)) break; vector<string> s(n), t(m); for (string& u : s) cin >> u; for (string& u : t) cin >> u; auto yuchan = Parser(n, s).parse(); auto reikun = Parser(m, t).parse(); puts(yuchan == reikun ? "Yes" : "No"); } }

Python

import sys input = sys.stdin.readline MOD = 998244353 def Mod(n): return pow(n, 1, MOD) class Parser: def __init__(self, n, s): self.n = n self.s = '\n'.join(s) + '\n' + '.' self.cur = 0 self.output = [[0, -1]] self.values = { '?' : 0 } def parse(self): self.__program() def __program(self): while self.__peek(1) != '.': if self.__peek(1) == '\n': self.cur += 1 elif self.__peek(2)[1] == '<': self.__assignment() elif self.__peek(2) == 'fo': self.__for() elif self.__peek(2) == 'pr': self.__print() else: self.__declaration() def __declaration(self): self.values[self.__value_name()] = 0 def __assignment(self, n_loops=1): name = self.__value_name() self.__skip('<-') val = self.__expression(n_loops=n_loops, name=name) self.values[name] = val def __print(self, n_loops=1): self.__skip('print(') val = self.__expression() self.__skip(')') if n_loops == 0: pass elif self.output[-1][1] == val: self.output[-1][0] += n_loops else: self.output.append([n_loops, val]) def __for(self): n_loops = 1 while self.__peek(2) == 'fo': self.__skip('for(') n_loops *= self.__expression() self.__skip(')') if self.__peek(2) == 'pr': self.__print(n_loops) else: self.__assignment(n_loops) def __expression(self, n_loops=1, name='?'): p, q = 0, 0 sgn = 1 while True: if self.__peek(1) == '+': sgn = 1 self.cur += 1 if self.__peek(1) == '-': sgn = -1 self.cur += 1 if self.__peek(1).isdigit(): q += sgn * self.__number() elif self.__peek(1).isalpha(): if self.__peek(1) == name: p += sgn self.cur += 1 else: q += sgn * self.values[self.__value_name()] else: break a = self.values[name] if n_loops == 0: return a elif n_loops >= 1 and p == 1: return Mod(a + n_loops * q) else: num = Mod(q * (pow(p, n_loops, MOD) - 1)) den_inv = pow(p - 1, MOD - 2, MOD) cnst = Mod(a * pow(p, n_loops, MOD)) return Mod(num * den_inv + cnst) def __number(self): res = 0 while self.__peek(1).isdigit(): res = 10*res + int(self.__peek(1)) self.cur += 1 return res def __value_name(self): name = self.__peek(1) self.cur += 1 return name def __skip(self, s): for c in s: assert(c == self.__peek(1)) self.cur += 1 def __peek(self, n): return self.s[self.cur : self.cur + n] def main(): while True: n, m = map(int, input().split()) if n == 0 and m == 0: break s = [input() for _ in range(n)] t = [input() for _ in range(m)] yuchan = Parser(n, s) yuchan.parse() reikun = Parser(m, t) reikun.parse() print('Yes' if yuchan.output == reikun.output else 'No') if __name__ == '__main__': main()