From 45b2b81c6a015a39958e8827df2aa992c1cab8cc Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Fri, 2 Jun 2017 19:26:59 -0400 Subject: [PATCH] add `merge` and `structdiff` for named tuples --- base/namedtuple.jl | 41 +++++++++++++++++++++++++++++++++++++++++ test/namedtuple.jl | 10 ++++++++++ 2 files changed, 51 insertions(+) diff --git a/base/namedtuple.jl b/base/namedtuple.jl index 5637fccdd957b0..207d26a76db203 100644 --- a/base/namedtuple.jl +++ b/base/namedtuple.jl @@ -89,3 +89,44 @@ end namedtuple($NT, $(args...)) end end + +# a version of `in` for the older world these generated functions run in +function sym_in(x, itr) + for y in itr + y === x && return true + end + return false +end + +@generated function merge(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} + names = Symbol[an...] + for n in bn + if !sym_in(n, an) + push!(names, n) + end + end + vals = map(names) do n + if sym_in(n, bn) + :(getfield(b, $(Expr(:quote, n)))) + else + :(getfield(a, $(Expr(:quote, n)))) + end + end + names = (names...,) + :(namedtuple(NamedTuple{$names}, $(vals...))) +end + +@generated function structdiff(a::NamedTuple{an}, + b::Union{NamedTuple{bn},Type{NamedTuple{bn}}}) where {an,bn} + names = Symbol[] + for n in an + if !sym_in(n, bn) + push!(names, n) + end + end + vals = map(names) do n + :(getfield(a, $(Expr(:quote, n)))) + end + names = (names...,) + :(namedtuple(NamedTuple{$names}, $(vals...))) +end diff --git a/test/namedtuple.jl b/test/namedtuple.jl index 316da4e64bcf5d..fbd5f287c272e2 100644 --- a/test/namedtuple.jl +++ b/test/namedtuple.jl @@ -45,3 +45,13 @@ @test map(+, (x=1, y=2), (x=10, y=20)) == (x=11, y=22) @test map(string, (x=1, y=2)) == (x="1", y="2") @test map(round, (x=1//3, y=Int), (x=3, y=2//3)) == (x=0.333, y=1) + +@test merge((a=1, b=2), (a=10,)) == (a=10, b=2) +@test merge((a=1, b=2), (a=10, z=20)) == (a=10, b=2, z=20) +@test merge((a=1, b=2), (z=20,)) == (a=1, b=2, z=20) + +@test Base.structdiff((a=1, b=2), (b=3,)) == (a=1,) +@test Base.structdiff((a=1, b=2, z=20), (b=3,)) == (a=1, z=20) +@test Base.structdiff((a=1, b=2, z=20), (b=3, q=20, z=1)) == (a=1,) +@test Base.structdiff((a=1, b=2, z=20), (b=3, q=20, z=1, a=0)) == NamedTuple() +@test Base.structdiff((a=1, b=2, z=20), NamedTuple{(:b,)}) == (a=1, z=20)