na_o_ysのブログ

プログラミングなど

join から理解する State モナド

すごいH本 14 章に出てきた State モナドの理解が難しかったので, まとめてみました.

State モナドとは

  • 状態付き計算を表現するモナド
  • 現状態 s1 を受け取り, 値 a と次状態 s2 のペアを返す関数
  • newtype State s a = State { runState :: s -> (a, s) }
    • s: 状態の型
    • a: 値の型
  • (現状態, (値, 次状態)) の集合ともみなせる

利用例

詳しくはすごいH本の 14.3 章や 解説記事 を参照.

難しい

これまでのモナドと比べて, bind の定義が複雑で難しいです.

instance Monad (State s) where
  return a        = State $ \s -> (a, s)
  (State h) >>= f = State $ \s -> let (a, newState) = h s
                                      (State g)     = f a
                                  in  g newState

f の引数は何だ?とか, newState の型は?とか考えて, 頭がこんがらがります.

そこで, bind を少し回り道して考えてみます.

bind を fmap, join に分解する

モナドの bind は fmap, join の 2 つの成分に分解することができます. また, 逆に fmap と join から bind を構成することが出来ます.

モナドの種類によって, fmap に特徴があるものと join に特徴にあるものに分かれます.

State モナドは join に特徴があるようです.

以下では, State モナドの join, fmap を順番に定義してみます.

join

join は入れ子になったモナドを潰して一つにする操作です.

join :: Monad m => m (m a) -> m a で表されます.

List モナドだと平坦化 (flatten) と同じですね.

State モナドの join は, 入れ子になった状態付き計算を順に適用する, 関数合成のようなイメージです.

join' :: State s (State s a) -> State s a
join' (State outer) = State $
  \s ->                                     -- 現状態 s を受け取り
    let ((State inner), newState) = outer s -- outer に適用してから
    in  inner newState                      -- inner に適用

join' は入れ子の State モナドを受け取り, 「現状態 s に対して outer, inner を順に適用した結果を返す」State モナドを返します.

入れ子になった State モナドの文脈 (状態付き計算) が, join' によって単一の文脈に糊付けされます.

State モナドは「現状態 s1 を受け取り, 値 a と次状態 s2 のペアを返す関数」なので, ここで定義した join' が State モナドの文脈に沿っていることは納得できると思います.

これが State モナドの核となる部分のようです.

fmap

fmap は単に計算結果の値を map するだけです.

(s -> (a, s)), (s -> (b, s)) はそれぞれ State モナドを構成する関数を表しています.

mapValue :: (a -> b) -> (s -> (a, s)) -> (s -> (b, s))
mapValue f runState' = \s ->
                          let (a, newState) = runState' s
                          in  (f a, newState)  -- 計算結果 a を変換

fmap' f (State h) = State $ mapValue f h

fmap, join から bind を導く

bind は fmap, join を使って次のように構成できます.

m >>= f = join $ fmap f m

f :: Monad m => a -> m b ですので, fmap f m で入れ子モナド m (m b) が作られます.

この入れ子モナドを join で糊付けすれば bind が出来上がります.

instance Monad (State s) where
  return a = State $ \s -> (a, s)
  s >>= f  = join' $ fmap' f s

この宣言で State はモナドとして正しく動作します.

また join', fmap' を前述の定義式に置き換えて適当に式変形をすれば bind の定義式そのものを導くこともできます.

参考文献

(unrated) SRM 644 Div1 Easy, OkonomiyakiParty

問題

N (< 50) 種類のお好み焼きがある. i 番目の大きさは配列 osize[i] (< 10000) で与えられる. このうち M (< N) 種類を選ぶときに, 最も小さいものと最も大きいものの差が K 以下となるような選び方は何通りあるか, 1,000,000,007 の剰余で求めよ.

方針

昇順でソートし, 最も小さいお好み焼きを i としたときの選び方を数え上げる.

解答

using ll = long long;
const int MOD = 1000000007;

class OkonomiyakiParty
{
public:
    int count(vector <int> osize, int M, int K) {
        fill(memo[0], memo[59]+60, -1);
        int N = osize.size();
        sort(all(osize));
        ll ans = 0;
        loop (N, i) {
            int j = i+1;
            while (j < N && osize[j]-osize[i] <= K) j++;
            ans = (ans + cmb(j-i-1, M-1)) % MOD;
        }
        return ans;
    }

