This seems to work now, but I'm wondering whether Charles is correct that
inheritance isn't such a great idea here.

The advantage of inheritance is that I don't have to implement forwarding for
all the functions, which is a pretty big advantage. (I wonder if there is some
way to do some of these as a generic 'mixin'?)

I was concerned that using an np.vectorize function in __array_wrap__ would
result in a recursive call to __array_wrap__ (and in some earlier version I
did get infinite recursion). This version doesn't seem to have a problem,
although I don't know why.

Any thoughts?  Code attached.
import numpy as np

        
def rnd (x, frac_bits, _max):
    """Drop ``frac_bits`` low bits of ``x``, rounding to nearest (ties up).

    All but one of the discarded bits are shifted out first; the remaining
    low bit acts as the rounding bit, so adding 1 and shifting once more
    rounds half up.  If the partially shifted value equals ``_max`` the
    increment is skipped and the value is simply truncated.
    """
    partial = x >> (frac_bits - 1)
    if partial != _max:
        partial = partial + 1
    return partial >> 1

def shift_left (x, bits):
    """Return ``x`` shifted left by ``bits`` (scaled up by 2**bits for ints)."""
    shifted = x << bits
    return shifted

def shift_right (x, bits):
    """Return ``x`` arithmetically shifted right by ``bits`` (floor division by 2**bits for ints)."""
    shifted = x >> bits
    return shifted

def shiftup_or_rnddn (x, bits, _max, rnd_policy):
    """Rescale ``x`` by 2**bits.

    A positive ``bits`` shifts left exactly.  A negative ``bits`` delegates
    to ``rnd_policy(x, -bits, _max)``, which drops bits with rounding.  Zero
    leaves ``x`` untouched.
    """
    if bits > 0:
        return shift_left (x, bits)
    if bits < 0:
        return rnd_policy (x, -bits, _max)
    return x

def clip (x, _min, _max):
    """Overflow policy that saturates ``x`` into the closed range [_min, _max]."""
    if x > _max:
        return _max
    if x < _min:
        return _min
    return x

class fixed_pt_overflow (OverflowError):
    """Raised by ``throw`` when a value falls outside [_min, _max]."""
    pass


def throw (x, _min, _max):
    """Overflow policy that rejects out-of-range values instead of clipping.

    Returns ``x`` unchanged when ``_min <= x <= _max``.

    Raises
    ------
    fixed_pt_overflow
        If ``x`` is above ``_max`` or below ``_min``.  The class is defined
        once at module level, so callers can catch it by name.  (Redefining
        it inside the function would create a new, uncatchable class per
        call.)  It derives from OverflowError and thus from Exception.
    """
    if x > _max or x < _min:
        raise fixed_pt_overflow ('%r outside fixed-point range [%r, %r]' % (x, _min, _max))
    return x

class fixed_pt (object):
    """A single fixed-point number stored as a raw integer in ``self.val``.

    The format has ``int_bits`` integer bits and ``frac_bits`` fractional
    bits (``total_bits = int_bits + frac_bits``).

    Parameters
    ----------
    int_bits, frac_bits : int
        Bits to the left and right of the binary point.
    val
        With ``scale=True`` (the default), ``val`` is shifted by
        ``frac_bits`` via ``shiftup_or_rnddn``.  That function applies ``<<``
        or ``>>``, so ``val`` is presumably integer-valued.  The result is
        converted to ``base_type`` and passed through ``overflow_policy``.
        With ``scale=False``, ``val`` is taken as an already-scaled raw value
        and only converted to ``base_type``.
    scale : bool
    base_type : callable
        Type used for the raw value (``int`` by default).
    rnd_policy : callable ``(x, bits, _max)``
        Used when bits must be dropped (negative shift).
    overflow_policy : callable ``(x, _min, _max)``
        Applied to the scaled raw value, e.g. ``clip`` or ``throw``.
    is_signed : bool
    """

    def get_max(self):
        """Largest raw value that fits in ``total_bits`` (one less bit if signed)."""
        if self.is_signed:
            return (~(self.base_type(-1) << (self.total_bits-1)))
        else:
            return (~(self.base_type (-1) << self.total_bits))

    def get_min(self):
        """Smallest raw value: the most negative ``total_bits`` value if signed, else 0."""
        if self.is_signed:
            return ((self.base_type(-1) << (self.total_bits-1)))
        else:
            return 0

    def __init__ (self, int_bits, frac_bits, val, scale=True, base_type=int, rnd_policy=rnd, overflow_policy=clip, is_signed=True):
        self.is_signed = is_signed
        self.int_bits = int_bits
        self.frac_bits = frac_bits
        self.base_type = base_type
        self.total_bits = int_bits + frac_bits
        self.rnd_policy = rnd_policy
        self.overflow_policy = overflow_policy
        # get_max/get_min read is_signed, base_type and total_bits set above.
        self._max = self.get_max ()
        self._min = self.get_min ()

        if scale:
            raw = shiftup_or_rnddn (val, frac_bits, self._max, self.rnd_policy)
            self.val = self.overflow_policy (self.base_type (raw), self._min, self._max)
        else:
            # Without this branch self.val would never be assigned, and
            # as_double(), as_base() and repr() would raise AttributeError.
            self.val = self.base_type (val)

    def as_double (self):
        """Return the real value, i.e. ``val * 2**-frac_bits``, as a float."""
        return np.ldexp (self.val, -self.frac_bits)

    def as_base (self):
        """Undo the ``frac_bits`` scaling, dropping bits through ``rnd_policy``."""
        return shiftup_or_rnddn (self.val, -self.frac_bits, self._max, self.rnd_policy)

    def __repr__(self):
        return "[%s <%s,%s>]" % (self.val, self.int_bits, self.frac_bits)
    
