http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/cassandra/CassandraDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/cassandra/CassandraDataModel.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/cassandra/CassandraDataModel.java new file mode 100644 index 0000000..b220993 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/cassandra/CassandraDataModel.java @@ -0,0 +1,465 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.model.cassandra;

import com.google.common.base.Preconditions;
import me.prettyprint.cassandra.model.HColumnImpl;
import me.prettyprint.cassandra.serializers.BytesArraySerializer;
import me.prettyprint.cassandra.serializers.FloatSerializer;
import me.prettyprint.cassandra.serializers.LongSerializer;
import me.prettyprint.cassandra.service.OperationType;
import me.prettyprint.hector.api.Cluster;
import me.prettyprint.hector.api.ConsistencyLevelPolicy;
import me.prettyprint.hector.api.HConsistencyLevel;
import me.prettyprint.hector.api.Keyspace;
import me.prettyprint.hector.api.beans.ColumnSlice;
import me.prettyprint.hector.api.beans.HColumn;
import me.prettyprint.hector.api.factory.HFactory;
import me.prettyprint.hector.api.mutation.Mutator;
import me.prettyprint.hector.api.query.ColumnQuery;
import me.prettyprint.hector.api.query.CountQuery;
import me.prettyprint.hector.api.query.SliceQuery;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.Cache;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.Retriever;
import org.apache.mahout.cf.taste.impl.model.GenericItemPreferenceArray;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;

import java.io.Closeable;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/**
 * <p>A {@link DataModel} based on a Cassandra keyspace. By default it uses keyspace "recommender" but this
 * can be configured. Create the keyspace before using this class; this can be done on the Cassandra command
 * line with a command like {@code create keyspace recommender;}.</p>
 *
 * <p>Within the keyspace, this model uses four column families:</p>
 *
 * <p>First, it uses a column family called "users". This is keyed by the user ID as an 8-byte long.
 * It contains a column for every preference the user expresses. The column name is item ID, again as
 * an 8-byte long, and value is a floating point value represented as an IEEE 32-bit floating point value.</p>
 *
 * <p>It uses an analogous column family called "items" for the same data, but keyed by item ID rather
 * than user ID. In this column family, column names are user IDs instead.</p>
 *
 * <p>It uses a column family called "userIDs" as well, with an identical schema. It has one row under key
 * 0. It contains a column for every user ID in the model. It has no values.</p>
 *
 * <p>Finally it also uses an analogous column family "itemIDs" containing item IDs.</p>
 *
 * <p>Each of these four column families needs to be created ahead of time. Again the
 * Cassandra CLI can be used to do so, with commands like {@code create column family users;}.</p>
 *
 * <p>Note that this class uses a long-lived Cassandra client which will run until terminated. You
 * must {@link #close()} this implementation when done or the JVM will not terminate.</p>
 *
 * <p>This implementation still relies heavily on reading data into memory and caching,
 * as it remains too data-intensive to be effective even against Cassandra. It will take some time to
 * "warm up" as the first few requests will block loading user and item data into caches. This is still going
 * to send a great deal of query traffic to Cassandra. It would be advisable to employ caching wrapper
 * classes in your implementation, like {@link org.apache.mahout.cf.taste.impl.recommender.CachingRecommender}
 * or {@link org.apache.mahout.cf.taste.impl.similarity.CachingItemSimilarity}.</p>
 */
public final class CassandraDataModel implements DataModel, Closeable {

  /** Default Cassandra host. Default: localhost */
  private static final String DEFAULT_HOST = "localhost";

  /** Default Cassandra port. Default: 9160 */
  private static final int DEFAULT_PORT = 9160;

  /** Default Cassandra keyspace. Default: recommender */
  private static final String DEFAULT_KEYSPACE = "recommender";

  static final String USERS_CF = "users";
  static final String ITEMS_CF = "items";
  static final String USER_IDS_CF = "userIDs";
  static final String ITEM_IDS_CF = "itemIDs";

  /** Key of the single row that holds every ID in the "userIDs" / "itemIDs" column families. */
  private static final long ID_ROW_KEY = 0L;

  /** Placeholder column value for the ID column families, which carry no values. */
  private static final byte[] EMPTY = new byte[0];

  private final Cluster cluster;
  private final Keyspace keyspace;
  private final Cache<Long,PreferenceArray> userCache;
  private final Cache<Long,PreferenceArray> itemCache;
  private final Cache<Long,FastIDSet> itemIDsFromUserCache;
  private final Cache<Long,FastIDSet> userIDsFromItemCache;
  // A null value means "not counted yet"; both are reset to null by refresh().
  private final AtomicReference<Integer> userCountCache;
  private final AtomicReference<Integer> itemCountCache;

