The algorithm is defined in this paper:
https://herbie.uwplse.org/pldi15-paper.pdf
This is how I have divided the tasks:
1. Error Calculation on sampling points
- Find the working precision needed to compute the actual (real) value
correctly to 64 bits: keep increasing the precision until the first 64 bits
of the result stop changing.
- Calculate the value using floats (hardware precision).
- Compare the real and float values by calculating *E(x,y) = log(2,z)*, where
z is the number of floating-point values between the real and approximate
answers.
- Average these differences over the sampled points to see how accurate
the expression is.
1. Pick Candidate
- Pick candidates (subexpressions) from the table and compute the error,
as well as the local error at each location, on the sampled points.
- The database will have a set of rewrite rules such as commutativity,
associativity, distributivity, (x + y) = (x**2 - y**2)/(x - y), (x - y) =
(x**3 - y**3)/(x**2 + y**2 + x*y), x = exp(log(x)), etc.
1. Recursive rewrite
- Rewrite candidates using the database of rules, then simplify to
cancel terms.
- Recursively repeat the algorithm on the best subexpression.
2. Series Expansion
1. Find the series expansion of expressions to remove error near 0
and infinity.
3. Candidate Table
1. Only keep those candidates that give the best accuracy on at least
one sampled point.
2. On every iteration of the outer loop, the algorithm chooses a
program from the table and uses it to find new candidates; every program
is used once.
3. Candidate programs are saved as a pair of maps that tie each
candidate to the points where it is best.
4. If more than one candidate gives the same results, remove the
redundant ones based on their results at other points.
5. Before the series approximation step, we will use the set cover
approximation algorithm to prune the candidates to a minimal set.
4. Get Piecewise solutions
1. The split points are found using dynamic programming and later
refined using binary search.
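As a starting point for step 1, here is a minimal sketch of the error measure. It assumes IEEE doubles, counts z inclusively (so identical values give 0 bits), and uses mpmath for the "real" value. The helper names (float_to_ordinal, bits_of_error, exact_as_double, average_error) and the starting precision are my own choices, not taken from the paper.

```python
import math
import struct

from mpmath import mp, mpf


def float_to_ordinal(value: float) -> int:
    """Map a double to an integer so that adjacent doubles differ by 1."""
    (bits,) = struct.unpack("<q", struct.pack("<d", value))
    # Negative doubles have the sign bit set; mirror them below zero.
    return bits if bits >= 0 else -(bits & 0x7FFFFFFFFFFFFFFF)


def bits_of_error(exact: float, approx: float) -> float:
    """E(exact, approx) = log2(z), z = doubles from one value to the other."""
    z = abs(float_to_ordinal(exact) - float_to_ordinal(approx)) + 1
    return math.log2(z)


def exact_as_double(f, x, start_prec=80, max_prec=20000):
    """Raise the precision until the value rounded to a double stops changing."""
    previous = None
    prec = start_prec
    while prec <= max_prec:
        with mp.workprec(prec):
            current = float(f(mpf(x)))
        if current == previous:
            return current
        previous = current
        prec *= 2
    raise ArithmeticError("no stable value found")


def average_error(float_fn, exact_fn, points):
    """Average bits of error of float_fn over the sample points."""
    errors = [bits_of_error(exact_as_double(exact_fn, p), float_fn(p))
              for p in points]
    return sum(errors) / len(errors)


# Example: sqrt(x + 1) - sqrt(x) loses most of its bits for large x.
points = [1e12, 1e14, 1e15]
avg = average_error(lambda t: math.sqrt(t + 1) - math.sqrt(t),
                    lambda t: mp.sqrt(t + 1) - mp.sqrt(t),
                    points)
print("average bits of error:", avg)
```

Doubling the precision is only one possible schedule; any increasing sequence would do as long as the loop stops once the rounded value is stable.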
These are the functions:
*Definition main(program) :*
  points := sample-inputs(program)
  exacts := evaluate-exact(program, points)
  table := make-candidate-table(simplify(program))
  repeat N times
    candidate := pick-candidate(table)
    locations := sort-by-local-error(all-locations(candidate))
    locations.take(M)
    rewritten := recursive-rewrite(candidate, locations)
    new-candidates := simplify-each(rewritten)
    table.add(new-candidates)
    approximated := series-expansion(candidate, locations)
    table.add(approximated)
  return infer-regimes(table).as-program
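For the series-expansion step called in the main loop above, SymPy's series function could be a starting point. The sketch below is only an illustration under my own assumptions: the example expression 1 - cos(x), the expansion order, and the test point are not from the paper.

```python
import math

from mpmath import mp, mpf, cos as mp_cos
from sympy import Symbol, cos, lambdify, series

x = Symbol("x")
expr = 1 - cos(x)

# Truncated Taylor expansion around 0; removeO() drops the O(x**6) term.
approx = series(expr, x, 0, 6).removeO()

naive_f = lambdify(x, expr, "math")
series_f = lambdify(x, approx, "math")

point = 1e-9
mp.dps = 50
exact = mpf(1) - mp_cos(mpf(point))  # high-precision reference value

naive_val = naive_f(point)
series_val = series_f(point)

print("naive :", naive_val)
print("series:", series_val)
print("exact :", exact)
```

Near 0 the naive form cancels almost all of its digits, while the truncated series keeps them; away from 0 the truncation error grows, which is why the regimes step is needed to choose between the two.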
*Definition local-error(expr, points) :*
  for point ∈ points :
    args := evaluate-exact(expr.children)
    exact-ans := F(expr.operation.apply-exact(args))
    approx-ans := expr.operation.apply-approx(F(args))
    accumulate E(exact-ans, approx-ans)
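To make local-error concrete, here is a small sketch for a single operation at a single point. It assumes F means "round to a double" (done here with float()), uses a fixed mpmath precision instead of an adaptive one, and the names ordinal, bits_of_error and local_error are mine.

```python
import math
import struct

from mpmath import mp, mpf, sqrt as mp_sqrt

mp.prec = 256  # assumed to be enough precision for this example


def ordinal(value: float) -> int:
    """Integer position of a double in the ordered set of doubles."""
    (bits,) = struct.unpack("<q", struct.pack("<d", value))
    return bits if bits >= 0 else -(bits & 0x7FFFFFFFFFFFFFFF)


def bits_of_error(exact: float, approx: float) -> float:
    return math.log2(abs(ordinal(exact) - ordinal(approx)) + 1)


def local_error(apply_exact, apply_approx, exact_args):
    """Error introduced by one operation when its inputs are exact."""
    # F(op.apply-exact(args)): exact operation, then round to a double.
    exact_ans = float(apply_exact(*exact_args))
    # op.apply-approx(F(args)): round the arguments, then the float operation.
    approx_ans = apply_approx(*[float(a) for a in exact_args])
    return bits_of_error(exact_ans, approx_ans)


# The two kinds of operation in sqrt(x + 1) - sqrt(x), at x = 1e15.
x = mpf(10) ** 15
sqrt_error = local_error(mp_sqrt, math.sqrt, [x + 1])
sub_error = local_error(lambda a, b: a - b,
                        lambda a, b: a - b,
                        [mp_sqrt(x + 1), mp_sqrt(x)])
print("sqrt:", sqrt_error, "subtraction:", sub_error)
```

The square root contributes almost nothing, while the subtraction is where the bits are lost, so the subtraction is the location that should be rewritten first.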
*Definition recursive-rewrite(expr, target) :*
  select input -> output from RULES
  where input.head = expr.operator ∧ output.head = target.head
  for (subexpr, subpattern) ∈ zip(expr.children, input.children) :
    if ¬matches(subexpr, subpattern) :
      recursive-rewrite(subexpr, subpattern)
  where matches(expr, input)
  expr.rewrite(input -> output)
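The sketch below is a much simpler stand-in for recursive-rewrite: it applies one rule at every location where it matches directly and does not try to rewrite subexpressions to make a rule match. RULES here is a hypothetical one-element list and the function names are my own.

```python
from sympy import Add, Symbol, preorder_traversal, sqrt


def difference_of_squares(node):
    """a + b -> (a**2 - b**2)/(a - b); None if the rule does not apply."""
    if isinstance(node, Add) and len(node.args) == 2:
        a, b = node.args
        return (a**2 - b**2) / (a - b)
    return None


RULES = [difference_of_squares]  # hypothetical stand-in for the rule database


def rewrite_everywhere(expr):
    """Return every program obtained by applying one rule at one location."""
    results = []
    for node in preorder_traversal(expr):
        for rule in RULES:
            replacement = rule(node)
            if replacement is not None:
                # xreplace swaps exactly this subtree and nothing else.
                results.append(expr.xreplace({node: replacement}))
    return results


x = Symbol("x")
expr = sqrt(x + 1) - sqrt(x)
results = rewrite_everywhere(expr)
for candidate in results:
    print(candidate)
```

Each result is a whole program, so it can go straight into the error calculation and be compared against the original on the sampled points.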
*Definition infer-regimes(candidates, points) :*
  for x_i ∈ points :
    best-split_0[x_i] = [regime(best-candidate(−∞, x_i), −∞, x_i)]
  for n ∈ N until best-split_{n+1} = best-split_n :
    for x_i ∈ points ∪ {∞} :
      for x_j ∈ points, x_j < x_i :
        extra-regime := regime(best-candidate(x_j, x_i), x_j, x_i)
        option[x_j] := best-split[x_j] ++ [extra-regime]
      best-split_{n+1}[x_i] := lowest-error(option)
      if best-split_n[x_i].error − 1 ≤ best-split_{n+1}[x_i].error :
        best-split_{n+1}[x_i] := best-split_n[x_i]
  split := best-split_*[∞]
  split.refine-by-binary-search()
  return split
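A simplified dynamic-programming sketch of the regime search follows. It is not the exact procedure above: it assumes the errors are already tabulated per candidate and per sorted point, minimises the summed error plus a fixed penalty for every extra regime, and skips the binary-search refinement. The function name and the penalty parameter are my own.

```python
def infer_regimes(errors, penalty=1.0):
    """errors[c][i] is the error of candidate c at the i-th sorted point.

    Returns a list of (start, end, candidate) index ranges, end exclusive.
    """
    n = len(errors[0])
    # prefix[c][i] = sum(errors[c][:i]), so any segment cost is one subtraction.
    prefix = [[0.0] * (n + 1) for _ in errors]
    for c, row in enumerate(errors):
        for i, e in enumerate(row):
            prefix[c][i + 1] = prefix[c][i] + e

    def best_candidate(lo, hi):
        costs = [p[hi] - p[lo] for p in prefix]
        c = min(range(len(costs)), key=costs.__getitem__)
        return costs[c], c

    # best[i] = (total cost, regimes) covering the first i points.
    best = [(0.0, [])]
    for hi in range(1, n + 1):
        options = []
        for lo in range(hi):
            seg_cost, c = best_candidate(lo, hi)
            prev_cost, prev_regimes = best[lo]
            extra = penalty if prev_regimes else 0.0  # first regime is free
            options.append((prev_cost + seg_cost + extra,
                            prev_regimes + [(lo, hi, c)]))
        best.append(min(options, key=lambda option: option[0]))
    return best[n][1]


# Candidate 0 is accurate on the first three points, candidate 1 on the rest.
errors = [[0, 0, 0, 10, 10, 10],
          [10, 10, 10, 0, 0, 0]]
print(infer_regimes(errors))
```

The penalty plays the role of the "must improve by at least one bit" test: a new regime is only added when it lowers the error by more than it costs.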
I have written a basic brute-force implementation:
from sympy import *
import numpy as np
x = Symbol("x")
y = Symbol("y")
z = Symbol("z")
#points: 256 evenly spaced sample points between mini and maxi
maxi = 1000000000000000000000000000000
mini = -1000000000000000000000000000000
step = (maxi-mini)/256
start = mini
points = []
for i in range(0,256):
    points.append(start)
    start += step
#calculate error
def calc_error(expr, point):
    # Evaluate expr twice: once with the point as given and once with
    # a 1000-digit mpmath value that serves as the reference.
    from mpmath import mp, mpf
    mp.dps = 1000
    symb = expr.atoms(Symbol)
    subst_sym = []
    subst_mpf = []
    for sym in symb:
        # substitute every free symbol, not only x
        subst_sym.append((sym, point))
        subst_mpf.append((sym, mpf(point)))
    ans1 = expr.subs(subst_sym)
    ans2 = expr.subs(subst_mpf)
    return ans1, ans2
#replacement functions
#database
rule = []
# a + b -> (a**3 + b**3)/(a**2 + b**2 - a*b), only if neither term is a bare Symbol
rule.append(lambda exp: (exp.args[0]**3 + exp.args[1]**3)/(exp.args[0]**2 +
    exp.args[1]**2 - exp.args[1]*exp.args[0])
    if type(exp) == Add and len(exp.args) == 2
    and (type(exp.args[0]) != Symbol and type(exp.args[1]) != Symbol)
    else False)
# a + b -> (a**2 - b**2)/(a - b), with the same guard
rule.append(lambda exp: (exp.args[0]**2 - exp.args[1]**2)/(exp.args[0] - exp.args[1])
    if type(exp) == Add and len(exp.args) == 2
    and (type(exp.args[0]) != Symbol and type(exp.args[1]) != Symbol)
    else False)
#rule.append(lambda expr: exp(log(expr)))
#rule1 = lambdify([x, y, (x**2 - y**2)/(x+y))
expr = (x+1)**0.5 - x**0.5
#expr = x+1-x #+ (1/(x-1)) - (2/x)
#expr = (-y + (-4*x*z + y**2)**0.5)/(2*x)
#expr = (x+1)**(1/3) - x**(1/3)
#expr = (x+1)**0.25 - x**0.25
expr = simplify(expr)
main_expr = expr
temp_expr = expr
pprint(rule[0](expr))
print(rule[1](expr))
all_func = []
mapper = dict()
def pre(expr):
    # try every rule on expr at every sample point, then recurse into its arguments
    for fi in rule:
        print("see ", expr)
    for pnt in points:
        for fi in rule:
            k = fi(expr)
            if k is not False:
                ans1, ans2 = calc_error(expr, pnt)
                update_func = main_expr.subs(expr, k)
                ans3, ans4 = calc_error(update_func, pnt)
                diff1 = ans2 - ans1
                diff2 = ans4 - ans3
                try:
                    # keep the rewrite if it is at least as accurate at this point
                    if abs(diff2) <= abs(diff1):
                        all_func.append(update_func)
                        print("inn ", pnt)
                        # remember the points where this rewrite wins
                        mapper.setdefault(update_func, []).append(pnt)
                except Exception:
                    print("failed ", pnt)
    print(expr, " space ")
    for arg in expr.args:
        pre(arg)
# seed the result list with the original expression
if len(all_func)==0:
    all_func.append(simplify(main_expr))
pre(expr)
#print(srepr(expr))
#pprint(set(all_func))
#print("mapper")
#print(mapper)
values = list(set(all_func))
# rewrite u**1.0 as u so that exponents equal to 1 disappear
for i, val in enumerate(values):
    values[i] = val.replace(lambda e: type(e) == Pow and e.args[1] == 1,
                            lambda e: Mul(e.args[0], 1))
pprint(values)
Output (the list printed by pprint, written here in linear form):
[(-x**1.5 + (x + 1)**1.5)/(x**0.5*(x + 1)**0.5 + 2*x + 1),
 1/(x**0.5 + (x + 1)**0.5),
 -x**0.5 + (x + 1)**0.5]
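To check that a rewrite like 1/(x**0.5 + (x + 1)**0.5) from the output really is more accurate than the original, the two forms can be compared against a high-precision reference. This is a small sketch using relative error rather than the bits measure; the function names and the test points are my own choices.

```python
import math

from mpmath import mp, mpf, sqrt as mp_sqrt

mp.dps = 60  # digits used for the reference value


def original(x):
    return math.sqrt(x + 1) - math.sqrt(x)


def rewritten(x):
    return 1 / (math.sqrt(x + 1) + math.sqrt(x))


def reference(x):
    x = mpf(x)
    return mp_sqrt(x + 1) - mp_sqrt(x)


def relative_error(value, exact):
    return float(abs((mpf(value) - exact) / exact))


for x in (1.0, 1e5, 1e10, 1e15):
    exact = reference(x)
    print(x,
          relative_error(original(x), exact),
          relative_error(rewritten(x), exact))
```

The error of the original form grows with x because of cancellation in the subtraction, while the rewritten form stays close to machine precision over the whole range.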
Please tell me your thoughts on this, including any additions, deletions,
or algorithmic improvements.