A Floor Sum

Computer Science Level pending

Write a function that computes

$$S(n,r) = \sum_{i=1}^{n} \left\lfloor i\sqrt{r} \right\rfloor$$

Submit $S(10^{100}, 2020) \bmod 10^{14}$.

Note that you don't need a powerful computer for this computation; it should take less than a second.
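For small inputs the definition can be evaluated directly, which is useful as a reference when testing a faster method. The sketch below is not part of the original problem: the function name `floor_sum_bruteforce` is hypothetical, and it relies on the fact that $\lfloor i\sqrt{r} \rfloor$ equals the integer square root of $i^2 r$, so no floating point is involved.

```python
from math import isqrt  # available since Python 3.8


def floor_sum_bruteforce(n: int, r: int) -> int:
    """Return sum_{i=1}^{n} floor(i*sqrt(r)) by direct summation.

    Hypothetical reference helper: O(n) time, so only usable for small n.
    floor(i*sqrt(r)) == isqrt(i*i*r) for non-negative integers i and r.
    """
    return sum(isqrt(i * i * r) for i in range(1, n + 1))


if __name__ == "__main__":
    # floor(k*sqrt(2)) for k = 1..5 is 1, 2, 4, 5, 7
    print(floor_sum_bruteforce(5, 2))
```

This is hopeless for $n = 10^{100}$, but it gives exact values to compare against for small $n$.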


The answer is 38653388879862.


1 solution

Julian Poon
Nov 14, 2020

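The idea is to use Beatty sequences to simplify the computation. (In this solution, $S(n,r)$ denotes $\sum_{i=1}^{n} \lfloor ir \rfloor$, so the problem asks for $S(10^{100}, \sqrt{2020})$.) Given positive irrational numbers $r$ and $s$ with $\frac{1}{r} + \frac{1}{s} = 1$, the sequences $\lfloor si \rfloor$ and $\lfloor ri \rfloor$, for positive integers $i$, partition the set of positive integers. Hence we have: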

$$S(n,r) = \frac{e(e+1)}{2} - S(n',s), \quad e = \lfloor nr \rfloor, \quad n' = \left\lfloor \frac{e}{s} \right\rfloor$$
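The partition property and the identity above can be checked numerically on a small case. The sketch below is an illustration only, assuming the classic pair $r = \sqrt{2}$, $s = 2 + \sqrt{2}$ (which satisfies $\frac{1}{r} + \frac{1}{s} = 1$); the helper names are hypothetical, and the floors are computed exactly with integer square roots.

```python
from math import isqrt


def beatty_r(i: int) -> int:
    """floor(i*sqrt(2)), computed exactly."""
    return isqrt(2 * i * i)


def beatty_s(i: int) -> int:
    """floor(i*(2 + sqrt(2))) = 2*i + floor(i*sqrt(2))."""
    return 2 * i + isqrt(2 * i * i)


def check_partition(limit: int) -> bool:
    """True if the two sequences split {1, ..., limit} with no overlap."""
    a = {beatty_r(i) for i in range(1, limit + 1) if beatty_r(i) <= limit}
    b = {beatty_s(i) for i in range(1, limit + 1) if beatty_s(i) <= limit}
    return a.isdisjoint(b) and (a | b) == set(range(1, limit + 1))


def check_identity(n: int) -> bool:
    """True if S(n, r) == e(e+1)/2 - S(n', s) for r = sqrt(2)."""
    e = beatty_r(n)
    # n' counts the terms of the s-sequence that do not exceed e.
    n_prime = 0
    while beatty_s(n_prime + 1) <= e:
        n_prime += 1
    lhs = sum(beatty_r(i) for i in range(1, n + 1))
    rhs = e * (e + 1) // 2 - sum(beatty_s(j) for j in range(1, n_prime + 1))
    return lhs == rhs


if __name__ == "__main__":
    print(check_partition(1000), all(check_identity(n) for n in range(1, 200)))
```

The identity holds because the integers $1, \dots, e$ sum to $e(e+1)/2$, and each of them appears in exactly one of the two sequences.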

Now if $1<r<2$, then $n' \approx n(r-1)$ is smaller than $n$. Since $\lfloor i(r+k) \rfloor = \lfloor ir \rfloor + ik$ for any integer $k$, we have $S(n,r+k) = k\frac{n(n+1)}{2} + S(n,r)$, and we can form the following recursion:

$$\begin{aligned}
S(n,r) &= r_0\frac{n(n+1)}{2} + S(n,r_1) \\
&= r_0\frac{n(n+1)}{2} + \frac{e_1(e_1+1)}{2} - S(n_1,s_1) \\
r_0 &= \lfloor r \rfloor-1, \quad r_1 = r - r_0 \\
e_1 &= \lfloor nr_1\rfloor , \quad n_1 = \left\lfloor \frac{e_1}{s_1} \right\rfloor , \quad s_1 = 1 - \frac{1}{1-r_1}
\end{aligned}$$

This keeps $n_1$ below $n$ at every step (roughly $n_1 \approx n(r_1 - 1)$, with $1 < r_1 < 2$), and hence drastically speeds up the computation of the recursion. For accuracy, this can be implemented in sympy. The implementation is below and takes roughly 36 s to solve.
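Before turning to the full implementations, a single step of the recursion can be sanity-checked against direct summation for small $n$. The sketch below is only an illustration for $r = \sqrt{2020}$, with hypothetical function names. It assumes two hand-derived facts: $s_1 = 1 - \frac{1}{1-r_1} = \frac{128 + \sqrt{2020}}{84}$ (from $r_1 = \sqrt{2020} - 43$), and $\lfloor e_1/s_1 \rfloor = e_1 - n$ (from $\frac{1}{s_1} = 1 - \frac{1}{r_1}$).

```python
from math import isqrt

M = 2020  # the sum uses r = sqrt(M)


def S_direct(n: int) -> int:
    """sum_{i=1}^{n} floor(i*sqrt(M)) by direct summation (small n only)."""
    return sum(isqrt(M * i * i) for i in range(1, n + 1))


def S_one_step(n: int) -> int:
    """Apply one step of the recursion, then sum the remainder directly."""
    r0 = isqrt(M) - 1                   # floor(r) - 1 = 43
    e1 = isqrt(M * n * n) - r0 * n      # floor(n*r1), with r1 = sqrt(M) - r0
    n1 = e1 - n                         # assumed equal to floor(e1/s1), see above
    # floor(j*s1) with s1 = (128 + sqrt(2020))/84, computed in integers;
    # the constants 128 and 84 are specific to M = 2020.
    tail = sum((128 * j + isqrt(M * j * j)) // 84 for j in range(1, n1 + 1))
    return r0 * n * (n + 1) // 2 + e1 * (e1 + 1) // 2 - tail


if __name__ == "__main__":
    print(all(S_one_step(n) == S_direct(n) for n in range(1, 300)))
```

Here $r_1 - 1 \approx 0.944$, so for this particular $r$ the first step shrinks $n$ only slightly; later steps use different values of $r_1$.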

But we can do much better. We can carry out the symbolic calculations using only integers, with no floats, to preserve precision. We can also compute rational approximations to $\sqrt{r}$ to the precision we need, cache the result, and refine it only when higher precision is required. Finally, we can remove all recursion (Python has significant overhead for recursive calls). This results in a roughly 4000x speedup, computing the answer in only 0.0089 s. The implementation is also below, where I also compute the sum for larger $n = 10^{1000}$ and $n = 10^{10000}$. Note that I'm using Python 3.8 here, which is required for math.comb.
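To illustrate the idea of a rational approximation to $\sqrt{r}$, here is a minimal sketch using the binomial series $\sqrt{1+x} = \sum_{i \ge 0} \binom{2i}{i} \frac{(-x)^i}{4^i (1-2i)}$ with $x = (r - N^2)/N^2$ and $N$ the integer nearest to $\sqrt{r}$. This appears to be the same series that `compute_sqrt` in the code below accumulates term by term, but the version here is a simplified stand-in: the name `sqrt_series` is hypothetical, it uses `fractions.Fraction` instead of hand-rolled numerator and denominator pairs, and it does no caching.

```python
from fractions import Fraction
from math import comb, isqrt  # both require Python 3.8+


def sqrt_series(n: int, terms: int) -> Fraction:
    """Rational approximation to sqrt(n) from `terms` binomial-series terms."""
    N = isqrt(n)
    # Move to the nearer integer so that |x| is as small as possible.
    if n - N * N > (N + 1) * (N + 1) - n:
        N += 1
    x = Fraction(n - N * N, N * N)  # sqrt(n) = N * sqrt(1 + x)
    total = sum(
        Fraction(comb(2 * i, i), 4**i * (1 - 2 * i)) * (-x) ** i
        for i in range(terms)
    )
    return N * total


if __name__ == "__main__":
    approx = sqrt_series(2020, 10)
    # The squared approximation should be very close to 2020.
    print(float(abs(approx * approx - 2020)))
```

For $r = 2020$ the nearest integer is $N = 45$, so $x = -5/2025$ and each extra term gains a bit more than two decimal digits.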

######################
### Sympy Solution ###
######################

import sympy as sym
import time

def S_sym(n, r):

    if n==0: return 0

    r0 = sym.floor(r) - 1
    r1 = r - r0
    s1 = 1 - 1/(1-r1)
    e1 = sym.floor(n*r1)
    n1 = sym.floor(e1/s1)

    return r0*n*(n+1)/2 + e1*(e1+1)/2 - S_sym(n1, s1)

n,r = 10**100, 2020
t = time.time()
print(int(S_sym(n, sym.sqrt(r)))%10**14)
print(time.time()-t)

# 38653388879862
# 35.8780882358551 << Time taken

#####################################
### Very efficient implementation ###
#####################################

from math import comb, gcd, sqrt  # comb requires Python 3.8+
import time

def frac_simplify(frac):
    a,b = frac
    g = gcd(a, b)
    a //= g; b //= g
    return a,b

def frac_add(frac1, frac2):
    a,b = frac1; c,d = frac2
    return a*d + c*b, d*b

def compute_sqrt(n, thres):

    '''Compute p/q ~ sqrt(n) with error < 1/thres'''

    if 'sqrt_cache' not in globals():
        global sqrt_cache
        sqrt_cache = None

    if not sqrt_cache:

        N = int(sqrt(n))
        if N*N == n: return N,1

        d0 = n-N*N
        d1 = n-(N+1)*(N+1)

        if d0 > abs(d1): d = d1; N += 1
        else: d = d0

        N2 = N*N
        end = N2//d

        p,q = 0,1
        i = 0
        d_pow, N_pow, f_pow, end_pow = 1, 1, 1, 1

    else:
        (p,q), (d_pow,N_pow,f_pow,end_pow), (N,N2,i,d,end) = sqrt_cache

    while True:

        if end_pow>thres:
            sqrt_cache = (p,q), (d_pow,N_pow,f_pow,end_pow), (N,N2,i,d,end)
            return frac_simplify((p*N,q))

        a = comb(2*i,i) * d_pow
        b = f_pow*(1-2*i) * N_pow
        p,q = frac_simplify(frac_add((p,q), (a,b)))

        d_pow *= -d
        N_pow *= N2
        f_pow *= 4
        end_pow *= end
        i += 1

def evaluate_floor_r(r, n):

    '''
    compute floor(n*r)
    r has the format ((a,b), (c,d), e), representing a/b + c/d*sqrt(e)
    '''

    (a,b), (c,d), e = r
    p,q = compute_sqrt(e, n*c*n)
    p,q = p*c,q*d

    return (n*(p*b + a*q))//(q*b)

def inv_r(r):

    '''
    compute r_new = 1/r
    r has the format ((a,b), (c,d), e), representing a/b + c/d*sqrt(e)
    '''

    (a,b), (c,d), e = r

    ad = a*d; bd = b*d; cb = c*b
    x = ad*bd
    y = bd*cb
    z = ad*ad - cb*cb*e

    a0,b0 = frac_simplify((x, z))
    c0,d0 = frac_simplify((y,-z))

    r_inv = ((a0,b0), (c0,d0), e)

    return r_inv

def f(r):

    '''
    Computes r_new = 1 - 1/(1-r)
    r and r_new have the format ((a,b), (c,d), e), representing a/b + c/d*sqrt(e)
    '''

    (a,b), (c,d), e = r

    db = d*b; bc = b*c; bad = (b-a)*d
    x = bad*db
    y = db*bc
    z = bad*bad - bc*bc*e

    a0,b0 = frac_simplify((z-x, z))
    c0,d0 = frac_simplify((y,-z))

    return ((a0, b0), (c0, d0), e)

def S(n, r):

    r'''
    Computes \sum_{i=1}^n floor(r*i)
    r has the format ((a,b), (c,d), e), representing a/b + c/d*sqrt(e)
    '''

    total_sum = 0
    idx = 1

    while True:

        if n==0: break

        (a,b), (c,d), e = r

        # r0 = floor(r) - 1
        # r1 = r - r0
        # s1 = 1 - 1/(1-r1)
        # e1 = floor(n*r1)
        # n1 = floor(e1/s1)

        r0 = evaluate_floor_r(r,1) - 1
        r1 = ((a-r0*b, b), (c,d), e)
        s1 = f(r1)
        e1 = evaluate_floor_r(r1, n)
        n1 = evaluate_floor_r(inv_r(s1), e1)

        total_sum += (r0*n*(n+1)//2 + e1*(e1+1)//2)*idx
        idx *= -1

        n,r = n1,s1

    return total_sum

def solution(n, r):

    if 'sqrt_cache' in globals():
        global sqrt_cache
        del sqrt_cache

    return S(n,((0,1),(1,1),r))

n,r = 10**100, 2020
t = time.time()
print(solution(n, r) % 10**14)
print(time.time() - t)

# 38653388879862
# 0.008942127227783203 <-- Time taken

n,r = 10**1000, 2020
t = time.time()
print(solution(n, r))
print(time.time() - t)

# 224722050542442318645981404454906794299619661601512563934969136743207566452911195980186706665123649785482407050725287201518746127960864664054652757964292072332407886622946404903694436943603198826523036433718643698823482202078591618516994780171735800732348180970141118366498308051833919263005256446692487665838555011745613780620481724119305586121571897614465351521332102337366073886934638197508465901343881636717543991485964805322756340380803000028158302190959272765643602469548520744244195209788002686980674233395833456790517124753663659031238763575126086699660605304126091443981234947328970242362554181870783173414452289687060453162382361328849754586801162125671899220415078194336073762248313686008778102077061781685384959201780062289386171386745120254585949217603487340024299615832406145939750591238414155416369455460276461609440966146080630342181614108767245580926772344092403436236834777062884309442615046306743376623999525238497982158600218106851180358069512416405445523583101342328443018454133402679045890567384902775241733173315288107316803984729597915033295309749154211120367955917436797945098460377422795333981219747260215388960650449716737058277752664407986594876153615306764924903297716792897064025647530234157617195134444890488072398184462397148644303769854183185058146496470472758540275051513168882527160647260930455373395061435967223098817618285420246517824251513575954377339987798873150422531165379008602105818660914925659415998466654842020692052880175398235234892532458135505673910193680847489568632212014256935622266870620428717171088125068160608826066705319227099764318531653815015830571253268913913937839504854587265302866267156584243721760014371065629766054330139710719803274446926579577130275407565997186773989111662721199458729135418499150051111230926155789795483265718831677030333760547948906732616867151588269800903162326135714264305345471040778970601516885758620151555531123364219662323202716398174629538119573853005608064870240887801521425910223693552199373097634500982978
# 9585  (continuation of the output above, wrapped onto a second line)
# 1.645200252532959 <-- Time taken

n,r = 10**10000, 2020
t = time.time()
print(solution(n, r))
print(time.time() - t)

# (output not shown)
# 1439.6152136325836 <-- Time taken
