Skip to content

Commit

Permalink
Wire up requestGaussianSplatting API with persistent task backing
Browse files Browse the repository at this point in the history
  • Loading branch information
hobinjk-ptc committed Feb 19, 2024
1 parent f425a00 commit 0dcd8f3
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 1 deletion.
20 changes: 19 additions & 1 deletion controllers/object.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const path = require('path');
const formidable = require('formidable');
const utilities = require('../libraries/utilities');
const {fileExists, unlinkIfExists, mkdirIfNotExists} = utilities;
const {startSplatTask} = require('./object/SplatTask.js');

// Variables populated from server.js with setup()
var objects = {};
Expand Down Expand Up @@ -363,6 +364,22 @@ const generateXml = async function(objectID, body, callback) {
callback(200, 'ok');
};

/**
* @param {string} objectId
* @return {{done: boolean, gaussianSplatRequestId: string|undefined}} result
*/
async function requestGaussianSplatting(objectId) {
const object = utilities.getObject(objects, objectId);
if (!object) {
throw new Error('Object not found');
}

let splatTask = await startSplatTask(object);
// Starting splat task can modify object
await utilities.writeObjectToFile(objects, objectId, globalVariables.saveToDisk);
return splatTask.getStatus();
}

/**
* Enable sharing of Spatial Tools from this server to objects on other servers
* @todo: see github issue #23 - function is currently unimplemented
Expand Down Expand Up @@ -426,5 +443,6 @@ module.exports = {
generateXml: generateXml,
setFrameSharingEnabled: setFrameSharingEnabled,
getObject: getObject,
setup: setup
setup: setup,
requestGaussianSplatting,
};
161 changes: 161 additions & 0 deletions controllers/object/SplatTask.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
const fetch = require('node-fetch');
const FormData = require('form-data');
const path = require('path');
const WebSocket = require('ws');

const {objectsPath} = require('../../config.js');
const {identityFolderName} = require('../../constants.js');
const fsProm = require('../../persistence/fsProm.js');
const utilities = require('../../libraries/utilities.js');
const {fileExists} = utilities;

const SPLAT_HOST = 'change me:3000';

/**
* A class for starting and monitoring the progress of an area target to splat
* conversion problem, persisting the resulting file if successful
*/
class SplatTask {
/**
* @param {ObjectModel} object
*/
constructor(object) {
this.object = object;
this.gaussianSplatRequestId = null;
if (this.object.gaussianSplatRequestId) {
this.gaussianSplatRequestId = this.object.gaussianSplatRequestId;
}
this.done = false;

this.onOpen = this.onOpen.bind(this);
this.onMessage = this.onMessage.bind(this);
}
/**
* Starts the splatTask, running until we at least get a request id
* @return {string|undefined} gaussianSplatRequestId if started, undefined if already complete or started
*/
async start() {
const objectName = this.object.name;
const splatPath = path.join(objectsPath, objectName, identityFolderName, 'target', 'target.splat');
if (await fileExists(splatPath)) {
this.done = true;
return;
}

if (this.gaussianSplatRequestId) {
let success = await this.download();
if (success) {
return;
}
// Continue, download failed because the request is still in progress
} else {
const tdtPath = path.join(objectsPath, objectName, identityFolderName, 'target', 'target.3dt');
if (!await fileExists(tdtPath)) {
throw new Error('No 3dt file to upload');
}
const targetTdtBuf = await fsProm.readFile(tdtPath);
console.log('tdt?', tdtPath, !!targetTdtBuf);

const form = new FormData();
form.append('3dt', targetTdtBuf, {filename: 'target.3dt', name: '3dt', contentType: 'application/octet-stream'});
const res = await fetch(`http://${SPLAT_HOST}/upload`, {
method: 'POST',
headers: {
...form.getHeaders(),
},
body: form,
});

const gaussianSplatRequestId = await res.text();
this.object.gaussianSplatRequestId = gaussianSplatRequestId;
this.gaussianSplatRequestId = gaussianSplatRequestId;
}

this.openSocket();

return this.gaussianSplatRequestId;
}

openSocket() {
this.ws = new WebSocket('ws://' + SPLAT_HOST);
this.ws.addEventListener('open', this.onOpen);
this.ws.addEventListener('message', this.onMessage);
}

onOpen() {
this.ws.send(this.gaussianSplatRequestId);
}

onMessage(event) {
let message;
try {
message = JSON.parse(event.data);
} catch (e) {
console.error('SplatTask: json parse error', event.data);
return;
}

if (message.checkpointComplete) {
this.stop();
this.download();
}
}

getStatus() {
return {
done: this.done,
gaussianSplatRequestId: this.gaussianSplatRequestId,
};
}

stop() {
if (!this.ws) {
return;
}

try {
this.ws.removeEventListener('open', this.onOpen);
this.ws.removeEventListener('message', this.onMessage);
this.ws.close();
this.ws = null;
} catch (_e) {
// Don't care about potential errors
}
}

async download() {
const splatPath = path.join(objectsPath, this.object.name, identityFolderName, 'target', 'target.splat');
try {
let res = await fetch(`http://${SPLAT_HOST}/downloads/${this.gaussianSplatRequestId}`);
if (!res.ok) {
throw new Error(`Unexpected response: ${res.statusText}`);
}
let body = await res.arrayBuffer();
await fsProm.writeFile(splatPath, new Uint8Array(body));
this.done = true;
return true;
} catch (err) {
console.warn(`error downloading target.splat for ${this.object.objectId}`, err);
return false;
}
}
}