    ll memo[60][60];
    ll cmb(ll n, ll r) {
        if (n < 0 || r < 0 || n < r) return 0;
        if (memo[n][r] != -1) return memo[n][r];
        if (!r) return 1;
        return memo[n][r] = (cmb(n-1, r-1) + cmb(n-1, r)) % MOD;
    }
};

Codeforces Good Bye 2014 C, New Year Book Reading

問題

http://codeforces.com/contest/500/problem/C

本が N (< 500) 冊ある. それぞれの本には重さがある. 本は全て重ねて積んだ状態である. 上から I 番目の本を読む場合には, I 番目の本を取り出して山の一番上に移動させる. この時, あなたには 1~I-1 番目の本の重量分の負荷がかかる.

M (< 1000) 日間, 1 日 1 冊の本を読みたい. I 日目に読む本は b[i] で与えられる. 同じ本を複数の日に読む場合もある. M 日間の負荷の最小値を求めよ.

方針

読み終わった本を上に重ねるので, 新しい本 B を読む場合の負荷は「B より前に読んだ本の重さ + B の上にある未読本の重さ」となる. 本を初めて読む順番で積んでおけば, 後者がゼロになる.

解答

int main(int argc, char const* argv[])
{
    int n, m; cin >> n >> m;
    vector<int> weight(n), b(m);
    loop (n, i) cin >> weight[i];
    loop (m, i) { cin >> b[i]; b[i]--; }

    vector<int> order;
    loop (m, i) if (find(all(order), b[i]) == order.end()) {
        order.push_back(b[i]);
    }

    int ans = 0;
    loop (m, i) {
        int j = 0;
        while (order[j] != b[i]) {
            ans += weight[order[j]];
            j++;
        }
        order.erase(order.begin()+j);
        order.insert(order.begin(), b[i]);
    }

    cout << ans << endl;
    return 0;
}

Codeforces Good Bye 2014 D, New Year Santa Network

方針を考えるのが楽しい問題だった.

問題

http://codeforces.com/contest/500/problem/D

ノード数 N (< 100000) の木が与えられる. 各辺は長さを持つ. 1 年に一回, ある辺の長さが更新される. 3 頂点 a, b, c を一様ランダムに選んだ時の, D = distance(a, b) + distance(b, c) + distance(c, a) の期待値を q (< 100000) 年分に求めよ.

方針

各辺について, 辺が D に含まれる(a-b-c のパスの一部になる)確率を求め, 期待値計算をする. 年毎に期待値を更新して出力する.

ある辺について考えた時に, 辺が D に含まれる回数は 0 回または 2 回である. 0 回となるのは, 辺のどちらか片側に a, b, c が全て含まれる場合である.

そこで各辺について, 辺の片側に a, b, c が全て含まれる確率を計算する. 辺の片側のノード数が分かればこれは計算できる.

辺の片側のノード数は, 深さが大きい方の端点の子の数と一致する. よって事前に各頂点の深さと子の数を DFS で計算しておく.

解答

using ll = long long;
using P = pair<int, int>;

const int MAX = 100010;

int degree[MAX] = {};
int children[MAX] = {};
P road[MAX];
double prob[MAX];
vector<int> G[MAX];

int dfs(int prev, int cur, int d)
{
    degree[cur] = d;
    for (int nxt : G[cur]) {
        if (nxt == prev) continue;
        children[cur] += dfs(cur, nxt, d+1)+1;
    }
    return children[cur];
}

ll perm3(ll m)
{
    return m * (m-1) * (m-2);
}

int main(int argc, char const* argv[])
{
    int n; cin >> n;
    vector<P> road;
    vector<int> len;
    loop (n-1, i) {
        int a, b, l; cin >> a >> b >> l;
        a--; b--;
        G[a].push_back(b);
        G[b].push_back(a);
        road.emplace_back(a, b);
        len.push_back(l);
    }

    // 各 node の深さと子の数計算
    dfs(-1, 0, 0);

    loop (n-1, i) {
        int a = road[i].first, b = road[i].second;
        if (degree[a] > degree[b]) swap(a, b);

        // road[i] の方側のノード数
        int count = children[b]+1;

        double probZero = (1.0*perm3(count) + perm3(n-count)) / perm3(n);
        // 3 点が road[i] をまたぐ確率
        prob[i] = 1.0 - probZero;
    }

    // 期待値初期値
    double exp = 0;
    loop (n-1, i) exp += prob[i] * len[i] * 2;

    int q; cin >> q;
    vector<int> r(q), w(q);
    loop (q, i) { cin >> r[i] >> w[i]; r[i]--; }

    cout << setprecision(10) << fixed;
    loop (q, i) {
        int idx = r[i];
        exp -= prob[idx] * len[idx] * 2;
        len[idx] = w[i];
        exp += prob[idx] * len[idx] * 2;
        cout << exp << endl;
    }

    return 0;
}

