Hi All,
Using Sqlalchemy 1.3.23
I am getting a NotImplementedError: Operator 'getitem' is not supported on
this expression
when sorting on some hybrid_properties.
I have attached a sample code to replicate it.
falls over with the following traceback:
Traceback (most recent call last):
File "testholdings.py", line 526, in <module>
trans =
db_session.query(Transaction).order_by(desc(Transaction.total_cost)).all()
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/ext/hybrid.py",
line 898, in __get__
return self._expr_comparator(owner)
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/ext/hybrid.py",
line 1105, in expr_comparator
comparator(owner),
File "testholdings.py", line 135, in total_cost
return TotalCostComparator(cls)
File "testholdings.py", line 89, in __init__
expr = case(
File "<string>", line 2, in case
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py",
line 2437, in __init__
whenlist = [
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py",
line 2439, in <listcomp>
for (c, r) in whens
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/operators.py",
line 432, in __getitem__
return self.operate(getitem, index)
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py",
line 762, in operate
return op(self.comparator, *other, **kwargs)
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/operators.py",
line 432, in __getitem__
return self.operate(getitem, index)
File "<string>", line 1, in <lambda>
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/type_api.py",
line 67, in operate
return o[0](self.expr, op, *(other + o[1:]), **kwargs)
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/default_comparator.py",
line 237, in _getitem_impl
_unsupported_impl(expr, op, other, **kw)
File
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/default_comparator.py",
line 241, in _unsupported_impl
raise NotImplementedError(
NotImplementedError: Operator 'getitem' is not supported on this expression
Thanks in advance,
George
--
SQLAlchemy -
The Python SQL Toolkit and Object Relational Mapper
http://www.sqlalchemy.org/
To post example code, please provide an MCVE: Minimal, Complete, and Verifiable
Example. See http://stackoverflow.com/help/mcve for a full description.
---
You received this message because you are subscribed to the Google Groups
"sqlalchemy" group.
To unsubscribe from this group and stop receiving emails from it, send an email
to [email protected].
To view this discussion on the web visit
https://groups.google.com/d/msgid/sqlalchemy/6c323f70-51c1-4724-ac12-58100fb1e3fen%40googlegroups.com.
import decimal
from datetime import date
from sqlalchemy import select, asc, desc, cast, case, event, create_engine
from sqlalchemy import Column, ForeignKey, Integer, Numeric, Enum, Date, String
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm import relationship, configure_mappers, sessionmaker
from sqlalchemy.ext.hybrid import hybrid_property, Comparator
from sqlalchemy.ext.declarative import declarative_base
engine = create_engine('sqlite:///:memory:', echo=False)
Base = declarative_base()
class StockCompanyShareInfo(Base):
__tablename__ = "StockCompanyShareInfos"
Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)
LastTradeDate = Column(Date, default=None)
LastPrice = Column(Numeric(9,4), default=0)
# OneToOne side of StockCompany
ItemStockCompany_Id = Column(Integer, ForeignKey("StockCompanies.Id"))
ItemStockCompany = relationship("StockCompany", back_populates="ShareInfo",
primaryjoin="StockCompany.Id==StockCompanyShareInfo.ItemStockCompany_Id")
class StockCompany(Base):
__tablename__ = "StockCompanies"
Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)
Ticker = Column(String(8), index=True, nullable=False)
# One2One side of CompanyShare
ShareInfo = relationship("StockCompanyShareInfo", uselist=False,
back_populates="ItemStockCompany",
primaryjoin="StockCompanyShareInfo.ItemStockCompany_Id==StockCompany.Id",
cascade="all, delete-orphan")
class HybridComparator(Comparator):
def __init__(self, expr):
super().__init__(expr)
def __eq__(self, val):
expr = self.__clause_element__()
return expr == val
def __ne__(self, val):
expr = self.__clause_element__()
return expr != val
def __ge__(self, val):
expr = self.__clause_element__()
return expr >= val
def __gt__(self, val):
expr = self.__clause_element__()
return expr > val
def __le__(self, val):
expr = self.__clause_element__()
return expr <= val
def __lt__(self, val):
expr = self.__clause_element__()
return expr < val
def asc(self):
expr = self.__clause_element__()
return asc(expr)
def desc(self):
expr = self.__clause_element__()
return desc(expr)
class TotalValueComparator(HybridComparator):
def __init__(self, cls):
# need to cast to 2 decimals - display is 2 decimals
# cls.UnitPrice is 4 decimals
expr = cast(cls.Units * cls.UnitPrice, Numeric(9, 2))
super().__init__(expr)
class TotalCostComparator(HybridComparator):
def __init__(self, cls):
expr = case(
(cls.Type == "SELL", cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) - cls.Brokerage),
else_=cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) + cls.Brokerage
)
#expr = case(
# (cls.Type == "SELL", cls.total_value - cls.Brokerage),
# else_=cls.total_value + cls.Brokerage
#)
super().__init__(expr)
class Transaction(Base):
__tablename__ = "Transactions"
Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)
Type = Column(Enum("BUY", "SELL", name="HoldingTransactionType"),
nullable=False, default="BUY"
)
Date = Column(Date, nullable=False, default=None)
Units = Column(Integer, nullable=False)
UnitPrice = Column(Numeric(9, 4), nullable=False)
Brokerage = Column(Numeric(9, 2))
# Many2One side of Holding
ItemHolding_Id = Column(Integer, ForeignKey("Holdings.Id"))
# calculated columns
@hybrid_property
def total_value(self):
return self.Units * self.UnitPrice
@total_value.comparator
def total_value(cls):
return TotalValueComparator(cls)
@hybrid_property
def total_cost(self):
if self.Type == "SELL":
return self.total_value - self.Brokerage
return self.total_value + self.Brokerage
@total_cost.comparator
def total_cost(cls):
return TotalCostComparator(cls)
def on_transaction_delete(mapper, connection, target):
db_session = object_session(target)
holding = getattr(target, "ItemHolding")
total_units = 0
total_value = decimal.Decimal(0)
running_unit_price = decimal.Decimal(0)
for trans in sorted(holding.Transactions, key=lambda obj: obj.Date):
if trans in db_session.deleted:
continue
units = int(trans.Units)
trans_unit_price = decimal.Decimal(trans.UnitPrice)
if trans.Type == "SELL":
total_units = total_units - units
total_value = total_value - (units * running_unit_price)
else:
total_units = total_units + units
total_value = total_value + (units * trans_unit_price)
if total_units == 0:
running_unit_price = 0
else:
running_unit_price = total_value / total_units
if total_units == 0:
ave_unit_price = 0
else:
ave_unit_price = total_value / total_units
connection.execute(
Holding.__table__.update().
values(Units=total_units, UnitPrice=ave_unit_price).
where(Holding.Id == target.ItemHolding_Id)
)
event.listen(Transaction, "after_delete", on_transaction_delete) # Mapper Event
class MarketValueComparator(HybridComparator):
def __init__(self, cls):
# need to cast to 2 decimals - display is 2 decimals
# cls.last_price is 4 decimals
expr = cast(cls.Units * cls.last_price, Numeric(9, 2))
super().__init__(expr)
class VarianceComparator(HybridComparator):
def __init__(self, cls):
# need to cast to 2 decimals - display is 2 decimals
total_cost = cast(cls.Units * cls.UnitPrice, Numeric(9, 2))
market_value = cast(cls.Units * cls.last_price, Numeric(9, 2))
expr = market_value - total_cost
super().__init__(expr)
class VariancePercentComparator(HybridComparator):
def __init__(self, cls):
total_cost = cast(cls.Units * cls.UnitPrice, Numeric(9, 2))
market_value = cast(cls.Units * cls.last_price, Numeric(9, 2))
if total_cost == 0:
expr = 0
else:
expr = ((market_value - total_cost) / total_cost) * 100
super().__init__(expr)
class ThresholdPriceComparator(HybridComparator):
def __init__(self, cls):
if cls.Threshold == 0:
expr = 0
else:
threshold = (1 - (cls.Threshold / 100))
expr = case(
(cls.UnitPrice > cls.last_price, cast(cls.UnitPrice * threshold, Numeric(9,2))),
else_=cast(cls.last_price * threshold, Numeric(9,2))
)
super().__init__(expr)
class ThresholdValueComparator(HybridComparator):
def __init__(self, cls):
expr = cls.threshold_price * cls.Units
super().__init__(expr)
class ThresholdVarianceComparator(HybridComparator):
def __init__(self, cls):
expr = cls.threshold_value - cls.total_cost
super().__init__(expr)
class Holding(Base):
__tablename__ = "Holdings"
Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)
# Many2One
StockCompany_Id = Column(Integer, ForeignKey("StockCompanies.Id"), nullable=False)
StockCompany = relationship("StockCompany", primaryjoin="StockCompany.Id==Holding.StockCompany_Id")
Units = Column(Integer, default=0)
UnitPrice = Column(Numeric(9, 4), default=0)
Threshold = Column(Integer, default=0)
# One2Many
Transactions = relationship("Transaction", uselist=True, backref="ItemHolding",
order_by="desc(Transaction.Date)",
cascade="all, delete-orphan")
# calculated columns
@hybrid_property
def total_cost(self):
return self.Units * self.UnitPrice
@total_cost.comparator
def total_cost(cls):
return TotalValueComparator(cls)
@hybrid_property
def last_price(self):
return self.StockCompany.ShareInfo.LastPrice
@last_price.expression
def last_price(cls):
return select([StockCompanyShareInfo.LastPrice]).\
where(StockCompanyShareInfo.ItemStockCompany_Id == cls.StockCompany_Id).\
as_scalar()
@hybrid_property
def market_value(self):
return self.Units * self.last_price
@market_value.comparator
def market_value(cls):
return MarketValueComparator(cls)
@hybrid_property
def variance(self):
return self.market_value - self.total_cost
@variance.comparator
def variance(cls):
return VarianceComparator(cls)
@hybrid_property
def variance_percent(self):
if self.total_cost == 0:
return 0
return (self.variance / self.total_cost) * 100
@variance_percent.comparator
def variance_percent(cls):
return VariancePercentComparator(cls)
@hybrid_property
def threshold_price(self):
if self.Threshold == 0:
return 0
threshold = decimal.Decimal(1 - (self.Threshold / 100)) # float to decimal
if self.UnitPrice > self.last_price:
return self.UnitPrice * threshold
else:
return self.last_price * threshold
@threshold_price.comparator
def threshold_price(cls):
return ThresholdPriceComparator(cls)
@hybrid_property
def threshold_value(self):
return self.threshold_price * self.Units
@threshold_value.comparator
def threshold_value(cls):
return ThresholdValueComparator(cls)
@hybrid_property
def threshold_variance(self):
return self.threshold_value - self.total_cost
@threshold_variance.comparator
def threshold_variance(cls):
return ThresholdVarianceComparator(cls)
def on_holding_update(mapper, connection, target):
db_session = object_session(target)
transactions = getattr(target, "Transactions")
total_units = 0
total_value = decimal.Decimal(0)
running_unit_price = decimal.Decimal(0)
counter = []
for trans in sorted(transactions, key=lambda obj: obj.Date):
# dont know why but there are duplicate transactions in update mode
if trans in db_session.dirty:
if trans.Id in counter:
continue
else:
counter.append(trans.Id)
units = int(trans.Units)
trans_unit_price = decimal.Decimal(trans.UnitPrice)
if trans.Type == "SELL":
total_units = total_units - units
total_value = total_value - (units * running_unit_price)
else:
total_units = total_units + units
total_value = total_value + (units * trans_unit_price)
if total_units == 0:
running_unit_price = 0
else:
running_unit_price = total_value / total_units
if total_units == 0:
ave_unit_price = 0
else:
ave_unit_price = total_value / total_units
setattr(target, "Units", total_units)
setattr(target, "UnitPrice", ave_unit_price)
event.listen(Holding, "before_insert", on_holding_update) # Mapper Event
event.listen(Holding, "before_update", on_holding_update) # Mapper Event
if __name__ == '__main__':
configure_mappers()
Base.metadata.create_all(engine)
db_session = sessionmaker(bind=engine)()
# populate tables
coy = StockCompany()
coy.Ticker = "GVV"
info = StockCompanyShareInfo()
info.LastTradeDate = date(2021, 11, 18)
info.LastPrice = 0.0300
db_session.add(info)
coy.ShareInfo = info
db_session.add(coy)
db_session.commit()
coy = StockCompany()
coy.Ticker = "PRV"
info = StockCompanyShareInfo()
info.LastTradeDate = date(2021, 11, 18)
info.LastPrice = 0.1000
db_session.add(info)
coy.ShareInfo = info
db_session.add(coy)
db_session.commit()
coy = db_session.query(StockCompany).filter(StockCompany.Ticker == "GVV").first()
if coy is not None:
hold = Holding()
hold.StockCompany = coy
hold.Threshold = 10
db_session.add(hold)
trans = Transaction()
trans.Type = "BUY"
trans.Date = date(2018, 3, 27)
trans.Units = 250000
trans.UnitPrice = 0.0200
trans.Brokerage = 19.95
db_session.add(trans)
hold.Transactions.append(trans)
trans = Transaction()
trans.Type = "SELL"
trans.Date = date(2018, 4, 20)
trans.Units = 250000
trans.UnitPrice = 0.0210
trans.Brokerage = 19.95
db_session.add(trans)
hold.Transactions.append(trans)
trans = Transaction()
trans.Type = "BUY"
trans.Date = date(2018, 5, 2)
trans.Units = 312500
trans.UnitPrice = 0.0160
trans.Brokerage = 19.95
db_session.add(trans)
hold.Transactions.append(trans)
db_session.commit()
coy = db_session.query(StockCompany).filter(StockCompany.Ticker == "PRV").first()
if coy is not None:
hold = Holding()
hold.StockCompany = coy
hold.Threshold = 0
db_session.add(hold)
trans = Transaction()
trans.Type = "BUY"
trans.Date = date(2021, 7, 12)
trans.Units = 57472
trans.UnitPrice = 0.0870
trans.Brokerage = 19.95
db_session.add(trans)
hold.Transactions.append(trans)
trans = Transaction()
trans.Type = "BUY"
trans.Date = date(2021, 10, 7)
trans.Units = 143800
trans.UnitPrice = 0.1450
trans.Brokerage = 19.95
db_session.add(trans)
hold.Transactions.append(trans)
trans = Transaction()
trans.Type = "BUY"
trans.Date = date(2021, 11, 1)
trans.Units = 1643
trans.UnitPrice = 0.1450
trans.Brokerage = 0
db_session.add(trans)
hold.Transactions.append(trans)
db_session.commit()
print("StockCompanies")
coys = db_session.query(StockCompany).all()
format_string = "{:<8} {:<13} {:>9}"
print(format_string.format("Ticker","LastTradeDate","LastPrice"))
for co in coys:
print(format_string.format(co.Ticker, str(co.ShareInfo.LastTradeDate),
str(co.ShareInfo.LastPrice)))
print("\nHoldings")
holds = db_session.query(Holding).all()
format_string = "{:<8} {:>7} {:>9} {:>9} {:>10} {:>11} {:>10} {:>9} {:>13} {:>14} {:>17} {:>10}"
print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
"MarketValue","Variance","Variance%","ThresholdPrice",
"ThresholdValue","ThresholdVariance","Threshold%"))
for hold in holds:
print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
str(hold.UnitPrice), str(hold.last_price),
str(hold.total_cost), str(hold.market_value),
str(hold.variance),
"{:.2f}".format(hold.variance_percent),
"{:.4f}".format(hold.threshold_price),
"{:.2f}".format(hold.threshold_value),
"{:.2f}".format(hold.threshold_variance),
str(hold.Threshold)))
print("\nHoldings - Sort by Variance% asc - OK")
holds = db_session.query(Holding).order_by(asc(Holding.variance_percent)).all()
print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
"MarketValue","Variance","Variance%","ThresholdPrice",
"ThresholdValue","ThresholdVariance","Threshold%"))
for hold in holds:
print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
str(hold.UnitPrice), str(hold.last_price),
str(hold.total_cost), str(hold.market_value),
str(hold.variance),
"{:.2f}".format(hold.variance_percent),
"{:.4f}".format(hold.threshold_price),
"{:.2f}".format(hold.threshold_value),
"{:.2f}".format(hold.threshold_variance),
str(hold.Threshold)))
print("\nTransactions")
trans = db_session.query(Transaction).all()
format_string = "{:<8} {:<4} {:<10} {:>7} {:>9} {:>10} {:>9} {:>10}"
print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue",
"Brokerage","TotalCost"))
for tran in trans:
print(format_string.format(tran.ItemHolding.StockCompany.Ticker,
tran.Type, str(tran.Date), str(tran.Units),
str(tran.UnitPrice),
"{:.2f}".format(tran.total_value),
str(tran.Brokerage),
"{:.2f}".format(tran.total_cost)))
print("\nTransactions - Sort by TotalValue asc - OK")
trans = db_session.query(Transaction).order_by(asc(Transaction.total_value)).all()
print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue",
"Brokerage","TotalCost"))
for tran in trans:
print(format_string.format(tran.ItemHolding.StockCompany.Ticker,
tran.Type, str(tran.Date), str(tran.Units),
str(tran.UnitPrice),
"{:.2f}".format(tran.total_value),
str(tran.Brokerage),
"{:.2f}".format(tran.total_cost)))
###################################################################################
print("\nTransactions - Sort by TotalCost desc - NOT OK")
trans = db_session.query(Transaction).order_by(desc(Transaction.total_cost)).all()
# trans = db_session.query(Transaction).order_by(Transaction.total_cost.desc()).all()
print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue",
"Brokerage","TotalCost"))
for tran in trans:
print(format_string.format(tran.ItemHolding.StockCompany.Ticker,
tran.Type, str(tran.Date), str(tran.Units),
str(tran.UnitPrice),
"{:.2f}".format(tran.total_value),
str(tran.Brokerage),
"{:.2f}".format(tran.total_cost)))
print("\nHoldings - Sort by ThresholdPrice asc - NOT OK")
holds = db_session.query(Holding).order_by(asc(Holding.threshold_price)).all()
format_string = "{:<8} {:>7} {:>9} {:>9} {:>10} {:>11} {:>10} {:>9} {:>13} {:>14} {:>17} {:>10}"
print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
"MarketValue","Variance","Variance%","ThresholdPrice",
"ThresholdValue","ThresholdVariance","Threshold%"))
for hold in holds:
print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
str(hold.UnitPrice), str(hold.last_price),
str(hold.total_cost), str(hold.market_value),
str(hold.variance),
"{:.2f}".format(hold.variance_percent),
"{:.4f}".format(hold.threshold_price),
"{:.2f}".format(hold.threshold_value),
"{:.2f}".format(hold.threshold_variance),
str(hold.Threshold)))
print("\nHoldings - Sort by ThresholdValue asc - NOT OK")
holds = db_session.query(Holding).order_by(asc(Holding.threshold_value)).all()
print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
"MarketValue","Variance","Variance%","ThresholdPrice",
"ThresholdValue","ThresholdVariance","Threshold%"))
for hold in holds:
print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
str(hold.UnitPrice), str(hold.last_price),
str(hold.total_cost), str(hold.market_value),
str(hold.variance),
"{:.2f}".format(hold.variance_percent),
"{:.4f}".format(hold.threshold_price),
"{:.2f}".format(hold.threshold_value),
"{:.2f}".format(hold.threshold_variance),
str(hold.Threshold)))
print("\nHoldings - Sort by ThresholdVariance asc - NOT OK")
holds = db_session.query(Holding).order_by(asc(Holding.threshold_variance)).all()
print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
"MarketValue","Variance","Variance%","ThresholdPrice",
"ThresholdValue","ThresholdVariance","Threshold%"))
for hold in holds:
print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
str(hold.UnitPrice), str(hold.last_price),
str(hold.total_cost), str(hold.market_value),
str(hold.variance),
"{:.2f}".format(hold.variance_percent),
"{:.4f}".format(hold.threshold_price),
"{:.2f}".format(hold.threshold_value),
"{:.2f}".format(hold.threshold_variance),
str(hold.Threshold)))