

6.3. Benchmarking flags for FFTW plans

We benchmark planning FFTW transforms with different planner flags.

Here are the packages we need.

using FFTW
using Plots
using LinearAlgebra: mul!
using Test
using BenchmarkTools
using Random
using BenchmarkPlots
using StatsPlots

@info "Threads: $(FFTW.nthreads())"
[ Info: Threads: 8

Preparing the benchmarks

We prepare a suite of benchmarks, using BenchmarkTools.BenchmarkGroup().

The best plan usually depends on the size of the array, which in this case is $N \times N$. In particular, it depends on the factorization of $N$. So, we benchmark different values of $N$.

We choose the following values:

Ns = [2^6, 3^4, 2 * 3 * 5 * 7, 2^8, 2^2 * 3^2 * 5^2, 2^10, 2^11, 2^2 * 3 * 5^2 * 7]
8-element Vector{Int64}:
   64
   81
  210
  256
  900
 1024
 2048
 2100
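To make the factorizations explicit, here is a small sketch that factors each chosen N by trial division. The helper `factorize_small` is hypothetical (the notebook does not define it) and uses only Base. A package such as Primes.jl would do the same job.

```julia
# Hypothetical helper: prime factorization by trial division (Base only),
# returning a Dict prime => exponent. Fine for numbers as small as these.
function factorize_small(n::Integer)
    factors = Dict{Int,Int}()
    p = 2
    while p * p <= n
        while n % p == 0
            factors[p] = get(factors, p, 0) + 1
            n ÷= p
        end
        p += 1
    end
    n > 1 && (factors[n] = get(factors, n, 0) + 1)  # leftover prime factor
    return factors
end

Ns = [2^6, 3^4, 2 * 3 * 5 * 7, 2^8, 2^2 * 3^2 * 5^2, 2^10, 2^11, 2^2 * 3 * 5^2 * 7]

for N in Ns
    f = factorize_small(N)
    # Print as, e.g., "900 = 2^2 * 3^2 * 5^2"
    println(N, " = ", join(["$p^$(f[p])" for p in sort(collect(keys(f)))], " * "))
end
```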

We must also prepare some other variables.

List of flags for planning the FFTs:

flags = ["ESTIMATE", "MEASURE", "PATIENT", "EXHAUSTIVE"]
ext_flags = ["NO PLAN", flags...]
nothing

Physical space:

L = 2π
κ₀ = 2π/L
nothing

Excited modes for the test field:

rng = Xoshiro(123)
num_modes = 24
max_mode = 16

kxs = rand(rng, 1:max_mode, num_modes)
kys = rand(rng, 1:max_mode, num_modes)
ars = 10*randn(rng, num_modes)
ais = 10*randn(rng, num_modes)
nothing

The suite of benchmarks:

suite = BenchmarkGroup()
for flag in ext_flags
    suite[flag] = BenchmarkGroup()
end

plan_stats = Dict{String, Dict{Int, Float64}}(flag => Dict() for flag in flags)
Dict{String, Dict{Int64, Float64}} with 4 entries:
  "ESTIMATE" => Dict()
  "EXHAUSTIVE" => Dict()
  "MEASURE" => Dict()
  "PATIENT" => Dict()

Now we are ready to prepare the suite of benchmarks. Keep in mind that preparing the suite includes planning the transforms, and plans with the PATIENT or EXHAUSTIVE flag take some time for large $N$.

for N in Ns
    x = y = (L/N):(L/N):L
    vort = sum(
        [
            2κ₀^2 * (kx^2 + ky^2) * (
                ar * cos.(κ₀ * (kx * one.(y) * x' + ky * y * one.(x)'))
                - ai * sin.(κ₀ * (kx * one.(y) * x' + ky * y * one.(x)'))
            )
            for (kx, ky, ar, ai) in zip(kxs, kys, ars, ais)
        ]
    )

    vort_hat = rfft(vort)

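    # Baseline: rfft creates a new plan (default ESTIMATE flag) and a new output array on every call
    flag = "NO PLAN"
    @info "N = $N; flag: $flag"
    suite[flag][N] = @benchmarkable rfft(w) setup = (w = copy($vort));

    for flag in flags
        @info "N = $N; flag: $flag"
        # Look up the flag constant (e.g. FFTW.PATIENT) by name and time the planning
        planned, pstats... = @timed plan_rfft(vort; flags = getfield(FFTW, Symbol(flag)))
        plan_stats[flag][N] = pstats.time

        # Benchmark only the application of the precomputed plan
        suite[flag][N] = @benchmarkable mul!(w_hat, p, w) setup = (
            w_hat = copy($vort_hat);
            p = $planned;
            w = copy($vort)
        );
    end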
end
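The planning loop above uses a slurping assignment to split the result of `@timed`, which is a `NamedTuple`, into the plan and the remaining statistics. Here is a small stand-alone sketch of the same pattern. It assumes Julia 1.6 or later, where slurping in assignments is available. The exact set of fields returned by `@timed` varies between Julia versions.

```julia
# `@timed` returns a NamedTuple whose first field is the value of the expression;
# slurping collects the remaining fields (time, bytes, gctime, ...) into `tstats`.
value, tstats... = @timed sum(abs2, 1:1000)

println(value)          # the value of sum(abs2, 1:1000)
println(tstats.time)    # elapsed time in seconds
println(keys(tstats))   # names of the remaining fields
```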

suite
5-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "ESTIMATE" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  81 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  1024 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  210 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  900 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2100 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2048 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  256 => Benchmark(evals=1, seconds=5.0, samples=10000)
  "EXHAUSTIVE" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  81 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  1024 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  210 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  900 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2100 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2048 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  256 => Benchmark(evals=1, seconds=5.0, samples=10000)
  "NO PLAN" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  81 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  1024 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  210 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  900 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2100 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2048 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  256 => Benchmark(evals=1, seconds=5.0, samples=10000)
  "MEASURE" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  81 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  1024 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  210 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  900 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2100 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2048 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  256 => Benchmark(evals=1, seconds=5.0, samples=10000)
  "PATIENT" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  81 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  1024 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  210 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  900 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2100 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  2048 => Benchmark(evals=1, seconds=5.0, samples=10000)
	  256 => Benchmark(evals=1, seconds=5.0, samples=10000)

Looking at plan_stats, we can see the time spent in planning the transforms.

for N in Ns
    @info "N = $N"
    for flag in flags
        @info "$flag: \t$(BenchmarkTools.prettytime(plan_stats[flag][N] * 1.0e+9))"
    end
end
[ Info: N = 64
[ Info: ESTIMATE: 	33.391 ms
[ Info: MEASURE: 	124.770 ms
[ Info: PATIENT: 	414.759 ms
[ Info: EXHAUSTIVE: 	12.395 s
[ Info: N = 81
[ Info: ESTIMATE: 	143.500 μs
[ Info: MEASURE: 	37.177 ms
[ Info: PATIENT: 	190.513 ms
[ Info: EXHAUSTIVE: 	2.837 s
[ Info: N = 210
[ Info: ESTIMATE: 	191.000 μs
[ Info: MEASURE: 	114.820 ms
[ Info: PATIENT: 	209.885 ms
[ Info: EXHAUSTIVE: 	4.047 s
[ Info: N = 256
[ Info: ESTIMATE: 	131.250 μs
[ Info: MEASURE: 	257.434 ms
[ Info: PATIENT: 	1.802 s
[ Info: EXHAUSTIVE: 	40.778 s
[ Info: N = 900
[ Info: ESTIMATE: 	208.584 μs
[ Info: MEASURE: 	912.623 ms
[ Info: PATIENT: 	15.810 s
[ Info: EXHAUSTIVE: 	155.129 s
[ Info: N = 1024
[ Info: ESTIMATE: 	265.750 μs
[ Info: MEASURE: 	560.984 ms
[ Info: PATIENT: 	9.662 s
[ Info: EXHAUSTIVE: 	109.921 s
[ Info: N = 2048
[ Info: ESTIMATE: 	204.125 μs
[ Info: MEASURE: 	778.754 ms
[ Info: PATIENT: 	42.324 s
[ Info: EXHAUSTIVE: 	209.617 s
[ Info: N = 2100
[ Info: ESTIMATE: 	546.166 μs
[ Info: MEASURE: 	1.265 s
[ Info: PATIENT: 	72.559 s
[ Info: EXHAUSTIVE: 	339.066 s

Running the benchmark

This should take some time as well.

results, stats... = @timed run(suite)
nothing

Here are the stats of the run:

stats
(time = 151.039047292, bytes = 476946126208, gctime = 17.151591493, gcstats = Base.GC_Diff(476946126208, 432074, 0, 8819757, 1324, 0, 17151591493, 2844, 185))

Let's take a look at the results. The time shown is the minimum time of the trial runs for each benchmark.

results
5-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "ESTIMATE" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Trial(5.625 μs)
	  81 => Trial(23.958 μs)
	  1024 => Trial(3.346 ms)
	  210 => Trial(363.125 μs)
	  900 => Trial(3.543 ms)
	  2100 => Trial(25.550 ms)
	  2048 => Trial(20.249 ms)
	  256 => Trial(151.084 μs)
  "EXHAUSTIVE" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Trial(5.625 μs)
	  81 => Trial(22.958 μs)
	  1024 => Trial(2.758 ms)
	  210 => Trial(344.250 μs)
	  900 => Trial(2.588 ms)
	  2100 => Trial(17.807 ms)
	  2048 => Trial(15.177 ms)
	  256 => Trial(133.167 μs)
  "NO PLAN" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Trial(24.500 μs)
	  81 => Trial(112.375 μs)
	  1024 => Trial(2.840 ms)
	  210 => Trial(492.625 μs)
	  900 => Trial(2.694 ms)
	  2100 => Trial(17.967 ms)
	  2048 => Trial(15.296 ms)
	  256 => Trial(197.958 μs)
  "MEASURE" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Trial(5.625 μs)
	  81 => Trial(23.000 μs)
	  1024 => Trial(2.940 ms)
	  210 => Trial(342.291 μs)
	  900 => Trial(2.645 ms)
	  2100 => Trial(20.090 ms)
	  2048 => Trial(16.487 ms)
	  256 => Trial(133.958 μs)
  "PATIENT" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  64 => Trial(5.625 μs)
	  81 => Trial(22.125 μs)
	  1024 => Trial(2.754 ms)
	  210 => Trial(336.459 μs)
	  900 => Trial(2.581 ms)
	  2100 => Trial(17.853 ms)
	  2048 => Trial(15.179 ms)
	  256 => Trial(132.958 μs)

Analysis of the benchmark

We start by plotting the minimum and median times of the benchmark trials.

First, the minimum times for the lower values of N:

plt = plot(
    title="Minimum times for different plans with vector fields of different sizes",
    xlabel = "N",
    ylabel = "time (ns)",
    xticks = Ns[1:4],
    rotation = 90,
    titlefont=10,
    legend=:topleft
)

for flag in ext_flags
    plot!(plt, Ns[1:4], N -> minimum(values(results[flag][N]).times),
    linestyle = :dash,
    markershape = :rect,
    label="$flag"
    )
end

plt

Now, the minimum time for all values of N:

plt = plot(
    title="Minimum times for different plans with vector fields of different sizes",
    xlabel = "N",
    ylabel = "time (ns)",
    xticks = Ns,
    rotation = 90,
    titlefont=10,
    legend=:topleft
)

for flag in ext_flags
    plot!(plt, Ns, N -> minimum(values(results[flag][N]).times),
    linestyle = :dash,
    markershape = :rect,
    label="$flag"
    )
end

plt

Next, the median times, starting with the lower values of N:

plt = plot(
    title="Median times for different plans with vector fields of different sizes",
    xlabel = "N",
    ylabel = "time (ms)",
    xticks = Ns[1:4],
    rotation = 90,
    titlefont=10,
    legend=:topleft
)

for flag in ext_flags
    plot!(plt, Ns[1:4], N -> median(values(results[flag][N]).times) ./ 1.0e+6,
    linestyle = :solid,
    markershape = :circle,
    label="$flag"
    )
end
plt

Now with all values of N:

plt = plot(
    title="Median times for different plans with vector fields of different sizes",
    xlabel = "N",
    ylabel = "time (ms)",
    xticks = Ns,
    rotation = 90,
    titlefont=10,
    legend=:topleft
)

for flag in ext_flags
    plot!(plt, Ns, N -> median(values(results[flag][N]).times)./ 1.0e+6,
    linestyle = :solid,
    markershape = :circle,
    label="$flag"
    )
end

plt

Next, we look at the full set of trials with violin plots.

plts = []

for N in Ns
    push!(
        plts,
        violin(
            [results[flag][N].times for flag in flags],
            title = "Trials with N = $N",
            titlefont = 12,
            xticks = (1:length(flags), string.(flags)),
            yaxis = "time (ns)",
            legend = nothing
        )
    )
end

if isodd(length(Ns))
    push!(plts, plot(border = :none))
end

plt = plot(plts..., layout = (div(length(plts), 2), 2), size = (800, 1200))

We could also use the plot recipe that BenchmarkPlots provides for the results of running a benchmark suite, but we would first need to work out how to sort the values of N.
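The values of N in each group come out in no particular order, as the listings of the suite and of the results show (64, 81, 1024, 210, ...). Below is a minimal sketch of sorting them before plotting. It uses a plain `Dict` as a stand-in for one group, with some of the PATIENT median times in ns. For the real results, `sort(collect(keys(results[flag])))` should give the sorted sizes, assuming `keys` behaves on a `BenchmarkGroup` as it does on a `Dict`.

```julia
# Stand-in for one group of results: N => median time (ns), PATIENT values.
medians = Dict(64 => 5_750.0, 1024 => 2_835_000.0, 210 => 343_459.0, 256 => 133_542.0)

# Dict iteration order is arbitrary, so sort the keys first
# and look the times up in that order.
sorted_Ns = sort(collect(keys(medians)))
sorted_times = [medians[N] for N in sorted_Ns]
```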

Let us take a final closer look at the median values.

for N in Ns
    @info "N = $N"
    for flag in flags
        @info "median time for flag $flag: $(BenchmarkTools.prettytime(median(values(results[flag][N]).times)))"
    end
end
[ Info: N = 64
[ Info: median time for flag ESTIMATE: 5.709 μs
[ Info: median time for flag MEASURE: 5.709 μs
[ Info: median time for flag PATIENT: 5.750 μs
[ Info: median time for flag EXHAUSTIVE: 5.750 μs
[ Info: N = 81
[ Info: median time for flag ESTIMATE: 24.292 μs
[ Info: median time for flag MEASURE: 23.500 μs
[ Info: median time for flag PATIENT: 23.750 μs
[ Info: median time for flag EXHAUSTIVE: 23.958 μs
[ Info: N = 210
[ Info: median time for flag ESTIMATE: 384.875 μs
[ Info: median time for flag MEASURE: 360.667 μs
[ Info: median time for flag PATIENT: 343.459 μs
[ Info: median time for flag EXHAUSTIVE: 354.250 μs
[ Info: N = 256
[ Info: median time for flag ESTIMATE: 151.625 μs
[ Info: median time for flag MEASURE: 134.334 μs
[ Info: median time for flag PATIENT: 133.542 μs
[ Info: median time for flag EXHAUSTIVE: 133.750 μs
[ Info: N = 900
[ Info: median time for flag ESTIMATE: 3.623 ms
[ Info: median time for flag MEASURE: 2.713 ms
[ Info: median time for flag PATIENT: 2.634 ms
[ Info: median time for flag EXHAUSTIVE: 2.643 ms
[ Info: N = 1024
[ Info: median time for flag ESTIMATE: 3.417 ms
[ Info: median time for flag MEASURE: 3.008 ms
[ Info: median time for flag PATIENT: 2.835 ms
[ Info: median time for flag EXHAUSTIVE: 2.836 ms
[ Info: N = 2048
[ Info: median time for flag ESTIMATE: 20.574 ms
[ Info: median time for flag MEASURE: 16.540 ms
[ Info: median time for flag PATIENT: 15.498 ms
[ Info: median time for flag EXHAUSTIVE: 15.226 ms
[ Info: N = 2100
[ Info: median time for flag ESTIMATE: 25.783 ms
[ Info: median time for flag MEASURE: 20.312 ms
[ Info: median time for flag PATIENT: 18.203 ms
[ Info: median time for flag EXHAUSTIVE: 17.921 ms

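Notice that values of N that are powers of 2 do a little better, while the timings for the other values are skewed up. For instance, with PATIENT, N = 210 takes more than twice as long as N = 256, even though it has fewer points.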
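To quantify this, we can divide each PATIENT median time by $N^2\log(N^2)$, the expected cost of a two-dimensional FFT on an $N \times N$ grid. This makes the different sizes comparable. The sketch below is rough and uses medians copied from the listing above. The names are ours, not the notebook's.

```julia
# PATIENT median times copied from the listing above, in microseconds.
patient_median_us = Dict(
    64 => 5.750, 81 => 23.750, 210 => 343.459, 256 => 133.542,
    900 => 2634.0, 1024 => 2835.0, 2048 => 15498.0, 2100 => 18203.0,
)

# Time per unit of K*log(K) work, with K = N^2, in picoseconds
# (1 μs = 1e6 ps; picoseconds just give numbers of order 100).
normalized(N) = patient_median_us[N] * 1.0e6 / (N^2 * log(N^2))

for N in sort(collect(keys(patient_median_us)))
    println(rpad(N, 6), round(normalized(N), digits = 1), " ps")
end
```

With these numbers, N = 210 costs roughly four times as much per unit of work as N = 256, and N = 2100 about 11% more than N = 2048. N = 900, which is built from small primes, stays within the range of the powers of two.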

We can take a better look at that by fitting the expected order of complexity. The fast Fourier transform of $K$ points is of order $K\log(K)$, and since this is a two-dimensional problem, $K = N^2$.
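Written out, the model fitted below is a straight line in the variable $K\log(K)$. Its coefficients come from linear least squares over the sampled sizes $N_1, \dots, N_m$. Here $t_i$ is the median time (in ns) for $N_i$, and $\log$ is the natural logarithm, as in the code. For a tall matrix $A$, Julia's `A \ t` returns this least-squares solution (via a pivoted QR factorization).

$$
t(N) \approx a + b\,K\log(K), \qquad K = N^2,
$$

$$
(a, b) = \arg\min_{a,\,b} \sum_{i=1}^{m} \bigl(a + b\,K_i\log(K_i) - t_i\bigr)^2, \qquad
A = \begin{pmatrix} 1 & K_1\log(K_1) \\ \vdots & \vdots \\ 1 & K_m\log(K_m) \end{pmatrix}.
$$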

Let us fit a single plan, say PATIENT, with the median times.

flag = "PATIENT"
"PATIENT"

We only fit the data for the powers of two.

twos = filter(N -> isinteger(log2(N)), Ns)
4-element Vector{Int64}:
   64
  256
 1024
 2048
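Julia's `Base.ispow2` performs the same test exactly on integers, without going through the floating-point `log2`. Here is a small sketch with the same `Ns`:

```julia
Ns = [2^6, 3^4, 2 * 3 * 5 * 7, 2^8, 2^2 * 3^2 * 5^2, 2^10, 2^11, 2^2 * 3 * 5^2 * 7]

# Base.ispow2 checks for a power of two exactly on integers,
# with no floating-point log2 involved.
twos = filter(ispow2, Ns)
```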

Here is the matrix for the fit. It is Vandermonde-like, with the basis functions $1$ and $N^2\log(N^2)$ as its columns.

mat = [ones(length(twos)) [N^2 * log(N^2) for N in twos]]
4×2 Matrix{Float64}:
 1.0  34069.6
 1.0      7.26817e5
 1.0      1.45363e7
 1.0      6.39599e7

We get

a, b = mat \ [median(values(results[flag][N].times)) for N in twos]
2-element Vector{Float64}:
 -215516.36056311236
       0.2439469359125431

Now we plot the resulting fit against the data:

plt = plot(
    title="Median times for plan $flag with vector fields of different sizes",
    xlabel = "N (stretched out as N²)",
    ylabel = "time (ms)",
    xticks = (Ns.^2, string.(Ns)),
    rotation = 90,
    titlefont=10,
    legend=:topleft
)

plot!(plt, Ns.^2, map(N -> median(values(results[flag][isqrt(N)]).times)./ 1.0e+6, Ns.^2),
linestyle = :solid,
markershape = :circle,
label="$flag"
)

plot!(plt, Ns.^2, K -> (a + b * K * log(K) )./ 1.0e+6, linestyle = :dash, markershape = :square, label="N²log(N²) fit")


plt
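Evaluating this fit at a few sizes shows why the low-N results need separate treatment. The negative intercept $a$ dominates for small N, so the fit even predicts negative times there. For the larger non-powers of two, the measured medians sit above the curve. The sketch below uses the coefficients printed above. The names `a_fit`, `b_fit`, `fit_ns` and `measured_ns` are ours.

```julia
# Coefficients of the PATIENT fit printed above (times in ns).
a_fit, b_fit = -215516.36056311236, 0.2439469359125431

# Fitted model t(N) ≈ a + b*K*log(K), with K = N^2.
fit_ns(N) = a_fit + b_fit * N^2 * log(N^2)

# PATIENT median times copied from the listing above, in ns.
measured_ns = Dict(210 => 343_459.0, 256 => 133_542.0, 900 => 2.634e6, 2100 => 18.203e6)

for N in sort(collect(keys(measured_ns)))
    println("N = $N: fit $(round(fit_ns(N) / 1e6, digits = 3)) ms, ",
            "measured $(round(measured_ns[N] / 1e6, digits = 3)) ms")
end
```

For N = 2100 the measured median is about 12% above the fit, and for N = 900 about 6.5%.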

Now we take a closer look at the results for low N. The times for low and high N are far apart, since they grow roughly like $N^2$. So instead of zooming into the plot for low N, we do a separate fit for these values. One option would be a weighted least-squares fit, but that did not turn out to work best, so we simply restrict the fit to the lower values of N.
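For reference, here is what the weighted alternative might look like. Weighting each squared residual by $1/t_i^2$ fits relative rather than absolute errors, so the small times are no longer swamped by the large ones. The weights the author tried are not stated, so $1/t_i^2$ is our assumption. This is a minimal sketch on made-up data: the coefficients and the 10% bump are invented.

```julia
using LinearAlgebra

# Made-up data (ns) following t = a + b*K*log(K), K = N^2, with the largest
# size bumped up by 10% to mimic a noisy measurement.
sizes = [64, 256, 1024, 2048]
Ks = sizes .^ 2
Afit = [ones(length(Ks)) Ks .* log.(Ks)]
tobs = Afit * [-500.0, 0.2]
tobs[end] *= 1.1

# Ordinary least squares: absolute errors, so the largest times dominate.
x_ols = Afit \ tobs

# Weighted least squares with weights 1/t_i^2 (relative errors):
# scale each row of the system by 1/t_i, then solve as usual.
W = Diagonal(1 ./ tobs)
x_wls = (W * Afit) \ (W * tobs)

# Relative residuals (A*x - t)/t of a given fit.
relres(x) = (Afit * x .- tobs) ./ tobs
println("OLS relative residuals: ", round.(relres(x_ols), digits = 4))
println("WLS relative residuals: ", round.(relres(x_wls), digits = 4))
```

By construction, the weighted fit has the smallest sum of squared relative residuals. That does not make it the better fit for every size, which is consistent with it not being the choice here.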

We select the powers of two among the lower values of N:

lowtwos = filter(N -> isinteger(log2(N)), Ns[1:4])
2-element Vector{Int64}:
  64
 256

And fit them with the corresponding Vandermonde-like matrix:

lowmat = [ones(length(lowtwos)) [N^2 * log(N^2) for N in lowtwos]]
2×2 Matrix{Float64}:
 1.0  34069.6
 1.0      7.26817e5

We get

c, d = lowmat \ [median(values(results[flag][N].times)) for N in lowtwos]
2-element Vector{Float64}:
 -534.8524590163925
    0.18447114004194656

Now we can visualize the result.

plt = plot(
    title="Median times for plan $flag with vector fields of different sizes",
    xlabel = "N (stretched out as N²)",
    ylabel = "time (ms)",
    xticks = (Ns[1:4].^2, string.(Ns[1:4])),
    rotation = 90,
    titlefont=10,
    legend=:topleft
)

plot!(plt, Ns[1:4].^2, map(N -> median(values(results[flag][isqrt(N)]).times)./ 1.0e+6, Ns[1:4].^2),
linestyle = :solid,
markershape = :circle,
label="$flag"
)

plot!(plt, Ns[1:4].^2, K -> (c + d * K * log(K) )./ 1.0e+6, linestyle = :dash, markershape = :square, label="N²log(N²) fit")


plt

Notice that there are just two powers of two in this low range of N. With two data points and two coefficients, the fit simply passes through both, so we could have drawn the curve through them directly. We keep the generic form in case the values in the vector Ns change.

Conclusions

Well, "NO PLAN" is clearly the worst for the smaller values of N, while for the larger ones it is "ESTIMATE" that lags behind; the remaining plans are close together. In general, "PATIENT" and "EXHAUSTIVE" perform better, but planning them is costly, especially "EXHAUSTIVE", which doesn't seem to be worth it.

Hence, for a quick plan, "MEASURE" seems to be a good choice. Otherwise, "PATIENT" seems like the best option.
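To put "worth it" in numbers, we can ask how many transforms it takes for a more rigorous flag's extra planning time to pay for itself through faster execution. Below is a back-of-the-envelope sketch for N = 2048, with numbers copied from the listings above. The medians differ by only about a millisecond, so treat the result as an order of magnitude. The names `plan_s`, `median_s` and `breakeven` are ours.

```julia
# Planning times (s) and median execution times (s) for N = 2048,
# copied from the listings above.
plan_s   = Dict("MEASURE" => 0.778754, "PATIENT" => 42.324, "EXHAUSTIVE" => 209.617)
median_s = Dict("MEASURE" => 16.540e-3, "PATIENT" => 15.498e-3, "EXHAUSTIVE" => 15.226e-3)

# Number of transforms after which the extra planning time of the `costly` flag
# is paid back by its faster execution compared with the `cheap` flag.
breakeven(cheap, costly) =
    (plan_s[costly] - plan_s[cheap]) / (median_s[cheap] - median_s[costly])

println("PATIENT over MEASURE:    ", round(Int, breakeven("MEASURE", "PATIENT")))
println("EXHAUSTIVE over PATIENT: ", round(Int, breakeven("PATIENT", "EXHAUSTIVE")))
```

This gives roughly 40 thousand transforms for PATIENT over MEASURE, and over 600 thousand for EXHAUSTIVE over PATIENT.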

Moreover, powers of single primes, and in particular powers of 2, tend also to be faster, relatively speaking.

Hence, prefer values of N that are powers of 2, plan with PATIENT, and apply the plan with mul!.
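A minimal sketch of that pattern: plan once with PATIENT for a power-of-two size, preallocate the output, and apply the plan with `mul!`. The size and the random input are arbitrary choices for illustration. Planning on a copy is a precaution, since the FFTW library documentation warns that its planner may overwrite the input array for flags other than ESTIMATE. A plan should only be applied to arrays of the same size and element type as the one it was made for.

```julia
using FFTW
using LinearAlgebra: mul!

N = 256               # a power of 2, as recommended above
w = rand(N, N)        # some real N×N input field

# Plan once with PATIENT, on a scratch copy, since FFTW's planner
# may overwrite its input for flags other than ESTIMATE.
p = plan_rfft(copy(w); flags = FFTW.PATIENT)

# Preallocate the output: rfft of a real N×N array is a complex (N÷2 + 1)×N array.
w_hat = Array{ComplexF64}(undef, N ÷ 2 + 1, N)

# Apply the plan, writing into w_hat; in a time loop this call would be repeated
# with new data in w, without replanning or allocating a new output array.
mul!(w_hat, p, w)
```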



Last modified: March 24, 2022. Built with Franklin.jl, using the Book Template.