SRM 640 Div1 Easy, ChristmasTreeDecoration

特に引っ掛けは無いので, 方針だけ思いつけば簡単な問題.

問題

http://community.topcoder.com/stat?c=problem_statement&pm=13551&rd=16083

N (<50) 個の星と M (<200) 個のリボンがある. 星はそれぞれ色があり, i 番目の星の色は col[i] で与えられる. リボンは指定された 2 つの星を結ぶことができる. j 番目のリボンは星 x[j] と星 y[j] を結ぶことができる.

このような配列 col, x, y が与えられたときに, リボンを適切に選んで全ての星を連結にしたい. ただし, 同じ色の星を結ぶリボンの数が最も少なくなるようにしたい. 同じ色の星を結ぶリボンの最小値を求めよ.

与えられた全てのリボンを使うと全ての星が連結となることは保証されている.

方針

異なる色の星を結ぶリボンのみを使ってグラフを作る. グラフの連結成分の個数 g を求める. 与えられた全てのリボンを使うと全ての星が連結となることが保証されているので, 連結成分同士をつなぐ g-1 個の同色リボンが必ず存在する.よって, g-1 が求める答えとなる.

解答

class ChristmasTreeDecoration
{
public:
    vector<int> G[51];
    int vis[51] = {};

    int solve (vector <int> col, vector <int> x, vector <int> y)
    {
        int N = col.size(), M = x.size();
        loop (M, i) {
            if (col[x[i]-1] == col[y[i]-1]) continue;
            G[x[i]-1].push_back(y[i]-1);
            G[y[i]-1].push_back(x[i]-1);
        }

        int g = 0;
        loop (N, i) {
            if (vis[i]) continue;
            dfs(i);
            g++;
        }
        return g-1;
    }

    void dfs(int cur)
    {
        if (vis[cur]) return;
        vis[cur] = true;
        for (int nxt : G[cur]) dfs(nxt);
    }
};

SRM 643 Div1 Med, TheKingsArmyDiv1

コーナーケースに気をつけて慎重に戦略を考える問題(もしくはDP解法). 本番では題意を微妙に取り違えていて, なぜか pretest 通ってしまい, 呆気無く撃墜された.

問題

http://community.topcoder.com/stat?c=problem_statement&pm=13526&rd=16086

2N 人の兵士が 2 行 N 列に並んでいる. 各兵士は Happy/Sad のどちらかの状態を持っている. 兵士 X が兵士 Y に話しかけると, 兵士 Y の状態が兵士 X と同じになる. あなたは次の命令ができる.

  1. 一人の兵士を選び, 隣(上下左右)の誰かに話しかけさせる.
  2. ある行の隣り合う 2 人以上の兵士を選び, 同列他行の兵士に話しかけさせる.
  3. ある列の兵士(2人)を選び, 隣の列の兵士に話しかけさせる.

全員を Happy にするための最小命令回数を求めよ. 不可能な場合は -1 を出力する.

方針

基本的に 2. の命令が最も効率的なので, 2. を使う戦略を考える.

上下の状態が全て異なる場合

HHSSHHSSHH
SSHHSSHHSS

上記の場合, 次の 3 命令で全員 Happy にできる.

  • 1 行 2~3 列の兵士に 2. を命令
  • 1 行 6~7 列の兵士に 2. を命令
  • 0 行 0~9 列の兵士に 2. を命令

よって, 横に状態が等しい隣接ブロックの数を B とすると, 命令回数は B/2 + 1 回となる.

上下どちらも Happy のペアがいる場合

HHH
SHS

上記の場合, 0 行 0~2 列の兵士に 2. を命令すれば良い. すなわち, 上下 Happy のペアを無視して, 前述の上下が全て異なる場合と同じ計算を適用できる.

上下どちらも Sad のペアがいる場合

HSH
SSS

この場合もまず, 0 行 0~2 列の兵士に 2. を命令する. 次に, 0 列の兵士に右隣に話しかけさせれば良い(命令 3). すなわち, 上下 Sad のペアを無視して, 最後に命令 3. を上下 Sad の個数だけ行う.

まとめ