def get_max(is_signed, base_type, total_bits):
    """Return the largest raw value representable in ``total_bits`` bits.

    Signed formats reserve the top bit for the sign, leaving one fewer
    magnitude bit.  The value is formed by shifting an all-ones
    ``base_type(-1)`` left and inverting it, so it has ``base_type``'s type.
    """
    magnitude_bits = total_bits - 1 if is_signed else total_bits
    return ~(base_type(-1) << magnitude_bits)

def get_min(is_signed, base_type, total_bits):
    """Return the smallest raw value representable in ``total_bits`` bits.

    Unsigned formats bottom out at 0.  Signed formats return
    ``-2**(total_bits-1)``, built as ``base_type(-1)`` shifted left.
    """
    if not is_signed:
        return 0
    return base_type(-1) << (total_bits - 1)


class fixed_pt_array(np.ndarray):
    """ndarray subclass whose elements are raw fixed-point integers.

    Every element is a raw value with ``int_bits`` integer bits and
    ``frac_bits`` fractional bits.  ``__array_finalize__`` copies the
    fixed-point parameters onto arrays derived from an instance.
    ``__array_wrap__`` runs ``overflow_policy`` over every ufunc result
    (e.g. ``obj * 100``).
    """

    # Parameters that __array_finalize__ copies from the parent array.
    _fp_attrs = ('int_bits', 'frac_bits', 'rnd_policy', 'overflow_policy',
                 'is_signed', 'scale', 'base_type', 'total_bits',
                 '_max', '_min')

    def __new__(cls, input_array, int_bits, frac_bits, scale=True, base_type=int, rnd_policy=rnd, overflow_policy=clip, is_signed=True):
        """Build a fixed-point array from ``input_array``.

        With ``scale=True``, each element is shifted by ``frac_bits`` via
        ``shiftup_or_rnddn`` (a negative ``frac_bits`` drops bits through
        ``rnd_policy``).  The result is converted to ``base_type`` and passed
        through ``overflow_policy``.  With ``scale=False``, elements are taken
        as raw values and only converted to ``base_type``.  In that case the
        result may share memory with ``input_array`` when it already has the
        requested dtype.
        """
        total_bits = int_bits + frac_bits
        _max = get_max(is_signed, base_type, total_bits)
        _min = get_min(is_signed, base_type, total_bits)

        data = np.asarray(input_array, dtype=base_type)
        if scale:
            def _scale (val):
                return overflow_policy (base_type (shiftup_or_rnddn (val, frac_bits, _max, rnd_policy)), _min, _max)
            # Vectorize over the plain ndarray so scaling does not pass
            # through __array_wrap__.  otypes=[object] avoids np.vectorize's
            # error on size-0 inputs, and np.asarray restores base_type.
            # np.vectorize returns a new array, so input_array is untouched.
            data = np.asarray(np.vectorize(_scale, otypes=[object])(data), dtype=base_type)

        obj = data.view(cls)
        obj.int_bits = int_bits
        obj.frac_bits = frac_bits
        obj.rnd_policy = rnd_policy
        obj.overflow_policy = overflow_policy
        obj.is_signed = is_signed
        obj.scale = scale
        obj.base_type = base_type
        obj.total_bits = total_bits
        obj._max = _max
        obj._min = _min
        return obj

    def __array_finalize__(self, obj):
        """Copy the fixed-point parameters from the array ``self`` derives from.

        ``obj`` is None for explicit construction and the source array for
        views, slices and template-based creation.  If ``obj`` is not a
        fixed-point array, ``self`` is left without the parameters.
        """
        if hasattr(obj, 'int_bits'):
            for name in self._fp_attrs:
                if hasattr(obj, name):
                    setattr(self, name, getattr(obj, name))

    def as_double (self):
        """Return the real values ``raw * 2**-frac_bits`` as a plain float ndarray."""
        values = np.array(self, dtype=float)
        # Scale in place; passing out= also keeps a 0-d result an ndarray.
        np.ldexp(values, -self.frac_bits, out=values)
        return values

    def as_base (self):
        """Undo the ``frac_bits`` scaling and return a plain ``base_type`` ndarray.

        Dropped bits are rounded through ``rnd_policy``.
        """
        frac_bits, _max, rnd_policy = self.frac_bits, self._max, self.rnd_policy
        def _shiftup_or_rnddn (x):
            return shiftup_or_rnddn (x, -frac_bits, _max, rnd_policy)
        # Vectorize over a plain ndarray so the overflow policy in
        # __array_wrap__ is not applied to the converted values.
        vecfunc = np.vectorize (_shiftup_or_rnddn, otypes=[object])
        return np.array (vecfunc (np.asarray (self)), dtype=self.base_type)

    def check (self):
        """Return a copy with ``overflow_policy`` applied to every element.

        The result is a fixed_pt_array carrying ``self``'s parameters.  With
        ``throw`` as the policy, an out-of-range element raises.
        """
        policy, lo, hi = self.overflow_policy, self._min, self._max
        def _overflow_func (x):
            return policy (x, lo, hi)
        vecfunc = np.vectorize (_overflow_func, otypes=[object])
        checked = np.asarray (vecfunc (np.asarray (self)), dtype=self.dtype)
        # ndarray.__array_wrap__ views `checked` as type(self) and finalizes
        # it from self, so the fixed-point parameters are attached.
        return super(fixed_pt_array, self).__array_wrap__(checked)

    def __array_wrap__(self, *args):
        """Apply ``overflow_policy`` to a ufunc result, then wrap it as usual.

        ``args[0]`` is the computed output array.  Any further positional
        arguments (``context``, and ``return_scalar`` on newer NumPy) are
        forwarded unchanged to ``ndarray.__array_wrap__``.
        """
        out_arr = args[0]
        policy = getattr(self, 'overflow_policy', None)
        raw = np.asarray(out_arr)
        if policy is not None and raw.size:
            lo, hi = self._min, self._max
            def _overflow_func (x):
                return policy (x, lo, hi)
            # Vectorize a plain-ndarray view.  Vectorizing the subclass itself
            # would make NumPy call __array_wrap__ again on the intermediate
            # object array.
            checked = np.vectorize (_overflow_func, otypes=[object]) (raw)
            if raw.flags.writeable:
                # Writing in place keeps the output dtype and leaves a
                # caller-supplied out= array holding the checked values.
                raw[...] = checked
            else:
                out_arr = np.asarray (checked, dtype=raw.dtype)
        # Pass the extra args flat, not wrapped in a nested tuple.
        return super(fixed_pt_array, self).__array_wrap__(out_arr, *args[1:])

        
# Demo: build a fixed-point scalar and array, then multiply the array so the
# overflow policy is applied to the product.
fp = fixed_pt (5, 5, 1)

arr = np.arange(5,dtype=int)
obj = fixed_pt_array(arr, int_bits=5, frac_bits=5)

# With 10 signed bits, _max is 511.  Every nonzero scaled element
# (32, 64, ...) times 100 exceeds it, so the default clip saturates it.
# print(...) with a single argument behaves the same on Python 2 and 3.
print(obj * 100)


## obj2 = fixed_pt_array (arr, int_bits=5, frac_bits=5, overflow_policy=throw)
## obj3 = obj2 * 100


_______________________________________________
NumPy-Discussion mailing list
NumPy-Discussion@scipy.org
http://mail.scipy.org/mailman/listinfo/numpy-discussion

Reply via email to