Dear all,

I’m using scikit-learn to generate new features for a dataset with a Restricted Boltzmann Machine (RBM, sklearn.neural_network.BernoulliRBM). I use the following environment:
python 3.5.0
numpy==1.11.1
scikit-learn==0.18

I have already tried a large number of iterations (n_iter=6000) and a low learning rate (learning_rate=0.0001) on all of the training data (373 samples). However, the new features generated by the RBM are all highly correlated. Can anyone explain why this happens?

Below is an MWE:

```python
import numpy as np
from sklearn.neural_network import BernoulliRBM

# training data: each row is one sample, each column one input feature (all values in [0, 1])
train_data = np.array([
    [0.0326086956522,0.0,0.0,0.0200400801603,0.0674157303371,0.000805152979066,0.00200803212851,0.243243243243,0.0123456790123,0.55,0.0233428760185,0.0,0.0,0.0,0.444444444,0.0,0.0,0.157556270138,0.0188679245283,0.0983652512615],
    [0.0108695652174,0.2,0.0,0.00200400801603,0.0112359550562,0.0,0.0,0.027027027027,0.0123456790123,1.0,0.00154151068047,0.0,0.0,1.0,1.0,0.0,0.0,0.0289389067571,0.0,0.0],
    [0.0869565217391,0.0,0.152542372881,0.0260521042084,0.0749063670412,0.00322061191626,0.0180722891566,0.108108108108,0.0987654320988,0.4,0.022241796961,0.2,0.0909090909091,0.0,0.40625,0.0,0.0,0.053054662388,0.0188679245283,0.129097937384],
    [0.0326086956522,0.2,0.0847457627119,0.0140280561122,0.0149812734082,0.000268384326355,0.0120481927711,0.027027027027,0.0246913580247,0.25,0.00352345298392,1.0,0.0,0.75,0.555555556,0.0,0.0,0.0192926045047,0.0188679245283,0.0983652512615],
    [0.0978260869565,0.0,0.0,0.0100200400802,0.0711610486891,0.00214707461084,0.00803212851406,0.027027027027,0.111111111111,0.265625,0.0262056815679,1.0,0.0,0.0,0.518518519,0.0,0.0,0.0568060021635,0.0566037735849,0.213107498008],
    [0.0760869565217,0.8,0.0,0.0180360721443,0.0936329588015,0.0,0.0120481927711,0.0810810810811,0.0864197530864,0.3333333335,0.0561550319313,0.0,0.0,0.863636364,0.342857143,0.5,0.333333333333,0.168121267841,0.169811320755,0.463705037033],
    [0.0978260869565,1.0,0.0,0.0100200400802,0.063670411985,0.00697799248524,0.0,0.135135135135,0.0740740740741,0.4166666665,0.0156353226162,0.0,0.0,0.949367089,0.333333333,0.25,0.266666666667,0.0316184351626,0.0566037735849,0.163932249402],
    [0.0326086956522,0.2,0.0,0.0380761523046,0.0374531835206,0.000805152979066,0.0281124497992,0.135135135135,0.037037037037,1.0,0.00836820083682,0.0,0.0,0.923076923,0.583333333,0.0,0.0,0.0562700964881,0.0188679245283,0.0491752486057],
    [0.0108695652174,0.0,0.0,0.0200400801603,0.00374531835206,0.0,0.0160642570281,0.0540540540541,0.0123456790123,1.0,0.000220215811495,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0188679245283,0.147540499867],
    [0.217391304348,0.0,0.0,0.0140280561122,0.295880149813,0.0365002683843,0.0100401606426,0.135135135135,0.123456790123,0.4487534625,0.183880202599,1.0,0.0909090909091,0.0,0.19375,0.0,0.0,0.191961414822,0.188679245283,0.287703974741],
    [0.0652173913043,0.0,0.0,0.0160320641283,0.0224719101124,0.00402576489533,0.0140562248996,0.027027027027,0.0740740740741,1.0,0.00132129486897,0.0,0.0,0.0,0.444444444,0.0,0.0,0.0,0.0188679245283,0.147540499867],
    [0.0326086956522,0.6,0.0,0.0100200400802,0.0411985018727,0.000268384326355,0.00200803212851,0.108108108108,0.0123456790123,0.25,0.00902884827131,1.0,0.0909090909091,0.971428571,0.75,0.25,0.133333333333,0.0594855305401,0.0566037735849,0.147540499867],
    [0.119565217391,0.2,0.0,0.0140280561122,0.0973782771536,0.0,0.0100401606426,0.0540540540541,0.135802469136,0.29,0.0398590618806,1.0,0.0,0.529411765,0.409090909,0.0,0.0,0.0723472668927,0.0188679245283,0.107306205553],
    [0.0326086956522,0.2,0.0,0.0100200400802,0.0262172284644,0.000268384326355,0.00200803212851,0.108108108108,0.037037037037,0.25,0.00638625853336,1.0,0.0,0.818181818,0.666666667,0.0,0.0,0.0401929260499,0.0188679245283,0.0983652512615],
    [0.173913043478,0.4,0.0,0.0300601202405,0.243445692884,0.020397208803,0.0,0.405405405405,0.16049382716,0.46,0.106364236952,1.0,0.0,0.725490196,0.311111111,0.0,0.0,0.136254019315,0.169811320755,0.230532031043],
    [0.163043478261,0.4,0.0,0.0180360721443,0.153558052434,0.0,0.0,0.243243243243,0.185185185185,0.3392857145,0.044924025545,1.0,0.0909090909091,0.725490196,0.225,0.25,0.133333333333,0.0594855305401,0.0377358490566,0.226223848446],
    [0.152173913043,0.6,0.0508474576271,0.0220440881764,0.10861423221,0.0228126677402,0.00602409638554,0.216216216216,0.135802469136,0.2884615385,0.0237833076415,1.0,0.0909090909091,0.759259259,0.321428571,0.0,0.0,0.0316949931128,0.0754716981132,0.189692820679],
    [0.29347826087,0.4,0.0,0.0160320641283,0.378277153558,0.0421363392378,0.0100401606426,0.0810810810811,0.185185185185,0.4123931625,0.283197533583,0.888888889,0.0909090909091,0.294117647,0.183760684,0.25,0.466666666667,0.220078599537,0.0754716981132,0.163932249402],
    [0.0326086956522,0.0,0.0,0.00400801603206,0.0112359550562,0.000805152979066,0.00401606425703,0.0,0.037037037037,0.75,0.000880863245981,0.0,0.0,0.0,0.666666667,0.0,0.0,0.0,0.0188679245283,0.147540499867],
    [0.597826086957,0.4,0.135593220339,0.0400801603206,0.397003745318,0.352388620505,0.0160642570281,0.324324324324,0.111111111111,0.4782763535,0.249504514424,1.0,0.181818181818,0.406593407,0.195454545,0.0,0.0,0.0922537270084,0.188679245283,0.273613857004],
])

# define the RBM model
random_state = 200
model = BernoulliRBM(n_components=10, n_iter=10, random_state=random_state)

# build the RBM and create the RBM features.
# Each column is one RBM feature (hidden unit); each row corresponds to one row of train_data.
RBM_feature_data = model.fit_transform(train_data)
print(RBM_feature_data)
```

Thank you!

Masanari Kondo
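To put a number on "highly correlated", a sketch like the following computes the mean absolute off-diagonal Pearson correlation between the RBM features (columns of the fit_transform output) and between the hidden units' learned weight vectors (rows of model.components_). If the hidden units end up with nearly identical weight vectors, their activations tend to be highly correlated as well, so comparing the two numbers may help narrow down the cause. It runs both the MWE's settings (n_iter=10 with the default learning_rate=0.1) and the settings mentioned above (n_iter=6000, learning_rate=0.0001). Assumptions: a random 20 x 20 matrix in [0, 1] stands in for train_data, and mean_abs_offdiag_corr is a hypothetical helper, not a scikit-learn function.

```python
import numpy as np
from sklearn.neural_network import BernoulliRBM


def mean_abs_offdiag_corr(features):
    """Hypothetical helper: mean |Pearson r| between distinct columns."""
    corr = np.corrcoef(features, rowvar=False)
    # Drop the diagonal, which is always 1.
    off_diag = corr[~np.eye(corr.shape[0], dtype=bool)]
    return np.mean(np.abs(off_diag))


# Assumption: random stand-in with the same shape as the MWE's train_data
# (20 samples x 20 features, values in [0, 1]); replace with the real array.
rng = np.random.RandomState(0)
train_data = rng.rand(20, 20)

for n_iter, learning_rate in [(10, 0.1), (6000, 0.0001)]:
    model = BernoulliRBM(n_components=10, learning_rate=learning_rate,
                         n_iter=n_iter, random_state=200)
    features = model.fit_transform(train_data)
    # First number: correlation between hidden-unit activations.
    # Second number: correlation between hidden units' weight vectors
    # (components_ has shape (n_components, n_features)).
    print(n_iter, learning_rate,
          mean_abs_offdiag_corr(features),
          mean_abs_offdiag_corr(model.components_.T))
```

Note that transform (and therefore fit_transform) returns the hidden-unit activation probabilities P(h=1|v), values between 0 and 1, rather than binary samples.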
_______________________________________________
scikit-learn mailing list
scikit-learn@python.org
https://mail.python.org/mailman/listinfo/scikit-learn