Skip to content

Commit

Permalink
Fixes issue #298, reset rho(t) for each additional pass
Browse files Browse the repository at this point in the history
- Fixes update_alpha problem when calling rho(); pass_ wasn't in scope
- Change rho to *add* the pass count to the number of updates
- On update, set chunksize to be the corpus length if len(corpus) < chunksize
and no chunksize was specified
  • Loading branch information
cscorley committed Apr 27, 2015
1 parent 4863040 commit edc3ce5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 18 deletions.
37 changes: 23 additions & 14 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ def update_alpha(self, gammat, rho):

dalpha = -(gradf - b) / q

if all(rho() * dalpha + self.alpha > 0):
self.alpha += rho() * dalpha
if all(rho * dalpha + self.alpha > 0):
self.alpha += rho * dalpha
else:
logger.warning("updated alpha not positive")
logger.info("optimized alpha %s" % list(self.alpha))
Expand Down Expand Up @@ -500,8 +500,6 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
"""
# use parameters given in constructor, unless user explicitly overrode them
if chunksize is None:
chunksize = self.chunksize
if decay is None:
decay = self.decay
if offset is None:
Expand All @@ -517,9 +515,6 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
if gamma_threshold is None:
gamma_threshold = self.gamma_threshold

# rho is the "speed" of updating; TODO try other fncs
rho = lambda: pow(offset + self.num_updates / self.chunksize, -decay)

try:
lencorpus = len(corpus)
except:
Expand All @@ -529,6 +524,9 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
logger.warning("LdaModel.update() called with an empty corpus")
return

if chunksize is None:
chunksize = min(lencorpus, self.chunksize)

self.state.numdocs += lencorpus

if update_every:
Expand All @@ -552,6 +550,12 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
logger.warning("too few updates, training might not converge; consider "
"increasing the number of passes or iterations to improve accuracy")

# rho is the "speed" of updating; TODO try other fncs
# pass_ + num_updates handles increasing the starting t for each pass,
# while allowing it to "reset" on the first pass of each update
def rho():
return pow(offset + pass_ + (self.num_updates / chunksize), -decay)

for pass_ in xrange(passes):
if self.dispatcher:
logger.info('initializing %s workers' % self.numworkers)
Expand Down Expand Up @@ -579,7 +583,7 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
gammat = self.do_estep(chunk, other)

if self.optimize_alpha:
self.update_alpha(gammat, rho)
self.update_alpha(gammat, rho())

dirty = True
del chunk
Expand All @@ -590,8 +594,8 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
# distributed mode: wait for all workers to finish
logger.info("reached the end of input; now waiting for all remaining jobs to finish")
other = self.dispatcher.getstate()
self.do_mstep(rho(), other)
del other # free up some mem
self.do_mstep(rho(), other, pass_ > 0)
del other # frees up memory

if self.dispatcher:
logger.info('initializing workers')
Expand All @@ -609,13 +613,13 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
# distributed mode: wait for all workers to finish
logger.info("reached the end of input; now waiting for all remaining jobs to finish")
other = self.dispatcher.getstate()
self.do_mstep(rho(), other)
self.do_mstep(rho(), other, pass_ > 0)
del other
dirty = False
#endfor entire corpus update


def do_mstep(self, rho, other):
def do_mstep(self, rho, other, extra_pass=False):
"""
M step: use linear interpolation between the existing topics and
collected sufficient statistics in `other` to update the topics.
Expand All @@ -628,9 +632,14 @@ def do_mstep(self, rho, other):
self.state.blend(rho, other)
diff -= self.state.get_Elogbeta()
self.sync_state()
self.print_topics(15) # print out some debug info at the end of each EM iteration

# print out some debug info at the end of each EM iteration
self.print_topics(5)
logger.info("topic diff=%f, rho=%f" % (numpy.mean(numpy.abs(diff)), rho))
self.num_updates += other.numdocs

if not extra_pass:
# only update if this isn't an additional pass
self.num_updates += other.numdocs


def bound(self, corpus, gamma=None, subsample_ratio=1.0):
Expand Down
35 changes: 31 additions & 4 deletions gensim/test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def testTransform(self):
for i in range(5): # restart at most 5 times
# create the transformation model
model = ldamodel.LdaModel(id2word=dictionary, num_topics=2, passes=100)
model.update(corpus)
model.update(self.corpus)

# transform one document
doc = list(corpus)[0]
Expand All @@ -247,9 +247,36 @@ def testTransform(self):
def testTopTopics(self):
# create the transformation model
model = ldamodel.LdaModel(id2word=dictionary, num_topics=2, passes=100)
model.update(corpus)
model.update(self.corpus)

model.top_topics(self.corpus)

def testPasses(self):
# long message includes the original error message with a custom one
self.longMessage = True

# construct what we expect when passes aren't involved
test_rhots = list()
model = ldamodel.LdaModel(id2word=dictionary, chunksize=1, num_topics=2)
final_rhot = lambda: pow(model.offset + (1 * model.num_updates) / model.chunksize, -model.decay)

model.top_topics(corpus)
# generate 5 updates to test rhot on
for x in range(5):
model.update(self.corpus)
test_rhots.append(final_rhot())

for passes in [1, 5, 10, 50, 100]:
model = ldamodel.LdaModel(id2word=dictionary, chunksize=1, num_topics=2, passes=passes)
self.assertEqual(final_rhot(), 1.0)
# make sure the rhot matches the test after each update
for test_rhot in test_rhots:
model.update(self.corpus)

msg = ", ".join(map(str, [passes, model.num_updates, model.state.numdocs]))
self.assertAlmostEqual(final_rhot(), test_rhot, msg=msg)

self.assertEqual(model.state.numdocs, len(corpus) * len(test_rhots))
self.assertEqual(model.num_updates, len(corpus) * len(test_rhots))

def testTopicSeeding(self):
passed = False
Expand All @@ -268,7 +295,7 @@ def testTopicSeeding(self):
eta[topic, system] *= 10

model = ldamodel.LdaModel(id2word=dictionary, num_topics=2, passes=200, eta=eta)
model.update(corpus)
model.update(self.corpus)

topics = [dict((word, p) for p, word in model.show_topic(j)) for j in range(2)]

Expand Down

0 comments on commit edc3ce5

Please sign in to comment.