// Suffix-automaton value: `nodes` is the node table, `last` is the index of the
// node created by the most recent `.extend` (0, the root, right after `init`).
SAM : Type sam = (nodes : [SAMNode]) -> (last: Int) -> SAM
// Empty automaton: a single root node built as `samNode link len next` with
// link = -1 (no suffix link), len = 0 and no outgoing transitions; last = 0.
init : SAM init = sam [samNode -1 0 empty] 0
// Append character `ch` to the automaton and return the updated SAM, whose
// `last` is the newly created node `curr` (len = len(old last) + 1).
// Online construction: create `curr`, walk suffix links from the old `last`
// adding ch-transitions to `curr`, then choose curr's suffix link by one of the
// three cases below (cloning a node when the existing target is too long).
.extend : SAM -> Char -> SAM
.extend (sam nodes last) ch =
  curr = len nodes
  // add the new node (link 0 for now) and point every suffix-link ancestor of
  // `last` that has no ch-transition yet at it; stop at the first one that has
  nodes ::= samNode 0 (nodes[last].len + 1) empty
  conn nodes state =
    if state == -1 then (nodes, state)
    else case nodes[state].next[ch] of
      some x => (nodes, state)
      none => conn (nodes[state].next[ch] <- curr) nodes[state].link
  (nodes, state) = conn nodes last
  // case 1: no ancestor had a ch-transition, so curr's link stays the root (0)
  if state == -1 then sam nodes curr
  else
    // case 2: the ch-target of `state` is exactly one longer than `state`,
    // so it can be curr's suffix link as is
    next = nodes[state].next[ch]
    if nodes[state].len + 1 == nodes[next].len then
      sam (nodes[curr].link <- next) curr
    else
      // case 3: split `next` by cloning it with len = len(state) + 1
      clone = len nodes
      nodes ::= samNode nodes[next].link (nodes[state].len + 1) nodes[next].next
      // redirect ch-transitions that point at `next` to `clone`; stop at the
      // first ancestor whose ch-transition goes somewhere else
      move nodes state =
        if state == -1 then nodes
        else case nodes[state].next[ch] of
          some x if x == next => move (nodes[state].next[ch] <- clone) nodes[state].link
          some x => nodes
          none => nodes
      nodes = move nodes state
      // both the original target and the new node now hang off the clone
      nodes = (nodes[next].link <- clone)
      nodes = (nodes[curr].link <- clone)
      sam nodes curr
标记每个终止结点(从 last 出发沿 link 不断跳转直到 -1 为止途经的所有结点,包括根)的 cnt 为 1,再用 DFS 求出从 v 出发到达终止结点的不同路径条数,即为 cnt[v](也就是 v 所代表的子串在原串中的出现次数)。
我们采用第二种方法,因为它不需要修改 extend 的逻辑。
绷语言实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Per-node path counts: counts[v] is the number of distinct transition paths
// from v to a terminal node. Terminal nodes are `last` and every node reached
// from it by repeatedly following suffix links (the root included).
.cnt : SAM -> [Int]
.cnt (sam nodes last) =
  // seed: each terminal node starts with a count of 1
  markTerminal counts node =
    if node == -1 then counts
    else markTerminal (counts[node] <- 1) nodes[node].link
  seeded = markTerminal [0 for len nodes] last
  // memoised DFS from the root: after a child is finished (visited now or
  // earlier), its count is added to the parent's
  visit counts seen v =
    seen = seen[v] <- true
    fold nodes[v].next (counts, seen) ((counts, seen) => (c, child) =>
      (counts, seen) = if seen[child] then (counts, seen) else visit counts seen child
      (counts[v] <- counts[v] + counts[child], seen))
  (visit seeded [false for len nodes] 0)[0]
// Smallest-prime-factor sieve: for 2 <= j <= n, arr[j] is the smallest prime
// dividing j; arr[0] and arr[1] stay 0.
// A multiple j is written only while it is still unmarked, so the first
// (smallest) prime that reaches it wins instead of being overwritten by every
// larger prime factor.
// The table has n + 1 slots so index n is valid whether or not `to` is inclusive.
firstFactor : Int -> [Int]
firstFactor n = fold [2 to n] [0 for (n + 1)] (arr => i =>
  if arr[i] > 0 then arr
  else fold [i to n step i] arr (arr => j =>
    if arr[j] > 0 then arr else arr[j] <- i))
// Prime-factor lookup table, built once at top level. It was previously rebuilt
// inside every `factor` call, which main triggers once per SAM node.
// NOTE(review): covers n up to roughly 1e6 (the sieve size) — confirm against
// the input limits.
factorSieve : [Int]
factorSieve = firstFactor 1000004

// Prime factorisation of n as (prime, exponent) pairs, one pair per distinct
// prime. Equal primes are merged because successive sieve lookups return each
// prime in one contiguous run. Returns [] for n < 2 (avoids division by zero
// or negative indexing for n <= 0).
factor : Int -> [(Int, Int)]
factor n =
  genFactor arr rem =
    if rem == 1 then arr
    else
      fac = factorSieve[rem]
      if len arr == 0 || arr[-1][0] != fac then genFactor (arr :: (fac, 1)) (rem / fac)
      else genFactor (arr[-1][1] +<- 1) (rem / fac)
  if n < 2 then [] else genFactor [] n
// Fold `f` over the divisors of `num`, generated from its prime factorisation,
// starting from accumulator `init`; f receives the accumulator and a divisor
// and returns the new accumulator (hence T -> Int -> T).
// `pred` prunes the search: once a partial divisor `base` satisfies it, that
// branch is abandoned and no larger multiple along it is visited, so pred
// should hold for every multiple of a value it holds for (e.g. base > bound).
foldFactor : Int -> T -> (Int -> Bool) -> (T -> Int -> T) -> T
foldFactor num init pred f =
  facList = factor num
  // pick an exponent for facList[facID], then recurse on the remaining primes;
  // once every prime has an exponent, `base` is a complete divisor
  foldRest res base facID =
    if facID == len facList then f res base
    else
      (fac, cnt) = facList[facID]
      foldSingle res base usingCnt =
        if usingCnt > cnt || pred base then res
        else
          newRes = foldRest res base (facID + 1)
          foldSingle newRes (base * fac) (usingCnt + 1)
      foldSingle res base 0
  foldRest init 1 0
// Read n and the string, build its suffix automaton, and print the sum over
// all non-root nodes of (number of lengths in [l, r] that divide cnt[id]) *
// cnt[id], where [l, r] = [len(link) + 1, len] is the node's length range and
// cnt comes from `.cnt`.
main : Void
main =
  n, str = input, input
  s = fold str init (s => (id, ch) => s.extend ch)
  cnt = s.cnt
  res = fold s.nodes 0 (res => (id, node) =>
    // only the root has len 0
    if node.len == 0 then res
    else
      l = s.nodes[node.link].len + 1
      r = node.len
      // count divisors of cnt[id] inside [l, r]. foldFactor abandons a branch
      // when its predicate holds, so the predicate must be "already too long"
      // (fac > r), not "still in range".
      mul = foldFactor cnt[id] 0 (fac => fac > r) (acc => fac =>
        if fac >= l && fac <= r then acc + 1 else acc)
      res + mul * cnt[id])
  print res