Skip to content

Commit

Permalink
Refactored IEnumerable<> calls to use 'yield return' for streaming fo…
Browse files Browse the repository at this point in the history
…r large files
  • Loading branch information
Andrew Mattie committed Oct 25, 2013
1 parent 2353279 commit ef445cd
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 43 deletions.
4 changes: 1 addition & 3 deletions src/LinqToExcel/Extensions/CommonExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ public static object Cast(this object @object, Type castType)

public static IEnumerable<TResult> Cast<TResult>(this IEnumerable<object> list, Func<object, TResult> caster)
{
var results = new List<TResult>();
foreach (var item in list)
results.Add(caster(item));
return results;
yield return caster(item);
}

public static IEnumerable<TResult> Cast<TResult>(this IEnumerable<object> list)
Expand Down
86 changes: 46 additions & 40 deletions src/LinqToExcel/Query/ExcelQueryExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,34 +164,34 @@ protected IEnumerable<object> GetDataResults(SqlParts sql, QueryModel queryModel
{
IEnumerable<object> results;
OleDbDataReader data = null;
using (var conn = new OleDbConnection(_connectionString))
using (var command = conn.CreateCommand())
var conn = new OleDbConnection(_connectionString);
var command = conn.CreateCommand();

conn.Open();
command.CommandText = sql.ToString();
command.Parameters.AddRange(sql.Parameters.ToArray());
try { data = command.ExecuteReader(); }
catch (OleDbException e)
{
conn.Open();
command.CommandText = sql.ToString();
command.Parameters.AddRange(sql.Parameters.ToArray());
try { data = command.ExecuteReader(); }
catch (OleDbException e)
{
if (e.Message.Contains(_args.WorksheetName))
throw new DataException(
string.Format("'{0}' is not a valid worksheet name. Valid worksheet names are: '{1}'",
_args.WorksheetName, string.Join("', '", ExcelUtilities.GetWorksheetNames(_args.FileName).ToArray())));
if (!CheckIfInvalidColumnNameUsed(sql))
throw e;
}

var columns = ExcelUtilities.GetColumnNames(data);
LogColumnMappingWarnings(columns);
if (columns.Count() == 1 && columns.First() == "Expr1000")
results = GetScalarResults(data);
else if (queryModel.MainFromClause.ItemType == typeof(Row))
results = GetRowResults(data, columns);
else if (queryModel.MainFromClause.ItemType == typeof(RowNoHeader))
results = GetRowNoHeaderResults(data);
else
results = GetTypeResults(data, columns, queryModel);
if (e.Message.Contains(_args.WorksheetName))
throw new DataException(
string.Format("'{0}' is not a valid worksheet name. Valid worksheet names are: '{1}'",
_args.WorksheetName, string.Join("', '", ExcelUtilities.GetWorksheetNames(_args.FileName).ToArray())));
if (!CheckIfInvalidColumnNameUsed(sql))
throw e;
}

var columns = ExcelUtilities.GetColumnNames(data);
LogColumnMappingWarnings(columns);
if (columns.Count() == 1 && columns.First() == "Expr1000")
results = GetScalarResults(data, conn, command);
else if (queryModel.MainFromClause.ItemType == typeof(Row))
results = GetRowResults(data, columns, conn, command);
else if (queryModel.MainFromClause.ItemType == typeof(RowNoHeader))
results = GetRowNoHeaderResults(data, conn, command);
else
results = GetTypeResults(data, columns, queryModel, conn, command);

return results;
}

Expand Down Expand Up @@ -229,9 +229,8 @@ private bool CheckIfInvalidColumnNameUsed(SqlParts sql)
return false;
}

private IEnumerable<object> GetRowResults(IDataReader data, IEnumerable<string> columns)
private IEnumerable<object> GetRowResults(IDataReader data, IEnumerable<string> columns, OleDbConnection conn, OleDbCommand command)
{
var results = new List<object>();
var columnIndexMapping = new Dictionary<string, int>();
for (var i = 0; i < columns.Count(); i++)
columnIndexMapping[columns.ElementAt(i)] = i;
Expand All @@ -241,27 +240,29 @@ private IEnumerable<object> GetRowResults(IDataReader data, IEnumerable<string>
IList<Cell> cells = new List<Cell>();
for (var i = 0; i < columns.Count(); i++)
cells.Add(new Cell(data[i]));
results.CallMethod("Add", new Row(cells, columnIndexMapping));
yield return new Row(cells, columnIndexMapping);
}
return results.AsEnumerable();

conn.Dispose();
command.Dispose();
}

private IEnumerable<object> GetRowNoHeaderResults(OleDbDataReader data)
private IEnumerable<object> GetRowNoHeaderResults(OleDbDataReader data, OleDbConnection conn, OleDbCommand command)
{
var results = new List<object>();
while (data.Read())
{
IList<Cell> cells = new List<Cell>();
for (var i = 0; i < data.FieldCount; i++)
cells.Add(new Cell(data[i]));
results.CallMethod("Add", new RowNoHeader(cells));
yield return new RowNoHeader(cells);
}
return results.AsEnumerable();

conn.Dispose();
command.Dispose();
}

private IEnumerable<object> GetTypeResults(IDataReader data, IEnumerable<string> columns, QueryModel queryModel)
private IEnumerable<object> GetTypeResults(IDataReader data, IEnumerable<string> columns, QueryModel queryModel, OleDbConnection conn, OleDbCommand command)
{
var results = new List<object>();
var fromType = queryModel.MainFromClause.ItemType;
var props = fromType.GetProperties();
if (_args.StrictMapping.Value != StrictMappingType.None)
Expand All @@ -278,9 +279,11 @@ private IEnumerable<object> GetTypeResults(IDataReader data, IEnumerable<string>
if (columns.Contains(columnName))
result.SetProperty(prop.Name, GetColumnValue(data, columnName, prop.Name).Cast(prop.PropertyType));
}
results.Add(result);
yield return result;
}
return results.AsEnumerable();

conn.Dispose();
command.Dispose();
}

private void ConfirmStrictMapping(IEnumerable<string> columns, PropertyInfo[] properties, StrictMappingType strictMappingType)
Expand Down Expand Up @@ -323,10 +326,13 @@ private object GetColumnValue(IDataRecord data, string columnName, string proper
data[columnName];
}

private IEnumerable<object> GetScalarResults(IDataReader data)
private IEnumerable<object> GetScalarResults(IDataReader data, OleDbConnection conn, OleDbCommand command)
{
data.Read();
return new List<object> { data[0] };
yield return data[0];

conn.Dispose();
command.Dispose();
}

private void LogSqlStatement(SqlParts sqlParts)
Expand Down

0 comments on commit ef445cd

Please sign in to comment.