const splatTasks = {};
module.exports.splatTasks = splatTasks;

/**
* @param {ObjectModel} object
*/
module.exports.startSplatTask = async function startSplatTask(object) {
const objectId = object.objectId;
const oldTask = splatTasks[object.objectId];
if (oldTask) {
return oldTask;
}

splatTasks[objectId] = new SplatTask(object);
// Kick off the splatting to the point where we get a request id
await splatTasks[objectId].start();
return splatTasks[objectId];
};
1 change: 1 addition & 0 deletions models/ObjectModel.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ function ObjectModel(ip, version, protocol, objectId) {
this.isAnchor = false;
this.type = 'object'; // or: 'world' or 'human' or 'avatar' etc...
this.timestamp = null; // timestamp optionally stores when the object was first created
this.gaussianSplatRequestId = null; // Optional id for in-progress request to GS server
}

/**
Expand Down
16 changes: 16 additions & 0 deletions routers/object.js
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,22 @@ router.post('/:objectName/frame/:frameName/pinned/', function (req, res) {
});
});

router.post('/:objectName/requestGaussianSplatting/', async function (req, res) {
if (!utilities.isValidId(req.params.objectName)) {
res.status(400).send('Invalid object name. Must be alphanumeric.');
return;
}
// splat status (commonly referred to as "splattus") is
// {done: boolean, progress: number}
try {
const splatStatus = await objectController.requestGaussianSplatting(req.params.objectName);
res.json(splatStatus);
} catch (e) {
console.error(e);
res.sendStatus(500);
}
});

const setupDeveloperRoutes = function() {
// normal nodes
router.post('/:objectName/frame/:frameName/node/:nodeName/size/', function (req, res) {
Expand Down
4 changes: 4 additions & 0 deletions server.js
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ const linkController = require('./controllers/link.js');
const logicNodeController = require('./controllers/logicNode.js');
const nodeController = require('./controllers/node.js');
const objectController = require('./controllers/object.js');
const {splatTasks} = require('./controllers/object/SplatTask.js');
const spatialController = require('./controllers/spatial');

const signallingController = require('./controllers/signalling.js');
Expand Down Expand Up @@ -1155,6 +1156,9 @@ async function exit() {
clearInterval(socketUpdaterInterval);
staleObjectCleaner.clearCleanupIntervals();
humanPoseFuser.stop();
for (const splatTask of Object.values(splatTasks)) {
splatTask.stop();
}
console.info('Server exited successfully');
if (process.env.NODE_ENV !== 'test') {
process.exit(0);
Expand Down

0 comments on commit 0dcd8f3

Please sign in to comment.