Skip to content

Commit

Permalink
Simplify session management
Browse files Browse the repository at this point in the history
Change-Id: I4f49354787c3dd848d4daacd04a049b27c858a82
  • Loading branch information
jerryshao committed Nov 24, 2016
1 parent b8a0839 commit 24c8eff
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 33 deletions.
19 changes: 10 additions & 9 deletions rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public class RSCClient implements LivyClient {

private ContextInfo contextInfo;
private volatile boolean isAlive;
private volatile String replState;

private SessionStateListener stateListener;

RSCClient(RSCConf conf, Promise<ContextInfo> ctx) throws IOException {
this.conf = conf;
Expand Down Expand Up @@ -94,6 +95,10 @@ public void onFailure(Throwable error) {
isAlive = true;
}

public void registerStateListener(SessionStateListener stateListener) {
this.stateListener = stateListener;
}

private synchronized void connectToContext(final ContextInfo info) throws Exception {
this.contextInfo = info;

Expand Down Expand Up @@ -287,13 +292,6 @@ public Future<ReplJobResults> getReplJobResults() throws Exception {
return deferredCall(new BaseProtocol.GetReplJobResults(), ReplJobResults.class);
}

/**
* @return Return the repl state. If this's not connected to a repl session, it will return null.
*/
public String getReplState() {
return replState;
}

private class ClientProtocol extends BaseProtocol {

<T> JobHandleImpl<T> submit(Job<T> job) {
Expand Down Expand Up @@ -389,7 +387,10 @@ private void handle(ChannelHandlerContext ctx, JobStarted msg) {

private void handle(ChannelHandlerContext ctx, ReplState msg) {
LOG.trace("Received repl state for {}", msg.state);
replState = msg.state;

if (stateListener != null) {
stateListener.onStateUpdated(msg.state);
}
}
}
}
26 changes: 26 additions & 0 deletions rsc/src/main/java/com/cloudera/livy/rsc/SessionStateListener.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.cloudera.livy.rsc;

public interface SessionStateListener {

/**
* Action when state is updated.
*/
void onStateUpdated(String state);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.launcher.SparkLauncher

import com.cloudera.livy._
import com.cloudera.livy.client.common.HttpMessages._
import com.cloudera.livy.rsc.{PingJob, RSCClient, RSCConf}
import com.cloudera.livy.rsc.{PingJob, RSCClient, RSCConf, SessionStateListener}
import com.cloudera.livy.rsc.driver.Statement
import com.cloudera.livy.server.recovery.SessionStore
import com.cloudera.livy.sessions._
Expand Down Expand Up @@ -333,11 +333,11 @@ class InteractiveSession(
sessionStore: SessionStore,
mockApp: Option[SparkApp]) // For unit test.
extends Session(id, owner, livyConf)
with SparkAppListener {
with SparkAppListener with SessionStateListener {

import InteractiveSession._

private var serverSideState: SessionState = initialState
@volatile private var _state: SessionState = initialState

private val operations = mutable.Map[Long, String]()
private val operationCounter = new AtomicLong(0)
Expand Down Expand Up @@ -376,6 +376,9 @@ class InteractiveSession(
}
uriFuture onFailure { case e => warn("Fail to get rsc uri", e) }

// Register this class to RSCClient as a session state listener
client.get.registerStateListener(this)

// Send a dummy job that will return once the client is ready to be used, and set the
// state to "idle" at that point.
client.get.submit(new PingJob()).addListener(new JobHandle.Listener[Void]() {
Expand All @@ -387,14 +390,14 @@ class InteractiveSession(
override def onJobFailed(job: JobHandle[Void], cause: Throwable): Unit = errorOut()

override def onJobSucceeded(job: JobHandle[Void], result: Void): Unit = {
transition(SessionState.Running())
transition(SessionState.Idle())
}

private def errorOut(): Unit = {
// Other code might call stop() to close the RPC channel. When RPC channel is closing,
// this callback might be triggered. Check and don't call stop() to avoid nested called
// if the session is already shutting down.
if (serverSideState != SessionState.ShuttingDown()) {
if (_state != SessionState.ShuttingDown()) {
transition(SessionState.Error())
stop()
}
Expand All @@ -407,17 +410,7 @@ class InteractiveSession(
override def recoveryMetadata: RecoveryMetadata =
InteractiveRecoveryMetadata(id, appId, appTag, kind, owner, proxyUser, rscDriverUri)

override def state: SessionState = {
if (serverSideState.isInstanceOf[SessionState.Running]) {
// If session is in running state, return the repl state from RSCClient.
client
.flatMap(s => Option(s.getReplState))
.map(SessionState(_))
.getOrElse(SessionState.Busy()) // If repl state is unknown, assume repl is busy.
} else {
serverSideState
}
}
override def state: SessionState = _state

override def stopSession(): Unit = {
try {
Expand Down Expand Up @@ -511,24 +504,24 @@ class InteractiveSession(
// If the session crashed because of the error, the session should instead go to dead state.
// Since these 2 transitions are triggered by different threads, there's a race condition.
// Make sure we won't transit from dead to error state.
val areSameStates = serverSideState.getClass() == newState.getClass()
val transitFromInactiveToActive = !serverSideState.isActive && newState.isActive
val areSameStates = _state.getClass() == newState.getClass()
val transitFromInactiveToActive = !_state.isActive && newState.isActive
if (!areSameStates && !transitFromInactiveToActive) {
debug(s"$this session state change from ${serverSideState} to $newState")
serverSideState = newState
debug(s"$this session state change from ${_state} to $newState")
_state = newState
}
}

private def ensureActive(): Unit = synchronized {
require(serverSideState.isActive, "Session isn't active.")
require(_state.isActive, "Session isn't active.")
require(client.isDefined, "Session is active but client hasn't been created.")
}

private def ensureRunning(): Unit = synchronized {
serverSideState match {
case SessionState.Running() =>
_state match {
case SessionState.Running() | SessionState.Idle() | SessionState.Busy() => Unit
case _ =>
throw new IllegalStateException("Session is in state %s" format serverSideState)
throw new IllegalStateException(s"Session is in state ${_state}")
}
}

Expand Down Expand Up @@ -560,4 +553,6 @@ class InteractiveSession(
}

override def infoChanged(appInfo: AppInfo): Unit = { this.appInfo = appInfo }

override def onStateUpdated(state: String): Unit = { transition(SessionState(state)) }
}

0 comments on commit 24c8eff

Please sign in to comment.