Issue #298: LDA's rho behavior
- Fix update_alpha problem when calling rho(): pass_ wasn't in scope
- Change rho:
    - Don't rely on num_updates / chunksize: when updating from a
      corpus smaller than chunksize, rho remained 1.0 for most updates.
    - *Add* the pass count to the number of updates instead of
      multiplying them.
    - Change num_updates to actually count updates, rather than
      documents seen. This matches the paper now, and relates to the
      chunksize problem.
cscorley committed Apr 27, 2015
1 parent d446fb8 commit 1ec7acc
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions gensim/models/ldamodel.py
@@ -444,8 +444,8 @@ def update_alpha(self, gammat, rho):

         dalpha = -(gradf - b) / q

-        if all(rho() * dalpha + self.alpha > 0):
-            self.alpha += rho() * dalpha
+        if all(rho * dalpha + self.alpha > 0):
+            self.alpha += rho * dalpha
         else:
             logger.warning("updated alpha not positive")
         logger.info("optimized alpha %s" % list(self.alpha))
@@ -541,13 +541,12 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
             "increasing the number of passes or iterations to improve accuracy")

         # rho is the "speed" of updating; TODO try other fncs
-        # pass_ * num_updates handles increasing the starting t for each pass,
+        # pass_ + num_updates handles increasing the starting t for each pass,
         # while allowing it to "reset" on the first pass of each update
         def rho():
-            return pow(offset + ((pass_ * self.num_updates) / self.chunksize),
-                       -decay)
+            return pow(offset + ((pass_ + self.num_updates)), -decay)

-        for pass_ in xrange(1, passes + 1):
+        for pass_ in xrange(passes):
             if self.dispatcher:
                 logger.info('initializing %s workers' % self.numworkers)
                 self.dispatcher.reset(self.state)
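
To see what this hunk fixes, here is a minimal sketch comparing the two schedules. The constants (chunksize=2000, a 100-document corpus, offset=1.0, decay=0.5) are illustrative; // mimics Python 2's integer division in the old formula, and the corpus is assumed to fit in one chunk, i.e. one update per pass:

offset, decay, chunksize = 1.0, 0.5, 2000
corpus_size, passes = 100, 5   # corpus much smaller than chunksize

docs_seen = updates = 0
for pass_ in range(1, passes + 1):   # old loop counted passes from 1
    docs_seen += corpus_size         # old num_updates: documents seen
    updates += 1                     # new num_updates: one per update
    rho_old = pow(offset + (pass_ * docs_seen) // chunksize, -decay)
    rho_new = pow(offset + (pass_ - 1) + updates, -decay)
    print(pass_, rho_old, round(rho_new, 3))
# rho_old stays pinned at 1.0 until pass 5; rho_new decays from the start.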
@@ -574,7 +573,7 @@ def rho():
                     gammat = self.do_estep(chunk, other)

                 if self.optimize_alpha:
-                    self.update_alpha(gammat, rho)
+                    self.update_alpha(gammat, rho())

                 dirty = True
                 del chunk
@@ -585,7 +584,7 @@ def rho():
                     # distributed mode: wait for all workers to finish
                     logger.info("reached the end of input; now waiting for all remaining jobs to finish")
                     other = self.dispatcher.getstate()
-                self.do_mstep(rho(), other, pass_ != 1)
+                self.do_mstep(rho(), other, pass_ > 0)
                 del other  # frees up memory

                 if self.dispatcher:
@@ -604,7 +603,7 @@ def rho():
                 # distributed mode: wait for all workers to finish
                 logger.info("reached the end of input; now waiting for all remaining jobs to finish")
                 other = self.dispatcher.getstate()
-            self.do_mstep(rho(), other, pass_ != 1)
+            self.do_mstep(rho(), other, pass_ > 0)
             del other
             dirty = False
         # endfor entire corpus update
@@ -624,12 +623,12 @@ def do_mstep(self, rho, other, extra_pass=False):
         self.sync_state()

         # print out some debug info at the end of each EM iteration
-        self.print_topics(15)
+        self.print_topics(5)
         logger.info("topic diff=%f, rho=%f" % (numpy.mean(numpy.abs(diff)), rho))

         if not extra_pass:
             # only update if this isn't an additional pass
-            self.num_updates += other.numdocs
+            self.num_updates += 1

     def bound(self, corpus, gamma=None, subsample_ratio=1.0):
         """
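
With num_updates now counting updates rather than documents, t in the schedule is the update index itself, which is what the online-LDA paper's rho_t = (tau0 + t) ** -kappa uses. A small sketch (single pass, so the pass_ term is 0; values illustrative):

tau0, kappa = 1.0, 0.5   # i.e. offset and decay

num_updates = 0
for chunk_no in range(4):                     # one E-M update per chunk
    rho_t = pow(tau0 + num_updates, -kappa)   # 1.0, 0.707, 0.577, 0.5
    # ... do_mstep(rho_t, other) merges the chunk's statistics here ...
    num_updates += 1                          # was: += other.numdocs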