  /**
   * Uses the standard Cassandra host and port (localhost:9160), and keyspace name ("recommender").
   */
  public CassandraDataModel() {
    this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_KEYSPACE);
  }

  /**
   * @param host Cassandra server host name
   * @param port Cassandra server port
   * @param keyspaceName name of Cassandra keyspace to use
   */
  public CassandraDataModel(String host, int port, String keyspaceName) {

    Preconditions.checkNotNull(host);
    Preconditions.checkArgument(port > 0, "port must be greater then 0!");
    Preconditions.checkNotNull(keyspaceName);

    cluster = HFactory.getOrCreateCluster(CassandraDataModel.class.getSimpleName(), host + ':' + port);
    keyspace = HFactory.createKeyspace(keyspaceName, cluster);
    keyspace.setConsistencyLevelPolicy(new OneConsistencyLevelPolicy());

    userCache = new Cache<>(new UserPrefArrayRetriever(), 1 << 20);
    itemCache = new Cache<>(new ItemPrefArrayRetriever(), 1 << 20);
    itemIDsFromUserCache = new Cache<>(new ItemIDsFromUserRetriever(), 1 << 20);
    userIDsFromItemCache = new Cache<>(new UserIDsFromItemRetriever(), 1 << 20);
    userCountCache = new AtomicReference<>(null);
    itemCountCache = new AtomicReference<>(null);
  }

  /**
   * Reads every column name of the single ID row in "userIDs". This queries Cassandra on each call;
   * the result is not cached.
   */
  @Override
  public LongPrimitiveIterator getUserIDs() {
    return readAllIDs(USER_IDS_CF);
  }

  @Override
  public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
    return userCache.get(userID);
  }

  @Override
  public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
    return itemIDsFromUserCache.get(userID);
  }

  /**
   * Reads every column name of the single ID row in "itemIDs". This queries Cassandra on each call;
   * the result is not cached.
   */
  @Override
  public LongPrimitiveIterator getItemIDs() {
    return readAllIDs(ITEM_IDS_CF);
  }

  @Override
  public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
    return itemCache.get(itemID);
  }

  /**
   * Looks the value up directly in the "users" column family, bypassing the caches.
   *
   * @return the stored preference value, or {@code null} if the column does not exist
   */
  @Override
  public Float getPreferenceValue(long userID, long itemID) {
    ColumnQuery<Long,Long,Float> query =
        HFactory.createColumnQuery(keyspace, LongSerializer.get(), LongSerializer.get(), FloatSerializer.get());
    query.setColumnFamily(USERS_CF);
    query.setKey(userID);
    query.setName(itemID);
    HColumn<Long,Float> column = query.execute().get();
    return column == null ? null : column.getValue();
  }

  /**
   * Looks the column up directly in the "users" column family, bypassing the caches.
   *
   * @return the clock of the stored column (set to the write time in milliseconds by
   *  {@link #setPreference(long, long, float)}), or {@code null} if the column does not exist
   */
  @Override
  public Long getPreferenceTime(long userID, long itemID) {
    ColumnQuery<Long,Long,?> query =
        HFactory.createColumnQuery(keyspace, LongSerializer.get(), LongSerializer.get(), BytesArraySerializer.get());
    query.setColumnFamily(USERS_CF);
    query.setKey(userID);
    query.setName(itemID);
    HColumn<Long,?> result = query.execute().get();
    return result == null ? null : result.getClock();
  }

  @Override
  public int getNumItems() {
    return countIDs(ITEM_IDS_CF, itemCountCache);
  }

  @Override
  public int getNumUsers() {
    return countIDs(USER_IDS_CF, userCountCache);
  }

  @Override
  public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
    // Answered from the cached user-ID set for the item rather than by a separate count query.
    return userIDsFromItemCache.get(itemID).size();
  }

  @Override
  public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
    FastIDSet userIDs1 = userIDsFromItemCache.get(itemID1);
    FastIDSet userIDs2 = userIDsFromItemCache.get(itemID2);
    return userIDs1.size() < userIDs2.size()
        ? userIDs2.intersectionSize(userIDs1)
        : userIDs1.intersectionSize(userIDs2);
  }

  /**
   * Writes the preference to all four column families in a single mutation. A {@code NaN} value is
   * stored as {@code 1.0}. This method does not touch the in-memory caches; they are only cleared by
   * {@link #refresh(Collection)}.
   */
  @Override
  public void setPreference(long userID, long itemID, float value) {

    if (Float.isNaN(value)) {
      value = 1.0f;
    }

    // One timestamp for all four columns so they carry the same clock.
    long now = System.currentTimeMillis();

    Mutator<Long> mutator = HFactory.createMutator(keyspace, LongSerializer.get());

    HColumn<Long,Float> itemForUsers = new HColumnImpl<>(LongSerializer.get(), FloatSerializer.get());
    itemForUsers.setName(itemID);
    itemForUsers.setClock(now);
    itemForUsers.setValue(value);
    mutator.addInsertion(userID, USERS_CF, itemForUsers);

    HColumn<Long,Float> userForItems = new HColumnImpl<>(LongSerializer.get(), FloatSerializer.get());
    userForItems.setName(userID);
    userForItems.setClock(now);
    userForItems.setValue(value);
    mutator.addInsertion(itemID, ITEMS_CF, userForItems);

    HColumn<Long,byte[]> userIDs = new HColumnImpl<>(LongSerializer.get(), BytesArraySerializer.get());
    userIDs.setName(userID);
    userIDs.setClock(now);
    userIDs.setValue(EMPTY);
    mutator.addInsertion(ID_ROW_KEY, USER_IDS_CF, userIDs);

    HColumn<Long,byte[]> itemIDs = new HColumnImpl<>(LongSerializer.get(), BytesArraySerializer.get());
    itemIDs.setName(itemID);
    itemIDs.setClock(now);
    itemIDs.setValue(EMPTY);
    mutator.addInsertion(ID_ROW_KEY, ITEM_IDS_CF, itemIDs);

    mutator.execute();
  }

  /**
   * Deletes the preference from the "users" and "items" column families. The in-memory caches are
   * not touched by this method.
   */
  @Override
  public void removePreference(long userID, long itemID) {
    Mutator<Long> mutator = HFactory.createMutator(keyspace, LongSerializer.get());
    mutator.addDeletion(userID, USERS_CF, itemID, LongSerializer.get());
    mutator.addDeletion(itemID, ITEMS_CF, userID, LongSerializer.get());
    mutator.execute();
    // Not deleting from userIDs, itemIDs though
  }

  /**
   * @return true
   */
  @Override
  public boolean hasPreferenceValues() {
    return true;
  }

  /**
   * @return Float#NaN
   */
  @Override
  public float getMaxPreference() {
    return Float.NaN;
  }

  /**
   * @return Float#NaN
   */
  @Override
  public float getMinPreference() {
    return Float.NaN;
  }

  /**
   * Clears every cache and resets both cached counts, so subsequent reads go back to Cassandra.
   */
  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    userCache.clear();
    itemCache.clear();
    userIDsFromItemCache.clear();
    itemIDsFromUserCache.clear();
    userCountCache.set(null);
    itemCountCache.set(null);
  }

  @Override
  public String toString() {
    return "CassandraDataModel[" + keyspace + ']';
  }

  /** Shuts down the underlying Cassandra cluster client. */
  @Override
  public void close() {
    HFactory.shutdownCluster(cluster);
  }

  /**
   * Collects every column name stored under {@link #ID_ROW_KEY} in the given ID column family.
   *
   * @param cf either {@link #USER_IDS_CF} or {@link #ITEM_IDS_CF}
   * @return iterator over the IDs found
   */
  private LongPrimitiveIterator readAllIDs(String cf) {
    SliceQuery<Long,Long,byte[]> query = buildNoValueSliceQuery(cf);
    query.setKey(ID_ROW_KEY);
    FastIDSet ids = new FastIDSet();
    for (HColumn<Long,byte[]> idColumn : query.execute().get().getColumns()) {
      ids.add(idColumn.getName());
    }
    return ids.iterator();
  }

  /**
   * Returns the number of columns under {@link #ID_ROW_KEY} in the given ID column family, querying
   * Cassandra only when {@code countCache} holds no value yet.
   *
   * @param cf either {@link #USER_IDS_CF} or {@link #ITEM_IDS_CF}
   * @param countCache holder for the memoized count; {@code null} content means "not counted yet"
   */
  private int countIDs(String cf, AtomicReference<Integer> countCache) {
    Integer count = countCache.get();
    if (count == null) {
      CountQuery<Long,Long> countQuery =
          HFactory.createCountQuery(keyspace, LongSerializer.get(), LongSerializer.get());
      countQuery.setKey(ID_ROW_KEY);
      countQuery.setColumnFamily(cf);
      countQuery.setRange(null, null, Integer.MAX_VALUE);
      count = countQuery.execute().get();
      countCache.set(count);
    }
    return count;
  }

  /** Builds an unbounded slice query over {@code cf} that reads column values as raw bytes. */
  private SliceQuery<Long,Long,byte[]> buildNoValueSliceQuery(String cf) {
    SliceQuery<Long,Long,byte[]> query =
        HFactory.createSliceQuery(keyspace, LongSerializer.get(), LongSerializer.get(), BytesArraySerializer.get());
    query.setColumnFamily(cf);
    query.setRange(null, null, false, Integer.MAX_VALUE);
    return query;
  }

  /** Builds an unbounded slice query over {@code cf} that reads column values as floats. */
  private SliceQuery<Long,Long,Float> buildValueSliceQuery(String cf) {
    SliceQuery<Long,Long,Float> query =
        HFactory.createSliceQuery(keyspace, LongSerializer.get(), LongSerializer.get(), FloatSerializer.get());
    query.setColumnFamily(cf);
    query.setRange(null, null, false, Integer.MAX_VALUE);
    return query;
  }

  /** Consistency policy that answers {@link HConsistencyLevel#ONE} for every operation. */
  private static final class OneConsistencyLevelPolicy implements ConsistencyLevelPolicy {
    @Override
    public HConsistencyLevel get(OperationType op) {
      return HConsistencyLevel.ONE;
    }

    @Override
    public HConsistencyLevel get(OperationType op, String cfName) {
      return HConsistencyLevel.ONE;
    }
  }

  /** Loads all preferences of one user from the "users" column family. */
  private final class UserPrefArrayRetriever implements Retriever<Long, PreferenceArray> {
    @Override
    public PreferenceArray get(Long userID) throws TasteException {
      SliceQuery<Long,Long,Float> query = buildValueSliceQuery(USERS_CF);
      query.setKey(userID);

      ColumnSlice<Long,Float> result = query.execute().get();
      if (result == null) {
        throw new NoSuchUserException(userID);
      }
      List<HColumn<Long,Float>> itemIDColumns = result.getColumns();
      // A row without columns is treated as an unknown user.
      if (itemIDColumns.isEmpty()) {
        throw new NoSuchUserException(userID);
      }
      int size = itemIDColumns.size();
      PreferenceArray prefs = new GenericUserPreferenceArray(size);
      prefs.setUserID(0, userID);
      for (int i = 0; i < size; i++) {
        HColumn<Long,Float> itemIDColumn = itemIDColumns.get(i);
        prefs.setItemID(i, itemIDColumn.getName());
        prefs.setValue(i, itemIDColumn.getValue());
      }
      return prefs;
    }
  }

  /** Loads all preferences for one item from the "items" column family. */
  private final class ItemPrefArrayRetriever implements Retriever<Long, PreferenceArray> {
    @Override
    public PreferenceArray get(Long itemID) throws TasteException {
      SliceQuery<Long,Long,Float> query = buildValueSliceQuery(ITEMS_CF);
      query.setKey(itemID);
      ColumnSlice<Long,Float> result = query.execute().get();
      if (result == null) {
        throw new NoSuchItemException(itemID);
      }
      List<HColumn<Long,Float>> userIDColumns = result.getColumns();
      // A row without columns is treated as an unknown item.
      if (userIDColumns.isEmpty()) {
        throw new NoSuchItemException(itemID);
      }
      int size = userIDColumns.size();
      PreferenceArray prefs = new GenericItemPreferenceArray(size);
      prefs.setItemID(0, itemID);
      for (int i = 0; i < size; i++) {
        HColumn<Long,Float> userIDColumn = userIDColumns.get(i);
        prefs.setUserID(i, userIDColumn.getName());
        prefs.setValue(i, userIDColumn.getValue());
      }
      return prefs;
    }
  }

  /**
   * Loads the IDs of all users who expressed a preference for one item. Unlike the other retrievers,
   * an empty row yields an empty set rather than an exception.
   */
  private final class UserIDsFromItemRetriever implements Retriever<Long, FastIDSet> {
    @Override
    public FastIDSet get(Long itemID) throws TasteException {
      SliceQuery<Long,Long,byte[]> query = buildNoValueSliceQuery(ITEMS_CF);
      query.setKey(itemID);
      ColumnSlice<Long,byte[]> result = query.execute().get();
      if (result == null) {
        throw new NoSuchItemException(itemID);
      }
      List<HColumn<Long,byte[]>> columns = result.getColumns();
      FastIDSet userIDs = new FastIDSet(columns.size());
      for (HColumn<Long,?> userIDColumn : columns) {
        userIDs.add(userIDColumn.getName());
      }
      return userIDs;
    }
  }

  /** Loads the IDs of all items one user expressed a preference for. */
  private final class ItemIDsFromUserRetriever implements Retriever<Long, FastIDSet> {
    @Override
    public FastIDSet get(Long userID) throws TasteException {
      SliceQuery<Long,Long,byte[]> query = buildNoValueSliceQuery(USERS_CF);
      query.setKey(userID);
      ColumnSlice<Long,byte[]> result = query.execute().get();
      if (result == null) {
        throw new NoSuchUserException(userID);
      }
      List<HColumn<Long,byte[]>> columns = result.getColumns();
      // A row without columns is treated as an unknown user.
      if (columns.isEmpty()) {
        throw new NoSuchUserException(userID);
      }
      // Allocate only after validation, presized to the known number of columns.
      FastIDSet itemIDs = new FastIDSet(columns.size());
      for (HColumn<Long,?> itemIDColumn : columns) {
        itemIDs.add(itemIDColumn.getName());
      }
      return itemIDs;
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/hbase/HBaseDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/hbase/HBaseDataModel.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/hbase/HBaseDataModel.java new file mode 100644 index 0000000..9735ffe --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/hbase/HBaseDataModel.java @@ -0,0 +1,497 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.model.hbase;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.HColumnDescriptor;
import org.apache.hadoop.hbase.HTableDescriptor;
import org.apache.hadoop.hbase.KeyValue;
import org.apache.hadoop.hbase.client.Delete;
import org.apache.hadoop.hbase.client.Get;
import org.apache.hadoop.hbase.client.HBaseAdmin;
import org.apache.hadoop.hbase.client.HTableFactory;
import org.apache.hadoop.hbase.client.HTableInterface;
import org.apache.hadoop.hbase.client.HTablePool;
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.filter.FilterList;
import org.apache.hadoop.hbase.filter.FirstKeyOnlyFilter;
import org.apache.hadoop.hbase.filter.KeyOnlyFilter;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.model.GenericItemPreferenceArray;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;

/**
 * <p>Naive approach of storing one preference as one value in the table.
 * Preferences are indexed as (user, item) and (item, user) for O(1) lookups.</p>
 *
 * <p>The default table name is "taste", this can be set through a constructor
 * argument. Each row key starts with "i" or "u" followed by the
 * actual id encoded as a big endian long.</p>
 *
 * <p>E.g., "u\x00\x00\x00\x00\x00\x00\x04\xd2" is user 1234L</p>
 *
 * <p>There are two column families: "users" and "items".</p>
 *
 * <p>The "items" column family is written on user rows and holds user->item
 * preferences. Each itemID is the column qualifier and the value is the
 * preference.</p>
 *
 * <p>The "users" column family is written on item rows and holds item->user
 * preferences. Each userID is the column qualifier and the value is the
 * preference.</p>
 *
 * <p>User IDs and item IDs are cached in a FastIDSet since it requires a full
 * table scan to build these sets. Preferences are not cached since they
 * are pretty cheap lookups in HBase (also caching the Preferences defeats
 * the purpose of a scalable storage engine like HBase).</p>
 */
public final class HBaseDataModel implements DataModel, Closeable {

  private static final Logger log = LoggerFactory.getLogger(HBaseDataModel.class);

  private static final String DEFAULT_TABLE = "taste";
  private static final byte[] USERS_CF = Bytes.toBytes("users");
  private static final byte[] ITEMS_CF = Bytes.toBytes("items");

  private final HTablePool pool;
  private final String tableName;

  // Cache of user and item ids; replaced wholesale by refresh()
  private volatile FastIDSet itemIDs;
  private volatile FastIDSet userIDs;

  /**
   * Connects through the given ZooKeeper quorum and uses the default table "taste".
   *
   * @param zkConnect value for {@code hbase.zookeeper.quorum}
   * @throws IOException if the table cannot be created or the ID caches cannot be loaded
   */
  public HBaseDataModel(String zkConnect) throws IOException {
    this(zkConnect, DEFAULT_TABLE);
  }

  /**
   * Connects through the given ZooKeeper quorum, creates the table when needed and loads the ID caches.
   *
   * @param zkConnect value for {@code hbase.zookeeper.quorum}
   * @param tableName name of the HBase table to use
   * @throws IOException if the table cannot be created or the ID caches cannot be loaded
   */
  public HBaseDataModel(String zkConnect, String tableName) throws IOException {
    log.info("Using HBase table {}", tableName);
    Configuration conf = HBaseConfiguration.create();
    conf.set("hbase.zookeeper.quorum", zkConnect);
    HTableFactory tableFactory = new HTableFactory();
    this.pool = new HTablePool(conf, 8, tableFactory);
    this.tableName = tableName;

    bootstrap(conf);
    // Warm the cache
    refresh(null);
  }

  /**
   * Uses a caller-supplied table pool, creates the table when needed and loads the ID caches.
   *
   * @param pool pool to obtain table handles from; closed by {@link #close()}
   * @param tableName name of the HBase table to use
   * @param conf configuration used to reach the HBase admin interface
   * @throws IOException if the table cannot be created or the ID caches cannot be loaded
   */
  public HBaseDataModel(HTablePool pool, String tableName, Configuration conf) throws IOException {
    log.info("Using HBase table {}", tableName);
    this.pool = pool;
    this.tableName = tableName;

    bootstrap(conf);

    // Warm the cache
    refresh(null);
  }

  public String getTableName() {
    return tableName;
  }

  /**
   * Create the table if it doesn't exist
   */
  private void bootstrap(Configuration conf) throws IOException {
    try (HBaseAdmin admin = new HBaseAdmin(conf)) {
      // createTable() fails on an existing table, so check first to make start-up repeatable.
      if (admin.tableExists(tableName)) {
        log.info("Table {} already exists", tableName);
        return;
      }
      HTableDescriptor tDesc = new HTableDescriptor(Bytes.toBytes(tableName));
      tDesc.addFamily(new HColumnDescriptor(USERS_CF));
      tDesc.addFamily(new HColumnDescriptor(ITEMS_CF));
      admin.createTable(tDesc);
      log.info("Created table {}", tableName);
    }
  }

  /**
   * Prefix a user id with "u" and convert to byte[]
   */
  private static byte[] userToBytes(long userID) {
    ByteBuffer bb = ByteBuffer.allocate(9);
    bb.put((byte)0x75); // The letter "u"
    bb.putLong(userID);
    return bb.array();
  }

  /**
   * Prefix an item id with "i" and convert to byte[]
   */
  private static byte[] itemToBytes(long itemID) {
    ByteBuffer bb = ByteBuffer.allocate(9);
    bb.put((byte)0x69); // The letter "i"
    bb.putLong(itemID);
    return bb.array();
  }

  /**
   * Extract the id out of a prefix byte[] id
   */
  private static long bytesToUserOrItemID(byte[] ba) {
    ByteBuffer bb = ByteBuffer.wrap(ba);
    // Skip the one-byte "u"/"i" prefix
    return bb.getLong(1);
  }

  /**
   * Runs a single Get against the table, always handing the table back to the pool.
   *
   * @param get the lookup to execute
   * @param failureMessage message of the {@link TasteException} raised on I/O failure
   * @return the HBase result, possibly empty
   * @throws TasteException wrapping any {@link IOException} from HBase
   */
  private Result fetch(Get get, String failureMessage) throws TasteException {
    try {
      HTableInterface table = pool.getTable(tableName);
      try {
        return table.get(get);
      } finally {
        table.close();
      }
    } catch (IOException e) {
      throw new TasteException(failureMessage, e);
    }
  }

  /**
   * Decodes every column qualifier of {@code family} in {@code result} as a long ID.
   */
  private static FastIDSet qualifiersToIDs(Result result, byte[] family) {
    SortedMap<byte[], byte[]> columns = result.getFamilyMap(family);
    FastIDSet ids = new FastIDSet(columns.size());
    for (byte[] qualifier : columns.keySet()) {
      ids.add(Bytes.toLong(qualifier));
    }
    return ids;
  }

  /* DataModel interface */

  @Override
  public LongPrimitiveIterator getUserIDs() {
    return userIDs.iterator();
  }

  /**
   * Reads the whole "items" family of the user's row.
   *
   * @throws NoSuchUserException if the row holds no data
   * @throws TasteException if HBase cannot be read
   */
  @Override
  public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
    Get get = new Get(userToBytes(userID));
    get.addFamily(ITEMS_CF);
    Result result = fetch(get, "Failed to retrieve user preferences from HBase");

    if (result.isEmpty()) {
      throw new NoSuchUserException(userID);
    }

    SortedMap<byte[], byte[]> families = result.getFamilyMap(ITEMS_CF);
    PreferenceArray prefs = new GenericUserPreferenceArray(families.size());
    prefs.setUserID(0, userID);
    int i = 0;
    for (Map.Entry<byte[], byte[]> entry : families.entrySet()) {
      prefs.setItemID(i, Bytes.toLong(entry.getKey()));
      prefs.setValue(i, Bytes.toFloat(entry.getValue()));
      i++;
    }
    return prefs;
  }

  /**
   * Reads the qualifiers of the "items" family of the user's row.
   *
   * @throws NoSuchUserException if the row holds no data
   * @throws TasteException if HBase cannot be read
   */
  @Override
  public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
    Get get = new Get(userToBytes(userID));
    get.addFamily(ITEMS_CF);
    Result result = fetch(get, "Failed to retrieve item IDs from HBase");

    if (result.isEmpty()) {
      throw new NoSuchUserException(userID);
    }

    return qualifiersToIDs(result, ITEMS_CF);
  }

  @Override
  public LongPrimitiveIterator getItemIDs() {
    return itemIDs.iterator();
  }

  /**
   * Reads the whole "users" family of the item's row.
   *
   * @throws NoSuchItemException if the row holds no data
   * @throws TasteException if HBase cannot be read
   */
  @Override
  public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
    Get get = new Get(itemToBytes(itemID));
    get.addFamily(USERS_CF);
    Result result = fetch(get, "Failed to retrieve item preferences from HBase");

    if (result.isEmpty()) {
      throw new NoSuchItemException(itemID);
    }

    SortedMap<byte[], byte[]> families = result.getFamilyMap(USERS_CF);
    PreferenceArray prefs = new GenericItemPreferenceArray(families.size());
    prefs.setItemID(0, itemID);
    int i = 0;
    for (Map.Entry<byte[], byte[]> entry : families.entrySet()) {
      prefs.setUserID(i, Bytes.toLong(entry.getKey()));
      prefs.setValue(i, Bytes.toFloat(entry.getValue()));
      i++;
    }
    return prefs;
  }

  /**
   * Fetches the single (user row, "items", itemID) cell.
   *
   * <p>NOTE(review): the Get is restricted to one column, so an empty result is reported as
   * {@link NoSuchUserException} whether or not the user has other preferences; the {@code null}
   * branch below looks unreachable in practice. Behaviour kept as-is for existing callers.</p>
   *
   * @throws NoSuchUserException if the lookup returns nothing
   * @throws TasteException if HBase cannot be read
   */
  @Override
  public Float getPreferenceValue(long userID, long itemID) throws TasteException {
    byte[] qualifier = Bytes.toBytes(itemID);
    Get get = new Get(userToBytes(userID));
    get.addColumn(ITEMS_CF, qualifier);
    Result result = fetch(get, "Failed to retrieve user preferences from HBase");

    if (result.isEmpty()) {
      throw new NoSuchUserException(userID);
    }

    if (result.containsColumn(ITEMS_CF, qualifier)) {
      return Bytes.toFloat(result.getValue(ITEMS_CF, qualifier));
    } else {
      return null;
    }
  }

  /**
   * Fetches the HBase timestamp of the single (user row, "items", itemID) cell.
   *
   * <p>NOTE(review): as with {@link #getPreferenceValue(long, long)}, an empty result is always
   * reported as {@link NoSuchUserException}.</p>
   *
   * @throws NoSuchUserException if the lookup returns nothing
   * @throws TasteException if HBase cannot be read
   */
  @Override
  public Long getPreferenceTime(long userID, long itemID) throws TasteException {
    byte[] qualifier = Bytes.toBytes(itemID);
    Get get = new Get(userToBytes(userID));
    get.addColumn(ITEMS_CF, qualifier);
    Result result = fetch(get, "Failed to retrieve user preferences from HBase");

    if (result.isEmpty()) {
      throw new NoSuchUserException(userID);
    }

    if (result.containsColumn(ITEMS_CF, qualifier)) {
      KeyValue kv = result.getColumnLatest(ITEMS_CF, qualifier);
      return kv.getTimestamp();
    } else {
      return null;
    }
  }

  @Override
  public int getNumItems() {
    return itemIDs.size();
  }

  @Override
  public int getNumUsers() {
    return userIDs.size();
  }

  @Override
  public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
    PreferenceArray prefs = getPreferencesForItem(itemID);
    return prefs.length();
  }

  /**
   * Fetches both item rows in one batched call and counts the user IDs they have in common.
   *
   * @throws NoSuchItemException if either row holds no data
   * @throws TasteException if HBase cannot be read
   */
  @Override
  public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
    Result[] results;
    try {
      HTableInterface table = pool.getTable(tableName);
      try {
        List<Get> gets = new ArrayList<>(2);
        gets.add(new Get(itemToBytes(itemID1)));
        gets.add(new Get(itemToBytes(itemID2)));
        gets.get(0).addFamily(USERS_CF);
        gets.get(1).addFamily(USERS_CF);
        results = table.get(gets);
      } finally {
        // Return the table to the pool even when the lookup fails
        table.close();
      }
    } catch (IOException e) {
      throw new TasteException("Failed to retrieve item preferences from HBase", e);
    }

    if (results[0].isEmpty()) {
      throw new NoSuchItemException(itemID1);
    }
    if (results[1].isEmpty()) {
      throw new NoSuchItemException(itemID2);
    }

    FastIDSet idSet1 = qualifiersToIDs(results[0], USERS_CF);
    FastIDSet idSet2 = qualifiersToIDs(results[1], USERS_CF);

    return idSet1.intersectionSize(idSet2);
  }

  /**
   * Stores the preference twice: on the user row under "items" and on the item row under "users".
   * The cached ID sets are not updated by this method.
   *
   * @throws TasteException if HBase cannot be written
   */
  @Override
  public void setPreference(long userID, long itemID, float value) throws TasteException {
    try {
      HTableInterface table = pool.getTable(tableName);
      try {
        List<Put> puts = new ArrayList<>(2);
        puts.add(new Put(userToBytes(userID)));
        puts.add(new Put(itemToBytes(itemID)));
        puts.get(0).add(ITEMS_CF, Bytes.toBytes(itemID), Bytes.toBytes(value));
        puts.get(1).add(USERS_CF, Bytes.toBytes(userID), Bytes.toBytes(value));
        table.put(puts);
      } finally {
        // Return the table to the pool even when the write fails
        table.close();
      }
    } catch (IOException e) {
      throw new TasteException("Failed to store preference in HBase", e);
    }
  }

  /**
   * Deletes both copies of the preference (user row / "items" and item row / "users").
   * The cached ID sets are not updated by this method.
   *
   * @throws TasteException if HBase cannot be written
   */
  @Override
  public void removePreference(long userID, long itemID) throws TasteException {
    try {
      HTableInterface table = pool.getTable(tableName);
      try {
        List<Delete> deletes = new ArrayList<>(2);
        deletes.add(new Delete(userToBytes(userID)));
        deletes.add(new Delete(itemToBytes(itemID)));
        deletes.get(0).deleteColumns(ITEMS_CF, Bytes.toBytes(itemID));
        deletes.get(1).deleteColumns(USERS_CF, Bytes.toBytes(userID));
        table.delete(deletes);
      } finally {
        // Return the table to the pool even when the delete fails
        table.close();
      }
    } catch (IOException e) {
      throw new TasteException("Failed to remove preference from HBase", e);
    }
  }

  @Override
  public boolean hasPreferenceValues() {
    return true;
  }

  @Override
  public float getMaxPreference() {
    throw new UnsupportedOperationException();
  }

  @Override
  public float getMinPreference() {
    throw new UnsupportedOperationException();
  }

  /* Closeable interface */

  @Override
  public void close() throws IOException {
    pool.close();
  }

  /* Refreshable interface */

  /**
   * Rebuilds both ID caches unless this model is already listed in {@code alreadyRefreshed}.
   *
   * @param alreadyRefreshed objects refreshed so far; may be {@code null}
   * @throws IllegalStateException if the table scan fails
   */
  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    if (alreadyRefreshed == null || !alreadyRefreshed.contains(this)) {
      try {
        log.info("Refreshing item and user ID caches");
        long t1 = System.currentTimeMillis();
        refreshItemIDs();
        refreshUserIDs();
        long t2 = System.currentTimeMillis();
        log.info("Finished refreshing caches in {} ms", t2 - t1);
      } catch (IOException e) {
        throw new IllegalStateException("Could not reload DataModel", e);
      }
    }
  }

  /**
   * Scans the row-key range [startRow, stopRow) and decodes each row key into an ID.
   * The scanner and the pooled table are released even when the scan fails.
   */
  private FastIDSet scanIDs(byte[] startRow, byte[] stopRow) throws IOException {
    HTableInterface table = pool.getTable(tableName);
    try {
      Scan scan = new Scan(startRow, stopRow);
      scan.setFilter(new FilterList(FilterList.Operator.MUST_PASS_ALL, new KeyOnlyFilter(), new FirstKeyOnlyFilter()));
      FastIDSet ids = new FastIDSet();
      try (ResultScanner scanner = table.getScanner(scan)) {
        for (Result result : scanner) {
          ids.add(bytesToUserOrItemID(result.getRow()));
        }
      }
      return ids;
    } finally {
      table.close();
    }
  }

  /*
   * Refresh the item id cache. Warning: this does a large table scan
   */
  private synchronized void refreshItemIDs() throws IOException {
    // Row keys from "i" (0x69) up to, but excluding, 0x70; swap with the active set when done
    this.itemIDs = scanIDs(new byte[]{0x69}, new byte[]{0x70});
  }

  /*
   * Refresh the user id cache. Warning: this does a large table scan
   */
  private synchronized void refreshUserIDs() throws IOException {
    // Row keys from "u" (0x75) up to, but excluding, 0x76; swap with the active set when done
    this.userIDs = scanIDs(new byte[]{0x75}, new byte[]{0x76});
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractBooleanPrefJDBCDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractBooleanPrefJDBCDataModel.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractBooleanPrefJDBCDataModel.java new file mode 100644 index 0000000..79ca1ac --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractBooleanPrefJDBCDataModel.java @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.model.jdbc; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import javax.sql.DataSource; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.model.BooleanPreference; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.common.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; +
/**
 * Variant of {@link AbstractJDBCDataModel} for data without preference values: a row only records
 * that a user is associated with an item. Preferences are materialized as
 * {@link BooleanPreference}s and both preference bounds are reported as {@code 1.0}.
 */
public abstract class AbstractBooleanPrefJDBCDataModel extends AbstractJDBCDataModel {

  private static final Logger log = LoggerFactory.getLogger(AbstractBooleanPrefJDBCDataModel.class);

  static final String NO_SUCH_COLUMN = "NO_SUCH_COLUMN";

  /** Kept locally as well because {@link #setPreference(long, long, float)} binds only two parameters. */
  private final String setPreferenceSQL;

  /**
   * Passes every argument straight through to the {@link AbstractJDBCDataModel} constructor and
   * additionally remembers {@code setPreferenceSQL} for this class's own write path.
   */
  protected AbstractBooleanPrefJDBCDataModel(DataSource dataSource,
                                             String preferenceTable,
                                             String userIDColumn,
                                             String itemIDColumn,
                                             String preferenceColumn,
                                             String getPreferenceSQL,
                                             String getPreferenceTimeSQL,
                                             String getUserSQL,
                                             String getAllUsersSQL,
                                             String getNumItemsSQL,
                                             String getNumUsersSQL,
                                             String setPreferenceSQL,
                                             String removePreferenceSQL,
                                             String getUsersSQL,
                                             String getItemsSQL,
                                             String getPrefsForItemSQL,
                                             String getNumPreferenceForItemSQL,
                                             String getNumPreferenceForItemsSQL,
                                             String getMaxPreferenceSQL,
                                             String getMinPreferenceSQL) {
    super(dataSource,
          preferenceTable,
          userIDColumn,
          itemIDColumn,
          preferenceColumn,
          getPreferenceSQL,
          getPreferenceTimeSQL,
          getUserSQL,
          getAllUsersSQL,
          getNumItemsSQL,
          getNumUsersSQL,
          setPreferenceSQL,
          removePreferenceSQL,
          getUsersSQL,
          getItemsSQL,
          getPrefsForItemSQL,
          getNumPreferenceForItemSQL,
          getNumPreferenceForItemsSQL,
          getMaxPreferenceSQL,
          getMinPreferenceSQL);
    this.setPreferenceSQL = setPreferenceSQL;
  }

  /** Builds a value-less preference from result columns 1 (user ID) and 2 (item ID). */
  @Override
  protected Preference buildPreference(ResultSet rs) throws SQLException {
    long userID = getLongColumn(rs, 1);
    long itemID = getLongColumn(rs, 2);
    return new BooleanPreference(userID, itemID);
  }

  @Override
  String getSetPreferenceSQL() {
    return setPreferenceSQL;
  }

  /**
   * Records the user/item association. {@code value} is validated but never written: the
   * statement receives only the user ID and the item ID.
   *
   * @throws IllegalArgumentException if {@code value} is NaN
   * @throws TasteException if the update fails with a {@link SQLException}
   */
  @Override
  public void setPreference(long userID, long itemID, float value) throws TasteException {
    Preconditions.checkArgument(!Float.isNaN(value), "NaN value");
    log.debug("Setting preference for user {}, item {}", userID, itemID);

    Connection connection = null;
    PreparedStatement update = null;
    try {
      connection = getDataSource().getConnection();
      update = connection.prepareStatement(setPreferenceSQL);
      setLongParameter(update, 1, userID);
      setLongParameter(update, 2, itemID);

      log.debug("Executing SQL update: {}", setPreferenceSQL);
      update.executeUpdate();
    } catch (SQLException e) {
      log.warn("Exception while setting preference", e);
      throw new TasteException(e);
    } finally {
      // No ResultSet is involved in an update, hence the null first argument.
      IOUtils.quietClose(null, update, connection);
    }
  }

  /** @return always {@code false}; this model carries no preference values */
  @Override
  public boolean hasPreferenceValues() {
    return false;
  }

  /** @return always {@code 1.0} */
  @Override
  public float getMaxPreference() {
    return 1.0f;
  }

  /** @return always {@code 1.0} */
  @Override
  public float getMinPreference() {
    return 1.0f;
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractJDBCDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractJDBCDataModel.java
b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractJDBCDataModel.java new file mode 100644 index 0000000..66f0a77 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/AbstractJDBCDataModel.java @@ -0,0 +1,787 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.model.jdbc; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.Cache; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.Retriever; +import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; +import org.apache.mahout.cf.taste.impl.common.jdbc.ResultSetIterator; +import org.apache.mahout.cf.taste.impl.model.GenericItemPreferenceArray; +import org.apache.mahout.cf.taste.impl.model.GenericPreference; +import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray; +import org.apache.mahout.cf.taste.model.JDBCDataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import javax.sql.DataSource; + +/** + * <p> + * An abstract superclass for {@link JDBCDataModel} implementations, providing most of the common + * functionality that any such implementation would need. + * </p> + * + * <p> + * Performance will be a concern with any {@link JDBCDataModel}. There are going to be lots of + * simultaneous reads and some writes to one table. Make sure the table is set up optimally -- for example, + * you'll want to establish indexes. 
+ * </p> + * + * <p> + * You'll also want to use connection pooling of some kind. Most J2EE containers like Tomcat provide + * connection pooling, so make sure the {@link DataSource} it exposes is using pooling. Outside a J2EE + * container, you can use packages like Jakarta's <a href="http://jakarta.apache.org/commons/dbcp/">DBCP</a> + * to create a {@link DataSource} on top of your database whose {@link Connection}s are pooled. + * </p> + */
public abstract class AbstractJDBCDataModel extends AbstractJDBCComponent implements JDBCDataModel {

  private static final Logger log = LoggerFactory.getLogger(AbstractJDBCDataModel.class);

  public static final String DEFAULT_PREFERENCE_TABLE = "taste_preferences";
  public static final String DEFAULT_USER_ID_COLUMN = "user_id";
  public static final String DEFAULT_ITEM_ID_COLUMN = "item_id";
  public static final String DEFAULT_PREFERENCE_COLUMN = "preference";
  public static final String DEFAULT_PREFERENCE_TIME_COLUMN = "timestamp";

  private final DataSource dataSource;
  private final String preferenceTable;
  private final String userIDColumn;
  private final String itemIDColumn;
  private final String preferenceColumn;
  private final String getPreferenceSQL;
  private final String getPreferenceTimeSQL;
  private final String getUserSQL;
  private final String getAllUsersSQL;
  private final String getNumItemsSQL;
  private final String getNumUsersSQL;
  private final String setPreferenceSQL;
  private final String removePreferenceSQL;
  private final String getUsersSQL;
  private final String getItemsSQL;
  private final String getPrefsForItemSQL;
  private final String getNumPreferenceForItemsSQL;
  private final String getMaxPreferenceSQL;
  private final String getMinPreferenceSQL;
  /** Cached user count; a negative value means "not loaded yet". */
  private int cachedNumUsers;
  /** Cached item count; a negative value means "not loaded yet". */
  private int cachedNumItems;
  /** Per-item preference counts, loaded on demand through {@link ItemPrefCountRetriever}. */
  private final Cache<Long,Integer> itemPrefCounts;
  /** Cached maximum preference; NaN means "not loaded yet". */
  private float maxPreference;
  /** Cached minimum preference; NaN means "not loaded yet". */
  private float minPreference;

  /**
   * Creates a model over the default table and column names ({@link #DEFAULT_PREFERENCE_TABLE},
   * {@link #DEFAULT_USER_ID_COLUMN}, {@link #DEFAULT_ITEM_ID_COLUMN},
   * {@link #DEFAULT_PREFERENCE_COLUMN}).
   */
  protected AbstractJDBCDataModel(DataSource dataSource,
                                  String getPreferenceSQL,
                                  String getPreferenceTimeSQL,
                                  String getUserSQL,
                                  String getAllUsersSQL,
                                  String getNumItemsSQL,
                                  String getNumUsersSQL,
                                  String setPreferenceSQL,
                                  String removePreferenceSQL,
                                  String getUsersSQL,
                                  String getItemsSQL,
                                  String getPrefsForItemSQL,
                                  String getNumPreferenceForItemSQL,
                                  String getNumPreferenceForItemsSQL,
                                  String getMaxPreferenceSQL,
                                  String getMinPreferenceSQL) {
    this(dataSource,
         DEFAULT_PREFERENCE_TABLE,
         DEFAULT_USER_ID_COLUMN,
         DEFAULT_ITEM_ID_COLUMN,
         DEFAULT_PREFERENCE_COLUMN,
         getPreferenceSQL,
         getPreferenceTimeSQL,
         getUserSQL,
         getAllUsersSQL,
         getNumItemsSQL,
         getNumUsersSQL,
         setPreferenceSQL,
         removePreferenceSQL,
         getUsersSQL,
         getItemsSQL,
         getPrefsForItemSQL,
         getNumPreferenceForItemSQL,
         getNumPreferenceForItemsSQL,
         getMaxPreferenceSQL,
         getMinPreferenceSQL);
  }

  /**
   * Creates a model over the given table/columns and SQL statements. Every argument except
   * {@code getPreferenceTimeSQL} is checked with {@code checkNotNullAndLog}.
   *
   * @param getPreferenceTimeSQL may be {@code null}, in which case
   *        {@link #getPreferenceTime(long, long)} always returns {@code null}
   */
  protected AbstractJDBCDataModel(DataSource dataSource,
                                  String preferenceTable,
                                  String userIDColumn,
                                  String itemIDColumn,
                                  String preferenceColumn,
                                  String getPreferenceSQL,
                                  String getPreferenceTimeSQL,
                                  String getUserSQL,
                                  String getAllUsersSQL,
                                  String getNumItemsSQL,
                                  String getNumUsersSQL,
                                  String setPreferenceSQL,
                                  String removePreferenceSQL,
                                  String getUsersSQL,
                                  String getItemsSQL,
                                  String getPrefsForItemSQL,
                                  String getNumPreferenceForItemSQL,
                                  String getNumPreferenceForItemsSQL,
                                  String getMaxPreferenceSQL,
                                  String getMinPreferenceSQL) {

    log.debug("Creating AbstractJDBCModel...");

    AbstractJDBCComponent.checkNotNullAndLog("preferenceTable", preferenceTable);
    AbstractJDBCComponent.checkNotNullAndLog("userIDColumn", userIDColumn);
    AbstractJDBCComponent.checkNotNullAndLog("itemIDColumn", itemIDColumn);
    AbstractJDBCComponent.checkNotNullAndLog("preferenceColumn", preferenceColumn);

    AbstractJDBCComponent.checkNotNullAndLog("dataSource", dataSource);
    AbstractJDBCComponent.checkNotNullAndLog("getUserSQL", getUserSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getAllUsersSQL", getAllUsersSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getPreferenceSQL", getPreferenceSQL);
    // getPreferenceTimeSQL can be null
    AbstractJDBCComponent.checkNotNullAndLog("getNumItemsSQL", getNumItemsSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getNumUsersSQL", getNumUsersSQL);
    AbstractJDBCComponent.checkNotNullAndLog("setPreferenceSQL", setPreferenceSQL);
    AbstractJDBCComponent.checkNotNullAndLog("removePreferenceSQL", removePreferenceSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getUsersSQL", getUsersSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getItemsSQL", getItemsSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getPrefsForItemSQL", getPrefsForItemSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getNumPreferenceForItemSQL", getNumPreferenceForItemSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getNumPreferenceForItemsSQL", getNumPreferenceForItemsSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getMaxPreferenceSQL", getMaxPreferenceSQL);
    AbstractJDBCComponent.checkNotNullAndLog("getMinPreferenceSQL", getMinPreferenceSQL);

    if (!(dataSource instanceof ConnectionPoolDataSource)) {
      log.warn("You are not using ConnectionPoolDataSource. Make sure your DataSource pools connections "
               + "to the database itself, or database performance will be severely reduced.");
    }

    this.preferenceTable = preferenceTable;
    this.userIDColumn = userIDColumn;
    this.itemIDColumn = itemIDColumn;
    this.preferenceColumn = preferenceColumn;

    this.dataSource = dataSource;
    this.getPreferenceSQL = getPreferenceSQL;
    this.getPreferenceTimeSQL = getPreferenceTimeSQL;
    this.getUserSQL = getUserSQL;
    this.getAllUsersSQL = getAllUsersSQL;
    this.getNumItemsSQL = getNumItemsSQL;
    this.getNumUsersSQL = getNumUsersSQL;
    this.setPreferenceSQL = setPreferenceSQL;
    this.removePreferenceSQL = removePreferenceSQL;
    this.getUsersSQL = getUsersSQL;
    this.getItemsSQL = getItemsSQL;
    this.getPrefsForItemSQL = getPrefsForItemSQL;
    // getNumPreferenceForItemSQL has no field of its own; the retriever behind itemPrefCounts holds it.
    this.getNumPreferenceForItemsSQL = getNumPreferenceForItemsSQL;
    this.getMaxPreferenceSQL = getMaxPreferenceSQL;
    this.getMinPreferenceSQL = getMinPreferenceSQL;

    this.cachedNumUsers = -1;
    this.cachedNumItems = -1;
    this.itemPrefCounts = new Cache<>(new ItemPrefCountRetriever(getNumPreferenceForItemSQL));

    this.maxPreference = Float.NaN;
    this.minPreference = Float.NaN;
  }

  /** @return the {@link DataSource} that this instance is using */
  @Override
  public DataSource getDataSource() {
    return dataSource;
  }

  public String getPreferenceTable() {
    return preferenceTable;
  }

  public String getUserIDColumn() {
    return userIDColumn;
  }

  public String getItemIDColumn() {
    return itemIDColumn;
  }

  public String getPreferenceColumn() {
    return preferenceColumn;
  }

  String getSetPreferenceSQL() {
    return setPreferenceSQL;
  }

  /**
   * @return an iterator over the first column of the {@code getUsersSQL} result; see
   *         {@link ResultSetIDIterator} for its resource-handling caveat
   */
  @Override
  public LongPrimitiveIterator getUserIDs() throws TasteException {
    log.debug("Retrieving all users...");
    try {
      return new ResultSetIDIterator(getUsersSQL);
    } catch (SQLException sqle) {
      throw new TasteException(sqle);
    }
  }

  /**
   * Loads all preferences of one user with {@code getUserSQL}.
   *
   * @throws NoSuchUserException
   *           if there is no such user, i.e. the query returns no rows
   */
  @Override
  public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {

    log.debug("Retrieving user ID '{}'", userID);

    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;

    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(getUserSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(getFetchSize());
      setLongParameter(stmt, 1, userID);

      log.debug("Executing SQL query: {}", getUserSQL);
      rs = stmt.executeQuery();

      List<Preference> prefs = new ArrayList<>();
      while (rs.next()) {
        prefs.add(buildPreference(rs));
      }

      if (prefs.isEmpty()) {
        throw new NoSuchUserException(userID);
      }

      return new GenericUserPreferenceArray(prefs);

    } catch (SQLException sqle) {
      log.warn("Exception while retrieving user", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }

  }

  /**
   * Reads the whole {@code getAllUsersSQL} result into memory, one {@link PreferenceArray} per
   * user. A user's preferences are flushed whenever the user ID in column 1 changes, so the
   * query is expected to return each user's rows contiguously.
   */
  @Override
  public FastByIDMap<PreferenceArray> exportWithPrefs() throws TasteException {
    log.debug("Exporting all data");

    Connection conn = null;
    Statement stmt = null;
    ResultSet rs = null;

    FastByIDMap<PreferenceArray> result = new FastByIDMap<>();

    try {
      conn = dataSource.getConnection();
      stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(getFetchSize());

      log.debug("Executing SQL query: {}", getAllUsersSQL);
      rs = stmt.executeQuery(getAllUsersSQL);

      Long currentUserID = null;
      List<Preference> currentPrefs = new ArrayList<>();
      while (rs.next()) {
        long nextUserID = getLongColumn(rs, 1);
        if (currentUserID != null && !currentUserID.equals(nextUserID) && !currentPrefs.isEmpty()) {
          result.put(currentUserID, new GenericUserPreferenceArray(currentPrefs));
          currentPrefs.clear();
        }
        currentPrefs.add(buildPreference(rs));
        currentUserID = nextUserID;
      }
      // Flush the last user's preferences
      if (!currentPrefs.isEmpty()) {
        result.put(currentUserID, new GenericUserPreferenceArray(currentPrefs));
      }

      return result;

    } catch (SQLException sqle) {
      log.warn("Exception while exporting all data", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);

    }
  }

  /**
   * Like {@link #exportWithPrefs()} but keeps only the item IDs (column 2) of each user. The same
   * contiguity expectation on the {@code getAllUsersSQL} result applies.
   */
  @Override
  public FastByIDMap<FastIDSet> exportWithIDsOnly() throws TasteException {
    log.debug("Exporting all data");

    Connection conn = null;
    Statement stmt = null;
    ResultSet rs = null;

    FastByIDMap<FastIDSet> result = new FastByIDMap<>();

    try {
      conn = dataSource.getConnection();
      stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(getFetchSize());

      log.debug("Executing SQL query: {}", getAllUsersSQL);
      rs = stmt.executeQuery(getAllUsersSQL);

      boolean currentUserIDSet = false;
      long currentUserID = 0L; // value isn't used
      FastIDSet currentItemIDs = new FastIDSet(2);
      while (rs.next()) {
        long nextUserID = getLongColumn(rs, 1);
        if (currentUserIDSet && currentUserID != nextUserID && !currentItemIDs.isEmpty()) {
          result.put(currentUserID, currentItemIDs);
          currentItemIDs = new FastIDSet(2);
        }
        currentItemIDs.add(getLongColumn(rs, 2));
        currentUserID = nextUserID;
        currentUserIDSet = true;
      }
      // Flush the last user's items
      if (!currentItemIDs.isEmpty()) {
        result.put(currentUserID, currentItemIDs);
      }

      return result;

    } catch (SQLException sqle) {
      log.warn("Exception while exporting all data", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);

    }
  }

  /**
   * Collects column 2 (item ID) of every row returned by {@code getUserSQL} for the user.
   *
   * @throws NoSuchUserException
   *           if there is no such user, i.e. the query returns no rows
   */
  @Override
  public FastIDSet getItemIDsFromUser(long userID) throws TasteException {

    log.debug("Retrieving items for user ID '{}'", userID);

    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;

    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(getUserSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(getFetchSize());
      setLongParameter(stmt, 1, userID);

      log.debug("Executing SQL query: {}", getUserSQL);
      rs = stmt.executeQuery();

      FastIDSet result = new FastIDSet();
      while (rs.next()) {
        result.add(getLongColumn(rs, 2));
      }

      if (result.isEmpty()) {
        throw new NoSuchUserException(userID);
      }

      return result;

    } catch (SQLException sqle) {
      log.warn("Exception while retrieving items", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }

  }

  /**
   * @return column 1 of the first row returned by {@code getPreferenceSQL} for the user/item
   *         pair, or {@code null} if the query returns no row
   */
  @Override
  public Float getPreferenceValue(long userID, long itemID) throws TasteException {
    log.debug("Retrieving preferences for item ID '{}'", itemID);
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(getPreferenceSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(1);
      setLongParameter(stmt, 1, userID);
      setLongParameter(stmt, 2, itemID);

      log.debug("Executing SQL query: {}", getPreferenceSQL);
      rs = stmt.executeQuery();
      if (rs.next()) {
        return rs.getFloat(1);
      } else {
        return null;
      }
    } catch (SQLException sqle) {
      log.warn("Exception while retrieving prefs for item", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /**
   * @return column 1 of the first row returned by {@code getPreferenceTimeSQL} for the user/item
   *         pair, or {@code null} if no such statement was configured or the query returns no row
   */
  @Override
  public Long getPreferenceTime(long userID, long itemID) throws TasteException {
    if (getPreferenceTimeSQL == null) {
      return null;
    }
    log.debug("Retrieving preference time for item ID '{}'", itemID);
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(getPreferenceTimeSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(1);
      setLongParameter(stmt, 1, userID);
      setLongParameter(stmt, 2, itemID);

      log.debug("Executing SQL query: {}", getPreferenceTimeSQL);
      rs = stmt.executeQuery();
      if (rs.next()) {
        return rs.getLong(1);
      } else {
        return null;
      }
    } catch (SQLException sqle) {
      log.warn("Exception while retrieving time for item", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /**
   * @return an iterator over the first column of the {@code getItemsSQL} result; see
   *         {@link ResultSetIDIterator} for its resource-handling caveat
   */
  @Override
  public LongPrimitiveIterator getItemIDs() throws TasteException {
    log.debug("Retrieving all items...");
    try {
      return new ResultSetIDIterator(getItemsSQL);
    } catch (SQLException sqle) {
      throw new TasteException(sqle);
    }
  }

  /**
   * @throws NoSuchItemException if {@link #doGetPreferencesForItem(long)} finds no preferences
   */
  @Override
  public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
    List<Preference> list = doGetPreferencesForItem(itemID);
    if (list.isEmpty()) {
      throw new NoSuchItemException(itemID);
    }
    return new GenericItemPreferenceArray(list);
  }

  /**
   * Runs {@code getPrefsForItemSQL} for the item and builds one {@link Preference} per row.
   *
   * @return the preferences found; empty (never {@code null}) if the query returns no rows
   */
  protected List<Preference> doGetPreferencesForItem(long itemID) throws TasteException {
    log.debug("Retrieving preferences for item ID '{}'", itemID);
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(getPrefsForItemSQL, ResultSet.TYPE_FORWARD_ONLY,
                                   ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(getFetchSize());
      setLongParameter(stmt, 1, itemID);

      log.debug("Executing SQL query: {}", getPrefsForItemSQL);
      rs = stmt.executeQuery();
      List<Preference> prefs = new ArrayList<>();
      while (rs.next()) {
        prefs.add(buildPreference(rs));
      }
      return prefs;
    } catch (SQLException sqle) {
      log.warn("Exception while retrieving prefs for item", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /** @return the item count, queried once and then cached until {@link #refresh(Collection)} */
  @Override
  public int getNumItems() throws TasteException {
    if (cachedNumItems < 0) {
      cachedNumItems = getNumThings("items", getNumItemsSQL);
    }
    return cachedNumItems;
  }

  /** @return the user count, queried once and then cached until {@link #refresh(Collection)} */
  @Override
  public int getNumUsers() throws TasteException {
    if (cachedNumUsers < 0) {
      cachedNumUsers = getNumThings("users", getNumUsersSQL);
    }
    return cachedNumUsers;
  }

  /** @return the number of users with a preference for the item, served through a cache */
  @Override
  public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
    return itemPrefCounts.get(itemID);
  }

  /** @return the number of users with a preference for both items; queried on every call */
  @Override
  public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
    return getNumThings("user preferring items", getNumPreferenceForItemsSQL, itemID1, itemID2);
  }

  /**
   * Runs a counting query and returns column 1 of its first row as an int.
   *
   * @param name what is being counted; used only in log messages
   * @param sql the query to execute
   * @param args values bound, in order, to the statement's parameters starting at position 1
   * @throws TasteException if the query fails with a {@link SQLException}
   */
  private int getNumThings(String name, String sql, long... args) throws TasteException {
    log.debug("Retrieving number of {} in model", name);
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(getFetchSize());
      if (args != null) {
        for (int i = 1; i <= args.length; i++) {
          setLongParameter(stmt, i, args[i - 1]);
        }
      }
      log.debug("Executing SQL query: {}", sql);
      rs = stmt.executeQuery();
      rs.next();
      return rs.getInt(1);
    } catch (SQLException sqle) {
      log.warn("Exception while retrieving number of {}", name, sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /**
   * Executes {@code setPreferenceSQL} with the user ID, the item ID and the value bound twice
   * (parameters 3 and 4).
   *
   * @throws IllegalArgumentException if {@code value} is NaN
   * @throws TasteException if the update fails with a {@link SQLException}
   */
  @Override
  public void setPreference(long userID, long itemID, float value) throws TasteException {
    Preconditions.checkArgument(!Float.isNaN(value), "NaN value");

    log.debug("Setting preference for user {}, item {}", userID, itemID);

    Connection conn = null;
    PreparedStatement stmt = null;

    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(setPreferenceSQL);
      setLongParameter(stmt, 1, userID);
      setLongParameter(stmt, 2, itemID);
      stmt.setDouble(3, value);
      stmt.setDouble(4, value);

      log.debug("Executing SQL update: {}", setPreferenceSQL);
      stmt.executeUpdate();

    } catch (SQLException sqle) {
      log.warn("Exception while setting preference", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(null, stmt, conn);
    }
  }

  /**
   * Executes {@code removePreferenceSQL} for the user/item pair.
   *
   * @throws TasteException if the update fails with a {@link SQLException}
   */
  @Override
  public void removePreference(long userID, long itemID) throws TasteException {

    log.debug("Removing preference for user '{}', item '{}'", userID, itemID);

    Connection conn = null;
    PreparedStatement stmt = null;

    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(removePreferenceSQL);
      setLongParameter(stmt, 1, userID);
      setLongParameter(stmt, 2, itemID);

      log.debug("Executing SQL update: {}", removePreferenceSQL);
      stmt.executeUpdate();

    } catch (SQLException sqle) {
      log.warn("Exception while removing preference", sqle);
      throw new TasteException(sqle);
    } finally {
      IOUtils.quietClose(null, stmt, conn);
    }
  }

  /**
   * Drops every cached value (counts, min/max preference, per-item counts) so that each is
   * re-queried on next use. {@code alreadyRefreshed} is not consulted.
   */
  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    cachedNumUsers = -1;
    cachedNumItems = -1;
    minPreference = Float.NaN;
    maxPreference = Float.NaN;
    itemPrefCounts.clear();
  }

  @Override
  public boolean hasPreferenceValues() {
    return true;
  }

  /**
   * @return the result of {@code getMaxPreferenceSQL}, cached after the first successful query;
   *         NaN if the query fails (it is then retried on the next call)
   */
  @Override
  public float getMaxPreference() {
    if (Float.isNaN(maxPreference)) {
      maxPreference = queryPreferenceBound(getMaxPreferenceSQL, "maximum");
    }
    return maxPreference;
  }

  /**
   * @return the result of {@code getMinPreferenceSQL}, cached after the first successful query;
   *         NaN if the query fails (it is then retried on the next call)
   */
  @Override
  public float getMinPreference() {
    if (Float.isNaN(minPreference)) {
      minPreference = queryPreferenceBound(getMinPreferenceSQL, "minimum");
    }
    return minPreference;
  }

  /**
   * Runs a parameterless query and returns column 1 of its first row as a float.
   *
   * @param sql the query to execute
   * @param bound which bound is being fetched; used only in the warning message
   * @return the queried value, or NaN if a {@link SQLException} occurs (logged, not rethrown)
   */
  private float queryPreferenceBound(String sql, String bound) {
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(sql);

      log.debug("Executing SQL query: {}", sql);
      rs = stmt.executeQuery();
      rs.next();
      return rs.getFloat(1);

    } catch (SQLException sqle) {
      log.warn("Exception while retrieving {} preference", bound, sqle);
      // best effort: callers see NaN and the query is attempted again next time
      return Float.NaN;
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  // Some overrideable methods to customize the class behavior:

  /** Builds a preference from result columns 1 (user ID), 2 (item ID) and 3 (value). */
  protected Preference buildPreference(ResultSet rs) throws SQLException {
    return new GenericPreference(getLongColumn(rs, 1), getLongColumn(rs, 2), rs.getFloat(3));
  }

  /**
   * Reads an ID from a result column. Subclasses may wish to override this if ID values in the
   * database are not numeric. This provides a hook by which subclasses can inject an
   * {@link org.apache.mahout.cf.taste.model.IDMigrator} to perform translation.
   */
  protected long getLongColumn(ResultSet rs, int position) throws SQLException {
    return rs.getLong(position);
  }

  /**
   * Binds an ID to a statement parameter. Subclasses may wish to override this if ID values in
   * the database are not numeric. This provides a hook by which subclasses can inject an
   * {@link org.apache.mahout.cf.taste.model.IDMigrator} to perform translation.
   */
  protected void setLongParameter(PreparedStatement stmt, int position, long value) throws SQLException {
    stmt.setLong(position, value);
  }

  /**
   * <p>
   * An {@link java.util.Iterator} which returns items from a {@link ResultSet}. This is a useful way to
   * iterate over all user data since it does not require all data to be read into memory at once. It does
   * however require that the DB connection be held open. Note that this class will only release database
   * resources after {@link #hasNext()} has been called and has returned {@code false}; callers should
   * make sure to "drain" the entire set of data to avoid tying up database resources.
   * </p>
   */
  private final class ResultSetIDIterator extends ResultSetIterator<Long> implements LongPrimitiveIterator {

    private ResultSetIDIterator(String sql) throws SQLException {
      super(dataSource, sql);
    }

    @Override
    protected Long parseElement(ResultSet resultSet) throws SQLException {
      return getLongColumn(resultSet, 1);
    }

    @Override
    public long nextLong() {
      return next();
    }

    /**
     * @throws UnsupportedOperationException always
     */
    @Override
    public long peek() {
      // This could be supported; is it worth it?
      throw new UnsupportedOperationException();
    }
  }

  /** Loads the per-item preference count for {@link #itemPrefCounts} on a cache miss. */
  private final class ItemPrefCountRetriever implements Retriever<Long,Integer> {
    private final String getNumPreferenceForItemSQL;

    private ItemPrefCountRetriever(String getNumPreferenceForItemSQL) {
      this.getNumPreferenceForItemSQL = getNumPreferenceForItemSQL;
    }

    @Override
    public Integer get(Long key) throws TasteException {
      return getNumThings("user preferring item", getNumPreferenceForItemSQL, key);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/ConnectionPoolDataSource.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/ConnectionPoolDataSource.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/ConnectionPoolDataSource.java new file mode 100644 index 0000000..ff7f661 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/ConnectionPoolDataSource.java @@ -0,0 +1,122 @@ +/** + * Licensed to the
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.model.jdbc; + +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.logging.Logger; + +import javax.sql.DataSource; + +import org.apache.commons.dbcp.ConnectionFactory; +import org.apache.commons.dbcp.PoolableConnectionFactory; +import org.apache.commons.dbcp.PoolingDataSource; +import org.apache.commons.pool.impl.GenericObjectPool; + +import com.google.common.base.Preconditions; + +/** + * <p> + * A wrapper {@link DataSource} which pools connections. 
+ * </p> + */ +public final class ConnectionPoolDataSource implements DataSource { + + private final DataSource delegate; + + public ConnectionPoolDataSource(DataSource underlyingDataSource) { + Preconditions.checkNotNull(underlyingDataSource); + ConnectionFactory connectionFactory = new ConfiguringConnectionFactory(underlyingDataSource); + GenericObjectPool objectPool = new GenericObjectPool(); + objectPool.setTestOnBorrow(false); + objectPool.setTestOnReturn(false); + objectPool.setTestWhileIdle(true); + objectPool.setTimeBetweenEvictionRunsMillis(60 * 1000L); + // Constructor actually sets itself as factory on pool + new PoolableConnectionFactory(connectionFactory, objectPool, null, "SELECT 1", false, false); + delegate = new PoolingDataSource(objectPool); + } + + @Override + public Connection getConnection() throws SQLException { + return delegate.getConnection(); + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + return delegate.getConnection(username, password); + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return delegate.getLogWriter(); + } + + @Override + public void setLogWriter(PrintWriter printWriter) throws SQLException { + delegate.setLogWriter(printWriter); + } + + @Override + public void setLoginTimeout(int timeout) throws SQLException { + delegate.setLoginTimeout(timeout); + } + + @Override + public int getLoginTimeout() throws SQLException { + return delegate.getLoginTimeout(); + } + + @Override + public <T> T unwrap(Class<T> iface) throws SQLException { + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class<?> iface) throws SQLException { + return delegate.isWrapperFor(iface); + } + + // This exists for compatibility with Java 7 / JDBC 4.1, but doesn't exist + // in Java 6. In Java 7 it would @Override, but not in 6. 
+ // @Override + public Logger getParentLogger() throws SQLFeatureNotSupportedException { + throw new SQLFeatureNotSupportedException(); + } + + private static class ConfiguringConnectionFactory implements ConnectionFactory { + + private final DataSource underlyingDataSource; + + ConfiguringConnectionFactory(DataSource underlyingDataSource) { + this.underlyingDataSource = underlyingDataSource; + } + + @Override + public Connection createConnection() throws SQLException { + Connection connection = underlyingDataSource.getConnection(); + connection.setTransactionIsolation(Connection.TRANSACTION_READ_UNCOMMITTED); + connection.setHoldability(ResultSet.CLOSE_CURSORS_AT_COMMIT); + return connection; + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/GenericJDBCDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/GenericJDBCDataModel.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/GenericJDBCDataModel.java new file mode 100644 index 0000000..5dd0be9 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/model/jdbc/GenericJDBCDataModel.java @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.model.jdbc; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; + +import com.google.common.io.Closeables; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; + +/** + * <p> + * A generic {@link org.apache.mahout.cf.taste.model.DataModel} designed for use with other JDBC data sources; + * one just specifies all necessary SQL queries to the constructor here. Optionally, the queries can be + * specified from a {@link Properties} object, {@link File}, or {@link InputStream}. This class is most + * appropriate when other existing implementations of {@link AbstractJDBCDataModel} are not suitable. If you + * are using this class to support a major database, consider contributing a specialized implementation of + * {@link AbstractJDBCDataModel} to the project for this database. 
+ * </p> + */ +public final class GenericJDBCDataModel extends AbstractJDBCDataModel { + + public static final String DATA_SOURCE_KEY = "dataSource"; + public static final String GET_PREFERENCE_SQL_KEY = "getPreferenceSQL"; + public static final String GET_PREFERENCE_TIME_SQL_KEY = "getPreferenceTimeSQL"; + public static final String GET_USER_SQL_KEY = "getUserSQL"; + public static final String GET_ALL_USERS_SQL_KEY = "getAllUsersSQL"; + public static final String GET_NUM_USERS_SQL_KEY = "getNumUsersSQL"; + public static final String GET_NUM_ITEMS_SQL_KEY = "getNumItemsSQL"; + public static final String SET_PREFERENCE_SQL_KEY = "setPreferenceSQL"; + public static final String REMOVE_PREFERENCE_SQL_KEY = "removePreferenceSQL"; + public static final String GET_USERS_SQL_KEY = "getUsersSQL"; + public static final String GET_ITEMS_SQL_KEY = "getItemsSQL"; + public static final String GET_PREFS_FOR_ITEM_SQL_KEY = "getPrefsForItemSQL"; + public static final String GET_NUM_PREFERENCE_FOR_ITEM_KEY = "getNumPreferenceForItemSQL"; + public static final String GET_NUM_PREFERENCE_FOR_ITEMS_KEY = "getNumPreferenceForItemsSQL"; + public static final String GET_MAX_PREFERENCE_KEY = "getMaxPreferenceSQL"; + public static final String GET_MIN_PREFERENCE_KEY = "getMinPreferenceSQL"; + + /** + * <p> + * Specifies all SQL queries in a {@link Properties} object. See the {@code *_KEY} constants in this + * class (e.g. {@link #GET_USER_SQL_KEY}) for a list of all keys which must map to a value in this object. 
+ * </p> + * + * @param props + * {@link Properties} object containing values + * @throws TasteException + * if anything goes wrong during initialization + */ + public GenericJDBCDataModel(Properties props) throws TasteException { + super(AbstractJDBCComponent.lookupDataSource(props.getProperty(DATA_SOURCE_KEY)), + props.getProperty(GET_PREFERENCE_SQL_KEY), + props.getProperty(GET_PREFERENCE_TIME_SQL_KEY), + props.getProperty(GET_USER_SQL_KEY), + props.getProperty(GET_ALL_USERS_SQL_KEY), + props.getProperty(GET_NUM_ITEMS_SQL_KEY), + props.getProperty(GET_NUM_USERS_SQL_KEY), + props.getProperty(SET_PREFERENCE_SQL_KEY), + props.getProperty(REMOVE_PREFERENCE_SQL_KEY), + props.getProperty(GET_USERS_SQL_KEY), + props.getProperty(GET_ITEMS_SQL_KEY), + props.getProperty(GET_PREFS_FOR_ITEM_SQL_KEY), + props.getProperty(GET_NUM_PREFERENCE_FOR_ITEM_KEY), + props.getProperty(GET_NUM_PREFERENCE_FOR_ITEMS_KEY), + props.getProperty(GET_MAX_PREFERENCE_KEY), + props.getProperty(GET_MIN_PREFERENCE_KEY)); + } + + /** + * <p> + * See {@link #GenericJDBCDataModel(Properties)}. This constructor reads values from a file + * instead, as if with {@link Properties#load(InputStream)}. So, the file should be in standard Java + * properties file format -- containing {@code key=value} pairs, one per line. + * </p> + * + * @param propertiesFile + * properties file + * @throws TasteException + * if anything goes wrong during initialization + */ + public GenericJDBCDataModel(File propertiesFile) throws TasteException { + this(getPropertiesFromFile(propertiesFile)); + } + + /** + * <p> + * See {@link #GenericJDBCDataModel(Properties)}. This constructor reads values from a resource available in + * the classpath, as if with {@link Class#getResourceAsStream(String)} and + * {@link Properties#load(InputStream)}. This is useful if your configuration file is, for example, packaged + * in a JAR file that is in the classpath. + * </p> + * + * @param resourcePath + * path to resource in classpath (e.g. 
"/com/foo/TasteSQLQueries.properties") + * @throws TasteException + * if anything goes wrong during initialization + */ + public GenericJDBCDataModel(String resourcePath) throws TasteException { + this(getPropertiesFromStream(GenericJDBCDataModel.class + .getResourceAsStream(resourcePath))); + } + + private static Properties getPropertiesFromFile(File file) throws TasteException { + try { + return getPropertiesFromStream(new FileInputStream(file)); + } catch (FileNotFoundException fnfe) { + throw new TasteException(fnfe); + } + } + + private static Properties getPropertiesFromStream(InputStream is) throws TasteException { + try { + try { + Properties props = new Properties(); + props.load(is); + return props; + } finally { + Closeables.close(is, true); + } + } catch (IOException ioe) { + throw new TasteException(ioe); + } + } + +}
