Skip to content

Commit

Permalink
Fixes issue #298, reset rho(t) for each additional pass
Browse files Browse the repository at this point in the history
  • Loading branch information
cscorley committed Apr 25, 2015
1 parent 4863040 commit 0ce2504
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
21 changes: 13 additions & 8 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,6 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
if gamma_threshold is None:
gamma_threshold = self.gamma_threshold

# rho is the "speed" of updating; TODO try other fncs
rho = lambda: pow(offset + self.num_updates / self.chunksize, -decay)

try:
lencorpus = len(corpus)
except:
Expand Down Expand Up @@ -552,7 +549,13 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
logger.warning("too few updates, training might not converge; consider "
"increasing the number of passes or iterations to improve accuracy")

for pass_ in xrange(passes):
# rho is the "speed" of updating; TODO: try other functions
# pass_ * num_updates increases the starting t for each pass,
# while allowing it to "reset" on the first pass of each update
rho = lambda: pow(offset + (pass_ * self.num_updates) / self.chunksize, -decay)

for pass_ in xrange(1, passes + 1):
logger.info("RHO INFO AT PASS " + str(pass_) +": " + str(rho()))
if self.dispatcher:
logger.info('initializing %s workers' % self.numworkers)
self.dispatcher.reset(self.state)
Expand Down Expand Up @@ -590,7 +593,7 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
# distributed mode: wait for all workers to finish
logger.info("reached the end of input; now waiting for all remaining jobs to finish")
other = self.dispatcher.getstate()
self.do_mstep(rho(), other)
self.do_mstep(rho(), other, pass_ != 1)
del other # free up some mem

if self.dispatcher:
Expand All @@ -609,13 +612,13 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
# distributed mode: wait for all workers to finish
logger.info("reached the end of input; now waiting for all remaining jobs to finish")
other = self.dispatcher.getstate()
self.do_mstep(rho(), other)
self.do_mstep(rho(), other, pass_ != 1)
del other
dirty = False
#endfor entire corpus update


def do_mstep(self, rho, other):
def do_mstep(self, rho, other, extra_pass=False):
"""
M step: use linear interpolation between the existing topics and
collected sufficient statistics in `other` to update the topics.
Expand All @@ -630,7 +633,9 @@ def do_mstep(self, rho, other):
self.sync_state()
self.print_topics(15) # print out some debug info at the end of each EM iteration
logger.info("topic diff=%f, rho=%f" % (numpy.mean(numpy.abs(diff)), rho))
self.num_updates += other.numdocs
if not extra_pass:
# only count these documents on the first pass; additional passes must not increase num_updates
self.num_updates += other.numdocs


def bound(self, corpus, gamma=None, subsample_ratio=1.0):
Expand Down
35 changes: 31 additions & 4 deletions gensim/test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def testTransform(self):
for i in range(5): # restart at most 5 times
# create the transformation model
model = ldamodel.LdaModel(id2word=dictionary, num_topics=2, passes=100)
model.update(corpus)
model.update(self.corpus)

# transform one document
doc = list(corpus)[0]
Expand All @@ -247,9 +247,36 @@ def testTransform(self):
def testTopTopics(self):
# create the transformation model
model = ldamodel.LdaModel(id2word=dictionary, num_topics=2, passes=100)
model.update(corpus)
model.update(self.corpus)

model.top_topics(self.corpus)

def testPasses(self):
# longMessage=True makes failures report the standard assertion message together with the custom msg
self.longMessage = True

# construct what we expect when passes aren't involved
test_rhots = list()
model = ldamodel.LdaModel(id2word=dictionary, chunksize=1, num_topics=2)
final_rhot = lambda: pow(model.offset + (1 * model.num_updates) / model.chunksize, -model.decay)

model.top_topics(corpus)
# generate 5 updates to test rhot on
for x in range(5):
model.update(self.corpus)
test_rhots.append(final_rhot())

for passes in [1, 5, 10, 50, 100]:
model = ldamodel.LdaModel(id2word=dictionary, chunksize=1, num_topics=2, passes=passes)
self.assertEqual(final_rhot(), 1.0)
# make sure the rhot matches the test after each update
for test_rhot in test_rhots:
model.update(self.corpus)

msg = ", ".join(map(str, [passes, model.num_updates, model.state.numdocs]))
self.assertAlmostEqual(final_rhot(), test_rhot, msg=msg)

self.assertEqual(model.state.numdocs, len(corpus) * len(test_rhots))
self.assertEqual(model.num_updates, len(corpus) * len(test_rhots))

def testTopicSeeding(self):
passed = False
Expand All @@ -268,7 +295,7 @@ def testTopicSeeding(self):
eta[topic, system] *= 10

model = ldamodel.LdaModel(id2word=dictionary, num_topics=2, passes=200, eta=eta)
model.update(corpus)
model.update(self.corpus)

topics = [dict((word, p) for p, word in model.show_topic(j)) for j in range(2)]

Expand Down

0 comments on commit 0ce2504

Please sign in to comment.