This is how I have divided the tasks. The algorithm is described in this paper:
https://herbie.uwplse.org/pldi15-paper.pdf


   1. Error calculation on sample points
      - Find a working precision high enough that the exact (real) value is
        correct to 64 bits: keep increasing the precision until the first
        64 bits of the result stop changing.
      - Calculate the value using hardware floating point (double
        precision).
      - Compare the exact and floating-point values using
        E(x, y) = log2(z), where z is the number of floating-point values
        between the exact and the approximate answer.
      - Average these errors over the sample points to measure how
        accurate the expression is.

   2. Pick candidate
      - Pick a candidate program from the table and compute its error, as
        well as the local error at each location (subexpression), on the
        sampled points.
      - The database will hold rewrite rules such as commutativity,
        associativity, distributivity, (x + y) = (x**2 - y**2)/(x - y),
        (x - y) = (x**3 - y**3)/(x**2 + y**2 + x*y), x = exp(log(x)), etc.

   3. Recursive rewrite
      - Rewrite candidates using the rule database, and simplify to cancel
        terms.
      - Recursively repeat the algorithm on the best subexpression.
   4. Series expansion
      1. Find series expansions of expressions to remove error near 0 and
         infinity.
   5. Candidate tree
      1. Keep only those candidates that give the best accuracy on at least
         one sample point.
      2. On every iteration of the outer loop, the algorithm picks a program
         from the table and uses it to find new candidates; every program
         is used only once.
      3. Candidate programs are also stored as a pair of maps that tie each
         candidate to the sample points where it is best.
      4. Remove redundant candidates: when more than one candidate gives the
         same result at a point, decide which to keep based on their results
         at the other points.
      5. Before the series-approximation step, use a set-cover
         approximation algorithm to prune the candidates to a minimal set.
   6. Get piecewise solutions
      1. Split points are found using dynamic programming and later refined
         using binary search.
   

These are the functions:



*Definition main(program) :*
points := sample-inputs(program)

exacts := evaluate-exact(program, points)

table := make-candidate-table(simplify(program))

repeat N times

    candidate:= pick-candidate(table)

    locations := sort-by-local-error(all-locations(candidate)).take(M)

    rewritten := recursive-rewrite(candidate, locations)

    new-candidates := simplify-each(rewritten)

    table.add(new-candidates)

    approximated := series-expansion(candidate, locations)

    table.add(approximated)

return infer-regimes(table).as-program




*Definition local-error(expr, points) :*
for point ∈ points :

    args := evaluate-exact(expr.children)

    exact-ans := F(expr.operation.apply-exact(args))

    approx-ans := expr.operation.apply-approx(F(args))

    accumulate E(exact-ans, approx-ans)



*Definition recursive-rewrite(expr, target) :*

select input → output from RULES

where input.head = expr.operator ∧ output.head = target.head

for (subexpr, subpattern) ∈ zip(expr.children, input.children) :

    if ¬matches(subexpr, subpattern) :

        recursive-rewrite(subexpr, subpattern)

where matches(expr, input)

expr.rewrite(input -> output)



*Definition infer-regimes(candidates, points) :*

for x_i ∈ points :

    best-split_0[x_i] := [regime(best-candidate(−∞, x_i), −∞, x_i)]

for n ∈ ℕ until best-split_{n+1} = best-split_n :

    for x_i ∈ points ∪ {∞} :

        for x_j ∈ points, x_j < x_i :

            extra-regime := regime(best-candidate(x_j, x_i), x_j, x_i)

            option[x_j] := best-split_n[x_j] ++ [extra-regime]

        best-split_{n+1}[x_i] := lowest-error(option)

        if best-split_n[x_i].error − 1 ≤ best-split_{n+1}[x_i].error :

            best-split_{n+1}[x_i] := best-split_n[x_i]

split := best-split_*[∞]

split.refine-by-binary-search()

return split 



I have written a basic brute-force implementation:
from sympy import *
import numpy as np

