Function Approximators: Neural Networks and Soft Trees
Nonlinear mixed-effects models often require flexible functional forms to capture relationships that cannot be specified a priori. NoLimits supports two classes of learnable function approximators – neural networks and soft decision trees – that can be embedded directly into any model block. Their parameters are estimated jointly with all other model parameters during fitting.
The supported parameter constructors are:
- NNParameters(...) – wraps a Lux.jl neural network architecture.
- SoftTreeParameters(...) – constructs a differentiable soft decision tree.
Both are declared in @fixedEffects and exposed as callable model functions through the function_name keyword argument.
Where They Can Be Used
Model functions created from NNParameters and SoftTreeParameters are available throughout the model specification. Specifically, they can appear in:
- @randomEffects – parameterizing the distributions of random effects
- @preDifferentialEquation – computing time-constant derived quantities
- @DifferentialEquation – within the right-hand side of ODE systems
- @initialDE – setting initial conditions
- @formulas – constructing the observation model
Pattern 1: Population-Level Approximators with Separate Random Effects
In this pattern, the approximator parameters are shared across all individuals (population-level fixed effects), while between-subject variability is captured by separate, additive random effects. This is the simplest way to introduce flexible nonlinearity without dramatically increasing the dimensionality of the random-effects space.

    using NoLimits
    using Distributions
    using Lux

    # Small MLP: 2 inputs (Age, BMI) -> 4 hidden units (tanh) -> 1 output
    chain = Lux.Chain(Lux.Dense(2, 4, tanh), Lux.Dense(4, 1))

    model = @Model begin
        @fixedEffects begin
            sigma = RealNumber(0.3, scale=:log)
            # Approximator parameters are population-level fixed effects
            z_nn = NNParameters(chain; function_name=:NN1, calculate_se=false)
            z_st = SoftTreeParameters(2, 2; function_name=:ST1, calculate_se=false)
        end

        @covariates begin
            t = Covariate()
            x = ConstantCovariateVector([:Age, :BMI]; constant_on=:ID)
        end

        @randomEffects begin
            # Scalar random effect, separate from the approximator parameters
            eta = RandomEffect(Normal(0.0, 1.0); column=:ID)
        end

        @formulas begin
            # Shared approximators plus additive random-effect terms
            mu = NN1([x.Age, x.BMI], z_nn)[1] + ST1([x.Age, x.BMI], z_st)[1] + tanh(eta) + eta^2
            y ~ Gamma(abs(mu) + 1e-6, sigma)
        end
    end

Pattern 2: Full-Parameter Individualization via Random Effects
When the functional form itself is expected to vary across individuals, the entire parameter vector of an approximator can be treated as a random effect. Each individual receives a personalized set of network or tree weights drawn from a multivariate distribution centered on the population-level parameters. This enables fully individualized nonlinear mappings at the cost of a high-dimensional random-effects distribution.
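The idea of drawing individual weights around a population-level vector can be sketched outside the NoLimits DSL. The following is only an illustration in plain Julia: z_pop, draw_individual, and the numeric values are hypothetical stand-ins, not part of the NoLimits API, and an identity covariance is assumed.

```julia
using Random

# Hypothetical population-level weight vector (a stand-in for a flattened
# NNParameters or SoftTreeParameters block); the values are made up.
z_pop = [0.5, -1.2, 0.3, 0.8]

# Draw one individual's weights from a normal distribution centered on z
# with identity covariance: theta_i = z + e, where e ~ N(0, I).
draw_individual(rng, z) = z .+ randn(rng, length(z))

rng = MersenneTwister(42)

# Three hypothetical individuals, each with a personalized weight vector
thetas = [draw_individual(rng, z_pop) for _ in 1:3]

for (i, th) in enumerate(thetas)
    println("individual ", i, ": ", round.(th; digits=3))
end
```

Each individual's vector has the same length as the population-level vector, so the random-effects dimension grows with the size of the approximator. The model specification for this pattern expresses the same idea with MvNormal distributions centered on the fixed-effect parameter blocks.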

    using NoLimits
    using Distributions
    using Lux
    using LinearAlgebra

    chain_A1 = Lux.Chain(Lux.Dense(1, 4, tanh), Lux.Dense(4, 1))
    chain_A2 = Lux.Chain(Lux.Dense(1, 4, tanh), Lux.Dense(4, 1))

    model = @Model begin
        @helpers begin
            # Softplus keeps rates positive; for large u it returns u directly
            # to avoid overflow in exp(u)
            softplus(u) = u > 20 ? u : log1p(exp(u))
        end

        @covariates begin
            t = Covariate()
            d = ConstantCovariate(; constant_on=:ID)
        end

        @fixedEffects begin
            sigma = RealNumber(0.3, scale=:log)
            # Population-level parameter vectors for each approximator
            zA1 = NNParameters(chain_A1; function_name=:NNA1, calculate_se=false)
            zA2 = NNParameters(chain_A2; function_name=:NNA2, calculate_se=false)
            gC1 = SoftTreeParameters(1, 2; function_name=:STC1, calculate_se=false)
            gC2 = SoftTreeParameters(1, 2; function_name=:STC2, calculate_se=false)
        end

        @randomEffects begin
            # Individual parameter vectors, centered on the population-level values
            etaA1 = RandomEffect(MvNormal(zA1, Diagonal(ones(length(zA1)))); column=:ID)
            etaA2 = RandomEffect(MvNormal(zA2, Diagonal(ones(length(zA2)))); column=:ID)
            etaC1 = RandomEffect(MvNormal(gC1, Diagonal(ones(length(gC1)))); column=:ID)
            etaC2 = RandomEffect(MvNormal(gC2, Diagonal(ones(length(gC2)))); column=:ID)
        end

        @DifferentialEquation begin
            a_A(t) = softplus(depot)
            x_C(t) = softplus(center)
            # Approximators are evaluated with the individual-level parameters
            fA1(t) = softplus(NNA1([t / 24], etaA1)[1])
            fA2(t) = softplus(NNA2([a_A(t)], etaA2)[1])
            fC1(t) = -softplus(STC1([x_C(t)], etaC1)[1])
            fC2(t) = softplus(STC2([t / 24], etaC2)[1])
            D(depot) ~ -d * fA1(t) - fA2(t)
            D(center) ~ d * fA1(t) + fA2(t) + fC1(t) + d * fC2(t)
        end

        @initialDE begin
            depot = d
            center = 0.0
        end

        @formulas begin
            y ~ LogNormal(center(t), sigma)
        end
    end

Pattern 3: Hybrid Models Combining Both Strategies
A single model can combine population-level and fully individualized approximators. For instance, one network may capture a shared population-level transformation while another is individualized through random effects. This provides a principled way to decompose variation into components that are common across individuals and components that are subject-specific.

    using NoLimits
    using Distributions
    using Lux
    using LinearAlgebra

    chain = Lux.Chain(Lux.Dense(1, 4, tanh), Lux.Dense(4, 1))

    model = @Model begin
        @covariates begin
            t = Covariate()
            c = ConstantCovariate(; constant_on=:ID)
        end

        @fixedEffects begin
            sigma = RealNumber(0.3, scale=:log)
            z_fix = NNParameters(chain; function_name=:NNfix, calculate_se=false)
            g_mix = SoftTreeParameters(1, 2; function_name=:STmix, calculate_se=false)
        end

        @randomEffects begin
            # Only the soft tree is individualized; the network stays population-level
            eta_g = RandomEffect(MvNormal(g_mix, Diagonal(ones(length(g_mix)))); column=:ID)
        end

        @DifferentialEquation begin
            # NNfix uses the shared parameters z_fix; STmix uses the individual parameters eta_g
            D(x1) ~ -abs(NNfix([t / 24], z_fix)[1]) * x1 + abs(STmix([t / 24], eta_g)[1])
        end

        @initialDE begin
            x1 = c
        end

        @formulas begin
            y ~ Exponential(log1p(x1(t)^2) + sigma)
        end
    end

Practical Notes
- The function_name keyword controls the callable name used to invoke the approximator in model expressions. Each approximator must have a unique function name.
- Learned parameter blocks are typically declared with calculate_se=false, since standard error computation for high-dimensional parameter vectors is often neither feasible nor informative.
- The same @Model DSL is used for fixed-effects-only and mixed-effects workflows; only the presence and structure of @randomEffects determine whether individualization occurs.
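To make the dimensionality concern concrete, the sketch below counts the parameters of the small networks used in the patterns above. It is plain Julia and does not call NoLimits or Lux. It assumes each Dense layer carries a bias vector (the Lux default) and that the parameter block holds all weights and biases as one flat vector, which is an assumption about NNParameters rather than documented behavior. The helper names dense_param_count and mlp_param_count are hypothetical.

```julia
# Parameter count of a fully connected layer with bias: weights plus biases.
dense_param_count(n_in, n_out) = n_in * n_out + n_out

# Total count for a stack of dense layers given as (n_in, n_out) pairs.
mlp_param_count(layers) = sum(dense_param_count(i, o) for (i, o) in layers)

# Chain(Dense(2, 4, tanh), Dense(4, 1)) as in Pattern 1
n_pattern1 = mlp_param_count([(2, 4), (4, 1)])   # (8 + 4) + (4 + 1) = 17

# Chain(Dense(1, 4, tanh), Dense(4, 1)) as in Patterns 2 and 3
n_pattern2 = mlp_param_count([(1, 4), (4, 1)])   # (4 + 4) + (4 + 1) = 13

println((n_pattern1, n_pattern2))
```

Under these assumptions, each fully individualized network in Pattern 2 would contribute a 13-dimensional random effect per individual, which suggests why even small architectures can make standard errors for these blocks costly to compute.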