次の計算で最小命令数が計算できる

  1. 上下どちらも Happy なペアを除外する.
  2. 上下どちらも Sad なペアを除外し, その数だけ ans++ する.
  3. 隣接ブロックの数 B を計算し, ans += B/2 + 1 とする.

あとは, 全員 Sad の場合と, 上下が全て等しい場合だけ分岐すれば OK.

解答

class TheKingsArmyDiv1
{
public:
    int getNumber (vector <string> state)
    {
        int N = state[0].length();

        if (state[0] == string(N, 'S') && state[1] == string(N, 'S')) return -1;

        int ans = 0;
        string different;
        for (int i = 0; i < N; i++) {
            if (state[0][i] != state[1][i]) {
                different.push_back(state[0][i]);
            }
            else if (state[0][i] == 'S') ans++;
        }
        int M = different.length();

        if (M == 0) return ans;

        int blocks = 1;
        for (int i = 0; i < M-1; i++) if (different[i] != different[i+1]) blocks++;

        return ans + blocks/2 + 1;
    }
};

(おまけ) 区間 DP 解法

区間 [l, r) について, 配列 dp(l, r, k) を次のように持つ.

  • 0 行目を全て H にする手数 (k = 0)
  • 1 行目を全て H にする手数 (k = 1)
  • 全て H にする手数 (k = 2)

幅 (r-l) を 1 から N まで更新しながら, dp を埋めていく.

class TheKingsArmyDiv1
{
public:
    int getNumber (vector <string> state)
    {
        int N = state[0].length();
        int dp[201][201][3];
        fill(dp[0][0], dp[200][200]+3, INF);

        // 幅 1 の初期値
        for (int i = 0; i < N; i++) {
            if (state[0][i] == 'H') dp[i][i+1][0] = 0; // HS or HH
            if (state[1][i] == 'H') dp[i][i+1][1] = 0; // SH or HH
            if (state[0][i] == 'H' && state[1][i] == 'H') dp[i][i+1][2] = 0; // HH
        }

        for (int w = 1; w <= N; w++) {
            for (int l = 0; l+w <= N; l++) {
                for (int k = l+1; k <= l+w; k++) {
                    int r = l+w;
                    // 区間[l,r)
                    // HS or HH にするコスト
                    dp[l][r][0] = min({
                            dp[l][r][0],
                            dp[l][k][0] + dp[k][r][0],
                            dp[l][k][0] + r-k,
                            k-l + dp[k][r][0]});
                    // SH or HH にするコスト
                    dp[l][r][1] = min({
                            dp[l][r][1],
                            dp[l][k][1] + dp[k][r][1],
                            dp[l][k][1] + r-k,
                            k-l + dp[k][r][1]});
                    // HH にするコスト
                    dp[l][r][2] = min({
                            dp[l][r][2],
                            dp[l][k][2] + dp[k][r][2],
                            min(dp[l][r][0], dp[l][r][1]) + 1,
                            dp[l][k][2] + r-k,
                            k-l + dp[k][r][2]});
                    dp[l][r][0] = min(dp[l][r][0], dp[l][r][2]);
                    dp[l][r][1] = min(dp[l][r][1], dp[l][r][2]);
                }
            }
        }

        return dp[0][N][2] >= INF ? -1 : dp[0][N][2];
    }
};

SRM 643 Div1 Easy, TheKingsFactorization

本番は素数リスト作ったけどその必要も無かった.

問題

http://community.topcoder.com/stat?c=problem_statement&pm=13594&rd=16086

1018 以下の整数 N が与えられる. N を素因数分解せよ. ただしヒントとして, 昇順 0, 2, 4, 6, ... 番目の素因数が与えられる. (e.g. N = 60 の場合, 素因数 {2, 2, 3, 5} のうち {2, 3} が与えられる.)

方針

ヒントを考慮しない場合, √1018 = 109 までの試し割りが必要となり, 計算量オーバーする.

ヒントの素因数以外で, 106 を超える素因数は高々 1 つである. よって, 106 までを試し割れば良い.

解答

using ll = long long;
const int MAX = 1000001;

class TheKingsFactorization
{
public:
    vector<long long> getVector (long long N, vector<long long> primes)
    {
        for (ll p : primes) N /= p;

        for (int i = 2; i < MAX; i++) {
            while (N%i == 0) {
                primes.push_back(i);
                N /= i;
            }
        }
        if (N > 1) primes.push_back(N);

        sort(primes.begin(), primes.end());
        return primes;
    }
};