# Free symbols that candidate expressions may use.
x, y, z = map(Symbol, ("x", "y", "z"))


# Sample points: ``num_points`` values evenly spaced over [mini, maxi).
# Each point is computed directly from its index rather than by repeatedly
# adding ``step``, so rounding error does not accumulate along the range.
# NOTE(review): the Herbie paper samples inputs over floating-point bit
# patterns.  Evenly spaced values over +/-1e30 never visit any nonzero
# magnitude below ``step`` (about 7.8e27) -- confirm this range is intended.
maxi = 1000000000000000000000000000000
mini = -1000000000000000000000000000000
num_points = 256
step = (maxi - mini) / num_points

points = [mini + i * step for i in range(num_points)]

# Evaluate an expression at one sample point in two precisions.
def calc_error(expr, point, dps=1000):
    """Evaluate ``expr`` at ``point`` in working and in high precision.

    Every free symbol of ``expr`` is replaced by the same value ``point``.

    Args:
        expr: sympy expression to evaluate.
        point: number substituted for every free symbol.
        dps: decimal digits of precision used for the high-precision value
            (default 1000, the value that used to be hard-coded).

    Returns:
        ``(ans1, ans2)``. ``ans1`` is ``expr`` with the plain Python number
        substituted. ``ans2`` is ``expr`` with an mpmath ``mpf`` created at
        ``dps`` digits substituted.
    """
    from mpmath import mp, mpf

    # NOTE: this changes mpmath's global precision (as the original code
    # did); the setting stays in effect after this function returns.
    mp.dps = dps
    free_symbols = expr.atoms(Symbol)

    # Bug fix: the low-precision substitution used the global ``x`` for every
    # symbol, so any other symbol (y, z, ...) was left unevaluated in ans1
    # while ans2 substituted all of them.
    subst_sym = [(sym, point) for sym in free_symbols]
    subst_mpf = [(sym, mpf(point)) for sym in free_symbols]

    # NOTE(review): ans1 is computed by sympy's own Float arithmetic, not by
    # hardware doubles, so it may differ from what compiled float code would
    # return -- confirm this is the intended "approximate" value.
    ans1 = expr.subs(subst_sym)
    ans2 = expr.subs(subst_mpf)
    return ans1, ans2

# Rewrite-rule database.  Each rule takes a sympy expression and returns the
# rewritten expression, or False when the rule does not apply (callers test
# for False).  Both rules only fire on a two-term sum a + b in which neither
# term is a bare Symbol.  Rule parameters are not named ``exp``, which would
# shadow sympy's ``exp``.


def _two_term_sum(e):
    """Return ``(a, b)`` if ``e`` is ``a + b`` with two non-Symbol terms, else None."""
    if isinstance(e, Add) and len(e.args) == 2:
        a, b = e.args
        if not isinstance(a, Symbol) and not isinstance(b, Symbol):
            return a, b
    return None


def _sum_via_cubes(e):
    """Rewrite a + b as (a**3 + b**3) / (a**2 + b**2 - a*b)."""
    terms = _two_term_sum(e)
    if terms is None:
        return False
    a, b = terms
    return (a**3 + b**3) / (a**2 + b**2 - b*a)


def _sum_via_squares(e):
    """Rewrite a + b as (a**2 - b**2) / (a - b)."""
    terms = _two_term_sum(e)
    if terms is None:
        return False
    a, b = terms
    return (a**2 - b**2) / (a - b)


# Order matters: the script indexes rule[0] and rule[1] directly.
rule = [_sum_via_cubes, _sum_via_squares]
#rule.append(lambda expr: exp(log(expr)))

#rule1 = lambdify([x, y, (x**2 - y**2)/(x+y))

# Expression to improve.  Other inputs that have been tried:
#   x + 1 - x   (optionally + (1/(x-1)) - (2/x))
#   (-y + (-4*x*z + y**2)**0.5)/(2*x)
#   (x + 1)**(1/3) - x**(1/3)
#   (x + 1)**0.25 - x**0.25
expr = simplify((x + 1)**0.5 - x**0.5)
main_expr = temp_expr = expr

# Show what each rewrite rule produces for the whole expression.
pprint(rule[0](expr))
print(rule[1](expr))

# Results collected by pre(): every accepted rewrite of main_expr is appended
# to all_func (once per sample point at which it was accepted), and mapper
# maps each accepted rewrite to the list of those points.
all_func = []
mapper = dict()


def pre(expr):
    """Try every rewrite rule on ``expr`` and, recursively, on its arguments.

    For each rule that applies, ``expr`` is replaced inside ``main_expr`` and
    the rewritten program is evaluated at every sample point.  Wherever its
    error is no larger than the error of ``main_expr``, the rewritten program
    is appended to ``all_func`` and the point is added to
    ``mapper[rewritten program]``.

    NOTE(review): the error used here is the plain difference between the
    high-precision and working-precision results, not the bits-of-error
    measure E(x, y) described in the plan.
    """
    print("see ", expr)

    # Applying a rule and substituting into main_expr do not depend on the
    # sample point, so do it once per rule instead of once per (point, rule).
    rewritten = []
    for fi in rule:
        k = fi(expr)
        if k is not False:
            rewritten.append(main_expr.subs(expr, k))

    if rewritten:
        for pnt in points:
            # Bug fix: the baseline is the whole original program.  The old
            # code measured ``expr`` itself, so for a proper subexpression it
            # compared a subexpression's error with a whole program's error.
            ans1, ans2 = calc_error(main_expr, pnt)
            diff1 = ans2 - ans1
            for update_func in rewritten:
                ans3, ans4 = calc_error(update_func, pnt)
                diff2 = ans4 - ans3
                try:
                    improved = bool(abs(diff2) <= abs(diff1))
                except Exception:
                    # Best effort per point: sympy may be unable to order the
                    # two errors (e.g. NaN results).  Unlike the former bare
                    # ``except:``, KeyboardInterrupt is no longer swallowed.
                    print("failed ", pnt)
                    continue
                if improved:
                    all_func.append(update_func)
                    print("inn ", pnt)
                    mapper.setdefault(update_func, []).append(pnt)

    print(expr, " space ")
    for arg in expr.args:
        pre(arg)

# Seed the candidate list with the simplified original program.  all_func is
# still empty here because pre() has not run yet, so the original program is
# always kept as one of the candidates.
if not all_func:
    all_func.append(simplify(main_expr))

pre(expr)

# Deduplicate the candidates and replace any power whose exponent equals 1
# (e.g. ``x**1.0``) by its base; the old ``Mul(base, 1)`` evaluated to the
# base as well.  A comprehension keeps the loop variable from overwriting
# the global ``expr``, and sorting by the printed form makes the output
# order reproducible across runs.
values = [
    candidate.replace(
        lambda e: isinstance(e, Pow) and e.args[1] == 1,
        lambda e: e.args[0],
    )
    for candidate in sorted(set(all_func), key=str)
]
pprint(values)

Output:

$$
\left[\ \frac{-x^{1.5} + (x + 1)^{1.5}}{x^{0.5}\,(x + 1)^{0.5} + 2x + 1},\quad
\frac{1}{x^{0.5} + (x + 1)^{0.5}},\quad
-x^{0.5} + (x + 1)^{0.5}\ \right]
$$




Please tell me your thoughts on this: any additions, deletions, or algorithmic
improvements.

-- 
You received this message because you are subscribed to the Google Groups 
"sympy" group.
To unsubscribe from this group and stop receiving emails from it, send an email 
to [email protected].
To view this discussion on the web visit 
https://groups.google.com/d/msgid/sympy/f94a7d30-dc4c-4ba0-a333-e67ff1b63513%40googlegroups.com.

Reply via email to